net: af_key: use seq_hlist_foo() helpers
[safe/jmp/linux-2.6] / net / xfrm / xfrm_user.c
index 3e32ec2..d5a7129 100644 (file)
@@ -62,6 +62,22 @@ static int verify_one_alg(struct nlattr **attrs, enum xfrm_attr_type_t type)
        return 0;
 }
 
+static int verify_auth_trunc(struct nlattr **attrs)
+{
+       struct nlattr *rt = attrs[XFRMA_ALG_AUTH_TRUNC];
+       struct xfrm_algo_auth *algp;
+
+       if (!rt)
+               return 0;
+
+       algp = nla_data(rt);
+       if (nla_len(rt) < xfrm_alg_auth_len(algp))
+               return -EINVAL;
+
+       algp->alg_name[CRYPTO_MAX_ALG_NAME - 1] = '\0';
+       return 0;
+}
+
 static int verify_aead(struct nlattr **attrs)
 {
        struct nlattr *rt = attrs[XFRMA_ALG_AEAD];
@@ -128,7 +144,8 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
        err = -EINVAL;
        switch (p->id.proto) {
        case IPPROTO_AH:
-               if (!attrs[XFRMA_ALG_AUTH]      ||
+               if ((!attrs[XFRMA_ALG_AUTH]     &&
+                    !attrs[XFRMA_ALG_AUTH_TRUNC]) ||
                    attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_CRYPT]      ||
                    attrs[XFRMA_ALG_COMP])
@@ -139,10 +156,12 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
                if (attrs[XFRMA_ALG_COMP])
                        goto out;
                if (!attrs[XFRMA_ALG_AUTH] &&
+                   !attrs[XFRMA_ALG_AUTH_TRUNC] &&
                    !attrs[XFRMA_ALG_CRYPT] &&
                    !attrs[XFRMA_ALG_AEAD])
                        goto out;
                if ((attrs[XFRMA_ALG_AUTH] ||
+                    attrs[XFRMA_ALG_AUTH_TRUNC] ||
                     attrs[XFRMA_ALG_CRYPT]) &&
                    attrs[XFRMA_ALG_AEAD])
                        goto out;
@@ -152,6 +171,7 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
                if (!attrs[XFRMA_ALG_COMP]      ||
                    attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_AUTH]       ||
+                   attrs[XFRMA_ALG_AUTH_TRUNC] ||
                    attrs[XFRMA_ALG_CRYPT])
                        goto out;
                break;
@@ -161,6 +181,7 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
        case IPPROTO_ROUTING:
                if (attrs[XFRMA_ALG_COMP]       ||
                    attrs[XFRMA_ALG_AUTH]       ||
+                   attrs[XFRMA_ALG_AUTH_TRUNC] ||
                    attrs[XFRMA_ALG_AEAD]       ||
                    attrs[XFRMA_ALG_CRYPT]      ||
                    attrs[XFRMA_ENCAP]          ||
@@ -176,6 +197,8 @@ static int verify_newsa_info(struct xfrm_usersa_info *p,
 
        if ((err = verify_aead(attrs)))
                goto out;
+       if ((err = verify_auth_trunc(attrs)))
+               goto out;
        if ((err = verify_one_alg(attrs, XFRMA_ALG_AUTH)))
                goto out;
        if ((err = verify_one_alg(attrs, XFRMA_ALG_CRYPT)))
@@ -229,6 +252,66 @@ static int attach_one_algo(struct xfrm_algo **algpp, u8 *props,
        return 0;
 }
 
+static int attach_auth(struct xfrm_algo_auth **algpp, u8 *props,
+                      struct nlattr *rta)
+{
+       struct xfrm_algo *ualg;
+       struct xfrm_algo_auth *p;
+       struct xfrm_algo_desc *algo;
+
+       if (!rta)
+               return 0;
+
+       ualg = nla_data(rta);
+
+       algo = xfrm_aalg_get_byname(ualg->alg_name, 1);
+       if (!algo)
+               return -ENOSYS;
+       *props = algo->desc.sadb_alg_id;
+
+       p = kmalloc(sizeof(*p) + (ualg->alg_key_len + 7) / 8, GFP_KERNEL);
+       if (!p)
+               return -ENOMEM;
+
+       strcpy(p->alg_name, algo->name);
+       p->alg_key_len = ualg->alg_key_len;
+       p->alg_trunc_len = algo->uinfo.auth.icv_truncbits;
+       memcpy(p->alg_key, ualg->alg_key, (ualg->alg_key_len + 7) / 8);
+
+       *algpp = p;
+       return 0;
+}
+
+static int attach_auth_trunc(struct xfrm_algo_auth **algpp, u8 *props,
+                            struct nlattr *rta)
+{
+       struct xfrm_algo_auth *p, *ualg;
+       struct xfrm_algo_desc *algo;
+
+       if (!rta)
+               return 0;
+
+       ualg = nla_data(rta);
+
+       algo = xfrm_aalg_get_byname(ualg->alg_name, 1);
+       if (!algo)
+               return -ENOSYS;
+       if (ualg->alg_trunc_len > algo->uinfo.auth.icv_fullbits)
+               return -EINVAL;
+       *props = algo->desc.sadb_alg_id;
+
+       p = kmemdup(ualg, xfrm_alg_auth_len(ualg), GFP_KERNEL);
+       if (!p)
+               return -ENOMEM;
+
+       strcpy(p->alg_name, algo->name);
+       if (!p->alg_trunc_len)
+               p->alg_trunc_len = algo->uinfo.auth.icv_truncbits;
+
+       *algpp = p;
+       return 0;
+}
+
 static int attach_aead(struct xfrm_algo_aead **algpp, u8 *props,
                       struct nlattr *rta)
 {
@@ -332,10 +415,14 @@ static struct xfrm_state *xfrm_state_construct(struct net *net,
        if ((err = attach_aead(&x->aead, &x->props.ealgo,
                               attrs[XFRMA_ALG_AEAD])))
                goto error;
-       if ((err = attach_one_algo(&x->aalg, &x->props.aalgo,
-                                  xfrm_aalg_get_byname,
-                                  attrs[XFRMA_ALG_AUTH])))
+       if ((err = attach_auth_trunc(&x->aalg, &x->props.aalgo,
+                                    attrs[XFRMA_ALG_AUTH_TRUNC])))
                goto error;
+       if (!x->props.aalgo) {
+               if ((err = attach_auth(&x->aalg, &x->props.aalgo,
+                                      attrs[XFRMA_ALG_AUTH])))
+                       goto error;
+       }
        if ((err = attach_one_algo(&x->ealg, &x->props.ealgo,
                                   xfrm_ealg_get_byname,
                                   attrs[XFRMA_ALG_CRYPT])))
@@ -368,9 +455,9 @@ static struct xfrm_state *xfrm_state_construct(struct net *net,
                goto error;
 
        x->km.seq = p->seq;
-       x->replay_maxdiff = sysctl_xfrm_aevent_rseqth;
+       x->replay_maxdiff = net->xfrm.sysctl_aevent_rseqth;
        /* sysctl_xfrm_aevent_etime is in 100ms units */
-       x->replay_maxage = (sysctl_xfrm_aevent_etime*HZ)/XFRM_AE_ETH_M;
+       x->replay_maxage = (net->xfrm.sysctl_aevent_etime*HZ)/XFRM_AE_ETH_M;
        x->preplay.bitmap = 0;
        x->preplay.seq = x->replay.seq+x->replay_maxdiff;
        x->preplay.oseq = x->replay.oseq +x->replay_maxdiff;
@@ -548,6 +635,24 @@ static int copy_sec_ctx(struct xfrm_sec_ctx *s, struct sk_buff *skb)
        return 0;
 }
 
+static int copy_to_user_auth(struct xfrm_algo_auth *auth, struct sk_buff *skb)
+{
+       struct xfrm_algo *algo;
+       struct nlattr *nla;
+
+       nla = nla_reserve(skb, XFRMA_ALG_AUTH,
+                         sizeof(*algo) + (auth->alg_key_len + 7) / 8);
+       if (!nla)
+               return -EMSGSIZE;
+
+       algo = nla_data(nla);
+       strcpy(algo->alg_name, auth->alg_name);
+       memcpy(algo->alg_key, auth->alg_key, (auth->alg_key_len + 7) / 8);
+       algo->alg_key_len = auth->alg_key_len;
+
+       return 0;
+}
+
 /* Don't change this without updating xfrm_sa_len! */
 static int copy_to_user_state_extra(struct xfrm_state *x,
                                    struct xfrm_usersa_info *p,
@@ -563,8 +668,13 @@ static int copy_to_user_state_extra(struct xfrm_state *x,
 
        if (x->aead)
                NLA_PUT(skb, XFRMA_ALG_AEAD, aead_len(x->aead), x->aead);
-       if (x->aalg)
-               NLA_PUT(skb, XFRMA_ALG_AUTH, xfrm_alg_len(x->aalg), x->aalg);
+       if (x->aalg) {
+               if (copy_to_user_auth(x->aalg, skb))
+                       goto nla_put_failure;
+
+               NLA_PUT(skb, XFRMA_ALG_AUTH_TRUNC,
+                       xfrm_alg_auth_len(x->aalg), x->aalg);
+       }
        if (x->ealg)
                NLA_PUT(skb, XFRMA_ALG_CRYPT, xfrm_alg_len(x->ealg), x->ealg);
        if (x->calg)
@@ -671,7 +781,8 @@ static inline size_t xfrm_spdinfo_msgsize(void)
               + nla_total_size(sizeof(struct xfrmu_spdhinfo));
 }
 
-static int build_spdinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
+static int build_spdinfo(struct sk_buff *skb, struct net *net,
+                        u32 pid, u32 seq, u32 flags)
 {
        struct xfrmk_spdinfo si;
        struct xfrmu_spdinfo spc;
@@ -685,7 +796,7 @@ static int build_spdinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
 
        f = nlmsg_data(nlh);
        *f = flags;
-       xfrm_spd_getinfo(&si);
+       xfrm_spd_getinfo(net, &si);
        spc.incnt = si.incnt;
        spc.outcnt = si.outcnt;
        spc.fwdcnt = si.fwdcnt;
@@ -718,7 +829,7 @@ static int xfrm_get_spdinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (r_skb == NULL)
                return -ENOMEM;
 
-       if (build_spdinfo(r_skb, spid, seq, *flags) < 0)
+       if (build_spdinfo(r_skb, net, spid, seq, *flags) < 0)
                BUG();
 
        return nlmsg_unicast(net->xfrm.nlsk, r_skb, spid);
@@ -731,7 +842,8 @@ static inline size_t xfrm_sadinfo_msgsize(void)
               + nla_total_size(4); /* XFRMA_SAD_CNT */
 }
 
-static int build_sadinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
+static int build_sadinfo(struct sk_buff *skb, struct net *net,
+                        u32 pid, u32 seq, u32 flags)
 {
        struct xfrmk_sadinfo si;
        struct xfrmu_sadhinfo sh;
@@ -744,7 +856,7 @@ static int build_sadinfo(struct sk_buff *skb, u32 pid, u32 seq, u32 flags)
 
        f = nlmsg_data(nlh);
        *f = flags;
-       xfrm_sad_getinfo(&si);
+       xfrm_sad_getinfo(net, &si);
 
        sh.sadhmcnt = si.sadhmcnt;
        sh.sadhcnt = si.sadhcnt;
@@ -772,7 +884,7 @@ static int xfrm_get_sadinfo(struct sk_buff *skb, struct nlmsghdr *nlh,
        if (r_skb == NULL)
                return -ENOMEM;
 
-       if (build_sadinfo(r_skb, spid, seq, *flags) < 0)
+       if (build_sadinfo(r_skb, net, spid, seq, *flags) < 0)
                BUG();
 
        return nlmsg_unicast(net->xfrm.nlsk, r_skb, spid);
@@ -1688,21 +1800,16 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh,
        int err = -ENOMEM;
 
        if (!x)
-               return err;
+               goto nomem;
 
        err = verify_newpolicy_info(&ua->policy);
-       if (err) {
-               printk("BAD policy passed\n");
-               kfree(x);
-               return err;
-       }
+       if (err)
+               goto bad_policy;
 
        /*   build an XP */
        xp = xfrm_policy_construct(net, &ua->policy, attrs, &err);
-       if (!xp) {
-               kfree(x);
-               return err;
-       }
+       if (!xp)
+               goto free_state;
 
        memcpy(&x->id, &ua->id, sizeof(ua->id));
        memcpy(&x->props.saddr, &ua->saddr, sizeof(ua->saddr));
@@ -1727,6 +1834,13 @@ static int xfrm_add_acquire(struct sk_buff *skb, struct nlmsghdr *nlh,
        kfree(xp);
 
        return 0;
+
+bad_policy:
+       printk("BAD policy passed\n");
+free_state:
+       kfree(x);
+nomem:
+       return err;
 }
 
 #ifdef CONFIG_XFRM_MIGRATE
@@ -2115,8 +2229,11 @@ static inline size_t xfrm_sa_len(struct xfrm_state *x)
        size_t l = 0;
        if (x->aead)
                l += nla_total_size(aead_len(x->aead));
-       if (x->aalg)
-               l += nla_total_size(xfrm_alg_len(x->aalg));
+       if (x->aalg) {
+               l += nla_total_size(sizeof(struct xfrm_algo) +
+                                   (x->aalg->alg_key_len + 7) / 8);
+               l += nla_total_size(xfrm_alg_auth_len(x->aalg));
+       }
        if (x->ealg)
                l += nla_total_size(xfrm_alg_len(x->ealg));
        if (x->calg)
@@ -2520,10 +2637,9 @@ nla_put_failure:
        return -EMSGSIZE;
 }
 
-static int xfrm_send_report(u8 proto, struct xfrm_selector *sel,
-                           xfrm_address_t *addr)
+static int xfrm_send_report(struct net *net, u8 proto,
+                           struct xfrm_selector *sel, xfrm_address_t *addr)
 {
-       struct net *net = &init_net;
        struct sk_buff *skb;
 
        skb = nlmsg_new(xfrm_report_msgsize(), GFP_ATOMIC);
@@ -2607,22 +2723,24 @@ static int __net_init xfrm_user_net_init(struct net *net)
                                     xfrm_netlink_rcv, NULL, THIS_MODULE);
        if (nlsk == NULL)
                return -ENOMEM;
+       net->xfrm.nlsk_stash = nlsk; /* Don't set to NULL */
        rcu_assign_pointer(net->xfrm.nlsk, nlsk);
        return 0;
 }
 
-static void __net_exit xfrm_user_net_exit(struct net *net)
+static void __net_exit xfrm_user_net_exit(struct list_head *net_exit_list)
 {
-       struct sock *nlsk = net->xfrm.nlsk;
-
-       rcu_assign_pointer(net->xfrm.nlsk, NULL);
-       synchronize_rcu();
-       netlink_kernel_release(nlsk);
+       struct net *net;
+       list_for_each_entry(net, net_exit_list, exit_list)
+               rcu_assign_pointer(net->xfrm.nlsk, NULL);
+       synchronize_net();
+       list_for_each_entry(net, net_exit_list, exit_list)
+               netlink_kernel_release(net->xfrm.nlsk_stash);
 }
 
 static struct pernet_operations xfrm_user_net_ops = {
-       .init = xfrm_user_net_init,
-       .exit = xfrm_user_net_exit,
+       .init       = xfrm_user_net_init,
+       .exit_batch = xfrm_user_net_exit,
 };
 
 static int __init xfrm_user_init(void)