netns xfrm: deal with dst entries in netns
[safe/jmp/linux-2.6] / net / ipv4 / xfrm4_policy.c
index 42d8ded..67107d6 100644 (file)
@@ -1,24 +1,62 @@
-/* 
+/*
  * xfrm4_policy.c
  *
  * Changes:
  *     Kazunori MIYAZAWA @USAGI
  *     YOSHIFUJI Hideaki @USAGI
  *             Split up af-specific portion
- *     
+ *
  */
 
-#include <linux/compiler.h>
+#include <linux/err.h>
+#include <linux/kernel.h>
 #include <linux/inetdevice.h>
+#include <net/dst.h>
 #include <net/xfrm.h>
 #include <net/ip.h>
 
-static struct dst_ops xfrm4_dst_ops;
 static struct xfrm_policy_afinfo xfrm4_policy_afinfo;
 
-static int xfrm4_dst_lookup(struct xfrm_dst **dst, struct flowi *fl)
+static struct dst_entry *xfrm4_dst_lookup(struct net *net, int tos,
+                                         xfrm_address_t *saddr,
+                                         xfrm_address_t *daddr)
 {
-       return __ip_route_output_key((struct rtable**)dst, fl);
+       struct flowi fl = {
+               .nl_u = {
+                       .ip4_u = {
+                               .tos = tos,
+                               .daddr = daddr->a4,
+                       },
+               },
+       };
+       struct dst_entry *dst;
+       struct rtable *rt;
+       int err;
+
+       if (saddr)
+               fl.fl4_src = saddr->a4;
+
+       err = __ip_route_output_key(net, &rt, &fl);
+       dst = &rt->u.dst;
+       if (err)
+               dst = ERR_PTR(err);
+       return dst;
+}
+
+static int xfrm4_get_saddr(struct net *net,
+                          xfrm_address_t *saddr, xfrm_address_t *daddr)
+{
+       struct dst_entry *dst;
+       struct rtable *rt;
+
+       dst = xfrm4_dst_lookup(net, 0, NULL, daddr);
+       if (IS_ERR(dst))
+               return -EHOSTUNREACH;
+
+       rt = (struct rtable *)dst;
+       saddr->a4 = rt->rt_src;
+       dst_release(dst);
+       return 0;
 }
 
 static struct dst_entry *
@@ -28,12 +66,12 @@ __xfrm4_find_bundle(struct flowi *fl, struct xfrm_policy *policy)
 
        read_lock_bh(&policy->lock);
        for (dst = policy->bundles; dst; dst = dst->next) {
-               struct xfrm_dst *xdst = (struct xfrm_dst*)dst;
+               struct xfrm_dst *xdst = (struct xfrm_dst *)dst;
                if (xdst->u.rt.fl.oif == fl->oif &&     /*XXX*/
                    xdst->u.rt.fl.fl4_dst == fl->fl4_dst &&
-                   xdst->u.rt.fl.fl4_src == fl->fl4_src &&
-                   xdst->u.rt.fl.fl4_tos == fl->fl4_tos &&
-                   xfrm_bundle_ok(xdst, fl, AF_INET, 0)) {
+                   xdst->u.rt.fl.fl4_src == fl->fl4_src &&
+                   xdst->u.rt.fl.fl4_tos == fl->fl4_tos &&
+                   xfrm_bundle_ok(policy, xdst, fl, AF_INET, 0)) {
                        dst_clone(dst);
                        break;
                }
@@ -42,151 +80,67 @@ __xfrm4_find_bundle(struct flowi *fl, struct xfrm_policy *policy)
        return dst;
 }
 
-/* Allocate chain of dst_entry's, attach known xfrm's, calculate
- * all the metrics... Shortly, bundle a bundle.
- */
-
-static int
-__xfrm4_bundle_create(struct xfrm_policy *policy, struct xfrm_state **xfrm, int nx,
-                     struct flowi *fl, struct dst_entry **dst_p)
+static int xfrm4_get_tos(struct flowi *fl)
 {
-       struct dst_entry *dst, *dst_prev;
-       struct rtable *rt0 = (struct rtable*)(*dst_p);
-       struct rtable *rt = rt0;
-       u32 remote = fl->fl4_dst;
-       u32 local  = fl->fl4_src;
-       struct flowi fl_tunnel = {
-               .nl_u = {
-                       .ip4_u = {
-                               .saddr = local,
-                               .daddr = remote,
-                               .tos = fl->fl4_tos
-                       }
-               }
-       };
-       int i;
-       int err;
-       int header_len = 0;
-       int trailer_len = 0;
+       return fl->fl4_tos;
+}
 
-       dst = dst_prev = NULL;
-       dst_hold(&rt->u.dst);
+static int xfrm4_init_path(struct xfrm_dst *path, struct dst_entry *dst,
+                          int nfheader_len)
+{
+       return 0;
+}
 
-       for (i = 0; i < nx; i++) {
-               struct dst_entry *dst1 = dst_alloc(&xfrm4_dst_ops);
-               struct xfrm_dst *xdst;
-               int tunnel = 0;
+static int xfrm4_fill_dst(struct xfrm_dst *xdst, struct net_device *dev)
+{
+       struct rtable *rt = (struct rtable *)xdst->route;
 
-               if (unlikely(dst1 == NULL)) {
-                       err = -ENOBUFS;
-                       dst_release(&rt->u.dst);
-                       goto error;
-               }
+       xdst->u.rt.fl = rt->fl;
 
-               if (!dst)
-                       dst = dst1;
-               else {
-                       dst_prev->child = dst1;
-                       dst1->flags |= DST_NOHASH;
-                       dst_clone(dst1);
-               }
+       xdst->u.dst.dev = dev;
+       dev_hold(dev);
 
-               xdst = (struct xfrm_dst *)dst1;
-               xdst->route = &rt->u.dst;
+       xdst->u.rt.idev = in_dev_get(dev);
+       if (!xdst->u.rt.idev)
+               return -ENODEV;
 
-               dst1->next = dst_prev;
-               dst_prev = dst1;
-               if (xfrm[i]->props.mode != XFRM_MODE_TRANSPORT) {
-                       remote = xfrm[i]->id.daddr.a4;
-                       local  = xfrm[i]->props.saddr.a4;
-                       tunnel = 1;
-               }
-               header_len += xfrm[i]->props.header_len;
-               trailer_len += xfrm[i]->props.trailer_len;
-
-               if (tunnel) {
-                       fl_tunnel.fl4_src = local;
-                       fl_tunnel.fl4_dst = remote;
-                       err = xfrm_dst_lookup((struct xfrm_dst **)&rt,
-                                             &fl_tunnel, AF_INET);
-                       if (err)
-                               goto error;
-               } else
-                       dst_hold(&rt->u.dst);
-       }
+       xdst->u.rt.peer = rt->peer;
+       if (rt->peer)
+               atomic_inc(&rt->peer->refcnt);
 
-       dst_prev->child = &rt->u.dst;
-       dst->path = &rt->u.dst;
-
-       *dst_p = dst;
-       dst = dst_prev;
-
-       dst_prev = *dst_p;
-       i = 0;
-       for (; dst_prev != &rt->u.dst; dst_prev = dst_prev->child) {
-               struct xfrm_dst *x = (struct xfrm_dst*)dst_prev;
-               x->u.rt.fl = *fl;
-
-               dst_prev->xfrm = xfrm[i++];
-               dst_prev->dev = rt->u.dst.dev;
-               if (rt->u.dst.dev)
-                       dev_hold(rt->u.dst.dev);
-               dst_prev->obsolete      = -1;
-               dst_prev->flags        |= DST_HOST;
-               dst_prev->lastuse       = jiffies;
-               dst_prev->header_len    = header_len;
-               dst_prev->nfheader_len  = 0;
-               dst_prev->trailer_len   = trailer_len;
-               memcpy(&dst_prev->metrics, &x->route->metrics, sizeof(dst_prev->metrics));
-
-               /* Copy neighbout for reachability confirmation */
-               dst_prev->neighbour     = neigh_clone(rt->u.dst.neighbour);
-               dst_prev->input         = rt->u.dst.input;
-               dst_prev->output        = xfrm4_output;
-               if (rt->peer)
-                       atomic_inc(&rt->peer->refcnt);
-               x->u.rt.peer = rt->peer;
-               /* Sheit... I remember I did this right. Apparently,
-                * it was magically lost, so this code needs audit */
-               x->u.rt.rt_flags = rt0->rt_flags&(RTCF_BROADCAST|RTCF_MULTICAST|RTCF_LOCAL);
-               x->u.rt.rt_type = rt->rt_type;
-               x->u.rt.rt_src = rt0->rt_src;
-               x->u.rt.rt_dst = rt0->rt_dst;
-               x->u.rt.rt_gateway = rt->rt_gateway;
-               x->u.rt.rt_spec_dst = rt0->rt_spec_dst;
-               x->u.rt.idev = rt0->idev;
-               in_dev_hold(rt0->idev);
-               header_len -= x->u.dst.xfrm->props.header_len;
-               trailer_len -= x->u.dst.xfrm->props.trailer_len;
-       }
+       /* Sheit... I remember I did this right. Apparently,
+        * it was magically lost, so this code needs audit */
+       xdst->u.rt.rt_flags = rt->rt_flags & (RTCF_BROADCAST | RTCF_MULTICAST |
+                                             RTCF_LOCAL);
+       xdst->u.rt.rt_type = rt->rt_type;
+       xdst->u.rt.rt_src = rt->rt_src;
+       xdst->u.rt.rt_dst = rt->rt_dst;
+       xdst->u.rt.rt_gateway = rt->rt_gateway;
+       xdst->u.rt.rt_spec_dst = rt->rt_spec_dst;
 
-       xfrm_init_pmtu(dst);
        return 0;
-
-error:
-       if (dst)
-               dst_free(dst);
-       return err;
 }
 
 static void
-_decode_session4(struct sk_buff *skb, struct flowi *fl)
+_decode_session4(struct sk_buff *skb, struct flowi *fl, int reverse)
 {
-       struct iphdr *iph = skb->nh.iph;
-       u8 *xprth = skb->nh.raw + iph->ihl*4;
+       struct iphdr *iph = ip_hdr(skb);
+       u8 *xprth = skb_network_header(skb) + iph->ihl * 4;
 
        memset(fl, 0, sizeof(struct flowi));
        if (!(iph->frag_off & htons(IP_MF | IP_OFFSET))) {
                switch (iph->protocol) {
                case IPPROTO_UDP:
+               case IPPROTO_UDPLITE:
                case IPPROTO_TCP:
                case IPPROTO_SCTP:
                case IPPROTO_DCCP:
-                       if (pskb_may_pull(skb, xprth + 4 - skb->data)) {
-                               u16 *ports = (u16 *)xprth;
+                       if (xprth + 4 < skb->data ||
+                           pskb_may_pull(skb, xprth + 4 - skb->data)) {
+                               __be16 *ports = (__be16 *)xprth;
 
-                               fl->fl_ip_sport = ports[0];
-                               fl->fl_ip_dport = ports[1];
+                               fl->fl_ip_sport = ports[!!reverse];
+                               fl->fl_ip_dport = ports[!reverse];
                        }
                        break;
 
@@ -201,7 +155,7 @@ _decode_session4(struct sk_buff *skb, struct flowi *fl)
 
                case IPPROTO_ESP:
                        if (pskb_may_pull(skb, xprth + 4 - skb->data)) {
-                               u32 *ehdr = (u32 *)xprth;
+                               __be32 *ehdr = (__be32 *)xprth;
 
                                fl->fl_ipsec_spi = ehdr[0];
                        }
@@ -209,7 +163,7 @@ _decode_session4(struct sk_buff *skb, struct flowi *fl)
 
                case IPPROTO_AH:
                        if (pskb_may_pull(skb, xprth + 8 - skb->data)) {
-                               u32 *ah_hdr = (u32*)xprth;
+                               __be32 *ah_hdr = (__be32*)xprth;
 
                                fl->fl_ipsec_spi = ah_hdr[1];
                        }
@@ -217,7 +171,7 @@ _decode_session4(struct sk_buff *skb, struct flowi *fl)
 
                case IPPROTO_COMP:
                        if (pskb_may_pull(skb, xprth + 4 - skb->data)) {
-                               u16 *ipcomp_hdr = (u16 *)xprth;
+                               __be16 *ipcomp_hdr = (__be16 *)xprth;
 
                                fl->fl_ipsec_spi = htonl(ntohs(ipcomp_hdr[1]));
                        }
@@ -225,18 +179,20 @@ _decode_session4(struct sk_buff *skb, struct flowi *fl)
                default:
                        fl->fl_ipsec_spi = 0;
                        break;
-               };
+               }
        }
        fl->proto = iph->protocol;
-       fl->fl4_dst = iph->daddr;
-       fl->fl4_src = iph->saddr;
+       fl->fl4_dst = reverse ? iph->saddr : iph->daddr;
+       fl->fl4_src = reverse ? iph->daddr : iph->saddr;
        fl->fl4_tos = iph->tos;
 }
 
-static inline int xfrm4_garbage_collect(void)
+static inline int xfrm4_garbage_collect(struct dst_ops *ops)
 {
-       xfrm4_policy_afinfo.garbage_collect();
-       return (atomic_read(&xfrm4_dst_ops.entries) > xfrm4_dst_ops.gc_thresh*2);
+       struct net *net = container_of(ops, struct net, xfrm.xfrm4_dst_ops);
+
+       xfrm4_policy_afinfo.garbage_collect(net);
+       return (atomic_read(&ops->entries) > ops->gc_thresh * 2);
 }
 
 static void xfrm4_update_pmtu(struct dst_entry *dst, u32 mtu)
@@ -253,6 +209,8 @@ static void xfrm4_dst_destroy(struct dst_entry *dst)
 
        if (likely(xdst->u.rt.idev))
                in_dev_put(xdst->u.rt.idev);
+       if (likely(xdst->u.rt.peer))
+               inet_putpeer(xdst->u.rt.peer);
        xfrm_dst_destroy(xdst);
 }
 
@@ -266,7 +224,8 @@ static void xfrm4_dst_ifdown(struct dst_entry *dst, struct net_device *dev,
 
        xdst = (struct xfrm_dst *)dst;
        if (xdst->u.rt.idev->dev == dev) {
-               struct in_device *loopback_idev = in_dev_get(&loopback_dev);
+               struct in_device *loopback_idev =
+                       in_dev_get(dev_net(dev)->loopback_dev);
                BUG_ON(!loopback_idev);
 
                do {
@@ -284,24 +243,43 @@ static void xfrm4_dst_ifdown(struct dst_entry *dst, struct net_device *dev,
 
 static struct dst_ops xfrm4_dst_ops = {
        .family =               AF_INET,
-       .protocol =             __constant_htons(ETH_P_IP),
+       .protocol =             cpu_to_be16(ETH_P_IP),
        .gc =                   xfrm4_garbage_collect,
        .update_pmtu =          xfrm4_update_pmtu,
        .destroy =              xfrm4_dst_destroy,
        .ifdown =               xfrm4_dst_ifdown,
+       .local_out =            __ip_local_out,
        .gc_thresh =            1024,
-       .entry_size =           sizeof(struct xfrm_dst),
+       .entries =              ATOMIC_INIT(0),
 };
 
 static struct xfrm_policy_afinfo xfrm4_policy_afinfo = {
        .family =               AF_INET,
        .dst_ops =              &xfrm4_dst_ops,
        .dst_lookup =           xfrm4_dst_lookup,
+       .get_saddr =            xfrm4_get_saddr,
        .find_bundle =          __xfrm4_find_bundle,
-       .bundle_create =        __xfrm4_bundle_create,
        .decode_session =       _decode_session4,
+       .get_tos =              xfrm4_get_tos,
+       .init_path =            xfrm4_init_path,
+       .fill_dst =             xfrm4_fill_dst,
+};
+
+#ifdef CONFIG_SYSCTL
+static struct ctl_table xfrm4_policy_table[] = {
+       {
+               .procname       = "xfrm4_gc_thresh",
+               .data           = &init_net.xfrm.xfrm4_dst_ops.gc_thresh,
+               .maxlen         = sizeof(int),
+               .mode           = 0644,
+               .proc_handler   = proc_dointvec,
+       },
+       { }
 };
 
+static struct ctl_table_header *sysctl_hdr;
+#endif
+
 static void __init xfrm4_policy_init(void)
 {
        xfrm_policy_register_afinfo(&xfrm4_policy_afinfo);
@@ -309,12 +287,32 @@ static void __init xfrm4_policy_init(void)
 
 static void __exit xfrm4_policy_fini(void)
 {
+#ifdef CONFIG_SYSCTL
+       if (sysctl_hdr)
+               unregister_net_sysctl_table(sysctl_hdr);
+#endif
        xfrm_policy_unregister_afinfo(&xfrm4_policy_afinfo);
 }
 
-void __init xfrm4_init(void)
+void __init xfrm4_init(int rt_max_size)
 {
+       /*
+        * Select a default value for the gc_thresh based on the main route
+        * table hash size.  It seems to me the worst case scenario is when
+        * we have ipsec operating in transport mode, in which we create a
+        * dst_entry per socket.  The xfrm gc algorithm starts trying to remove
+        * entries at gc_thresh, and prevents new allocations at 2*gc_thresh,
+        * so let's set an initial xfrm gc_thresh value of rt_max_size/2.
+        * That will let us store an ipsec connection per route table entry,
+        * and start cleaning when we're 1/2 full.
+        */
+       xfrm4_dst_ops.gc_thresh = rt_max_size/2;
+
        xfrm4_state_init();
        xfrm4_policy_init();
+#ifdef CONFIG_SYSCTL
+       sysctl_hdr = register_net_sysctl_table(&init_net, net_ipv4_ctl_path,
+                                               xfrm4_policy_table);
+#endif
 }