net: x25: use seq_list_foo() helpers
[safe/jmp/linux-2.6] / net / xfrm / xfrm_user.c
index 8b5b01d..d5a7129 100644 (file)
@@ -62,6 +62,22 @@ static int verify_one_alg(struct nlattr **attrs, enum xfrm_attr_type_t type)
        return 0;
 }
 
+static int verify_auth_trunc(struct nlattr **attrs)
+{
+       struct nlattr *rt = attrs[XFRMA_ALG_AUTH_TRUNC];
+       struct xfrm_algo_auth *algp;
+
+       if (!rt)
+               return 0;
+
+       algp = nla_data(rt);
+       if (nla_len(rt) < xfrm_alg_auth_len(algp))
+               return -EINVAL;
+
+       algp->alg_name[CRYPTO_MAX_ALG_NAME - 1] = '\0';
+       return 0;
+}
+
 static int verify_aead(struct nlattr **attrs)
 {
        struct nlattr *rt = attrs[XFRMA_ALG_AEAD];
@@ -128,7 +144,8 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
        err = -EINVAL;
        switch (p->id.proto) {
        case IPPROTO_AH:
-               if (!attrs[XFRMA_ALG_AUTH]      ||
+               if ((!attrs[XFRMA_ALG_AUTH]     &&
+                    !attrs[XFRMA_ALG_AUTH_TRUNC]) ||
                    attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_CRYPT]      ||
                    attrs[XFRMA_ALG_COMP])
@@ -139,10 +156,12 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
                if (attrs[XFRMA_ALG_COMP])
                        goto out;
                if (!attrs[XFRMA_ALG_AUTH] &&
+                   !attrs[XFRMA_ALG_AUTH_TRUNC] &&
                    !attrs[XFRMA_ALG_CRYPT] &&
                    !attrs[XFRMA_ALG_AEAD])
                        goto out;
                if ((attrs[XFRMA_ALG_AUTH] ||
+                    attrs[XFRMA_ALG_AUTH_TRUNC] ||
                     attrs[XFRMA_ALG_CRYPT]) &&
                    attrs[XFRMA_ALG_AEAD])
                        goto out;
@@ -152,6 +171,7 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
                if (!attrs[XFRMA_ALG_COMP]      ||
                    attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_AUTH]       ||
+                   attrs[XFRMA_ALG_AUTH_TRUNC] ||
                    attrs[XFRMA_ALG_CRYPT])
                        goto out;
                break;
@@ -161,6 +181,7 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
        case IPPROTO_ROUTING:
                if (attrs[XFRMA_ALG_COMP]       ||
                    attrs[XFRMA_ALG_AUTH]       ||
+                   attrs[XFRMA_ALG_AUTH_TRUNC] ||
                    attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_CRYPT]      ||
                    attrs[XFRMA_ENCAP]          ||
@@ -176,6 +197,8 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
 
        if ((err = verify_aead(attrs)))
                goto out;
+       if ((err = verify_auth_trunc(attrs)))
+               goto out;
        if ((err = verify_one_alg(attrs, XFRMA_ALG_AUTH)))
                goto out;
        if ((err = verify_one_alg(attrs, XFRMA_ALG_CRYPT)))
@@ -229,6 +252,66 @@ static int attach_one_algo(struct xfrm_algo **algpp, u8 *props,
        return 0;
 }
 
+static int attach_auth(struct xfrm_algo_auth **algpp, u8 *props,
+                      struct nlattr *rta)
+{
+       struct xfrm_algo *ualg;
+       struct xfrm_algo_auth *p;
+       struct xfrm_algo_desc *algo;
+
+       if (!rta)
+               return 0;
+
+       ualg = nla_data(rta);
+
+       algo = xfrm_aalg_get_byname(ualg->alg_name, 1);
+       if (!algo)
+               return -ENOSYS;
+       *props = algo->desc.sadb_alg_id;
+
+       p = kmalloc(sizeof(*p) + (ualg->alg_key_len + 7) / 8, GFP_KERNEL);
+       if (!p)
+               return -ENOMEM;
+
+       strcpy(p->alg_name, algo->name);
+       p->alg_key_len = ualg->alg_key_len;
+       p->alg_trunc_len = algo->uinfo.auth.icv_truncbits;
+       memcpy(p->alg_key, ualg->alg_key, (ualg->alg_key_len + 7) / 8);
+
+       *algpp = p;
+       return 0;
+}
+
+static int attach_auth_trunc(struct xfrm_algo_auth **algpp, u8 *props,
+                            struct nlattr *rta)
+{
+       struct xfrm_algo_auth *p, *ualg;
+       struct xfrm_algo_desc *algo;
+
+       if (!rta)
+               return 0;
+
+       ualg = nla_data(rta);
+
+       algo = xfrm_aalg_get_byname(ualg->alg_name, 1);
+       if (!algo)
+               return -ENOSYS;
+       if (ualg->alg_trunc_len > algo->uinfo.auth.icv_fullbits)
+               return -EINVAL;
+       *props = algo->desc.sadb_alg_id;
+
+       p = kmemdup(ualg, xfrm_alg_auth_len(ualg), GFP_KERNEL);
+       if (!p)
+               return -ENOMEM;
+
+       strcpy(p->alg_name, algo->name);
+       if (!p->alg_trunc_len)
+               p->alg_trunc_len = algo->uinfo.auth.icv_truncbits;
+
+       *algpp = p;
+       return 0;
+}
+
 static int attach_aead(struct xfrm_algo_aead **algpp, u8 *props,
                       struct nlattr *rta)
 {
@@ -316,11 +399,12 @@ static void xfrm_update_ae_params(struct xfrm_state *x, struct nlattr **attrs)
                x->replay_maxdiff = nla_get_u32(rt);
 }
 
-static struct xfrm_state *xfrm_state_construct(struct xfrm_usersa_info *p,
+static struct xfrm_state *xfrm_state_construct(struct net *net,
+                                              struct xfrm_usersa_info *p,
                                               struct nlattr **attrs,
                                               int *errp)
 {
-       struct xfrm_state *x = xfrm_state_alloc(&init_net);
+       struct xfrm_state *x = xfrm_state_alloc(net);
        int err = -ENOMEM;
 
        if (!x)
@@ -331,10 +415,14 @@ static struct xfrm_state *xfrm_state_construct(struct xfrm_usersa_info *p,
        if ((err = attach_aead(&x->aead, &x->props.ealgo,
                               attrs[XFRMA_ALG_AEAD])))
                goto error;
-       if ((err = attach_one_algo(&x->aalg, &x->props.aalgo,
-                                  xfrm_aalg_get_byname,
-                                  attrs[XFRMA_ALG_AUTH])))
+       if ((err = attach_auth_trunc(&x->aalg, &x->props.aalgo,
+                                    attrs[XFRMA_ALG_AUTH_TRUNC])))
                goto error;
+       if (!x->props.aalgo) {
+               if ((err = attach_auth(&x->aalg, &x->props.aalgo,
+                                      attrs[XFRMA_ALG_AUTH])))
+                       goto error;
+       }
        if ((err = attach_one_algo(&x->ealg, &x->props.ealgo,
                                   xfrm_ealg_get_byname,
                                   attrs[XFRMA_ALG_CRYPT])))
@@ -367,9 +455,9 @@ static struct xfrm_state *xfrm_state_construct(struct xfrm_usersa_info *p,
                goto error;
 
        x->km.seq = p->seq;
-       x->replay_maxdiff = sysctl_xfrm_aevent_rseqth;
+       x->replay_maxdiff = net->xfrm.sysctl_aevent_rseqth;
        /* sysctl_xfrm_aevent_etime is in 100ms units */
-       x->replay_maxage = (sysctl_xfrm_aevent_etime*HZ)/XFRM_AE_ETH_M;
+       x->replay_maxage = (net->xfrm.sysctl_aevent_etime*HZ)/XFRM_AE_ETH_M;
        x->preplay.bitmap = 0;
        x->preplay.seq = x->replay.seq+x->replay_maxdiff;
        x->preplay.oseq = x->replay.oseq +x->replay_maxdiff;
@@ -391,6 +479,7 @@ error_no_put:
 static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_usersa_info *p = nlmsg_data(nlh);
        struct xfrm_state *x;
        int err;
@@ -403,7 +492,7 @@ static int xfrm_add_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (err)
                return err;
 
-       x = xfrm_state_construct(p, attrs, &err);
+       x = xfrm_state_construct(net, p, attrs, &err);
        if (!x)
                return err;
 
@@ -431,7 +520,8 @@ out:
        return err;
 }
 
-static struct xfrm_state *xfrm_user_state_lookup(struct xfrm_usersa_id *p,
+static struct xfrm_state *xfrm_user_state_lookup(struct net *net,
+                                                struct xfrm_usersa_id *p,
                                                 struct nlattr **attrs,
                                                 int *errp)
 {
@@ -440,7 +530,7 @@ static struct xfrm_state *xfrm_user_state_lookup(struct xfrm_usersa_id *p,
 
        if (xfrm_id_proto_match(p->proto, IPSEC_PROTO_ANY)) {
                err = -ESRCH;
-               x = xfrm_state_lookup(&init_net, &p->daddr, p->spi, p->proto, p->family);
+               x = xfrm_state_lookup(net, &p->daddr, p->spi, p->proto, p->family);
        } else {
                xfrm_address_t *saddr = NULL;
 
@@ -451,7 +541,7 @@ static struct xfrm_state *xfrm_user_state_lookup(struct xfrm_usersa_id *p,
                }
 
                err = -ESRCH;
-               x = xfrm_state_lookup_byaddr(&init_net, &p->daddr, saddr,
+               x = xfrm_state_lookup_byaddr(net, &p->daddr, saddr,
                                             p->proto, p->family);
        }
 
@@ -464,6 +554,7 @@ static struct xfrm_state *xfrm_user_state_lookup(struct xfrm_usersa_id *p,
 static int xfrm_del_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_state *x;
        int err = -ESRCH;
        struct km_event c;
@@ -472,7 +563,7 @@ static int xfrm_del_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        u32 sessionid = NETLINK_CB(skb).sessionid;
        u32 sid = NETLINK_CB(skb).sid;
 
-       x = xfrm_user_state_lookup(p, attrs, &err);
+       x = xfrm_user_state_lookup(net, p, attrs, &err);
        if (x == NULL)
                return err;
 
@@ -544,6 +635,24 @@ static int copy_sec_ctx(struct xfrm_sec_ctx *s, struct sk_buff *skb)
        return 0;
 }
 
+static int copy_to_user_auth(struct xfrm_algo_auth *auth, struct sk_buff *skb)
+{
+       struct xfrm_algo *algo;
+       struct nlattr *nla;
+
+       nla = nla_reserve(skb, XFRMA_ALG_AUTH,
+                         sizeof(*algo) + (auth->alg_key_len + 7) / 8);
+       if (!nla)
+               return -EMSGSIZE;
+
+       algo = nla_data(nla);
+       strcpy(algo->alg_name, auth->alg_name);
+       memcpy(algo->alg_key, auth->alg_key, (auth->alg_key_len + 7) / 8);
+       algo->alg_key_len = auth->alg_key_len;
+
+       return 0;
+}
+
 /* Don't change this without updating xfrm_sa_len! */
 static int copy_to_user_state_extra(struct xfrm_state *x,
                                    struct xfrm_usersa_info *p,
@@ -559,8 +668,13 @@ static int copy_to_user_state_extra(struct xfrm_state *x,
 
        if (x->aead)
                NLA_PUT(skb, XFRMA_ALG_AEAD, aead_len(x->aead), x->aead);
-       if (x->aalg)
-               NLA_PUT(skb, XFRMA_ALG_AUTH, xfrm_alg_len(x->aalg), x->aalg);
+       if (x->aalg) {
+               if (copy_to_user_auth(x->aalg, skb))
+                       goto nla_put_failure;
+
+               NLA_PUT(skb, XFRMA_ALG_AUTH_TRUNC,
+                       xfrm_alg_auth_len(x->aalg), x->aalg);
+       }
        if (x->ealg)
                NLA_PUT(skb, XFRMA_ALG_CRYPT, xfrm_alg_len(x->ealg), x->ealg);
        if (x->calg)
@@ -615,6 +729,7 @@ static int xfrm_dump_sa_done(struct netlink_callback *cb)
 
 static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_state_walk *walk = (struct xfrm_state_walk *) &cb->args[1];
        struct xfrm_dump_info info;
 
@@ -631,7 +746,7 @@ static int xfrm_dump_sa(struct sk_buff *skb, struct netlink_callback *cb)
                xfrm_state_walk_init(walk, 0);
        }
 
-       (void) xfrm_state_walk(&init_net, walk, dump_one_state, &info);
+       (void) xfrm_state_walk(net, walk, dump_one_state, &info);
 
        return skb->len;
 }
@@ -666,7 +781,8 @@ static inline size_t xfrm_spdinfo_msgsize(void)
               + nla_total_size(sizeof(struct xfrmu_spdhinfo));
 }
 
-static int build_spdinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
+static int build_spdinfo(struct sk_buff *skb, struct net *net,
+                        u32 pid, u32 seq, u32 flags)
 {
        struct xfrmk_spdinfo si;
        struct xfrmu_spdinfo spc;
@@ -680,7 +796,7 @@ static int build_spdinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
 
        f = nlmsg_data(nlh);
        *f = flags;
-       xfrm_spd_getinfo(&si);
+       xfrm_spd_getinfo(net, &si);
        spc.incnt = si.incnt;
        spc.outcnt = si.outcnt;
        spc.fwdcnt = si.fwdcnt;
@@ -713,7 +829,7 @@ static int xfrm_get_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (r_skb == NULL)
                return -ENOMEM;
 
-       if (build_spdinfo(r_skb, spid, seq, *flags) < 0)
+       if (build_spdinfo(r_skb, net, spid, seq, *flags) < 0)
                BUG();
 
        return nlmsg_unicast(net->xfrm.nlsk, r_skb, spid);
@@ -726,7 +842,8 @@ static inline size_t xfrm_sadinfo_msgsize(void)
               + nla_total_size(4); /* XFRMA_SAD_CNT */
 }
 
-static int build_sadinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
+static int build_sadinfo(struct sk_buff *skb, struct net *net,
+                        u32 pid, u32 seq, u32 flags)
 {
        struct xfrmk_sadinfo si;
        struct xfrmu_sadhinfo sh;
@@ -739,7 +856,7 @@ static int build_sadinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
 
        f = nlmsg_data(nlh);
        *f = flags;
-       xfrm_sad_getinfo(&si);
+       xfrm_sad_getinfo(net, &si);
 
        sh.sadhmcnt = si.sadhmcnt;
        sh.sadhcnt = si.sadhcnt;
@@ -767,7 +884,7 @@ static int xfrm_get_sadinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (r_skb == NULL)
                return -ENOMEM;
 
-       if (build_sadinfo(r_skb, spid, seq, *flags) < 0)
+       if (build_sadinfo(r_skb, net, spid, seq, *flags) < 0)
                BUG();
 
        return nlmsg_unicast(net->xfrm.nlsk, r_skb, spid);
@@ -776,13 +893,13 @@ static int xfrm_get_sadinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
 static int xfrm_get_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
-       struct net *net = &init_net;
+       struct net *net = sock_net(skb->sk);
        struct xfrm_usersa_id *p = nlmsg_data(nlh);
        struct xfrm_state *x;
        struct sk_buff *resp_skb;
        int err = -ESRCH;
 
-       x = xfrm_user_state_lookup(p, attrs, &err);
+       x = xfrm_user_state_lookup(net, p, attrs, &err);
        if (x == NULL)
                goto out_noput;
 
@@ -823,7 +940,7 @@ static int verify_userspi_info(struct xfrm_userspi_info *p)
 static int xfrm_alloc_userspi(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
-       struct net *net = &init_net;
+       struct net *net = sock_net(skb->sk);
        struct xfrm_state *x;
        struct xfrm_userspi_info *p;
        struct sk_buff *resp_skb;
@@ -1082,9 +1199,9 @@ static void copy_to_user_policy(struct xfrm_policy *xp, struct xfrm_userpolicy_i
        p->share = XFRM_SHARE_ANY; /* XXX xp->share */
 }
 
-static struct xfrm_policy *xfrm_policy_construct(struct xfrm_userpolicy_info *p, struct nlattr **attrs, int *errp)
+static struct xfrm_policy *xfrm_policy_construct(struct net *net, struct xfrm_userpolicy_info *p, struct nlattr **attrs, int *errp)
 {
-       struct xfrm_policy *xp = xfrm_policy_alloc(&init_net, GFP_KERNEL);
+       struct xfrm_policy *xp = xfrm_policy_alloc(net, GFP_KERNEL);
        int err;
 
        if (!xp) {
@@ -1114,6 +1231,7 @@ static struct xfrm_policy *xfrm_policy_construct(struct xfrm_userpolicy_info *p,
 static int xfrm_add_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_userpolicy_info *p = nlmsg_data(nlh);
        struct xfrm_policy *xp;
        struct km_event c;
@@ -1130,7 +1248,7 @@ static int xfrm_add_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (err)
                return err;
 
-       xp = xfrm_policy_construct(p, attrs, &err);
+       xp = xfrm_policy_construct(net, p, attrs, &err);
        if (!xp)
                return err;
 
@@ -1267,6 +1385,7 @@ static int xfrm_dump_policy_done(struct netlink_callback *cb)
 
 static int xfrm_dump_policy(struct sk_buff *skb, struct netlink_callback *cb)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_policy_walk *walk = (struct xfrm_policy_walk *) &cb->args[1];
        struct xfrm_dump_info info;
 
@@ -1283,7 +1402,7 @@ static int xfrm_dump_policy(struct sk_buff *skb, struct netlink_callback *cb)
                xfrm_policy_walk_init(walk, XFRM_POLICY_TYPE_ANY);
        }
 
-       (void) xfrm_policy_walk(&init_net, walk, dump_one_policy, &info);
+       (void) xfrm_policy_walk(net, walk, dump_one_policy, &info);
 
        return skb->len;
 }
@@ -1315,7 +1434,7 @@ static struct sk_buff *xfrm_policy_netlink(struct sk_buff *in_skb,
 static int xfrm_get_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
-       struct net *net = &init_net;
+       struct net *net = sock_net(skb->sk);
        struct xfrm_policy *xp;
        struct xfrm_userpolicy_id *p;
        u8 type = XFRM_POLICY_TYPE_MAIN;
@@ -1395,6 +1514,7 @@ out:
 static int xfrm_flush_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct km_event c;
        struct xfrm_usersa_flush *p = nlmsg_data(nlh);
        struct xfrm_audit audit_info;
@@ -1403,13 +1523,14 @@ static int xfrm_flush_sa(struct sk_buff *skb, struct nlmsghdr *nlh,
        audit_info.loginuid = NETLINK_CB(skb).loginuid;
        audit_info.sessionid = NETLINK_CB(skb).sessionid;
        audit_info.secid = NETLINK_CB(skb).sid;
-       err = xfrm_state_flush(&init_net, p->proto, &audit_info);
+       err = xfrm_state_flush(net, p->proto, &audit_info);
        if (err)
                return err;
        c.data.proto = p->proto;
        c.event = nlh->nlmsg_type;
        c.seq = nlh->nlmsg_seq;
        c.pid = nlh->nlmsg_pid;
+       c.net = net;
        km_state_notify(NULL, &c);
 
        return 0;
@@ -1462,7 +1583,7 @@ nla_put_failure:
 static int xfrm_get_ae(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
-       struct net *net = &init_net;
+       struct net *net = sock_net(skb->sk);
        struct xfrm_state *x;
        struct sk_buff *r_skb;
        int err;
@@ -1501,6 +1622,7 @@ static int xfrm_get_ae(struct sk_buff *skb, struct nlmsghdr *nlh,
 static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_state *x;
        struct km_event c;
        int err = - EINVAL;
@@ -1515,7 +1637,7 @@ static int xfrm_new_ae(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (!(nlh->nlmsg_flags&NLM_F_REPLACE))
                return err;
 
-       x = xfrm_state_lookup(&init_net, &p->sa_id.daddr, p->sa_id.spi, p->sa_id.proto, p->sa_id.family);
+       x = xfrm_state_lookup(net, &p->sa_id.daddr, p->sa_id.spi, p->sa_id.proto, p->sa_id.family);
        if (x == NULL)
                return -ESRCH;
 
@@ -1540,6 +1662,7 @@ out:
 static int xfrm_flush_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct km_event c;
        u8 type = XFRM_POLICY_TYPE_MAIN;
        int err;
@@ -1552,13 +1675,14 @@ static int xfrm_flush_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
        audit_info.loginuid = NETLINK_CB(skb).loginuid;
        audit_info.sessionid = NETLINK_CB(skb).sessionid;
        audit_info.secid = NETLINK_CB(skb).sid;
-       err = xfrm_policy_flush(&init_net, type, &audit_info);
+       err = xfrm_policy_flush(net, type, &audit_info);
        if (err)
                return err;
        c.data.type = type;
        c.event = nlh->nlmsg_type;
        c.seq = nlh->nlmsg_seq;
        c.pid = nlh->nlmsg_pid;
+       c.net = net;
        km_policy_notify(NULL, 0, &c);
        return 0;
 }
@@ -1566,6 +1690,7 @@ static int xfrm_flush_policy(struct sk_buff *skb, struct nlmsghdr *nlh,
 static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_policy *xp;
        struct xfrm_user_polexpire *up = nlmsg_data(nlh);
        struct xfrm_userpolicy_info *p = &up->pol;
@@ -1577,7 +1702,7 @@ static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
                return err;
 
        if (p->index)
-               xp = xfrm_policy_byid(&init_net, type, p->dir, p->index, 0, &err);
+               xp = xfrm_policy_byid(net, type, p->dir, p->index, 0, &err);
        else {
                struct nlattr *rt = attrs[XFRMA_SEC_CTX];
                struct xfrm_sec_ctx *ctx;
@@ -1594,7 +1719,7 @@ static int xfrm_add_pol_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
                        if (err)
                                return err;
                }
-               xp = xfrm_policy_bysel_ctx(&init_net, type, p->dir, &p->sel, ctx, 0, &err);
+               xp = xfrm_policy_bysel_ctx(net, type, p->dir, &p->sel, ctx, 0, &err);
                security_xfrm_policy_free(ctx);
        }
        if (xp == NULL)
@@ -1629,12 +1754,13 @@ out:
 static int xfrm_add_sa_expire(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_state *x;
        int err;
        struct xfrm_user_expire *ue = nlmsg_data(nlh);
        struct xfrm_usersa_info *p = &ue->state;
 
-       x = xfrm_state_lookup(&init_net, &p->id.daddr, p->id.spi, p->id.proto, p->family);
+       x = xfrm_state_lookup(net, &p->id.daddr, p->id.spi, p->id.proto, p->family);
 
        err = -ENOENT;
        if (x == NULL)
@@ -1663,31 +1789,27 @@ out:
 static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh,
                struct nlattr **attrs)
 {
+       struct net *net = sock_net(skb->sk);
        struct xfrm_policy *xp;
        struct xfrm_user_tmpl *ut;
        int i;
        struct nlattr *rt = attrs[XFRMA_TMPL];
 
        struct xfrm_user_acquire *ua = nlmsg_data(nlh);
-       struct xfrm_state *x = xfrm_state_alloc(&init_net);
+       struct xfrm_state *x = xfrm_state_alloc(net);
        int err = -ENOMEM;
 
        if (!x)
-               return err;
+               goto nomem;
 
        err = verify_newpolicy_info(&ua->policy);
-       if (err) {
-               printk("BAD policy passed\n");
-               kfree(x);
-               return err;
-       }
+       if (err)
+               goto bad_policy;
 
        /*   build an XP */
-       xp = xfrm_policy_construct(&ua->policy, attrs, &err);
-       if (!xp) {
-               kfree(x);
-               return err;
-       }
+       xp = xfrm_policy_construct(net, &ua->policy, attrs, &err);
+       if (!xp)
+               goto free_state;
 
        memcpy(&x->id, &ua->id, sizeof(ua->id));
        memcpy(&x->props.saddr, &ua->saddr, sizeof(ua->saddr));
@@ -1712,6 +1834,13 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh,
        kfree(xp);
 
        return 0;
+
+bad_policy:
+       printk("BAD policy passed\n");
+free_state:
+       kfree(x);
+nomem:
+       return err;
 }
 
 #ifdef CONFIG_XFRM_MIGRATE
@@ -2041,7 +2170,7 @@ static int build_expire(struct sk_buff *skb, struct xfrm_state *x, struct km_eve
 
 static int xfrm_exp_state_notify(struct xfrm_state *x, struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = xs_net(x);
        struct sk_buff *skb;
 
        skb = nlmsg_new(xfrm_expire_msgsize(), GFP_ATOMIC);
@@ -2056,7 +2185,7 @@ static int xfrm_exp_state_notify(struct xfrm_state *x, struct km_event *c)
 
 static int xfrm_aevent_state_notify(struct xfrm_state *x, struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = xs_net(x);
        struct sk_buff *skb;
 
        skb = nlmsg_new(xfrm_aevent_msgsize(), GFP_ATOMIC);
@@ -2071,7 +2200,7 @@ static int xfrm_aevent_state_notify(struct xfrm_state *x, struct km_event *c)
 
 static int xfrm_notify_sa_flush(struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = c->net;
        struct xfrm_usersa_flush *p;
        struct nlmsghdr *nlh;
        struct sk_buff *skb;
@@ -2100,8 +2229,11 @@ static inline size_t xfrm_sa_len(struct xfrm_state *x)
        size_t l = 0;
        if (x->aead)
                l += nla_total_size(aead_len(x->aead));
-       if (x->aalg)
-               l += nla_total_size(xfrm_alg_len(x->aalg));
+       if (x->aalg) {
+               l += nla_total_size(sizeof(struct xfrm_algo) +
+                                   (x->aalg->alg_key_len + 7) / 8);
+               l += nla_total_size(xfrm_alg_auth_len(x->aalg));
+       }
        if (x->ealg)
                l += nla_total_size(xfrm_alg_len(x->ealg));
        if (x->calg)
@@ -2122,7 +2254,7 @@ static inline size_t xfrm_sa_len(struct xfrm_state *x)
 
 static int xfrm_notify_sa(struct xfrm_state *x, struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = xs_net(x);
        struct xfrm_usersa_info *p;
        struct xfrm_usersa_id *id;
        struct nlmsghdr *nlh;
@@ -2266,6 +2398,7 @@ static int xfrm_send_acquire(struct xfrm_state *x, struct xfrm_tmpl *xt,
 static struct xfrm_policy *xfrm_compile_policy(struct sock *sk, int opt,
                                               u8 *data, int len, int *dir)
 {
+       struct net *net = sock_net(sk);
        struct xfrm_userpolicy_info *p = (struct xfrm_userpolicy_info *)data;
        struct xfrm_user_tmpl *ut = (struct xfrm_user_tmpl *) (p + 1);
        struct xfrm_policy *xp;
@@ -2304,7 +2437,7 @@ static struct xfrm_policy *xfrm_compile_policy(struct sock *sk, int opt,
        if (p->dir > XFRM_POLICY_OUT)
                return NULL;
 
-       xp = xfrm_policy_alloc(&init_net, GFP_KERNEL);
+       xp = xfrm_policy_alloc(net, GFP_KERNEL);
        if (xp == NULL) {
                *dir = -ENOBUFS;
                return NULL;
@@ -2357,7 +2490,7 @@ nlmsg_failure:
 
 static int xfrm_exp_policy_notify(struct xfrm_policy *xp, int dir, struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = xp_net(xp);
        struct sk_buff *skb;
 
        skb = nlmsg_new(xfrm_polexpire_msgsize(xp), GFP_ATOMIC);
@@ -2372,7 +2505,7 @@ static int xfrm_exp_policy_notify(struct xfrm_policy *xp, int dir, struct km_eve
 
 static int xfrm_notify_policy(struct xfrm_policy *xp, int dir, struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = xp_net(xp);
        struct xfrm_userpolicy_info *p;
        struct xfrm_userpolicy_id *id;
        struct nlmsghdr *nlh;
@@ -2432,7 +2565,7 @@ nlmsg_failure:
 
 static int xfrm_notify_policy_flush(struct km_event *c)
 {
-       struct net *net = &init_net;
+       struct net *net = c->net;
        struct nlmsghdr *nlh;
        struct sk_buff *skb;
 
@@ -2504,10 +2637,9 @@ nla_put_failure:
        return -EMSGSIZE;
 }
 
-static int xfrm_send_report(u8 proto, struct xfrm_selector *sel,
-                           xfrm_address_t *addr)
+static int xfrm_send_report(struct net *net, u8 proto,
+                           struct xfrm_selector *sel, xfrm_address_t *addr)
 {
-       struct net *net = &init_net;
        struct sk_buff *skb;
 
        skb = nlmsg_new(xfrm_report_msgsize(), GFP_ATOMIC);
@@ -2591,22 +2723,24 @@ static int __net_init xfrm_user_net_init(struct net *net)
                                     xfrm_netlink_rcv, NULL, THIS_MODULE);
        if (nlsk == NULL)
                return -ENOMEM;
+       net->xfrm.nlsk_stash = nlsk; /* Don't set to NULL */
        rcu_assign_pointer(net->xfrm.nlsk, nlsk);
        return 0;
 }
 
-static void __net_exit xfrm_user_net_exit(struct net *net)
+static void __net_exit xfrm_user_net_exit(struct list_head *net_exit_list)
 {
-       struct sock *nlsk = net->xfrm.nlsk;
-
-       rcu_assign_pointer(net->xfrm.nlsk, NULL);
-       synchronize_rcu();
-       netlink_kernel_release(nlsk);
+       struct net *net;
+       list_for_each_entry(net, net_exit_list, exit_list)
+               rcu_assign_pointer(net->xfrm.nlsk, NULL);
+       synchronize_net();
+       list_for_each_entry(net, net_exit_list, exit_list)
+               netlink_kernel_release(net->xfrm.nlsk_stash);
 }
 
 static struct pernet_operations xfrm_user_net_ops = {
-       .init = xfrm_user_net_init,
-       .exit = xfrm_user_net_exit,
+       .init       = xfrm_user_net_init,
+       .exit_batch = xfrm_user_net_exit,
 };
 
 static int __init xfrm_user_init(void)