[IPSEC]: Add missing sg_init_table() calls to ESP.
[safe/jmp/linux-2.6] / net / ipv4 / esp4.c
index e87377e..ba98401 100644 (file)
@@ -8,6 +8,7 @@
 #include <linux/kernel.h>
 #include <linux/pfkeyv2.h>
 #include <linux/random.h>
+#include <linux/spinlock.h>
 #include <net/icmp.h>
 #include <net/protocol.h>
 #include <net/udp.h>
 static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
 {
        int err;
-       struct iphdr *top_iph;
        struct ip_esp_hdr *esph;
        struct crypto_blkcipher *tfm;
        struct blkcipher_desc desc;
        struct esp_data *esp;
        struct sk_buff *trailer;
+       u8 *tail;
        int blksize;
        int clen;
        int alen;
        int nfrags;
 
-       /* Strip IP+ESP header. */
-       __skb_pull(skb, skb->h.raw - skb->data);
-       /* Now skb is pure payload to encrypt */
+       /* skb is pure payload to encrypt */
 
        err = -ENOMEM;
 
@@ -49,30 +48,32 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
                goto error;
 
        /* Fill padding... */
+       tail = skb_tail_pointer(trailer);
        do {
                int i;
                for (i=0; i<clen-skb->len - 2; i++)
-                       *(u8*)(trailer->tail + i) = i+1;
+                       tail[i] = i + 1;
        } while (0);
-       *(u8*)(trailer->tail + clen-skb->len - 2) = (clen - skb->len)-2;
+       tail[clen - skb->len - 2] = (clen - skb->len) - 2;
        pskb_put(skb, trailer, clen - skb->len);
 
-       __skb_push(skb, skb->data - skb->nh.raw);
-       top_iph = skb->nh.iph;
-       esph = (struct ip_esp_hdr *)(skb->nh.raw + top_iph->ihl*4);
-       top_iph->tot_len = htons(skb->len + alen);
-       *(u8*)(trailer->tail - 1) = top_iph->protocol;
+       skb_push(skb, -skb_network_offset(skb));
+       esph = ip_esp_hdr(skb);
+       *(skb_tail_pointer(trailer) - 1) = *skb_mac_header(skb);
+       *skb_mac_header(skb) = IPPROTO_ESP;
+
+       spin_lock_bh(&x->lock);
 
        /* this is non-NULL only with UDP Encapsulation */
        if (x->encap) {
                struct xfrm_encap_tmpl *encap = x->encap;
                struct udphdr *uh;
-               u32 *udpdata32;
+               __be32 *udpdata32;
 
                uh = (struct udphdr *)esph;
                uh->source = encap->encap_sport;
                uh->dest = encap->encap_dport;
-               uh->len = htons(skb->len + alen - top_iph->ihl*4);
+               uh->len = htons(skb->len + alen - skb_transport_offset(skb));
                uh->check = 0;
 
                switch (encap->encap_type) {
@@ -81,22 +82,25 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
                        esph = (struct ip_esp_hdr *)(uh + 1);
                        break;
                case UDP_ENCAP_ESPINUDP_NON_IKE:
-                       udpdata32 = (u32 *)(uh + 1);
+                       udpdata32 = (__be32 *)(uh + 1);
                        udpdata32[0] = udpdata32[1] = 0;
                        esph = (struct ip_esp_hdr *)(udpdata32 + 2);
                        break;
                }
 
-               top_iph->protocol = IPPROTO_UDP;
-       } else
-               top_iph->protocol = IPPROTO_ESP;
+               *skb_mac_header(skb) = IPPROTO_UDP;
+       }
 
        esph->spi = x->id.spi;
-       esph->seq_no = htonl(++x->replay.oseq);
-       xfrm_aevent_doreplay(x);
+       esph->seq_no = htonl(XFRM_SKB_CB(skb)->seq);
 
-       if (esp->conf.ivlen)
+       if (esp->conf.ivlen) {
+               if (unlikely(!esp->conf.ivinitted)) {
+                       get_random_bytes(esp->conf.ivec, esp->conf.ivlen);
+                       esp->conf.ivinitted = 1;
+               }
                crypto_blkcipher_set_iv(tfm, esp->conf.ivec, esp->conf.ivlen);
+       }
 
        do {
                struct scatterlist *sg = &esp->sgbuf[0];
@@ -104,8 +108,9 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
                if (unlikely(nfrags > ESP_NUM_FAST_SG)) {
                        sg = kmalloc(sizeof(struct scatterlist)*nfrags, GFP_ATOMIC);
                        if (!sg)
-                               goto error;
+                               goto unlock;
                }
+               sg_init_table(sg, nfrags);
                skb_to_sgvec(skb, sg, esph->enc_data+esp->conf.ivlen-skb->data, clen);
                err = crypto_blkcipher_encrypt(&desc, sg, sg, clen);
                if (unlikely(sg != &esp->sgbuf[0]))
@@ -113,7 +118,7 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
        } while (0);
 
        if (unlikely(err))
-               goto error;
+               goto unlock;
 
        if (esp->conf.ivlen) {
                memcpy(esph->enc_data, esp->conf.ivec, esp->conf.ivlen);
@@ -126,7 +131,8 @@ static int esp_output(struct xfrm_state *x, struct sk_buff *skb)
                memcpy(pskb_put(skb, trailer, alen), esp->auth.work_icv, alen);
        }
 
-       ip_send_check(top_iph);
+unlock:
+       spin_unlock_bh(&x->lock);
 
 error:
        return err;
@@ -147,7 +153,7 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
        struct sk_buff *trailer;
        int blksize = ALIGN(crypto_blkcipher_blocksize(tfm), 4);
        int alen = esp->auth.icv_trunc_len;
-       int elen = skb->len - sizeof(struct ip_esp_hdr) - esp->conf.ivlen - alen;
+       int elen = skb->len - sizeof(*esph) - esp->conf.ivlen - alen;
        int nfrags;
        int ihl;
        u8 nexthdr[2];
@@ -155,7 +161,7 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
        int padlen;
        int err;
 
-       if (!pskb_may_pull(skb, sizeof(struct ip_esp_hdr)))
+       if (!pskb_may_pull(skb, sizeof(*esph)))
                goto out;
 
        if (elen <= 0 || (elen & (blksize-1)))
@@ -183,7 +189,7 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
 
        skb->ip_summed = CHECKSUM_NONE;
 
-       esph = (struct ip_esp_hdr*)skb->data;
+       esph = (struct ip_esp_hdr *)skb->data;
 
        /* Get ivec. This can be wrong, check against another impls. */
        if (esp->conf.ivlen)
@@ -196,7 +202,8 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
                if (!sg)
                        goto out;
        }
-       skb_to_sgvec(skb, sg, sizeof(struct ip_esp_hdr) + esp->conf.ivlen, elen);
+       sg_init_table(sg, nfrags);
+       skb_to_sgvec(skb, sg, sizeof(*esph) + esp->conf.ivlen, elen);
        err = crypto_blkcipher_decrypt(&desc, sg, sg, elen);
        if (unlikely(sg != &esp->sgbuf[0]))
                kfree(sg);
@@ -210,14 +217,14 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
        if (padlen+2 >= elen)
                goto out;
 
-       /* ... check padding bits here. Silly. :-) */ 
+       /* ... check padding bits here. Silly. :-) */
 
-       iph = skb->nh.iph;
+       iph = ip_hdr(skb);
        ihl = iph->ihl * 4;
 
        if (x->encap) {
                struct xfrm_encap_tmpl *encap = x->encap;
-               struct udphdr *uh = (void *)(skb->nh.raw + ihl);
+               struct udphdr *uh = (void *)(skb_network_header(skb) + ihl);
 
                /*
                 * 1) if the NAT-T peer's IP or port changed then
@@ -231,7 +238,7 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
 
                        ipaddr.a4 = iph->saddr;
                        km_new_mapping(x, &ipaddr, uh->source);
-                               
+
                        /* XXX: perhaps add an extra
                         * policy check here, to see
                         * if we should allow or
@@ -240,7 +247,7 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
                         * address/port.
                         */
                }
-       
+
                /*
                 * 2) ignore UDP/TCP checksums in case
                 *    of NAT-T in Transport Mode, or
@@ -252,31 +259,43 @@ static int esp_input(struct xfrm_state *x, struct sk_buff *skb)
                        skb->ip_summed = CHECKSUM_UNNECESSARY;
        }
 
-       iph->protocol = nexthdr[1];
        pskb_trim(skb, skb->len - alen - padlen - 2);
-       skb->h.raw = __skb_pull(skb, sizeof(*esph) + esp->conf.ivlen) - ihl;
+       __skb_pull(skb, sizeof(*esph) + esp->conf.ivlen);
+       skb_set_transport_header(skb, -ihl);
 
-       return 0;
+       return nexthdr[1];
 
 out:
        return -EINVAL;
 }
 
-static u32 esp4_get_max_size(struct xfrm_state *x, int mtu)
+static u32 esp4_get_mtu(struct xfrm_state *x, int mtu)
 {
        struct esp_data *esp = x->data;
        u32 blksize = ALIGN(crypto_blkcipher_blocksize(esp->conf.tfm), 4);
-
-       if (x->props.mode == XFRM_MODE_TUNNEL) {
-               mtu = ALIGN(mtu + 2, blksize);
-       } else {
+       u32 align = max_t(u32, blksize, esp->conf.padlen);
+       u32 rem;
+
+       mtu -= x->props.header_len + esp->auth.icv_trunc_len;
+       rem = mtu & (align - 1);
+       mtu &= ~(align - 1);
+
+       switch (x->props.mode) {
+       case XFRM_MODE_TUNNEL:
+               break;
+       default:
+       case XFRM_MODE_TRANSPORT:
+               /* The worst case */
+               mtu -= blksize - 4;
+               mtu += min_t(u32, blksize - 4, rem);
+               break;
+       case XFRM_MODE_BEET:
                /* The worst case. */
-               mtu = ALIGN(mtu + 2, 4) + blksize - 4;
+               mtu += min_t(u32, IPV4_BEET_PHMAXLEN, rem);
+               break;
        }
-       if (esp->conf.padlen)
-               mtu = ALIGN(mtu, esp->conf.padlen);
 
-       return mtu + x->props.header_len + esp->auth.icv_trunc_len;
+       return mtu - 2;
 }
 
 static void esp4_err(struct sk_buff *skb, u32 info)
@@ -285,8 +304,8 @@ static void esp4_err(struct sk_buff *skb, u32 info)
        struct ip_esp_hdr *esph = (struct ip_esp_hdr*)(skb->data+(iph->ihl<<2));
        struct xfrm_state *x;
 
-       if (skb->h.icmph->type != ICMP_DEST_UNREACH ||
-           skb->h.icmph->code != ICMP_FRAG_NEEDED)
+       if (icmp_hdr(skb)->type != ICMP_DEST_UNREACH ||
+           icmp_hdr(skb)->code != ICMP_FRAG_NEEDED)
                return;
 
        x = xfrm_state_lookup((xfrm_address_t *)&iph->daddr, esph->spi, IPPROTO_ESP, AF_INET);
@@ -319,12 +338,8 @@ static int esp_init_state(struct xfrm_state *x)
 {
        struct esp_data *esp = NULL;
        struct crypto_blkcipher *tfm;
+       u32 align;
 
-       /* null auth and encryption can have zero length keys */
-       if (x->aalg) {
-               if (x->aalg->alg_key_len > 512)
-                       goto error;
-       }
        if (x->ealg == NULL)
                goto error;
 
@@ -336,15 +351,14 @@ static int esp_init_state(struct xfrm_state *x)
                struct xfrm_algo_desc *aalg_desc;
                struct crypto_hash *hash;
 
-               esp->auth.key = x->aalg->alg_key;
-               esp->auth.key_len = (x->aalg->alg_key_len+7)/8;
                hash = crypto_alloc_hash(x->aalg->alg_name, 0,
                                         CRYPTO_ALG_ASYNC);
                if (IS_ERR(hash))
                        goto error;
 
                esp->auth.tfm = hash;
-               if (crypto_hash_setkey(hash, esp->auth.key, esp->auth.key_len))
+               if (crypto_hash_setkey(hash, x->aalg->alg_key,
+                                      (x->aalg->alg_key_len + 7) / 8))
                        goto error;
 
                aalg_desc = xfrm_aalg_get_byname(x->aalg->alg_name, 0);
@@ -366,8 +380,7 @@ static int esp_init_state(struct xfrm_state *x)
                if (!esp->auth.work_icv)
                        goto error;
        }
-       esp->conf.key = x->ealg->alg_key;
-       esp->conf.key_len = (x->ealg->alg_key_len+7)/8;
+
        tfm = crypto_alloc_blkcipher(x->ealg->alg_name, 0, CRYPTO_ALG_ASYNC);
        if (IS_ERR(tfm))
                goto error;
@@ -378,13 +391,16 @@ static int esp_init_state(struct xfrm_state *x)
                esp->conf.ivec = kmalloc(esp->conf.ivlen, GFP_KERNEL);
                if (unlikely(esp->conf.ivec == NULL))
                        goto error;
-               get_random_bytes(esp->conf.ivec, esp->conf.ivlen);
+               esp->conf.ivinitted = 0;
        }
-       if (crypto_blkcipher_setkey(tfm, esp->conf.key, esp->conf.key_len))
+       if (crypto_blkcipher_setkey(tfm, x->ealg->alg_key,
+                                   (x->ealg->alg_key_len + 7) / 8))
                goto error;
        x->props.header_len = sizeof(struct ip_esp_hdr) + esp->conf.ivlen;
        if (x->props.mode == XFRM_MODE_TUNNEL)
                x->props.header_len += sizeof(struct iphdr);
+       else if (x->props.mode == XFRM_MODE_BEET)
+               x->props.header_len += IPV4_BEET_PHMAXLEN;
        if (x->encap) {
                struct xfrm_encap_tmpl *encap = x->encap;
 
@@ -400,7 +416,10 @@ static int esp_init_state(struct xfrm_state *x)
                }
        }
        x->data = esp;
-       x->props.trailer_len = esp4_get_max_size(x, 0) - x->props.header_len;
+       align = ALIGN(crypto_blkcipher_blocksize(esp->conf.tfm), 4);
+       if (esp->conf.padlen)
+               align = max_t(u32, align, esp->conf.padlen);
+       x->props.trailer_len = align + 1 + esp->auth.icv_trunc_len;
        return 0;
 
 error:
@@ -415,9 +434,10 @@ static struct xfrm_type esp_type =
        .description    = "ESP4",
        .owner          = THIS_MODULE,
        .proto          = IPPROTO_ESP,
+       .flags          = XFRM_TYPE_REPLAY_PROT,
        .init_state     = esp_init_state,
        .destructor     = esp_destroy,
-       .get_max_size   = esp4_get_max_size,
+       .get_mtu        = esp4_get_mtu,
        .input          = esp_input,
        .output         = esp_output
 };
@@ -453,3 +473,4 @@ static void __exit esp4_fini(void)
 module_init(esp4_init);
 module_exit(esp4_fini);
 MODULE_LICENSE("GPL");
+MODULE_ALIAS_XFRM_TYPE(AF_INET, XFRM_PROTO_ESP);