rpc: track number of users of the gss upcall pipe
[safe/jmp/linux-2.6] / net / sunrpc / auth_gss / auth_gss.c
index 834a831..51aa27d 100644 (file)
@@ -33,8 +33,6 @@
  *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
  *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- *
- * $Id$
  */
 
 
@@ -77,6 +75,8 @@ struct gss_auth {
        struct dentry *dentry;
 };
 
+static atomic_t pipe_users = ATOMIC_INIT(0);
+
 static void gss_free_ctx(struct gss_cl_ctx *);
 static struct rpc_pipe_ops gss_upcall_ops;
 
@@ -239,6 +239,7 @@ gss_release_msg(struct gss_upcall_msg *gss_msg)
 {
        if (!atomic_dec_and_test(&gss_msg->count))
                return;
+       atomic_dec(&pipe_users);
        BUG_ON(!list_empty(&gss_msg->list));
        if (gss_msg->ctx != NULL)
                gss_put_ctx(gss_msg->ctx);
@@ -331,16 +332,17 @@ gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid)
        struct gss_upcall_msg *gss_msg;
 
        gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
-       if (gss_msg != NULL) {
-               INIT_LIST_HEAD(&gss_msg->list);
-               rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
-               init_waitqueue_head(&gss_msg->waitqueue);
-               atomic_set(&gss_msg->count, 1);
-               gss_msg->msg.data = &gss_msg->uid;
-               gss_msg->msg.len = sizeof(gss_msg->uid);
-               gss_msg->uid = uid;
-               gss_msg->auth = gss_auth;
-       }
+       if (gss_msg == NULL)
+               return ERR_PTR(-ENOMEM);
+       atomic_inc(&pipe_users);
+       INIT_LIST_HEAD(&gss_msg->list);
+       rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
+       init_waitqueue_head(&gss_msg->waitqueue);
+       atomic_set(&gss_msg->count, 1);
+       gss_msg->msg.data = &gss_msg->uid;
+       gss_msg->msg.len = sizeof(gss_msg->uid);
+       gss_msg->uid = uid;
+       gss_msg->auth = gss_auth;
        return gss_msg;
 }
 
@@ -357,8 +359,8 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr
                uid = 0;
 
        gss_new = gss_alloc_msg(gss_auth, uid);
-       if (gss_new == NULL)
-               return ERR_PTR(-ENOMEM);
+       if (IS_ERR(gss_new))
+               return gss_new;
        gss_msg = gss_add_msg(gss_auth, gss_new);
        if (gss_msg == gss_new) {
                int res = rpc_queue_upcall(gss_auth->dentry->d_inode, &gss_new->msg);
@@ -371,6 +373,18 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr
        return gss_msg;
 }
 
+static void warn_gssd(void)
+{
+       static unsigned long ratelimit;
+       unsigned long now = jiffies;
+
+       if (time_after(now, ratelimit)) {
+               printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
+                               "Please check user daemon is running.\n");
+               ratelimit = now + 15*HZ;
+       }
+}
+
 static inline int
 gss_refresh_upcall(struct rpc_task *task)
 {
@@ -545,6 +559,13 @@ out:
        return err;
 }
 
+static int
+gss_pipe_open(struct inode *inode)
+{
+       atomic_inc(&pipe_users);
+       return 0;
+}
+
 static void
 gss_pipe_release(struct inode *inode)
 {
@@ -564,27 +585,22 @@ gss_pipe_release(struct inode *inode)
                spin_lock(&inode->i_lock);
        }
        spin_unlock(&inode->i_lock);
+
+       atomic_dec(&pipe_users);
 }
 
 static void
 gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
 {
        struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
-       static unsigned long ratelimit;
 
        if (msg->errno < 0) {
                dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
                                gss_msg);
                atomic_inc(&gss_msg->count);
                gss_unhash_msg(gss_msg);
-               if (msg->errno == -ETIMEDOUT) {
-                       unsigned long now = jiffies;
-                       if (time_after(now, ratelimit)) {
-                               printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
-                                                   "Please check user daemon is running!\n");
-                               ratelimit = now + 15*HZ;
-                       }
-               }
+               if (msg->errno == -ETIMEDOUT)
+                       warn_gssd();
                gss_release_msg(gss_msg);
        }
 }
@@ -652,7 +668,6 @@ static void
 gss_free(struct gss_auth *gss_auth)
 {
        rpc_unlink(gss_auth->dentry);
-       gss_auth->dentry = NULL;
        gss_mech_put(gss_auth->mech);
 
        kfree(gss_auth);
@@ -695,7 +710,7 @@ gss_destroying_context(struct rpc_cred *cred)
        struct rpc_task *task;
 
        if (gss_cred->gc_ctx == NULL ||
-           test_and_clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
+           test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
                return 0;
 
        gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
@@ -759,14 +774,12 @@ gss_free_cred_callback(struct rcu_head *head)
 }
 
 static void
-gss_destroy_cred(struct rpc_cred *cred)
+gss_destroy_nullcred(struct rpc_cred *cred)
 {
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
        struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
        struct gss_cl_ctx *ctx = gss_cred->gc_ctx;
 
-       if (gss_destroying_context(cred))
-               return;
        rcu_assign_pointer(gss_cred->gc_ctx, NULL);
        call_rcu(&cred->cr_rcu, gss_free_cred_callback);
        if (ctx)
@@ -774,6 +787,15 @@ gss_destroy_cred(struct rpc_cred *cred)
        kref_put(&gss_auth->kref, gss_free_callback);
 }
 
+static void
+gss_destroy_cred(struct rpc_cred *cred)
+{
+
+       if (gss_destroying_context(cred))
+               return;
+       gss_destroy_nullcred(cred);
+}
+
 /*
  * Lookup RPCSEC_GSS cred for the current process
  */
@@ -1019,7 +1041,7 @@ gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        *p++ = htonl(rqstp->rq_seqno);
 
-       status = rpc_call_xdrproc(encode, rqstp, p, obj);
+       status = encode(rqstp, p, obj);
        if (status)
                return status;
 
@@ -1113,7 +1135,7 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        *p++ = htonl(rqstp->rq_seqno);
 
-       status = rpc_call_xdrproc(encode, rqstp, p, obj);
+       status = encode(rqstp, p, obj);
        if (status)
                return status;
 
@@ -1172,12 +1194,12 @@ gss_wrap_req(struct rpc_task *task,
                /* The spec seems a little ambiguous here, but I think that not
                 * wrapping context destruction requests makes the most sense.
                 */
-               status = rpc_call_xdrproc(encode, rqstp, p, obj);
+               status = encode(rqstp, p, obj);
                goto out;
        }
        switch (gss_cred->gc_service) {
                case RPC_GSS_SVC_NONE:
-                       status = rpc_call_xdrproc(encode, rqstp, p, obj);
+                       status = encode(rqstp, p, obj);
                        break;
                case RPC_GSS_SVC_INTEGRITY:
                        status = gss_wrap_req_integ(cred, ctx, encode,
@@ -1293,7 +1315,7 @@ gss_unwrap_resp(struct rpc_task *task,
        cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
                                                + (savedlen - head->iov_len);
 out_decode:
-       status = rpc_call_xdrproc(decode, rqstp, p, obj);
+       status = decode(rqstp, p, obj);
 out:
        gss_put_ctx(ctx);
        dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
@@ -1326,7 +1348,7 @@ static const struct rpc_credops gss_credops = {
 
 static const struct rpc_credops gss_nullops = {
        .cr_name        = "AUTH_GSS",
-       .crdestroy      = gss_destroy_cred,
+       .crdestroy      = gss_destroy_nullcred,
        .crbind         = rpcauth_generic_bind_cred,
        .crmatch        = gss_match,
        .crmarshal      = gss_marshal,
@@ -1340,6 +1362,7 @@ static struct rpc_pipe_ops gss_upcall_ops = {
        .upcall         = gss_pipe_upcall,
        .downcall       = gss_pipe_downcall,
        .destroy_msg    = gss_pipe_destroy_msg,
+       .open_pipe      = gss_pipe_open,
        .release_pipe   = gss_pipe_release,
 };