[IGMP]: workaround for IGMP v1/v2 bug
[safe/jmp/linux-2.6] / net / ipv4 / igmp.c
index 1f31831..4a195c7 100644 (file)
@@ -872,11 +872,18 @@ int igmp_rcv(struct sk_buff *skb)
                return 0;
        }
 
-       if (!pskb_may_pull(skb, sizeof(struct igmphdr)) || 
-           (u16)csum_fold(skb_checksum(skb, 0, len, 0))) {
-               in_dev_put(in_dev);
-               kfree_skb(skb);
-               return 0;
+       if (!pskb_may_pull(skb, sizeof(struct igmphdr)))
+               goto drop;
+
+       switch (skb->ip_summed) {
+       case CHECKSUM_HW:
+               if (!(u16)csum_fold(skb->csum))
+                       break;
+               /* fall through */
+       case CHECKSUM_NONE:
+               skb->csum = 0;
+               if (__skb_checksum_complete(skb))
+                       goto drop;
        }
 
        ih = skb->h.igmph;
@@ -890,7 +897,10 @@ int igmp_rcv(struct sk_buff *skb)
                /* Is it our report looped back? */
                if (((struct rtable*)skb->dst)->fl.iif == 0)
                        break;
-               igmp_heard_report(in_dev, ih->group);
+               /* don't rely on MC router hearing unicast reports */
+               if (skb->pkt_type == PACKET_MULTICAST ||
+                   skb->pkt_type == PACKET_BROADCAST)
+                       igmp_heard_report(in_dev, ih->group);
                break;
        case IGMP_PIM:
 #ifdef CONFIG_IP_PIMSM_V1
@@ -904,8 +914,10 @@ int igmp_rcv(struct sk_buff *skb)
        case IGMP_MTRACE_RESP:
                break;
        default:
-               NETDEBUG(printk(KERN_DEBUG "New IGMP type=%d, why we do not know about it?\n", ih->type));
+               NETDEBUG(KERN_DEBUG "New IGMP type=%d, why we do not know about it?\n", ih->type);
        }
+
+drop:
        in_dev_put(in_dev);
        kfree_skb(skb);
        return 0;
@@ -1323,7 +1335,7 @@ static struct in_device * ip_mc_find_dev(struct ip_mreqn *imr)
        }
        if (dev) {
                imr->imr_ifindex = dev->ifindex;
-               idev = __in_dev_get(dev);
+               idev = __in_dev_get_rtnl(dev);
        }
        return idev;
 }
@@ -1603,7 +1615,7 @@ static void ip_mc_clear_src(struct ip_mc_list *pmc)
        }
        pmc->sources = NULL;
        pmc->sfmode = MCAST_EXCLUDE;
-       pmc->sfcount[MCAST_EXCLUDE] = 0;
+       pmc->sfcount[MCAST_INCLUDE] = 0;
        pmc->sfcount[MCAST_EXCLUDE] = 1;
 }
 
@@ -1615,9 +1627,10 @@ int ip_mc_join_group(struct sock *sk , struct ip_mreqn *imr)
 {
        int err;
        u32 addr = imr->imr_multiaddr.s_addr;
-       struct ip_mc_socklist *iml, *i;
+       struct ip_mc_socklist *iml=NULL, *i;
        struct in_device *in_dev;
        struct inet_sock *inet = inet_sk(sk);
+       int ifindex;
        int count = 0;
 
        if (!MULTICAST(addr))
@@ -1633,37 +1646,30 @@ int ip_mc_join_group(struct sock *sk , struct ip_mreqn *imr)
                goto done;
        }
 
-       iml = (struct ip_mc_socklist *)sock_kmalloc(sk, sizeof(*iml), GFP_KERNEL);
-
        err = -EADDRINUSE;
+       ifindex = imr->imr_ifindex;
        for (i = inet->mc_list; i; i = i->next) {
-               if (memcmp(&i->multi, imr, sizeof(*imr)) == 0) {
-                       /* New style additions are reference counted */
-                       if (imr->imr_address.s_addr == 0) {
-                               i->count++;
-                               err = 0;
-                       }
+               if (i->multi.imr_multiaddr.s_addr == addr &&
+                   i->multi.imr_ifindex == ifindex)
                        goto done;
-               }
                count++;
        }
        err = -ENOBUFS;
-       if (iml == NULL || count >= sysctl_igmp_max_memberships)
+       if (count >= sysctl_igmp_max_memberships)
+               goto done;
+       iml = (struct ip_mc_socklist *)sock_kmalloc(sk,sizeof(*iml),GFP_KERNEL);
+       if (iml == NULL)
                goto done;
+
        memcpy(&iml->multi, imr, sizeof(*imr));
        iml->next = inet->mc_list;
-       iml->count = 1;
        iml->sflist = NULL;
        iml->sfmode = MCAST_EXCLUDE;
        inet->mc_list = iml;
        ip_mc_inc_group(in_dev, addr);
-       iml = NULL;
        err = 0;
-
 done:
        rtnl_shunlock();
-       if (iml)
-               sock_kfree_s(sk, iml, sizeof(*iml));
        return err;
 }
 
@@ -1693,30 +1699,25 @@ int ip_mc_leave_group(struct sock *sk, struct ip_mreqn *imr)
 {
        struct inet_sock *inet = inet_sk(sk);
        struct ip_mc_socklist *iml, **imlp;
+       struct in_device *in_dev;
+       u32 group = imr->imr_multiaddr.s_addr;
+       u32 ifindex;
 
        rtnl_lock();
+       in_dev = ip_mc_find_dev(imr);
+       if (!in_dev) {
+               rtnl_unlock();
+               return -ENODEV;
+       }
+       ifindex = imr->imr_ifindex;
        for (imlp = &inet->mc_list; (iml = *imlp) != NULL; imlp = &iml->next) {
-               if (iml->multi.imr_multiaddr.s_addr==imr->imr_multiaddr.s_addr &&
-                   iml->multi.imr_address.s_addr==imr->imr_address.s_addr &&
-                   (!imr->imr_ifindex || iml->multi.imr_ifindex==imr->imr_ifindex)) {
-                       struct in_device *in_dev;
-
-                       in_dev = inetdev_by_index(iml->multi.imr_ifindex);
-                       if (in_dev)
-                               (void) ip_mc_leave_src(sk, iml, in_dev);
-                       if (--iml->count) {
-                               rtnl_unlock();
-                               if (in_dev)
-                                       in_dev_put(in_dev);
-                               return 0;
-                       }
+               if (iml->multi.imr_multiaddr.s_addr == group &&
+                   iml->multi.imr_ifindex == ifindex) {
+                       (void) ip_mc_leave_src(sk, iml, in_dev);
 
                        *imlp = iml->next;
 
-                       if (in_dev) {
-                               ip_mc_dec_group(in_dev, imr->imr_multiaddr.s_addr);
-                               in_dev_put(in_dev);
-                       }
+                       ip_mc_dec_group(in_dev, group);
                        rtnl_unlock();
                        sock_kfree_s(sk, iml, sizeof(*iml));
                        return 0;
@@ -1736,6 +1737,7 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
        struct in_device *in_dev = NULL;
        struct inet_sock *inet = inet_sk(sk);
        struct ip_sf_socklist *psl;
+       int leavegroup = 0;
        int i, j, rv;
 
        if (!MULTICAST(addr))
@@ -1755,15 +1757,20 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
        err = -EADDRNOTAVAIL;
 
        for (pmc=inet->mc_list; pmc; pmc=pmc->next) {
-               if (memcmp(&pmc->multi, mreqs, 2*sizeof(__u32)) == 0)
+               if (pmc->multi.imr_multiaddr.s_addr == imr.imr_multiaddr.s_addr
+                   && pmc->multi.imr_ifindex == imr.imr_ifindex)
                        break;
        }
-       if (!pmc)               /* must have a prior join */
+       if (!pmc) {             /* must have a prior join */
+               err = -EINVAL;
                goto done;
+       }
        /* if a source filter was set, must be the same mode as before */
        if (pmc->sflist) {
-               if (pmc->sfmode != omode)
+               if (pmc->sfmode != omode) {
+                       err = -EINVAL;
                        goto done;
+               }
        } else if (pmc->sfmode != omode) {
                /* allow mode switches for empty-set filters */
                ip_mc_add_src(in_dev, &mreqs->imr_multiaddr, omode, 0, NULL, 0);
@@ -1775,7 +1782,7 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
        psl = pmc->sflist;
        if (!add) {
                if (!psl)
-                       goto done;
+                       goto done;      /* err = -EADDRNOTAVAIL */
                rv = !0;
                for (i=0; i<psl->sl_count; i++) {
                        rv = memcmp(&psl->sl_addr[i], &mreqs->imr_sourceaddr,
@@ -1784,7 +1791,13 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
                                break;
                }
                if (rv)         /* source not found */
+                       goto done;      /* err = -EADDRNOTAVAIL */
+
+               /* special case - (INCLUDE, empty) == LEAVE_GROUP */
+               if (psl->sl_count == 1 && omode == MCAST_INCLUDE) {
+                       leavegroup = 1;
                        goto done;
+               }
 
                /* update the interface filter */
                ip_mc_del_src(in_dev, &mreqs->imr_multiaddr, omode, 1, 
@@ -1842,18 +1855,21 @@ int ip_mc_source(int add, int omode, struct sock *sk, struct
                &mreqs->imr_sourceaddr, 1);
 done:
        rtnl_shunlock();
+       if (leavegroup)
+               return ip_mc_leave_group(sk, &imr);
        return err;
 }
 
 int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
 {
-       int err;
+       int err = 0;
        struct ip_mreqn imr;
        u32 addr = msf->imsf_multiaddr;
        struct ip_mc_socklist *pmc;
        struct in_device *in_dev;
        struct inet_sock *inet = inet_sk(sk);
        struct ip_sf_socklist *newpsl, *psl;
+       int leavegroup = 0;
 
        if (!MULTICAST(addr))
                return -EINVAL;
@@ -1872,15 +1888,22 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
                err = -ENODEV;
                goto done;
        }
-       err = -EADDRNOTAVAIL;
+
+       /* special case - (INCLUDE, empty) == LEAVE_GROUP */
+       if (msf->imsf_fmode == MCAST_INCLUDE && msf->imsf_numsrc == 0) {
+               leavegroup = 1;
+               goto done;
+       }
 
        for (pmc=inet->mc_list; pmc; pmc=pmc->next) {
                if (pmc->multi.imr_multiaddr.s_addr == msf->imsf_multiaddr &&
                    pmc->multi.imr_ifindex == imr.imr_ifindex)
                        break;
        }
-       if (!pmc)               /* must have a prior join */
+       if (!pmc) {             /* must have a prior join */
+               err = -EINVAL;
                goto done;
+       }
        if (msf->imsf_numsrc) {
                newpsl = (struct ip_sf_socklist *)sock_kmalloc(sk,
                                IP_SFLSIZE(msf->imsf_numsrc), GFP_KERNEL);
@@ -1897,8 +1920,11 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
                        sock_kfree_s(sk, newpsl, IP_SFLSIZE(newpsl->sl_max));
                        goto done;
                }
-       } else
+       } else {
                newpsl = NULL;
+               (void) ip_mc_add_src(in_dev, &msf->imsf_multiaddr,
+                                    msf->imsf_fmode, 0, NULL, 0);
+       }
        psl = pmc->sflist;
        if (psl) {
                (void) ip_mc_del_src(in_dev, &msf->imsf_multiaddr, pmc->sfmode,
@@ -1909,8 +1935,11 @@ int ip_mc_msfilter(struct sock *sk, struct ip_msfilter *msf, int ifindex)
                        0, NULL, 0);
        pmc->sflist = newpsl;
        pmc->sfmode = msf->imsf_fmode;
+       err = 0;
 done:
        rtnl_shunlock();
+       if (leavegroup)
+               err = ip_mc_leave_group(sk, &imr);
        return err;
 }