rpc: store pointer to pipe inode in gss upcall message
[safe/jmp/linux-2.6] / net / sunrpc / auth_gss / auth_gss.c
index 019d4b4..fe06acd 100644 (file)
@@ -61,22 +61,11 @@ static const struct rpc_credops gss_nullops;
 # define RPCDBG_FACILITY       RPCDBG_AUTH
 #endif
 
-#define NFS_NGROUPS    16
-
-#define GSS_CRED_SLACK         1024            /* XXX: unused */
+#define GSS_CRED_SLACK         1024
 /* length of a krb5 verifier (48), plus data added before arguments when
  * using integrity (two 4-byte integers): */
 #define GSS_VERF_SLACK         100
 
-/* XXX this define must match the gssd define
-* as it is passed to gssd to signal the use of
-* machine creds should be part of the shared rpc interface */
-
-#define CA_RUN_AS_MACHINE  0x00000200
-
-/* dump the buffer in `emacs-hexl' style */
-#define isprint(c)      ((c > 0x1f) && (c < 0x7f))
-
 struct gss_auth {
        struct kref kref;
        struct rpc_auth rpc_auth;
@@ -86,6 +75,13 @@ struct gss_auth {
        struct dentry *dentry;
 };
 
+/* pipe_version >= 0 if and only if someone has a pipe open. */
+static int pipe_version = -1;
+static atomic_t pipe_users = ATOMIC_INIT(0);
+static DEFINE_SPINLOCK(pipe_version_lock);
+static struct rpc_wait_queue pipe_version_rpc_waitqueue;
+static DECLARE_WAIT_QUEUE_HEAD(pipe_version_waitqueue);
+
 static void gss_free_ctx(struct gss_cl_ctx *);
 static struct rpc_pipe_ops gss_upcall_ops;
 
@@ -144,7 +140,7 @@ simple_get_netobj(const void *p, const void *end, struct xdr_netobj *dest)
        q = (const void *)((const char *)p + len);
        if (unlikely(q > end || q < p))
                return ERR_PTR(-EFAULT);
-       dest->data = kmemdup(p, len, GFP_KERNEL);
+       dest->data = kmemdup(p, len, GFP_NOFS);
        if (unlikely(dest->data == NULL))
                return ERR_PTR(-ENOMEM);
        dest->len = len;
@@ -169,7 +165,7 @@ gss_alloc_context(void)
 {
        struct gss_cl_ctx *ctx;
 
-       ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+       ctx = kzalloc(sizeof(*ctx), GFP_NOFS);
        if (ctx != NULL) {
                ctx->gc_proc = RPC_GSS_PROC_DATA;
                ctx->gc_seq = 1;        /* NetApp 6.4R1 doesn't accept seq. no. 0 */
@@ -238,16 +234,40 @@ struct gss_upcall_msg {
        struct rpc_pipe_msg msg;
        struct list_head list;
        struct gss_auth *auth;
+       struct rpc_inode *inode;
        struct rpc_wait_queue rpc_waitqueue;
        wait_queue_head_t waitqueue;
        struct gss_cl_ctx *ctx;
 };
 
+static int get_pipe_version(void)
+{
+       int ret;
+
+       spin_lock(&pipe_version_lock);
+       if (pipe_version >= 0) {
+               atomic_inc(&pipe_users);
+               ret = 0;
+       } else
+               ret = -EAGAIN;
+       spin_unlock(&pipe_version_lock);
+       return ret;
+}
+
+static void put_pipe_version(void)
+{
+       if (atomic_dec_and_lock(&pipe_users, &pipe_version_lock)) {
+               pipe_version = -1;
+               spin_unlock(&pipe_version_lock);
+       }
+}
+
 static void
 gss_release_msg(struct gss_upcall_msg *gss_msg)
 {
        if (!atomic_dec_and_test(&gss_msg->count))
                return;
+       put_pipe_version();
        BUG_ON(!list_empty(&gss_msg->list));
        if (gss_msg->ctx != NULL)
                gss_put_ctx(gss_msg->ctx);
@@ -270,15 +290,15 @@ __gss_find_upcall(struct rpc_inode *rpci, uid_t uid)
        return NULL;
 }
 
-/* Try to add a upcall to the pipefs queue.
+/* Try to add an upcall to the pipefs queue.
  * If an upcall owned by our uid already exists, then we return a reference
  * to that upcall instead of adding the new upcall.
  */
 static inline struct gss_upcall_msg *
 gss_add_msg(struct gss_auth *gss_auth, struct gss_upcall_msg *gss_msg)
 {
-       struct inode *inode = gss_auth->dentry->d_inode;
-       struct rpc_inode *rpci = RPC_I(inode);
+       struct rpc_inode *rpci = gss_msg->inode;
+       struct inode *inode = &rpci->vfs_inode;
        struct gss_upcall_msg *old;
 
        spin_lock(&inode->i_lock);
@@ -304,8 +324,7 @@ __gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 static void
 gss_unhash_msg(struct gss_upcall_msg *gss_msg)
 {
-       struct gss_auth *gss_auth = gss_msg->auth;
-       struct inode *inode = gss_auth->dentry->d_inode;
+       struct inode *inode = &gss_msg->inode->vfs_inode;
 
        if (list_empty(&gss_msg->list))
                return;
@@ -321,7 +340,7 @@ gss_upcall_callback(struct rpc_task *task)
        struct gss_cred *gss_cred = container_of(task->tk_msg.rpc_cred,
                        struct gss_cred, gc_base);
        struct gss_upcall_msg *gss_msg = gss_cred->gc_upcall;
-       struct inode *inode = gss_msg->auth->dentry->d_inode;
+       struct inode *inode = &gss_msg->inode->vfs_inode;
 
        spin_lock(&inode->i_lock);
        if (gss_msg->ctx)
@@ -338,18 +357,25 @@ static inline struct gss_upcall_msg *
 gss_alloc_msg(struct gss_auth *gss_auth, uid_t uid)
 {
        struct gss_upcall_msg *gss_msg;
+       int vers;
 
-       gss_msg = kzalloc(sizeof(*gss_msg), GFP_KERNEL);
-       if (gss_msg != NULL) {
-               INIT_LIST_HEAD(&gss_msg->list);
-               rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
-               init_waitqueue_head(&gss_msg->waitqueue);
-               atomic_set(&gss_msg->count, 1);
-               gss_msg->msg.data = &gss_msg->uid;
-               gss_msg->msg.len = sizeof(gss_msg->uid);
-               gss_msg->uid = uid;
-               gss_msg->auth = gss_auth;
+       gss_msg = kzalloc(sizeof(*gss_msg), GFP_NOFS);
+       if (gss_msg == NULL)
+               return ERR_PTR(-ENOMEM);
+       vers = get_pipe_version();
+       if (vers < 0) {
+               kfree(gss_msg);
+               return ERR_PTR(vers);
        }
+       gss_msg->inode = RPC_I(gss_auth->dentry->d_inode);
+       INIT_LIST_HEAD(&gss_msg->list);
+       rpc_init_wait_queue(&gss_msg->rpc_waitqueue, "RPCSEC_GSS upcall waitq");
+       init_waitqueue_head(&gss_msg->waitqueue);
+       atomic_set(&gss_msg->count, 1);
+       gss_msg->msg.data = &gss_msg->uid;
+       gss_msg->msg.len = sizeof(gss_msg->uid);
+       gss_msg->uid = uid;
+       gss_msg->auth = gss_auth;
        return gss_msg;
 }
 
@@ -366,11 +392,12 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr
                uid = 0;
 
        gss_new = gss_alloc_msg(gss_auth, uid);
-       if (gss_new == NULL)
-               return ERR_PTR(-ENOMEM);
+       if (IS_ERR(gss_new))
+               return gss_new;
        gss_msg = gss_add_msg(gss_auth, gss_new);
        if (gss_msg == gss_new) {
-               int res = rpc_queue_upcall(gss_auth->dentry->d_inode, &gss_new->msg);
+               struct inode *inode = &gss_new->inode->vfs_inode;
+               int res = rpc_queue_upcall(inode, &gss_new->msg);
                if (res) {
                        gss_unhash_msg(gss_new);
                        gss_msg = ERR_PTR(res);
@@ -380,6 +407,18 @@ gss_setup_upcall(struct rpc_clnt *clnt, struct gss_auth *gss_auth, struct rpc_cr
        return gss_msg;
 }
 
+static void warn_gssd(void)
+{
+       static unsigned long ratelimit;
+       unsigned long now = jiffies;
+
+       if (time_after(now, ratelimit)) {
+               printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
+                               "Please check user daemon is running.\n");
+               ratelimit = now + 15*HZ;
+       }
+}
+
 static inline int
 gss_refresh_upcall(struct rpc_task *task)
 {
@@ -389,16 +428,25 @@ gss_refresh_upcall(struct rpc_task *task)
        struct gss_cred *gss_cred = container_of(cred,
                        struct gss_cred, gc_base);
        struct gss_upcall_msg *gss_msg;
-       struct inode *inode = gss_auth->dentry->d_inode;
+       struct inode *inode;
        int err = 0;
 
        dprintk("RPC: %5u gss_refresh_upcall for uid %u\n", task->tk_pid,
                                                                cred->cr_uid);
        gss_msg = gss_setup_upcall(task->tk_client, gss_auth, cred);
+       if (IS_ERR(gss_msg) == -EAGAIN) {
+               /* XXX: warning on the first, under the assumption we
+                * shouldn't normally hit this case on a refresh. */
+               warn_gssd();
+               task->tk_timeout = 15*HZ;
+               rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
+               return 0;
+       }
        if (IS_ERR(gss_msg)) {
                err = PTR_ERR(gss_msg);
                goto out;
        }
+       inode = &gss_msg->inode->vfs_inode;
        spin_lock(&inode->i_lock);
        if (gss_cred->gc_upcall != NULL)
                rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
@@ -425,18 +473,29 @@ out:
 static inline int
 gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
 {
-       struct inode *inode = gss_auth->dentry->d_inode;
+       struct inode *inode;
        struct rpc_cred *cred = &gss_cred->gc_base;
        struct gss_upcall_msg *gss_msg;
        DEFINE_WAIT(wait);
        int err = 0;
 
        dprintk("RPC:       gss_upcall for uid %u\n", cred->cr_uid);
+retry:
        gss_msg = gss_setup_upcall(gss_auth->client, gss_auth, cred);
+       if (PTR_ERR(gss_msg) == -EAGAIN) {
+               err = wait_event_interruptible_timeout(pipe_version_waitqueue,
+                               pipe_version >= 0, 15*HZ);
+               if (err)
+                       goto out;
+               if (pipe_version < 0)
+                       warn_gssd();
+               goto retry;
+       }
        if (IS_ERR(gss_msg)) {
                err = PTR_ERR(gss_msg);
                goto out;
        }
+       inode = &gss_msg->inode->vfs_inode;
        for (;;) {
                prepare_to_wait(&gss_msg->waitqueue, &wait, TASK_INTERRUPTIBLE);
                spin_lock(&inode->i_lock);
@@ -491,7 +550,6 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
 {
        const void *p, *end;
        void *buf;
-       struct rpc_clnt *clnt;
        struct gss_upcall_msg *gss_msg;
        struct inode *inode = filp->f_path.dentry->d_inode;
        struct gss_cl_ctx *ctx;
@@ -501,11 +559,10 @@ gss_pipe_downcall(struct file *filp, const char __user *src, size_t mlen)
        if (mlen > MSG_BUF_MAXSIZE)
                goto out;
        err = -ENOMEM;
-       buf = kmalloc(mlen, GFP_KERNEL);
+       buf = kmalloc(mlen, GFP_NOFS);
        if (!buf)
                goto out;
 
-       clnt = RPC_I(inode)->private;
        err = -EFAULT;
        if (copy_from_user(buf, src, mlen))
                goto err;
@@ -556,6 +613,20 @@ out:
        return err;
 }
 
+static int
+gss_pipe_open(struct inode *inode)
+{
+       spin_lock(&pipe_version_lock);
+       if (pipe_version < 0) {
+               pipe_version = 0;
+               rpc_wake_up(&pipe_version_rpc_waitqueue);
+               wake_up(&pipe_version_waitqueue);
+       }
+       atomic_inc(&pipe_users);
+       spin_unlock(&pipe_version_lock);
+       return 0;
+}
+
 static void
 gss_pipe_release(struct inode *inode)
 {
@@ -575,27 +646,22 @@ gss_pipe_release(struct inode *inode)
                spin_lock(&inode->i_lock);
        }
        spin_unlock(&inode->i_lock);
+
+       put_pipe_version();
 }
 
 static void
 gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
 {
        struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
-       static unsigned long ratelimit;
 
        if (msg->errno < 0) {
                dprintk("RPC:       gss_pipe_destroy_msg releasing msg %p\n",
                                gss_msg);
                atomic_inc(&gss_msg->count);
                gss_unhash_msg(gss_msg);
-               if (msg->errno == -ETIMEDOUT) {
-                       unsigned long now = jiffies;
-                       if (time_after(now, ratelimit)) {
-                               printk(KERN_WARNING "RPC: AUTH_GSS upcall timed out.\n"
-                                                   "Please check user daemon is running!\n");
-                               ratelimit = now + 15*HZ;
-                       }
-               }
+               if (msg->errno == -ETIMEDOUT)
+                       warn_gssd();
                gss_release_msg(gss_msg);
        }
 }
@@ -663,7 +729,6 @@ static void
 gss_free(struct gss_auth *gss_auth)
 {
        rpc_unlink(gss_auth->dentry);
-       gss_auth->dentry = NULL;
        gss_mech_put(gss_auth->mech);
 
        kfree(gss_auth);
@@ -706,7 +771,7 @@ gss_destroying_context(struct rpc_cred *cred)
        struct rpc_task *task;
 
        if (gss_cred->gc_ctx == NULL ||
-           test_and_clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
+           test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) == 0)
                return 0;
 
        gss_cred->gc_ctx->gc_proc = RPC_GSS_PROC_DESTROY;
@@ -770,14 +835,12 @@ gss_free_cred_callback(struct rcu_head *head)
 }
 
 static void
-gss_destroy_cred(struct rpc_cred *cred)
+gss_destroy_nullcred(struct rpc_cred *cred)
 {
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred, gc_base);
        struct gss_auth *gss_auth = container_of(cred->cr_auth, struct gss_auth, rpc_auth);
        struct gss_cl_ctx *ctx = gss_cred->gc_ctx;
 
-       if (gss_destroying_context(cred))
-               return;
        rcu_assign_pointer(gss_cred->gc_ctx, NULL);
        call_rcu(&cred->cr_rcu, gss_free_cred_callback);
        if (ctx)
@@ -785,6 +848,15 @@ gss_destroy_cred(struct rpc_cred *cred)
        kref_put(&gss_auth->kref, gss_free_callback);
 }
 
+static void
+gss_destroy_cred(struct rpc_cred *cred)
+{
+
+       if (gss_destroying_context(cred))
+               return;
+       gss_destroy_nullcred(cred);
+}
+
 /*
  * Lookup RPCSEC_GSS cred for the current process
  */
@@ -804,7 +876,7 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags)
        dprintk("RPC:       gss_create_cred for uid %d, flavor %d\n",
                acred->uid, auth->au_flavor);
 
-       if (!(cred = kzalloc(sizeof(*cred), GFP_KERNEL)))
+       if (!(cred = kzalloc(sizeof(*cred), GFP_NOFS)))
                goto out_err;
 
        rpcauth_init_cred(&cred->gc_base, acred, auth, &gss_credops);
@@ -1030,7 +1102,7 @@ gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        *p++ = htonl(rqstp->rq_seqno);
 
-       status = rpc_call_xdrproc(encode, rqstp, p, obj);
+       status = encode(rqstp, p, obj);
        if (status)
                return status;
 
@@ -1124,7 +1196,7 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
        offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        *p++ = htonl(rqstp->rq_seqno);
 
-       status = rpc_call_xdrproc(encode, rqstp, p, obj);
+       status = encode(rqstp, p, obj);
        if (status)
                return status;
 
@@ -1183,12 +1255,12 @@ gss_wrap_req(struct rpc_task *task,
                /* The spec seems a little ambiguous here, but I think that not
                 * wrapping context destruction requests makes the most sense.
                 */
-               status = rpc_call_xdrproc(encode, rqstp, p, obj);
+               status = encode(rqstp, p, obj);
                goto out;
        }
        switch (gss_cred->gc_service) {
                case RPC_GSS_SVC_NONE:
-                       status = rpc_call_xdrproc(encode, rqstp, p, obj);
+                       status = encode(rqstp, p, obj);
                        break;
                case RPC_GSS_SVC_INTEGRITY:
                        status = gss_wrap_req_integ(cred, ctx, encode,
@@ -1304,7 +1376,7 @@ gss_unwrap_resp(struct rpc_task *task,
        cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
                                                + (savedlen - head->iov_len);
 out_decode:
-       status = rpc_call_xdrproc(decode, rqstp, p, obj);
+       status = decode(rqstp, p, obj);
 out:
        gss_put_ctx(ctx);
        dprintk("RPC: %5u gss_unwrap_resp returning %d\n", task->tk_pid,
@@ -1337,7 +1409,7 @@ static const struct rpc_credops gss_credops = {
 
 static const struct rpc_credops gss_nullops = {
        .cr_name        = "AUTH_GSS",
-       .crdestroy      = gss_destroy_cred,
+       .crdestroy      = gss_destroy_nullcred,
        .crbind         = rpcauth_generic_bind_cred,
        .crmatch        = gss_match,
        .crmarshal      = gss_marshal,
@@ -1351,6 +1423,7 @@ static struct rpc_pipe_ops gss_upcall_ops = {
        .upcall         = gss_pipe_upcall,
        .downcall       = gss_pipe_downcall,
        .destroy_msg    = gss_pipe_destroy_msg,
+       .open_pipe      = gss_pipe_open,
        .release_pipe   = gss_pipe_release,
 };
 
@@ -1367,6 +1440,7 @@ static int __init init_rpcsec_gss(void)
        err = gss_svc_init();
        if (err)
                goto out_unregister;
+       rpc_init_wait_queue(&pipe_version_rpc_waitqueue, "gss pipe version");
        return 0;
 out_unregister:
        rpcauth_unregister(&authgss_ops);