net: add a noref bit on skb dst
[safe/jmp/linux-2.6] / net / core / sock.c
index 61a65a2..bf88a16 100644 (file)
@@ -307,6 +307,11 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
         */
        skb_len = skb->len;
 
+       /* we escape from rcu protected region, make sure we dont leak
+        * a norefcounted dst
+        */
+       skb_dst_force(skb);
+
        spin_lock_irqsave(&list->lock, flags);
        skb->dropcount = atomic_read(&sk->sk_drops);
        __skb_queue_tail(list, skb);
@@ -327,6 +332,10 @@ int sk_receive_skb(struct sock *sk, struct sk_buff *skb, const int nested)
 
        skb->dev = NULL;
 
+       if (sk_rcvqueues_full(sk, skb)) {
+               atomic_inc(&sk->sk_drops);
+               goto discard_and_relse;
+       }
        if (nested)
                bh_lock_sock_nested(sk);
        else
@@ -364,11 +373,11 @@ EXPORT_SYMBOL(sk_reset_txq);
 
 struct dst_entry *__sk_dst_check(struct sock *sk, u32 cookie)
 {
-       struct dst_entry *dst = sk->sk_dst_cache;
+       struct dst_entry *dst = __sk_dst_get(sk);
 
        if (dst && dst->obsolete && dst->ops->check(dst, cookie) == NULL) {
                sk_tx_queue_clear(sk);
-               sk->sk_dst_cache = NULL;
+               rcu_assign_pointer(sk->sk_dst_cache, NULL);
                dst_release(dst);
                return NULL;
        }
@@ -1157,7 +1166,7 @@ struct sock *sk_clone(const struct sock *sk, const gfp_t priority)
                skb_queue_head_init(&newsk->sk_async_wait_queue);
 #endif
 
-               rwlock_init(&newsk->sk_dst_lock);
+               spin_lock_init(&newsk->sk_dst_lock);
                rwlock_init(&newsk->sk_callback_lock);
                lockdep_set_class_and_name(&newsk->sk_callback_lock,
                                af_callback_keys + newsk->sk_family,
@@ -1207,7 +1216,7 @@ struct sock *sk_clone(const struct sock *sk, const gfp_t priority)
                 */
                sk_refcnt_debug_inc(newsk);
                sk_set_socket(newsk, NULL);
-               newsk->sk_sleep  = NULL;
+               newsk->sk_wq = NULL;
 
                if (newsk->sk_prot->sockets_allocated)
                        percpu_counter_inc(newsk->sk_prot->sockets_allocated);
@@ -1227,6 +1236,7 @@ void sk_setup_caps(struct sock *sk, struct dst_entry *dst)
        sk->sk_route_caps = dst->dev->features;
        if (sk->sk_route_caps & NETIF_F_GSO)
                sk->sk_route_caps |= NETIF_F_GSO_SOFTWARE;
+       sk->sk_route_caps &= ~sk->sk_route_nocaps;
        if (sk_can_gso(sk)) {
                if (dst->header_len) {
                        sk->sk_route_caps &= ~NETIF_F_GSO_MASK;
@@ -1395,7 +1405,7 @@ static long sock_wait_for_wmem(struct sock *sk, long timeo)
                if (signal_pending(current))
                        break;
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-               prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE);
+               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                if (atomic_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf)
                        break;
                if (sk->sk_shutdown & SEND_SHUTDOWN)
@@ -1404,7 +1414,7 @@ static long sock_wait_for_wmem(struct sock *sk, long timeo)
                        break;
                timeo = schedule_timeout(timeo);
        }
-       finish_wait(sk->sk_sleep, &wait);
+       finish_wait(sk_sleep(sk), &wait);
        return timeo;
 }
 
@@ -1531,6 +1541,7 @@ static void __release_sock(struct sock *sk)
                do {
                        struct sk_buff *next = skb->next;
 
+                       WARN_ON_ONCE(skb_dst_is_noref(skb));
                        skb->next = NULL;
                        sk_backlog_rcv(sk, skb);
 
@@ -1570,11 +1581,11 @@ int sk_wait_data(struct sock *sk, long *timeo)
        int rc;
        DEFINE_WAIT(wait);
 
-       prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE);
+       prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
        set_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
        rc = sk_wait_event(sk, timeo, !skb_queue_empty(&sk->sk_receive_queue));
        clear_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
-       finish_wait(sk->sk_sleep, &wait);
+       finish_wait(sk_sleep(sk), &wait);
        return rc;
 }
 EXPORT_SYMBOL(sk_wait_data);
@@ -1796,41 +1807,53 @@ EXPORT_SYMBOL(sock_no_sendpage);
 
 static void sock_def_wakeup(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
-       if (sk_has_sleeper(sk))
-               wake_up_interruptible_all(sk->sk_sleep);
-       read_unlock(&sk->sk_callback_lock);
+       struct socket_wq *wq;
+
+       rcu_read_lock();
+       wq = rcu_dereference(sk->sk_wq);
+       if (wq_has_sleeper(wq))
+               wake_up_interruptible_all(&wq->wait);
+       rcu_read_unlock();
 }
 
 static void sock_def_error_report(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
-       if (sk_has_sleeper(sk))
-               wake_up_interruptible_poll(sk->sk_sleep, POLLERR);
+       struct socket_wq *wq;
+
+       rcu_read_lock();
+       wq = rcu_dereference(sk->sk_wq);
+       if (wq_has_sleeper(wq))
+               wake_up_interruptible_poll(&wq->wait, POLLERR);
        sk_wake_async(sk, SOCK_WAKE_IO, POLL_ERR);
-       read_unlock(&sk->sk_callback_lock);
+       rcu_read_unlock();
 }
 
 static void sock_def_readable(struct sock *sk, int len)
 {
-       read_lock(&sk->sk_callback_lock);
-       if (sk_has_sleeper(sk))
-               wake_up_interruptible_sync_poll(sk->sk_sleep, POLLIN |
+       struct socket_wq *wq;
+
+       rcu_read_lock();
+       wq = rcu_dereference(sk->sk_wq);
+       if (wq_has_sleeper(wq))
+               wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
                                                POLLRDNORM | POLLRDBAND);
        sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
-       read_unlock(&sk->sk_callback_lock);
+       rcu_read_unlock();
 }
 
 static void sock_def_write_space(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
+       struct socket_wq *wq;
+
+       rcu_read_lock();
 
        /* Do not wake up a writer until he can make "significant"
         * progress.  --DaveM
         */
        if ((atomic_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf) {
-               if (sk_has_sleeper(sk))
-                       wake_up_interruptible_sync_poll(sk->sk_sleep, POLLOUT |
+               wq = rcu_dereference(sk->sk_wq);
+               if (wq_has_sleeper(wq))
+                       wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
                                                POLLWRNORM | POLLWRBAND);
 
                /* Should agree with poll, otherwise some programs break */
@@ -1838,7 +1861,7 @@ static void sock_def_write_space(struct sock *sk)
                        sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
        }
 
-       read_unlock(&sk->sk_callback_lock);
+       rcu_read_unlock();
 }
 
 static void sock_def_destruct(struct sock *sk)
@@ -1885,7 +1908,6 @@ void sock_init_data(struct socket *sock, struct sock *sk)
        sk->sk_allocation       =       GFP_KERNEL;
        sk->sk_rcvbuf           =       sysctl_rmem_default;
        sk->sk_sndbuf           =       sysctl_wmem_default;
-       sk->sk_backlog.limit    =       sk->sk_rcvbuf << 1;
        sk->sk_state            =       TCP_CLOSE;
        sk_set_socket(sk, sock);
 
@@ -1893,12 +1915,12 @@ void sock_init_data(struct socket *sock, struct sock *sk)
 
        if (sock) {
                sk->sk_type     =       sock->type;
-               sk->sk_sleep    =       &sock->wait;
+               sk->sk_wq       =       sock->wq;
                sock->sk        =       sk;
        } else
-               sk->sk_sleep    =       NULL;
+               sk->sk_wq       =       NULL;
 
-       rwlock_init(&sk->sk_dst_lock);
+       spin_lock_init(&sk->sk_dst_lock);
        rwlock_init(&sk->sk_callback_lock);
        lockdep_set_class_and_name(&sk->sk_callback_lock,
                        af_callback_keys + sk->sk_family,
@@ -2288,7 +2310,8 @@ out_free_request_sock_slab:
                prot->rsk_prot->slab = NULL;
        }
 out_free_request_sock_slab_name:
-       kfree(prot->rsk_prot->slab_name);
+       if (prot->rsk_prot)
+               kfree(prot->rsk_prot->slab_name);
 out_free_sock_slab:
        kmem_cache_destroy(prot->slab);
        prot->slab = NULL;