l2tp: fix UDP checksum support
[safe/jmp/linux-2.6] / drivers / net / pppol2tp.c
index 70cfdb4..f1a9467 100644 (file)
@@ -61,7 +61,6 @@
  */
 
 #include <linux/module.h>
-#include <linux/version.h>
 #include <linux/string.h>
 #include <linux/list.h>
 #include <asm/uaccess.h>
@@ -354,7 +353,7 @@ static void pppol2tp_recv_queue_skb(struct pppol2tp_session *session, struct sk_
        spin_lock_bh(&session->reorder_q.lock);
        skb_queue_walk_safe(&session->reorder_q, skbp, tmp) {
                if (PPPOL2TP_SKB_CB(skbp)->ns > ns) {
-                       __skb_insert(skb, skbp->prev, skbp, &session->reorder_q);
+                       __skb_queue_before(&session->reorder_q, skbp, skb);
                        PRINTK(session->debug, PPPOL2TP_MSG_SEQ, KERN_DEBUG,
                               "%s: pkt %hu, inserted before %hu, reorder_q len=%d\n",
                               session->name, ns, PPPOL2TP_SKB_CB(skbp)->ns,
@@ -490,6 +489,30 @@ out:
        spin_unlock_bh(&session->reorder_q.lock);
 }
 
+static inline int pppol2tp_verify_udp_checksum(struct sock *sk,
+                                              struct sk_buff *skb)
+{
+       struct udphdr *uh = udp_hdr(skb);
+       u16 ulen = ntohs(uh->len);
+       struct inet_sock *inet;
+       __wsum psum;
+
+       if (sk->sk_no_check || skb_csum_unnecessary(skb) || !uh->check)
+               return 0;
+
+       inet = inet_sk(sk);
+       psum = csum_tcpudp_nofold(inet->saddr, inet->daddr, ulen,
+                                 IPPROTO_UDP, 0);
+
+       if ((skb->ip_summed == CHECKSUM_COMPLETE) &&
+           !csum_fold(csum_add(psum, skb->csum)))
+               return 0;
+
+       skb->csum = psum;
+
+       return __skb_checksum_complete(skb);
+}
+
 /* Internal receive frame. Do the real work of receiving an L2TP data frame
  * here. The skb is not on a list when we get here.
  * Returns 0 if the packet was a data packet and was successfully passed on.
@@ -510,6 +533,9 @@ static int pppol2tp_recv_core(struct sock *sock, struct sk_buff *skb)
        if (tunnel == NULL)
                goto no_tunnel;
 
+       if (tunnel->sock && pppol2tp_verify_udp_checksum(tunnel->sock, skb))
+               goto discard_bad_csum;
+
        /* UDP always verifies the packet length. */
        __skb_pull(skb, sizeof(struct udphdr));
 
@@ -726,6 +752,14 @@ discard:
 
        return 0;
 
+discard_bad_csum:
+       LIMIT_NETDEBUG("%s: UDP: bad checksum\n", tunnel->name);
+       UDP_INC_STATS_USER(&init_net, UDP_MIB_INERRORS, 0);
+       tunnel->stats.rx_errors++;
+       kfree_skb(skb);
+
+       return 0;
+
 error:
        /* Put UDP header back */
        __skb_push(skb, sizeof(struct udphdr));
@@ -783,14 +817,18 @@ static int pppol2tp_recvmsg(struct kiocb *iocb, struct socket *sock,
        err = 0;
        skb = skb_recv_datagram(sk, flags & ~MSG_DONTWAIT,
                                flags & MSG_DONTWAIT, &err);
-       if (skb) {
-               err = memcpy_toiovec(msg->msg_iov, (unsigned char *) skb->data,
-                                    skb->len);
-               if (err < 0)
-                       goto do_skb_free;
-               err = skb->len;
-       }
-do_skb_free:
+       if (!skb)
+               goto end;
+
+       if (len > skb->len)
+               len = skb->len;
+       else if (len < skb->len)
+               msg->msg_flags |= MSG_TRUNC;
+
+       err = skb_copy_datagram_iovec(skb, 0, msg->msg_iov, len);
+       if (likely(err == 0))
+               err = len;
+
        kfree_skb(skb);
 end:
        return err;
@@ -848,7 +886,7 @@ static int pppol2tp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msgh
        static const unsigned char ppph[2] = { 0xff, 0x03 };
        struct sock *sk = sock->sk;
        struct inet_sock *inet;
-       __wsum csum = 0;
+       __wsum csum;
        struct sk_buff *skb;
        int error;
        int hdr_len;
@@ -856,6 +894,8 @@ static int pppol2tp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msgh
        struct pppol2tp_tunnel *tunnel;
        struct udphdr *uh;
        unsigned int len;
+       struct sock *sk_tun;
+       u16 udp_len;
 
        error = -ENOTCONN;
        if (sock_flag(sk, SOCK_DEAD) || !(sk->sk_state & PPPOX_CONNECTED))
@@ -867,7 +907,8 @@ static int pppol2tp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msgh
        if (session == NULL)
                goto error;
 
-       tunnel = pppol2tp_sock_to_tunnel(session->tunnel_sock);
+       sk_tun = session->tunnel_sock;
+       tunnel = pppol2tp_sock_to_tunnel(sk_tun);
        if (tunnel == NULL)
                goto error_put_sess;
 
@@ -890,11 +931,12 @@ static int pppol2tp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msgh
        skb_reset_transport_header(skb);
 
        /* Build UDP header */
-       inet = inet_sk(session->tunnel_sock);
+       inet = inet_sk(sk_tun);
+       udp_len = hdr_len + sizeof(ppph) + total_len;
        uh = (struct udphdr *) skb->data;
        uh->source = inet->sport;
        uh->dest = inet->dport;
-       uh->len = htons(hdr_len + sizeof(ppph) + total_len);
+       uh->len = htons(udp_len);
        uh->check = 0;
        skb_put(skb, sizeof(struct udphdr));
 
@@ -916,8 +958,22 @@ static int pppol2tp_sendmsg(struct kiocb *iocb, struct socket *sock, struct msgh
        skb_put(skb, total_len);
 
        /* Calculate UDP checksum if configured to do so */
-       if (session->tunnel_sock->sk_no_check != UDP_CSUM_NOXMIT)
-               csum = udp_csum_outgoing(sk, skb);
+       if (sk_tun->sk_no_check == UDP_CSUM_NOXMIT)
+               skb->ip_summed = CHECKSUM_NONE;
+       else if (!(skb->dst->dev->features & NETIF_F_V4_CSUM)) {
+               skb->ip_summed = CHECKSUM_COMPLETE;
+               csum = skb_checksum(skb, 0, udp_len, 0);
+               uh->check = csum_tcpudp_magic(inet->saddr, inet->daddr,
+                                             udp_len, IPPROTO_UDP, csum);
+               if (uh->check == 0)
+                       uh->check = CSUM_MANGLED_0;
+       } else {
+               skb->ip_summed = CHECKSUM_PARTIAL;
+               skb->csum_start = skb_transport_header(skb) - skb->head;
+               skb->csum_offset = offsetof(struct udphdr, check);
+               uh->check = ~csum_tcpudp_magic(inet->saddr, inet->daddr,
+                                              udp_len, IPPROTO_UDP, 0);
+       }
 
        /* Debug */
        if (session->send_seq)
@@ -1005,13 +1061,14 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
        struct sock *sk = (struct sock *) chan->private;
        struct sock *sk_tun;
        int hdr_len;
+       u16 udp_len;
        struct pppol2tp_session *session;
        struct pppol2tp_tunnel *tunnel;
        int rc;
        int headroom;
        int data_len = skb->len;
        struct inet_sock *inet;
-       __wsum csum = 0;
+       __wsum csum;
        struct udphdr *uh;
        unsigned int len;
        int old_headroom;
@@ -1057,6 +1114,8 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
        /* Setup L2TP header */
        pppol2tp_build_l2tp_header(session, __skb_push(skb, hdr_len));
 
+       udp_len = sizeof(struct udphdr) + hdr_len + sizeof(ppph) + data_len;
+
        /* Setup UDP header */
        inet = inet_sk(sk_tun);
        __skb_push(skb, sizeof(*uh));
@@ -1064,13 +1123,9 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
        uh = udp_hdr(skb);
        uh->source = inet->sport;
        uh->dest = inet->dport;
-       uh->len = htons(sizeof(struct udphdr) + hdr_len + sizeof(ppph) + data_len);
+       uh->len = htons(udp_len);
        uh->check = 0;
 
-       /* *BROKEN* Calculate UDP checksum if configured to do so */
-       if (sk_tun->sk_no_check != UDP_CSUM_NOXMIT)
-               csum = udp_csum_outgoing(sk_tun, skb);
-
        /* Debug */
        if (session->send_seq)
                PRINTK(session->debug, PPPOL2TP_MSG_DATA, KERN_DEBUG,
@@ -1105,6 +1160,24 @@ static int pppol2tp_xmit(struct ppp_channel *chan, struct sk_buff *skb)
        skb->dst = dst_clone(__sk_dst_get(sk_tun));
        pppol2tp_skb_set_owner_w(skb, sk_tun);
 
+       /* Calculate UDP checksum if configured to do so */
+       if (sk_tun->sk_no_check == UDP_CSUM_NOXMIT)
+               skb->ip_summed = CHECKSUM_NONE;
+       else if (!(skb->dst->dev->features & NETIF_F_V4_CSUM)) {
+               skb->ip_summed = CHECKSUM_COMPLETE;
+               csum = skb_checksum(skb, 0, udp_len, 0);
+               uh->check = csum_tcpudp_magic(inet->saddr, inet->daddr,
+                                             udp_len, IPPROTO_UDP, csum);
+               if (uh->check == 0)
+                       uh->check = CSUM_MANGLED_0;
+       } else {
+               skb->ip_summed = CHECKSUM_PARTIAL;
+               skb->csum_start = skb_transport_header(skb) - skb->head;
+               skb->csum_offset = offsetof(struct udphdr, check);
+               uh->check = ~csum_tcpudp_magic(inet->saddr, inet->daddr,
+                                              udp_len, IPPROTO_UDP, 0);
+       }
+
        /* Queue the packet to IP for output */
        len = skb->len;
        rc = ip_queue_xmit(skb, 1);
@@ -1350,6 +1423,7 @@ static int pppol2tp_release(struct socket *sock)
                        kfree_skb(skb);
                        sock_put(sk);
                }
+               sock_put(sk);
        }
 
        release_sock(sk);