Merge branch 'linus' into core/softirq
[safe/jmp/linux-2.6] / net / xfrm / xfrm_user.c
index 70dca1e..b976d9e 100644 (file)
@@ -30,9 +30,8 @@
 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
 #include <linux/in6.h>
 #endif
-#include <linux/audit.h>
 
-static inline int alg_len(struct xfrm_algo *alg)
+static inline int aead_len(struct xfrm_algo_aead *alg)
 {
        return sizeof(*alg) + ((alg->alg_key_len + 7) / 8);
 }
@@ -46,24 +45,13 @@ static int verify_one_alg(struct nlattr **attrs, enum xfrm_attr_type_t type)
                return 0;
 
        algp = nla_data(rt);
-       if (nla_len(rt) < alg_len(algp))
+       if (nla_len(rt) < xfrm_alg_len(algp))
                return -EINVAL;
 
        switch (type) {
        case XFRMA_ALG_AUTH:
-               if (!algp->alg_key_len &&
-                   strcmp(algp->alg_name, "digest_null") != 0)
-                       return -EINVAL;
-               break;
-
        case XFRMA_ALG_CRYPT:
-               if (!algp->alg_key_len &&
-                   strcmp(algp->alg_name, "cipher_null") != 0)
-                       return -EINVAL;
-               break;
-
        case XFRMA_ALG_COMP:
-               /* Zero length keys are legal.  */
                break;
 
        default:
@@ -74,6 +62,22 @@ static int verify_one_alg(struct nlattr **attrs, enum xfrm_attr_type_t type)
        return 0;
 }
 
+static int verify_aead(struct nlattr **attrs)
+{
+       struct nlattr *rt = attrs[XFRMA_ALG_AEAD];
+       struct xfrm_algo_aead *algp;
+
+       if (!rt)
+               return 0;
+
+       algp = nla_data(rt);
+       if (nla_len(rt) < aead_len(algp))
+               return -EINVAL;
+
+       algp->alg_name[CRYPTO_MAX_ALG_NAME - 1] = '\0';
+       return 0;
+}
+
 static void verify_one_addr(struct nlattr **attrs, enum xfrm_attr_type_t type,
                           xfrm_address_t **addrp)
 {
@@ -125,20 +129,28 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
        switch (p->id.proto) {
        case IPPROTO_AH:
                if (!attrs[XFRMA_ALG_AUTH]      ||
+                   attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_CRYPT]      ||
                    attrs[XFRMA_ALG_COMP])
                        goto out;
                break;
 
        case IPPROTO_ESP:
-               if ((!attrs[XFRMA_ALG_AUTH] &&
-                    !attrs[XFRMA_ALG_CRYPT])   ||
-                   attrs[XFRMA_ALG_COMP])
+               if (attrs[XFRMA_ALG_COMP])
+                       goto out;
+               if (!attrs[XFRMA_ALG_AUTH] &&
+                   !attrs[XFRMA_ALG_CRYPT] &&
+                   !attrs[XFRMA_ALG_AEAD])
+                       goto out;
+               if ((attrs[XFRMA_ALG_AUTH] ||
+                    attrs[XFRMA_ALG_CRYPT]) &&
+                   attrs[XFRMA_ALG_AEAD])
                        goto out;
                break;
 
        case IPPROTO_COMP:
                if (!attrs[XFRMA_ALG_COMP]      ||
+                   attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_AUTH]       ||
                    attrs[XFRMA_ALG_CRYPT])
                        goto out;
@@ -149,6 +161,7 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
        case IPPROTO_ROUTING:
                if (attrs[XFRMA_ALG_COMP]       ||
                    attrs[XFRMA_ALG_AUTH]       ||
+                   attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_CRYPT]      ||
                    attrs[XFRMA_ENCAP]          ||
                    attrs[XFRMA_SEC_CTX]        ||
@@ -161,6 +174,8 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
                goto out;
        }
 
+       if ((err = verify_aead(attrs)))
+               goto out;
        if ((err = verify_one_alg(attrs, XFRMA_ALG_AUTH)))
                goto out;
        if ((err = verify_one_alg(attrs, XFRMA_ALG_CRYPT)))
@@ -205,7 +220,7 @@ static int attach_one_algo(struct xfrm_algo **algpp, u8 *props,
                return -ENOSYS;
        *props = algo->desc.sadb_alg_id;
 
-       p = kmemdup(ualg, alg_len(ualg), GFP_KERNEL);
+       p = kmemdup(ualg, xfrm_alg_len(ualg), GFP_KERNEL);
        if (!p)
                return -ENOMEM;
 
@@ -214,23 +229,31 @@ static int attach_one_algo(struct xfrm_algo **algpp, u8 *props,
        return 0;
 }
 
-static int attach_encap_tmpl(struct xfrm_encap_tmpl **encapp, struct nlattr *rta)
+static int attach_aead(struct xfrm_algo_aead **algpp, u8 *props,
+                      struct nlattr *rta)
 {
-       struct xfrm_encap_tmpl *p, *uencap;
+       struct xfrm_algo_aead *p, *ualg;
+       struct xfrm_algo_desc *algo;
 
        if (!rta)
                return 0;
 
-       uencap = nla_data(rta);
-       p = kmemdup(uencap, sizeof(*p), GFP_KERNEL);
+       ualg = nla_data(rta);
+
+       algo = xfrm_aead_get_byname(ualg->alg_name, ualg->alg_icv_len, 1);
+       if (!algo)
+               return -ENOSYS;
+       *props = algo->desc.sadb_alg_id;
+
+       p = kmemdup(ualg, aead_len(ualg), GFP_KERNEL);
        if (!p)
                return -ENOMEM;
 
-       *encapp = p;
+       strcpy(p->alg_name, algo->name);
+       *algpp = p;
        return 0;
 }
 
-
 static inline int xfrm_user_sec_ctx_size(struct xfrm_sec_ctx *xfrm_ctx)
 {
        int len = 0;
@@ -242,33 +265,6 @@ static inline int xfrm_user_sec_ctx_size(struct xfrm_sec_ctx *xfrm_ctx)
        return len;
 }
 
-static int attach_sec_ctx(struct xfrm_state *x, struct nlattr *u_arg)
-{
-       struct xfrm_user_sec_ctx *uctx;
-
-       if (!u_arg)
-               return 0;
-
-       uctx = nla_data(u_arg);
-       return security_xfrm_state_alloc(x, uctx);
-}
-
-static int attach_one_addr(xfrm_address_t **addrpp, struct nlattr *rta)
-{
-       xfrm_address_t *p, *uaddrp;
-
-       if (!rta)
-               return 0;
-
-       uaddrp = nla_data(rta);
-       p = kmemdup(uaddrp, sizeof(*p), GFP_KERNEL);
-       if (!p)
-               return -ENOMEM;
-
-       *addrpp = p;
-       return 0;
-}
-
 static void copy_from_user_state(struct xfrm_state *x, struct xfrm_usersa_info *p)
 {
        memcpy(&x->id, &p->id, sizeof(x->id));
@@ -281,12 +277,9 @@ static void copy_from_user_state(struct xfrm_state *x, struct xfrm_usersa_info *
        memcpy(&x->props.saddr, &p->saddr, sizeof(x->props.saddr));
        x->props.flags = p->flags;
 
-       /*
-        * Set inner address family if the KM left it as zero.
-        * See comment in validate_tmpl.
-        */
        if (!x->sel.family)
                x->sel.family = p->family;
+
 }
 
 /*
@@ -336,6 +329,9 @@ static struct xfrm_state *xfrm_state_construct(struct xfrm_usersa_info *p,
 
        copy_from_user_state(x, p);
 
+       if ((err = attach_aead(&x->aead, &x->props.ealgo,
+                              attrs[XFRMA_ALG_AEAD])))
+               goto error;
        if ((err = attach_one_algo(&x->aalg, &x->props.aalgo,
                                   xfrm_aalg_get_byname,
                                   attrs[XFRMA_ALG_AUTH])))
@@ -348,15 +344,27 @@ static struct xfrm_state *xfrm_state_construct(struct xfrm_usersa_info *p,
                                   xfrm_calg_get_byname,
                                   attrs[XFRMA_ALG_COMP])))
                goto error;
-       if ((err = attach_encap_tmpl(&x->encap, attrs[XFRMA_ENCAP])))
-               goto error;
-       if ((err = attach_one_addr(&x->coaddr, attrs[XFRMA_COADDR])))
-               goto error;
+
+       if (attrs[XFRMA_ENCAP]) {
+               x->encap = kmemdup(nla_data(attrs[XFRMA_ENCAP]),
+                                  sizeof(*x->encap), GFP_KERNEL);
+               if (x->encap == NULL)
+                       goto error;
+       }
+
+       if (attrs[XFRMA_COADDR]) {
+               x->coaddr = kmemdup(nla_data(attrs[XFRMA_COADDR]),
+                                   sizeof(*x->coaddr), GFP_KERNEL);
+               if (x->coaddr == NULL)
+                       goto error;
+       }
+
        err = xfrm_init_state(x);
        if (err)
                goto error;
 
-       if ((err = attach_sec_ctx(x, attrs[XFRMA_SEC_CTX])))
+       if (attrs[XFRMA_SEC_CTX] &&
+           security_xfrm_state_alloc(x, nla_data(attrs[XFRMA_SEC_CTX])))
                goto error;
 
        x->km.seq = p->seq;
@@ -388,6 +396,9 @@ static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        struct xfrm_state *x;
        int err;
        struct km_event c;
+       uid_t loginuid = NETLINK_CB(skb).loginuid;
+       u32 sessionid = NETLINK_CB(skb).sessionid;
+       u32 sid = NETLINK_CB(skb).sid;
 
        err = verify_newsa_info(p, attrs);
        if (err)
@@ -403,8 +414,7 @@ static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        else
                err = xfrm_state_update(x);
 
-       xfrm_audit_log(NETLINK_CB(skb).loginuid, NETLINK_CB(skb).sid,
-                      AUDIT_MAC_IPSEC_ADDSA, err ? 0 : 1, NULL, x);
+       xfrm_audit_state_add(x, err ? 0 : 1, loginuid, sessionid, sid);
 
        if (err < 0) {
                x->km.state = XFRM_STATE_DEAD;
@@ -459,6 +469,9 @@ static int xfrm_del_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        int err = -ESRCH;
        struct km_event c;
        struct xfrm_usersa_id *p = nlmsg_data(nlh);
+       uid_t loginuid = NETLINK_CB(skb).loginuid;
+       u32 sessionid = NETLINK_CB(skb).sessionid;
+       u32 sid = NETLINK_CB(skb).sid;
 
        x = xfrm_user_state_lookup(p, attrs, &err);
        if (x == NULL)
@@ -483,8 +496,7 @@ static int xfrm_del_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        km_state_notify(x, &c);
 
 out:
-       xfrm_audit_log(NETLINK_CB(skb).loginuid, NETLINK_CB(skb).sid,
-                      AUDIT_MAC_IPSEC_DELSA, err ? 0 : 1, NULL, x);
+       xfrm_audit_state_delete(x, err ? 0 : 1, loginuid, sessionid, sid);
        xfrm_state_put(x);
        return err;
 }
@@ -510,15 +522,13 @@ struct xfrm_dump_info {
        struct sk_buff *out_skb;
        u32 nlmsg_seq;
        u16 nlmsg_flags;
-       int start_idx;
-       int this_idx;
 };
 
 static int copy_sec_ctx(struct xfrm_sec_ctx *s, struct sk_buff *skb)
 {
-       int ctx_size = sizeof(struct xfrm_sec_ctx) + s->ctx_len;
        struct xfrm_user_sec_ctx *uctx;
        struct nlattr *attr;
+       int ctx_size = sizeof(*uctx) + s->ctx_len;
 
        attr = nla_reserve(skb, XFRMA_SEC_CTX, ctx_size);
        if (attr == NULL)
@@ -535,29 +545,25 @@ static int copy_sec_ctx(struct xfrm_sec_ctx *s, struct sk_buff *skb)
        return 0;
 }
 
-static int dump_one_state(struct xfrm_state *x, int count, void *ptr)
+/* Don't change this without updating xfrm_sa_len! */
+static int copy_to_user_state_extra(struct xfrm_state *x,
+                                   struct xfrm_usersa_info *p,
+                                   struct sk_buff *skb)
 {
-       struct xfrm_dump_info *sp = ptr;
-       struct sk_buff *in_skb = sp->in_skb;
-       struct sk_buff *skb = sp->out_skb;
-       struct xfrm_usersa_info *p;
-       struct nlmsghdr *nlh;
-
-       if (sp->this_idx < sp->start_idx)
-               goto out;
+       copy_to_user_state(x, p);
 
-       nlh = nlmsg_put(skb, NETLINK_CB(in_skb).pid, sp->nlmsg_seq,
-                       XFRM_MSG_NEWSA, sizeof(*p), sp->nlmsg_flags);
-       if (nlh == NULL)
-               return -EMSGSIZE;
+       if (x->coaddr)
+               NLA_PUT(skb, XFRMA_COADDR, sizeof(*x->coaddr), x->coaddr);
 
-       p = nlmsg_data(nlh);
-       copy_to_user_state(x, p);
+       if (x->lastused)
+               NLA_PUT_U64(skb, XFRMA_LASTUSED, x->lastused);
 
+       if (x->aead)
+               NLA_PUT(skb, XFRMA_ALG_AEAD, aead_len(x->aead), x->aead);
        if (x->aalg)
-               NLA_PUT(skb, XFRMA_ALG_AUTH, alg_len(x->aalg), x->aalg);
+               NLA_PUT(skb, XFRMA_ALG_AUTH, xfrm_alg_len(x->aalg), x->aalg);
        if (x->ealg)
-               NLA_PUT(skb, XFRMA_ALG_CRYPT, alg_len(x->ealg), x->ealg);
+               NLA_PUT(skb, XFRMA_ALG_CRYPT, xfrm_alg_len(x->ealg), x->ealg);
        if (x->calg)
                NLA_PUT(skb, XFRMA_ALG_COMP, sizeof(*(x->calg)), x->calg);
 
@@ -567,34 +573,66 @@ static int dump_one_state(struct xfrm_state *x, int count, void *ptr)
        if (x->security && copy_sec_ctx(x->security, skb) < 0)
                goto nla_put_failure;
 
-       if (x->coaddr)
-               NLA_PUT(skb, XFRMA_COADDR, sizeof(*x->coaddr), x->coaddr);
+       return 0;
 
-       if (x->lastused)
-               NLA_PUT_U64(skb, XFRMA_LASTUSED, x->lastused);
+nla_put_failure:
+       return -EMSGSIZE;
+}
+
+static int dump_one_state(struct xfrm_state *x, int count, void *ptr)
+{
+       struct xfrm_dump_info *sp = ptr;
+       struct sk_buff *in_skb = sp->in_skb;
+       struct sk_buff *skb = sp->out_skb;
+       struct xfrm_usersa_info *p;
+       struct nlmsghdr *nlh;
+       int err;
+
+       nlh = nlmsg_put(skb, NETLINK_CB(in_skb).pid, sp->nlmsg_seq,
+                       XFRM_MSG_NEWSA, sizeof(*p), sp->nlmsg_flags);
+       if (nlh == NULL)
+               return -EMSGSIZE;
+
+       p = nlmsg_data(nlh);
+
+       err = copy_to_user_state_extra(x, p, skb);
+       if (err)
+               goto nla_put_failure;
 
        nlmsg_end(skb, nlh);
-out:
-       sp->this_idx++;
        return 0;
 
 nla_put_failure:
        nlmsg_cancel(skb, nlh);
-       return -EMSGSIZE;
+       return err;
+}
+
+static int xfrm_dump_sa_done(struct netlink_callback *cb)
+{
+       struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1];
+       xfrm_state_walk_done(walk);
+       return 0;
 }
 
 static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb)
 {
+       struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1];
        struct xfrm_dump_info info;
 
+       BUILD_BUG_ON(sizeof(struct xfrm_state_walk) >
+                    sizeof(cb->args) - sizeof(cb->args[0]));
+
        info.in_skb = cb->skb;
        info.out_skb = skb;
        info.nlmsg_seq = cb->nlh->nlmsg_seq;
        info.nlmsg_flags = NLM_F_MULTI;
-       info.this_idx = 0;
-       info.start_idx = cb->args[0];
-       (void) xfrm_state_walk(0, dump_one_state, &info);
-       cb->args[0] = info.this_idx;
+
+       if (!cb->args[0]) {
+               cb->args[0] = 1;
+               xfrm_state_walk_init(walk, 0);
+       }
+
+       (void) xfrm_state_walk(walk, dump_one_state, &info);
 
        return skb->len;
 }
@@ -613,7 +651,6 @@ static struct sk_buff *xfrm_state_netlink(struct sk_buff *in_skb,
        info.out_skb = skb;
        info.nlmsg_seq = seq;
        info.nlmsg_flags = 0;
-       info.this_idx = info.start_idx = 0;
 
        if (dump_one_state(x, 0, &info)) {
                kfree_skb(skb);
@@ -817,16 +854,11 @@ static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (x == NULL)
                goto out_noput;
 
-       resp_skb = ERR_PTR(-ENOENT);
-
-       spin_lock_bh(&x->lock);
-       if (x->km.state != XFRM_STATE_DEAD) {
-               xfrm_alloc_spi(x, htonl(p->min), htonl(p->max));
-               if (x->id.spi)
-                       resp_skb = xfrm_state_netlink(skb, x, nlh->nlmsg_seq);
-       }
-       spin_unlock_bh(&x->lock);
+       err = xfrm_alloc_spi(x, p->min, p->max);
+       if (err)
+               goto out;
 
+       resp_skb = xfrm_state_netlink(skb, x, nlh->nlmsg_seq);
        if (IS_ERR(resp_skb)) {
                err = PTR_ERR(resp_skb);
                goto out;
@@ -920,7 +952,7 @@ static int copy_from_user_sec_ctx(struct xfrm_policy *pol, struct nlattr **attrs
                return 0;
 
        uctx = nla_data(rt);
-       return security_xfrm_policy_alloc(pol, uctx);
+       return security_xfrm_policy_alloc(&pol->security, uctx);
 }
 
 static void copy_templates(struct xfrm_policy *xp, struct xfrm_user_tmpl *ut,
@@ -942,6 +974,8 @@ static void copy_templates(struct xfrm_policy *xp, struct xfrm_user_tmpl *ut,
                t->aalgos = ut->aalgos;
                t->ealgos = ut->ealgos;
                t->calgos = ut->calgos;
+               /* If all masks are ~0, then we allow all algorithms. */
+               t->allalgs = !~(t->aalgos & t->ealgos & t->calgos);
                t->encap_family = ut->family;
        }
 }
@@ -1069,7 +1103,8 @@ static struct xfrm_policy *xfrm_policy_construct(struct xfrm_userpolicy_info *p,
        return xp;
  error:
        *errp = err;
-       kfree(xp);
+       xp->dead = 1;
+       xfrm_policy_destroy(xp);
        return NULL;
 }
 
@@ -1081,6 +1116,9 @@ static int xfrm_add_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
        struct km_event c;
        int err;
        int excl;
+       uid_t loginuid = NETLINK_CB(skb).loginuid;
+       u32 sessionid = NETLINK_CB(skb).sessionid;
+       u32 sid = NETLINK_CB(skb).sid;
 
        err = verify_newpolicy_info(p);
        if (err)
@@ -1099,11 +1137,10 @@ static int xfrm_add_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
         * a type XFRM_MSG_UPDPOLICY - JHS */
        excl = nlh->nlmsg_type == XFRM_MSG_NEWPOLICY;
        err = xfrm_policy_insert(p->dir, xp, excl);
-       xfrm_audit_log(NETLINK_CB(skb).loginuid, NETLINK_CB(skb).sid,
-                      AUDIT_MAC_IPSEC_DELSPD, err ? 0 : 1, xp, NULL);
+       xfrm_audit_policy_add(xp, err ? 0 : 1, loginuid, sessionid, sid);
 
        if (err) {
-               security_xfrm_policy_free(xp);
+               security_xfrm_policy_free(xp->security);
                kfree(xp);
                return err;
        }
@@ -1195,9 +1232,6 @@ static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr
        struct sk_buff *skb = sp->out_skb;
        struct nlmsghdr *nlh;
 
-       if (sp->this_idx < sp->start_idx)
-               goto out;
-
        nlh = nlmsg_put(skb, NETLINK_CB(in_skb).pid, sp->nlmsg_seq,
                        XFRM_MSG_NEWPOLICY, sizeof(*p), sp->nlmsg_flags);
        if (nlh == NULL)
@@ -1213,8 +1247,6 @@ static int dump_one_policy(struct xfrm_policy *xp, int dir, int count, void *ptr
                goto nlmsg_failure;
 
        nlmsg_end(skb, nlh);
-out:
-       sp->this_idx++;
        return 0;
 
 nlmsg_failure:
@@ -1222,21 +1254,33 @@ nlmsg_failure:
        return -EMSGSIZE;
 }
 
+static int xfrm_dump_policy_done(struct netlink_callback *cb)
+{
+       struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *) &cb->args[1];
+
+       xfrm_policy_walk_done(walk);
+       return 0;
+}
+
 static int xfrm_dump_policy(struct sk_buff *skb, struct netlink_callback *cb)
 {
+       struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *) &cb->args[1];
        struct xfrm_dump_info info;
 
+       BUILD_BUG_ON(sizeof(struct xfrm_policy_walk) >
+                    sizeof(cb->args) - sizeof(cb->args[0]));
+
        info.in_skb = cb->skb;
        info.out_skb = skb;
        info.nlmsg_seq = cb->nlh->nlmsg_seq;
        info.nlmsg_flags = NLM_F_MULTI;
-       info.this_idx = 0;
-       info.start_idx = cb->args[0];
-       (void) xfrm_policy_walk(XFRM_POLICY_TYPE_MAIN, dump_one_policy, &info);
-#ifdef CONFIG_XFRM_SUB_POLICY
-       (void) xfrm_policy_walk(XFRM_POLICY_TYPE_SUB, dump_one_policy, &info);
-#endif
-       cb->args[0] = info.this_idx;
+
+       if (!cb->args[0]) {
+               cb->args[0] = 1;
+               xfrm_policy_walk_init(walk, XFRM_POLICY_TYPE_ANY);
+       }
+
+       (void) xfrm_policy_walk(walk, dump_one_policy, &info);
 
        return skb->len;
 }
@@ -1256,7 +1300,6 @@ static struct sk_buff *xfrm_policy_netlink(struct sk_buff *in_skb,
        info.out_skb = skb;
        info.nlmsg_seq = seq;
        info.nlmsg_flags = 0;
-       info.this_idx = info.start_idx = 0;
 
        if (dump_one_policy(xp, dir, 0, &info) < 0) {
                kfree_skb(skb);
@@ -1291,22 +1334,23 @@ static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
                xp = xfrm_policy_byid(type, p->dir, p->index, delete, &err);
        else {
                struct nlattr *rt = attrs[XFRMA_SEC_CTX];
-               struct xfrm_policy tmp;
+               struct xfrm_sec_ctx *ctx;
 
                err = verify_sec_ctx_len(attrs);
                if (err)
                        return err;
 
-               memset(&tmp, 0, sizeof(struct xfrm_policy));
+               ctx = NULL;
                if (rt) {
                        struct xfrm_user_sec_ctx *uctx = nla_data(rt);
 
-                       if ((err = security_xfrm_policy_alloc(&tmp, uctx)))
+                       err = security_xfrm_policy_alloc(&ctx, uctx);
+                       if (err)
                                return err;
                }
-               xp = xfrm_policy_bysel_ctx(type, p->dir, &p->sel, tmp.security,
+               xp = xfrm_policy_bysel_ctx(type, p->dir, &p->sel, ctx,
                                           delete, &err);
-               security_xfrm_policy_free(&tmp);
+               security_xfrm_policy_free(ctx);
        }
        if (xp == NULL)
                return -ENOENT;
@@ -1322,8 +1366,12 @@ static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
                                            NETLINK_CB(skb).pid);
                }
        } else {
-               xfrm_audit_log(NETLINK_CB(skb).loginuid, NETLINK_CB(skb).sid,
-                              AUDIT_MAC_IPSEC_DELSPD, err ? 0 : 1, xp, NULL);
+               uid_t loginuid = NETLINK_CB(skb).loginuid;
+               u32 sessionid = NETLINK_CB(skb).sessionid;
+               u32 sid = NETLINK_CB(skb).sid;
+
+               xfrm_audit_policy_delete(xp, err ? 0 : 1, loginuid, sessionid,
+                                        sid);
 
                if (err != 0)
                        goto out;
@@ -1349,6 +1397,7 @@ static int xfrm_flush_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        int err;
 
        audit_info.loginuid = NETLINK_CB(skb).loginuid;
+       audit_info.sessionid = NETLINK_CB(skb).sessionid;
        audit_info.secid = NETLINK_CB(skb).sid;
        err = xfrm_state_flush(p->proto, &audit_info);
        if (err)
@@ -1496,6 +1545,7 @@ static int xfrm_flush_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
                return err;
 
        audit_info.loginuid = NETLINK_CB(skb).loginuid;
+       audit_info.sessionid = NETLINK_CB(skb).sessionid;
        audit_info.secid = NETLINK_CB(skb).sid;
        err = xfrm_policy_flush(type, &audit_info);
        if (err)
@@ -1525,26 +1575,26 @@ static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
                xp = xfrm_policy_byid(type, p->dir, p->index, 0, &err);
        else {
                struct nlattr *rt = attrs[XFRMA_SEC_CTX];
-               struct xfrm_policy tmp;
+               struct xfrm_sec_ctx *ctx;
 
                err = verify_sec_ctx_len(attrs);
                if (err)
                        return err;
 
-               memset(&tmp, 0, sizeof(struct xfrm_policy));
+               ctx = NULL;
                if (rt) {
                        struct xfrm_user_sec_ctx *uctx = nla_data(rt);
 
-                       if ((err = security_xfrm_policy_alloc(&tmp, uctx)))
+                       err = security_xfrm_policy_alloc(&ctx, uctx);
+                       if (err)
                                return err;
                }
-               xp = xfrm_policy_bysel_ctx(type, p->dir, &p->sel, tmp.security,
-                                          0, &err);
-               security_xfrm_policy_free(&tmp);
+               xp = xfrm_policy_bysel_ctx(type, p->dir, &p->sel, ctx, 0, &err);
+               security_xfrm_policy_free(ctx);
        }
-
        if (xp == NULL)
                return -ENOENT;
+
        read_lock(&xp->lock);
        if (xp->dead) {
                read_unlock(&xp->lock);
@@ -1554,9 +1604,11 @@ static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
        read_unlock(&xp->lock);
        err = 0;
        if (up->hard) {
+               uid_t loginuid = NETLINK_CB(skb).loginuid;
+               uid_t sessionid = NETLINK_CB(skb).sessionid;
+               u32 sid = NETLINK_CB(skb).sid;
                xfrm_policy_delete(xp, p->dir);
-               xfrm_audit_log(NETLINK_CB(skb).loginuid, NETLINK_CB(skb).sid,
-                               AUDIT_MAC_IPSEC_DELSPD, 1, xp, NULL);
+               xfrm_audit_policy_delete(xp, 1, loginuid, sessionid, sid);
 
        } else {
                // reset the timers here?
@@ -1590,9 +1642,11 @@ static int xfrm_add_sa_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
        km_state_expired(x, ue->hard, current->pid);
 
        if (ue->hard) {
+               uid_t loginuid = NETLINK_CB(skb).loginuid;
+               uid_t sessionid = NETLINK_CB(skb).sessionid;
+               u32 sid = NETLINK_CB(skb).sid;
                __xfrm_state_delete(x);
-               xfrm_audit_log(NETLINK_CB(skb).loginuid, NETLINK_CB(skb).sid,
-                              AUDIT_MAC_IPSEC_DELSA, 1, NULL, x);
+               xfrm_audit_state_delete(x, 1, loginuid, sessionid, sid);
        }
        err = 0;
 out:
@@ -1833,6 +1887,7 @@ static const int xfrm_msg_min[XFRM_NR_MSGTYPES] = {
 #undef XMSGSIZE
 
 static const struct nla_policy xfrma_policy[XFRMA_MAX+1] = {
+       [XFRMA_ALG_AEAD]        = { .len = sizeof(struct xfrm_algo_aead) },
        [XFRMA_ALG_AUTH]        = { .len = sizeof(struct xfrm_algo) },
        [XFRMA_ALG_CRYPT]       = { .len = sizeof(struct xfrm_algo) },
        [XFRMA_ALG_COMP]        = { .len = sizeof(struct xfrm_algo) },
@@ -1852,15 +1907,18 @@ static const struct nla_policy xfrma_policy[XFRMA_MAX+1] = {
 static struct xfrm_link {
        int (*doit)(struct sk_buff *, struct nlmsghdr *, struct nlattr **);
        int (*dump)(struct sk_buff *, struct netlink_callback *);
+       int (*done)(struct netlink_callback *);
 } xfrm_dispatch[XFRM_NR_MSGTYPES] = {
        [XFRM_MSG_NEWSA       - XFRM_MSG_BASE] = { .doit = xfrm_add_sa        },
        [XFRM_MSG_DELSA       - XFRM_MSG_BASE] = { .doit = xfrm_del_sa        },
        [XFRM_MSG_GETSA       - XFRM_MSG_BASE] = { .doit = xfrm_get_sa,
-                                                  .dump = xfrm_dump_sa       },
+                                                  .dump = xfrm_dump_sa,
+                                                  .done = xfrm_dump_sa_done  },
        [XFRM_MSG_NEWPOLICY   - XFRM_MSG_BASE] = { .doit = xfrm_add_policy    },
        [XFRM_MSG_DELPOLICY   - XFRM_MSG_BASE] = { .doit = xfrm_get_policy    },
        [XFRM_MSG_GETPOLICY   - XFRM_MSG_BASE] = { .doit = xfrm_get_policy,
-                                                  .dump = xfrm_dump_policy   },
+                                                  .dump = xfrm_dump_policy,
+                                                  .done = xfrm_dump_policy_done },
        [XFRM_MSG_ALLOCSPI    - XFRM_MSG_BASE] = { .doit = xfrm_alloc_userspi },
        [XFRM_MSG_ACQUIRE     - XFRM_MSG_BASE] = { .doit = xfrm_add_acquire   },
        [XFRM_MSG_EXPIRE      - XFRM_MSG_BASE] = { .doit = xfrm_add_sa_expire },
@@ -1899,7 +1957,7 @@ static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
                if (link->dump == NULL)
                        return -EINVAL;
 
-               return netlink_dump_start(xfrm_nl, skb, nlh, link->dump, NULL);
+               return netlink_dump_start(xfrm_nl, skb, nlh, link->dump, link->done);
        }
 
        err = nlmsg_parse(nlh, xfrm_msg_min[type], attrs, XFRMA_MAX,
@@ -1913,16 +1971,11 @@ static int xfrm_user_rcv_msg(struct sk_buff *skb, struct nlmsghdr *nlh)
        return link->doit(skb, nlh, attrs);
 }
 
-static void xfrm_netlink_rcv(struct sock *sk, int len)
+static void xfrm_netlink_rcv(struct sk_buff *skb)
 {
-       unsigned int qlen = 0;
-
-       do {
-               mutex_lock(&xfrm_cfg_mutex);
-               netlink_run_queue(sk, &qlen, &xfrm_user_rcv_msg);
-               mutex_unlock(&xfrm_cfg_mutex);
-
-       } while (qlen);
+       mutex_lock(&xfrm_cfg_mutex);
+       netlink_rcv_skb(skb, &xfrm_user_rcv_msg);
+       mutex_unlock(&xfrm_cfg_mutex);
 }
 
 static inline size_t xfrm_expire_msgsize(void)
@@ -2002,14 +2055,24 @@ static int xfrm_notify_sa_flush(struct km_event *c)
 static inline size_t xfrm_sa_len(struct xfrm_state *x)
 {
        size_t l = 0;
+       if (x->aead)
+               l += nla_total_size(aead_len(x->aead));
        if (x->aalg)
-               l += nla_total_size(alg_len(x->aalg));
+               l += nla_total_size(xfrm_alg_len(x->aalg));
        if (x->ealg)
-               l += nla_total_size(alg_len(x->ealg));
+               l += nla_total_size(xfrm_alg_len(x->ealg));
        if (x->calg)
                l += nla_total_size(sizeof(*x->calg));
        if (x->encap)
                l += nla_total_size(sizeof(*x->encap));
+       if (x->security)
+               l += nla_total_size(sizeof(struct xfrm_user_sec_ctx) +
+                                   x->security->ctx_len);
+       if (x->coaddr)
+               l += nla_total_size(sizeof(*x->coaddr));
+
+       /* Must count x->lastused as it may become non-zero behind our back. */
+       l += nla_total_size(sizeof(u64));
 
        return l;
 }
@@ -2055,23 +2118,16 @@ static int xfrm_notify_sa(struct xfrm_state *x, struct km_event *c)
                p = nla_data(attr);
        }
 
-       copy_to_user_state(x, p);
-
-       if (x->aalg)
-               NLA_PUT(skb, XFRMA_ALG_AUTH, alg_len(x->aalg), x->aalg);
-       if (x->ealg)
-               NLA_PUT(skb, XFRMA_ALG_CRYPT, alg_len(x->ealg), x->ealg);
-       if (x->calg)
-               NLA_PUT(skb, XFRMA_ALG_COMP, sizeof(*(x->calg)), x->calg);
-
-       if (x->encap)
-               NLA_PUT(skb, XFRMA_ENCAP, sizeof(*x->encap), x->encap);
+       if (copy_to_user_state_extra(x, p, skb))
+               goto nla_put_failure;
 
        nlmsg_end(skb, nlh);
 
        return nlmsg_multicast(xfrm_nl, skb, 0, XFRMNLGRP_SA, GFP_ATOMIC);
 
 nla_put_failure:
+       /* Somebody screwed up with xfrm_sa_len! */
+       WARN_ON(1);
        kfree_skb(skb);
        return -1;
 }
@@ -2431,7 +2487,7 @@ static int __init xfrm_user_init(void)
 
        printk(KERN_INFO "Initializing XFRM netlink socket\n");
 
-       nlsk = netlink_kernel_create(NETLINK_XFRM, XFRMNLGRP_MAX,
+       nlsk = netlink_kernel_create(&init_net, NETLINK_XFRM, XFRMNLGRP_MAX,
                                     xfrm_netlink_rcv, NULL, THIS_MODULE);
        if (nlsk == NULL)
                return -ENOMEM;
@@ -2449,7 +2505,7 @@ static void __exit xfrm_user_exit(void)
        xfrm_unregister_km(&netlink_mgr);
        rcu_assign_pointer(xfrm_nl, NULL);
        synchronize_rcu();
-       sock_release(nlsk->sk_socket);
+       netlink_kernel_release(nlsk);
 }
 
 module_init(xfrm_user_init);