nfsd: silence lockdep warning
[safe/jmp/linux-2.6] / net / ipv6 / udp.c
index 71e259e..8905712 100644 (file)
 #include <linux/seq_file.h>
 #include "udp_impl.h"
 
+int ipv6_rcv_saddr_equal(const struct sock *sk, const struct sock *sk2)
+{      /*
+        * Decide whether the receive addresses of sk and sk2 overlap.
+        * Returns nonzero when they are considered equal, 0 otherwise.
+        * udp_v6_get_port() hands this function to udp_lib_get_port().
+        *
+        * sk is read directly via inet6_sk()/inet_sk()/ipv6_only_sock().
+        * sk2 is read through the inet6_rcv_saddr()/inet_rcv_saddr()/
+        * inet_v6_ipv6only() accessors; its IPv6 address pointer may be
+        * NULL (presumably when sk2 is not an IPv6 socket - confirm against
+        * those helpers), in which case sk2 is classified as
+        * IPV6_ADDR_MAPPED below.
+        */
+       const struct in6_addr *sk_rcv_saddr6 = &inet6_sk(sk)->rcv_saddr;
+       const struct in6_addr *sk2_rcv_saddr6 = inet6_rcv_saddr(sk2);
+       __be32 sk_rcv_saddr = inet_sk(sk)->rcv_saddr;
+       __be32 sk2_rcv_saddr = inet_rcv_saddr(sk2);
+       int sk_ipv6only = ipv6_only_sock(sk);
+       int sk2_ipv6only = inet_v6_ipv6only(sk2);
+       int addr_type = ipv6_addr_type(sk_rcv_saddr6);
+       int addr_type2 = sk2_rcv_saddr6 ? ipv6_addr_type(sk2_rcv_saddr6) : IPV6_ADDR_MAPPED;
+
+       /*
+        * if both are mapped, treat as IPv4: the sockets overlap unless
+        * sk2 is IPv6-only, provided that either IPv4 receive address is
+        * zero or the two IPv4 receive addresses are identical.
+        */
+       if (addr_type == IPV6_ADDR_MAPPED && addr_type2 == IPV6_ADDR_MAPPED)
+               return (!sk2_ipv6only &&
+                       (!sk_rcv_saddr || !sk2_rcv_saddr ||
+                         sk_rcv_saddr == sk2_rcv_saddr));
+       /*
+        * sk2 is IPV6_ADDR_ANY (the any-address): it overlaps with sk,
+        * except when sk2 is IPv6-only and sk holds a mapped address.
+        */
+       if (addr_type2 == IPV6_ADDR_ANY &&
+           !(sk2_ipv6only && addr_type == IPV6_ADDR_MAPPED))
+               return 1;
+       /* the mirrored case: sk is IPV6_ADDR_ANY, same exception for sk */
+       if (addr_type == IPV6_ADDR_ANY &&
+           !(sk_ipv6only && addr_type2 == IPV6_ADDR_MAPPED))
+               return 1;
+       /* otherwise only an exact IPv6 address match counts (sk2 may be NULL) */
+       if (sk2_rcv_saddr6 &&
+           ipv6_addr_equal(sk_rcv_saddr6, sk2_rcv_saddr6))
+               return 1;
+
+       return 0;
+}
+
 int udp_v6_get_port(struct sock *sk, unsigned short snum)
 {
        return udp_lib_get_port(sk, snum, ipv6_rcv_saddr_equal);
 }
 
+static inline int compute_score(struct sock *sk, struct net *net,
+                               unsigned short hnum,
+                               struct in6_addr *saddr, __be16 sport,
+                               struct in6_addr *daddr, __be16 dport,
+                               int dif)
+{      /*
+        * Rate how well socket sk matches an incoming packet described by
+        * (saddr, sport) -> (daddr, hnum) arriving on interface dif.
+        *
+        * Returns -1 when sk cannot receive the packet, otherwise the
+        * number of criteria (remote port, local address, remote address,
+        * bound device) that sk pins down and that match: 0 to 4.
+        * Criteria left unset on the socket neither add to nor veto the
+        * score.
+        *
+        * NOTE(review): dport is not used in the body; the local port is
+        * compared through hnum against sk->sk_hash.
+        */
+       int score = -1;
+       /* candidates: net matches, sk_hash equals hnum, family is PF_INET6 */
+       if (net_eq(sock_net(sk), net) && sk->sk_hash == hnum &&
+                       sk->sk_family == PF_INET6) {
+               struct ipv6_pinfo *np = inet6_sk(sk);
+               struct inet_sock *inet = inet_sk(sk);
+               /* each test below: a mismatch rejects sk, a match adds one */
+               score = 0;
+               if (inet->dport) {
+                       if (inet->dport != sport)       /* socket's remote port vs. packet source port */
+                               return -1;
+                       score++;
+               }
+               if (!ipv6_addr_any(&np->rcv_saddr)) {
+                       if (!ipv6_addr_equal(&np->rcv_saddr, daddr))
+                               return -1;
+                       score++;
+               }
+               if (!ipv6_addr_any(&np->daddr)) {
+                       if (!ipv6_addr_equal(&np->daddr, saddr))
+                               return -1;
+                       score++;
+               }
+               if (sk->sk_bound_dev_if) {
+                       if (sk->sk_bound_dev_if != dif)
+                               return -1;
+                       score++;
+               }
+       }
+       return score;
+}
+
 static struct sock *__udp6_lib_lookup(struct net *net,
                                       struct in6_addr *saddr, __be16 sport,
                                       struct in6_addr *daddr, __be16 dport,
-                                     int dif, struct hlist_head udptable[])
+                                     int dif, struct udp_table *udptable)
 {
-       struct sock *sk, *result = NULL;
-       struct hlist_node *node;
+       struct sock *sk, *result;
+       struct hlist_nulls_node *node;
        unsigned short hnum = ntohs(dport);
-       int badness = -1;
-
-       read_lock(&udp_hash_lock);
-       sk_for_each(sk, node, &udptable[udp_hashfn(net, hnum)]) {
-               struct inet_sock *inet = inet_sk(sk);
-
-               if (net_eq(sock_net(sk), net) && sk->sk_hash == hnum &&
-                               sk->sk_family == PF_INET6) {
-                       struct ipv6_pinfo *np = inet6_sk(sk);
-                       int score = 0;
-                       if (inet->dport) {
-                               if (inet->dport != sport)
-                                       continue;
-                               score++;
-                       }
-                       if (!ipv6_addr_any(&np->rcv_saddr)) {
-                               if (!ipv6_addr_equal(&np->rcv_saddr, daddr))
-                                       continue;
-                               score++;
-                       }
-                       if (!ipv6_addr_any(&np->daddr)) {
-                               if (!ipv6_addr_equal(&np->daddr, saddr))
-                                       continue;
-                               score++;
-                       }
-                       if (sk->sk_bound_dev_if) {
-                               if (sk->sk_bound_dev_if != dif)
-                                       continue;
-                               score++;
-                       }
-                       if (score == 4) {
-                               result = sk;
-                               break;
-                       } else if (score > badness) {
-                               result = sk;
-                               badness = score;
-                       }
+       unsigned int hash = udp_hashfn(net, hnum);
+       struct udp_hslot *hslot = &udptable->hash[hash];
+       int score, badness;
+       /*
+        * Find the best-scoring socket for the packet in the hash slot
+        * selected by udp_hashfn(net, hnum).  The chain is walked under
+        * rcu_read_lock() only, replacing the read_lock(&udp_hash_lock)
+        * of the removed code.  Returns NULL, or the chosen socket with
+        * its sk_refcnt incremented (the removed code used sock_hold();
+        * callers presumably still drop the reference with sock_put()).
+        *
+        * Every restart ("goto begin") resets result and badness, so a
+        * candidate from an aborted walk is never returned.
+        */
+       rcu_read_lock();
+begin:
+       result = NULL;
+       badness = -1;
+       sk_nulls_for_each_rcu(sk, node, &hslot->head) {
+               score = compute_score(sk, net, hnum, saddr, sport, daddr, dport, dif);
+               if (score > badness) {  /* strict '>': the first socket wins a tie */
+                       result = sk;
+                       badness = score;
                }
        }
-       if (result)
-               sock_hold(result);
-       read_unlock(&udp_hash_lock);
+       /*
+        * if the nulls value we got at the end of this lookup is
+        * not the expected one, we must restart lookup.
+        * We probably met an item that was moved to another chain.
+        */
+       if (get_nulls_value(node) != hash)
+               goto begin;
+       /*
+        * The walk took no reference, so take one now.  If the refcount
+        * already dropped to zero the lookup gives up and returns NULL
+        * (it does not retry).  Once the reference is held, the socket is
+        * scored again: a lower score than the one recorded during the
+        * walk means it changed underneath us, so the reference is
+        * dropped and the lookup restarts.
+        * NOTE(review): presumably this guards against the socket being
+        * freed and reused, since udpv6_prot is marked
+        * SLAB_DESTROY_BY_RCU elsewhere in this patch - confirm.
+        */
+       if (result) {
+               if (unlikely(!atomic_inc_not_zero(&result->sk_refcnt)))
+                       result = NULL;
+               else if (unlikely(compute_score(result, net, hnum, saddr, sport,
+                                       daddr, dport, dif) < badness)) {
+                       sock_put(result);
+                       goto begin;
+               }
+       }
+       rcu_read_unlock();
        return result;
 }
 
 static struct sock *__udp6_lib_lookup_skb(struct sk_buff *skb,
                                          __be16 sport, __be16 dport,
-                                         struct hlist_head udptable[])
+                                         struct udp_table *udptable)
 {
        struct sock *sk;
        struct ipv6hdr *iph = ipv6_hdr(skb);
@@ -138,6 +199,7 @@ int udpv6_recvmsg(struct kiocb *iocb, struct sock *sk,
        int peeked;
        int err;
        int is_udplite = IS_UDPLITE(sk);
+       int is_udp4;
 
        if (addr_len)
                *addr_len=sizeof(struct sockaddr_in6);
@@ -158,6 +220,8 @@ try_again:
        else if (copied < ulen)
                msg->msg_flags |= MSG_TRUNC;
 
+       is_udp4 = (skb->protocol == htons(ETH_P_IP));
+
        /*
         * If checksum is needed at all, try to do it while copying the
         * data.  If the data is truncated, or if we only want a partial
@@ -180,9 +244,14 @@ try_again:
        if (err)
                goto out_free;
 
-       if (!peeked)
-               UDP6_INC_STATS_USER(sock_net(sk),
-                               UDP_MIB_INDATAGRAMS, is_udplite);
+       if (!peeked) {
+               if (is_udp4)
+                       UDP_INC_STATS_USER(sock_net(sk),
+                                       UDP_MIB_INDATAGRAMS, is_udplite);
+               else
+                       UDP6_INC_STATS_USER(sock_net(sk),
+                                       UDP_MIB_INDATAGRAMS, is_udplite);
+       }
 
        sock_recv_timestamp(msg, sk, skb);
 
@@ -196,7 +265,7 @@ try_again:
                sin6->sin6_flowinfo = 0;
                sin6->sin6_scope_id = 0;
 
-               if (skb->protocol == htons(ETH_P_IP))
+               if (is_udp4)
                        ipv6_addr_set(&sin6->sin6_addr, 0, 0,
                                      htonl(0xffff), ip_hdr(skb)->saddr);
                else {
@@ -207,7 +276,7 @@ try_again:
                }
 
        }
-       if (skb->protocol == htons(ETH_P_IP)) {
+       if (is_udp4) {
                if (inet->cmsg_flags)
                        ip_cmsg_recv(msg, skb);
        } else {
@@ -228,8 +297,14 @@ out:
 
 csum_copy_err:
        lock_sock(sk);
-       if (!skb_kill_datagram(sk, skb, flags))
-               UDP6_INC_STATS_USER(sock_net(sk), UDP_MIB_INERRORS, is_udplite);
+       if (!skb_kill_datagram(sk, skb, flags)) {
+               if (is_udp4)
+                       UDP_INC_STATS_USER(sock_net(sk),
+                                       UDP_MIB_INERRORS, is_udplite);
+               else
+                       UDP6_INC_STATS_USER(sock_net(sk),
+                                       UDP_MIB_INERRORS, is_udplite);
+       }
        release_sock(sk);
 
        if (flags & MSG_DONTWAIT)
@@ -239,7 +314,7 @@ csum_copy_err:
 
 void __udp6_lib_err(struct sk_buff *skb, struct inet6_skb_parm *opt,
                    int type, int code, int offset, __be32 info,
-                   struct hlist_head udptable[]                    )
+                   struct udp_table *udptable)
 {
        struct ipv6_pinfo *np;
        struct ipv6hdr *hdr = (struct ipv6hdr*)skb->data;
@@ -275,7 +350,7 @@ static __inline__ void udpv6_err(struct sk_buff *skb,
                                 struct inet6_skb_parm *opt, int type,
                                 int code, int offset, __be32 info     )
 {
-       __udp6_lib_err(skb, opt, type, code, offset, info, udp_hash);
+       __udp6_lib_err(skb, opt, type, code, offset, info, &udp_table);
 }
 
 int udpv6_queue_rcv_skb(struct sock * sk, struct sk_buff *skb)
@@ -333,11 +408,11 @@ static struct sock *udp_v6_mcast_next(struct net *net, struct sock *sk,
                                      __be16 rmt_port, struct in6_addr *rmt_addr,
                                      int dif)
 {
-       struct hlist_node *node;
+       struct hlist_nulls_node *node;
        struct sock *s = sk;
        unsigned short num = ntohs(loc_port);
 
-       sk_for_each_from(s, node) {
+       sk_nulls_for_each_from(s, node) {
                struct inet_sock *inet = inet_sk(s);
 
                if (!net_eq(sock_net(s), net))
@@ -374,14 +449,15 @@ static struct sock *udp_v6_mcast_next(struct net *net, struct sock *sk,
  */
 static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
                struct in6_addr *saddr, struct in6_addr *daddr,
-               struct hlist_head udptable[])
+               struct udp_table *udptable)
 {
        struct sock *sk, *sk2;
        const struct udphdr *uh = udp_hdr(skb);
+       struct udp_hslot *hslot = &udptable->hash[udp_hashfn(net, ntohs(uh->dest))];
        int dif;
 
-       read_lock(&udp_hash_lock);
-       sk = sk_head(&udptable[udp_hashfn(net, ntohs(uh->dest))]);
+       spin_lock(&hslot->lock);
+       sk = sk_nulls_head(&hslot->head);
        dif = inet6_iif(skb);
        sk = udp_v6_mcast_next(net, sk, uh->dest, daddr, uh->source, saddr, dif);
        if (!sk) {
@@ -390,7 +466,7 @@ static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
        }
 
        sk2 = sk;
-       while ((sk2 = udp_v6_mcast_next(net, sk_next(sk2), uh->dest, daddr,
+       while ((sk2 = udp_v6_mcast_next(net, sk_nulls_next(sk2), uh->dest, daddr,
                                        uh->source, saddr, dif))) {
                struct sk_buff *buff = skb_clone(skb, GFP_ATOMIC);
                if (buff) {
@@ -409,7 +485,7 @@ static int __udp6_lib_mcast_deliver(struct net *net, struct sk_buff *skb,
                sk_add_backlog(sk, skb);
        bh_unlock_sock(sk);
 out:
-       read_unlock(&udp_hash_lock);
+       spin_unlock(&hslot->lock);
        return 0;
 }
 
@@ -447,7 +523,7 @@ static inline int udp6_csum_init(struct sk_buff *skb, struct udphdr *uh,
        return 0;
 }
 
-int __udp6_lib_rcv(struct sk_buff *skb, struct hlist_head udptable[],
+int __udp6_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
                   int proto)
 {
        struct sock *sk;
@@ -544,7 +620,7 @@ discard:
 
 static __inline__ int udpv6_rcv(struct sk_buff *skb)
 {
-       return __udp6_lib_rcv(skb, udp_hash, IPPROTO_UDP);
+       return __udp6_lib_rcv(skb, &udp_table, IPPROTO_UDP);    /* IPPROTO_UDP packets are looked up in udp_table */
 }
 
 /*
@@ -749,6 +825,9 @@ do_udp_sendmsg:
        if (!fl.oif)
                fl.oif = sk->sk_bound_dev_if;
 
+       if (!fl.oif)
+               fl.oif = np->sticky_pktinfo.ipi6_ifindex;
+
        if (msg->msg_controllen) {
                opt = &opt_space;
                memset(opt, 0, sizeof(struct ipv6_txoptions));
@@ -805,7 +884,8 @@ do_udp_sendmsg:
        if (final_p)
                ipv6_addr_copy(&fl.fl6_dst, final_p);
 
-       if ((err = __xfrm_lookup(&dst, &fl, sk, XFRM_LOOKUP_WAIT)) < 0) {
+       err = __xfrm_lookup(sock_net(sk), &dst, &fl, sk, XFRM_LOOKUP_WAIT);
+       if (err < 0) {
                if (err == -EREMOTE)
                        err = ip6_dst_blackhole(sk, &dst, &fl);
                if (err < 0)
@@ -1008,7 +1088,7 @@ int udp6_seq_show(struct seq_file *seq, void *v)
 static struct udp_seq_afinfo udp6_seq_afinfo = {
        .name           = "udp6",
        .family         = AF_INET6,
-       .hashtable      = udp_hash,
+       .udp_table      = &udp_table,
        .seq_fops       = {
                .owner  =       THIS_MODULE,
        },
@@ -1050,7 +1130,8 @@ struct proto udpv6_prot = {
        .sysctl_wmem       = &sysctl_udp_wmem_min,
        .sysctl_rmem       = &sysctl_udp_rmem_min,
        .obj_size          = sizeof(struct udp6_sock),
-       .h.udp_hash        = udp_hash,
+       .slab_flags        = SLAB_DESTROY_BY_RCU,
+       .h.udp_table       = &udp_table,
 #ifdef CONFIG_COMPAT
        .compat_setsockopt = compat_udpv6_setsockopt,
        .compat_getsockopt = compat_udpv6_getsockopt,