bridge br_multicast: Ensure to initialize BR_INPUT_SKB_CB(skb)->mrouters_only.
net/bridge/br_multicast.c
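The change named in the title zero-initializes the bridge's per-skb control block at the very top of br_multicast_rcv(), before any of its early returns (snooping disabled, non-IP protocol, and so on), presumably so that br_mdb_get() and the delivery path elsewhere in the bridge code never act on an uninitialized ->igmp or ->mrouters_only flag. A short excerpt of the relevant lines from br_multicast_rcv() further down in this file (listing lines 1547-1548); the surrounding context is abbreviated here:

    int br_multicast_rcv(struct net_bridge *br, struct net_bridge_port *port,
                         struct sk_buff *skb)
    {
            /* Always initialize the control block, even when returning early. */
            BR_INPUT_SKB_CB(skb)->igmp = 0;
            BR_INPUT_SKB_CB(skb)->mrouters_only = 0;

            if (br->multicast_disabled)
                    return 0;
            ...
    }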
1 /*
2  * Bridge multicast support.
3  *
4  * Copyright (c) 2010 Herbert Xu <herbert@gondor.apana.org.au>
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms of the GNU General Public License as published by the Free
8  * Software Foundation; either version 2 of the License, or (at your option)
9  * any later version.
10  *
11  */
12
13 #include <linux/err.h>
14 #include <linux/if_ether.h>
15 #include <linux/igmp.h>
16 #include <linux/jhash.h>
17 #include <linux/kernel.h>
18 #include <linux/log2.h>
19 #include <linux/netdevice.h>
20 #include <linux/netfilter_bridge.h>
21 #include <linux/random.h>
22 #include <linux/rculist.h>
23 #include <linux/skbuff.h>
24 #include <linux/slab.h>
25 #include <linux/timer.h>
26 #include <net/ip.h>
27 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
28 #include <net/ipv6.h>
29 #include <net/mld.h>
30 #include <net/addrconf.h>
31 #endif
32
33 #include "br_private.h"
34
35 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
36 static inline int ipv6_is_local_multicast(const struct in6_addr *addr)
37 {
38         if (ipv6_addr_is_multicast(addr) &&
39             IPV6_ADDR_MC_SCOPE(addr) <= IPV6_ADDR_SCOPE_LINKLOCAL)
40                 return 1;
41         return 0;
42 }
43 #endif
44
45 static inline int br_ip_equal(const struct br_ip *a, const struct br_ip *b)
46 {
47         if (a->proto != b->proto)
48                 return 0;
49         switch (a->proto) {
50         case htons(ETH_P_IP):
51                 return a->u.ip4 == b->u.ip4;
52 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
53         case htons(ETH_P_IPV6):
54                 return ipv6_addr_equal(&a->u.ip6, &b->u.ip6);
55 #endif
56         }
57         return 0;
58 }
59
60 static inline int __br_ip4_hash(struct net_bridge_mdb_htable *mdb, __be32 ip)
61 {
62         return jhash_1word(mdb->secret, (__force u32)ip) & (mdb->max - 1);
63 }
64
65 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
66 static inline int __br_ip6_hash(struct net_bridge_mdb_htable *mdb,
67                                 const struct in6_addr *ip)
68 {
69         return jhash2((__force u32 *)ip->s6_addr32, 4, mdb->secret) & (mdb->max - 1);
70 }
71 #endif
72
73 static inline int br_ip_hash(struct net_bridge_mdb_htable *mdb,
74                              struct br_ip *ip)
75 {
76         switch (ip->proto) {
77         case htons(ETH_P_IP):
78                 return __br_ip4_hash(mdb, ip->u.ip4);
79 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
80         case htons(ETH_P_IPV6):
81                 return __br_ip6_hash(mdb, &ip->u.ip6);
82 #endif
83         }
84         return 0;
85 }
86
87 static struct net_bridge_mdb_entry *__br_mdb_ip_get(
88         struct net_bridge_mdb_htable *mdb, struct br_ip *dst, int hash)
89 {
90         struct net_bridge_mdb_entry *mp;
91         struct hlist_node *p;
92
93         hlist_for_each_entry_rcu(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) {
94                 if (br_ip_equal(&mp->addr, dst))
95                         return mp;
96         }
97
98         return NULL;
99 }
100
101 static struct net_bridge_mdb_entry *br_mdb_ip4_get(
102         struct net_bridge_mdb_htable *mdb, __be32 dst)
103 {
104         struct br_ip br_dst;
105
106         br_dst.u.ip4 = dst;
107         br_dst.proto = htons(ETH_P_IP);
108
109         return __br_mdb_ip_get(mdb, &br_dst, __br_ip4_hash(mdb, dst));
110 }
111
112 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
113 static struct net_bridge_mdb_entry *br_mdb_ip6_get(
114         struct net_bridge_mdb_htable *mdb, const struct in6_addr *dst)
115 {
116         struct br_ip br_dst;
117
118         ipv6_addr_copy(&br_dst.u.ip6, dst);
119         br_dst.proto = htons(ETH_P_IPV6);
120
121         return __br_mdb_ip_get(mdb, &br_dst, __br_ip6_hash(mdb, dst));
122 }
123 #endif
124
125 static struct net_bridge_mdb_entry *br_mdb_ip_get(
126         struct net_bridge_mdb_htable *mdb, struct br_ip *dst)
127 {
128         return __br_mdb_ip_get(mdb, dst, br_ip_hash(mdb, dst));
129 }
130
131 struct net_bridge_mdb_entry *br_mdb_get(struct net_bridge *br,
132                                         struct sk_buff *skb)
133 {
134         struct net_bridge_mdb_htable *mdb = br->mdb;
135         struct br_ip ip;
136
137         if (!mdb || br->multicast_disabled)
138                 return NULL;
139
140         if (BR_INPUT_SKB_CB(skb)->igmp)
141                 return NULL;
142
143         ip.proto = skb->protocol;
144
145         switch (skb->protocol) {
146         case htons(ETH_P_IP):
147                 ip.u.ip4 = ip_hdr(skb)->daddr;
148                 break;
149 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
150         case htons(ETH_P_IPV6):
151                 ipv6_addr_copy(&ip.u.ip6, &ipv6_hdr(skb)->daddr);
152                 break;
153 #endif
154         default:
155                 return NULL;
156         }
157
158         return br_mdb_ip_get(mdb, &ip);
159 }
160
161 static void br_mdb_free(struct rcu_head *head)
162 {
163         struct net_bridge_mdb_htable *mdb =
164                 container_of(head, struct net_bridge_mdb_htable, rcu);
165         struct net_bridge_mdb_htable *old = mdb->old;
166
167         mdb->old = NULL;
168         kfree(old->mhash);
169         kfree(old);
170 }
171
172 static int br_mdb_copy(struct net_bridge_mdb_htable *new,
173                        struct net_bridge_mdb_htable *old,
174                        int elasticity)
175 {
176         struct net_bridge_mdb_entry *mp;
177         struct hlist_node *p;
178         int maxlen;
179         int len;
180         int i;
181
182         for (i = 0; i < old->max; i++)
183                 hlist_for_each_entry(mp, p, &old->mhash[i], hlist[old->ver])
184                         hlist_add_head(&mp->hlist[new->ver],
185                                        &new->mhash[br_ip_hash(new, &mp->addr)]);
186
187         if (!elasticity)
188                 return 0;
189
190         maxlen = 0;
191         for (i = 0; i < new->max; i++) {
192                 len = 0;
193                 hlist_for_each_entry(mp, p, &new->mhash[i], hlist[new->ver])
194                         len++;
195                 if (len > maxlen)
196                         maxlen = len;
197         }
198
199         return maxlen > elasticity ? -EINVAL : 0;
200 }
201
202 static void br_multicast_free_pg(struct rcu_head *head)
203 {
204         struct net_bridge_port_group *p =
205                 container_of(head, struct net_bridge_port_group, rcu);
206
207         kfree(p);
208 }
209
210 static void br_multicast_free_group(struct rcu_head *head)
211 {
212         struct net_bridge_mdb_entry *mp =
213                 container_of(head, struct net_bridge_mdb_entry, rcu);
214
215         kfree(mp);
216 }
217
218 static void br_multicast_group_expired(unsigned long data)
219 {
220         struct net_bridge_mdb_entry *mp = (void *)data;
221         struct net_bridge *br = mp->br;
222         struct net_bridge_mdb_htable *mdb;
223
224         spin_lock(&br->multicast_lock);
225         if (!netif_running(br->dev) || timer_pending(&mp->timer))
226                 goto out;
227
228         if (!hlist_unhashed(&mp->mglist))
229                 hlist_del_init(&mp->mglist);
230
231         if (mp->ports)
232                 goto out;
233
234         mdb = br->mdb;
235         hlist_del_rcu(&mp->hlist[mdb->ver]);
236         mdb->size--;
237
238         del_timer(&mp->query_timer);
239         call_rcu_bh(&mp->rcu, br_multicast_free_group);
240
241 out:
242         spin_unlock(&br->multicast_lock);
243 }
244
245 static void br_multicast_del_pg(struct net_bridge *br,
246                                 struct net_bridge_port_group *pg)
247 {
248         struct net_bridge_mdb_htable *mdb = br->mdb;
249         struct net_bridge_mdb_entry *mp;
250         struct net_bridge_port_group *p;
251         struct net_bridge_port_group **pp;
252
253         mp = br_mdb_ip_get(mdb, &pg->addr);
254         if (WARN_ON(!mp))
255                 return;
256
257         for (pp = &mp->ports; (p = *pp); pp = &p->next) {
258                 if (p != pg)
259                         continue;
260
261                 *pp = p->next;
262                 hlist_del_init(&p->mglist);
263                 del_timer(&p->timer);
264                 del_timer(&p->query_timer);
265                 call_rcu_bh(&p->rcu, br_multicast_free_pg);
266
267                 if (!mp->ports && hlist_unhashed(&mp->mglist) &&
268                     netif_running(br->dev))
269                         mod_timer(&mp->timer, jiffies);
270
271                 return;
272         }
273
274         WARN_ON(1);
275 }
276
277 static void br_multicast_port_group_expired(unsigned long data)
278 {
279         struct net_bridge_port_group *pg = (void *)data;
280         struct net_bridge *br = pg->port->br;
281
282         spin_lock(&br->multicast_lock);
283         if (!netif_running(br->dev) || timer_pending(&pg->timer) ||
284             hlist_unhashed(&pg->mglist))
285                 goto out;
286
287         br_multicast_del_pg(br, pg);
288
289 out:
290         spin_unlock(&br->multicast_lock);
291 }
292
293 static int br_mdb_rehash(struct net_bridge_mdb_htable **mdbp, int max,
294                          int elasticity)
295 {
296         struct net_bridge_mdb_htable *old = *mdbp;
297         struct net_bridge_mdb_htable *mdb;
298         int err;
299
300         mdb = kmalloc(sizeof(*mdb), GFP_ATOMIC);
301         if (!mdb)
302                 return -ENOMEM;
303
304         mdb->max = max;
305         mdb->old = old;
306
307         mdb->mhash = kzalloc(max * sizeof(*mdb->mhash), GFP_ATOMIC);
308         if (!mdb->mhash) {
309                 kfree(mdb);
310                 return -ENOMEM;
311         }
312
313         mdb->size = old ? old->size : 0;
314         mdb->ver = old ? old->ver ^ 1 : 0;
315
316         if (!old || elasticity)
317                 get_random_bytes(&mdb->secret, sizeof(mdb->secret));
318         else
319                 mdb->secret = old->secret;
320
321         if (!old)
322                 goto out;
323
324         err = br_mdb_copy(mdb, old, elasticity);
325         if (err) {
326                 kfree(mdb->mhash);
327                 kfree(mdb);
328                 return err;
329         }
330
331         call_rcu_bh(&mdb->rcu, br_mdb_free);
332
333 out:
334         rcu_assign_pointer(*mdbp, mdb);
335
336         return 0;
337 }
338
339 static struct sk_buff *br_ip4_multicast_alloc_query(struct net_bridge *br,
340                                                     __be32 group)
341 {
342         struct sk_buff *skb;
343         struct igmphdr *ih;
344         struct ethhdr *eth;
345         struct iphdr *iph;
346
347         skb = netdev_alloc_skb_ip_align(br->dev, sizeof(*eth) + sizeof(*iph) +
348                                                  sizeof(*ih) + 4);
349         if (!skb)
350                 goto out;
351
352         skb->protocol = htons(ETH_P_IP);
353
354         skb_reset_mac_header(skb);
355         eth = eth_hdr(skb);
356
357         memcpy(eth->h_source, br->dev->dev_addr, 6);
358         eth->h_dest[0] = 1;
359         eth->h_dest[1] = 0;
360         eth->h_dest[2] = 0x5e;
361         eth->h_dest[3] = 0;
362         eth->h_dest[4] = 0;
363         eth->h_dest[5] = 1;
364         eth->h_proto = htons(ETH_P_IP);
365         skb_put(skb, sizeof(*eth));
366
367         skb_set_network_header(skb, skb->len);
368         iph = ip_hdr(skb);
369
370         iph->version = 4;
371         iph->ihl = 6;
372         iph->tos = 0xc0;
373         iph->tot_len = htons(sizeof(*iph) + sizeof(*ih) + 4);
374         iph->id = 0;
375         iph->frag_off = htons(IP_DF);
376         iph->ttl = 1;
377         iph->protocol = IPPROTO_IGMP;
378         iph->saddr = 0;
379         iph->daddr = htonl(INADDR_ALLHOSTS_GROUP);
380         ((u8 *)&iph[1])[0] = IPOPT_RA;
381         ((u8 *)&iph[1])[1] = 4;
382         ((u8 *)&iph[1])[2] = 0;
383         ((u8 *)&iph[1])[3] = 0;
384         ip_send_check(iph);
385         skb_put(skb, 24);       /* 20-byte IPv4 header + 4-byte Router Alert option (ihl = 6) */
386
387         skb_set_transport_header(skb, skb->len);
388         ih = igmp_hdr(skb);
389         ih->type = IGMP_HOST_MEMBERSHIP_QUERY;
390         ih->code = (group ? br->multicast_last_member_interval :
391                             br->multicast_query_response_interval) /
392                    (HZ / IGMP_TIMER_SCALE);
393         ih->group = group;
394         ih->csum = 0;
395         ih->csum = ip_compute_csum((void *)ih, sizeof(struct igmphdr));
396         skb_put(skb, sizeof(*ih));
397
398         __skb_pull(skb, sizeof(*eth));
399
400 out:
401         return skb;
402 }
403
404 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
405 static struct sk_buff *br_ip6_multicast_alloc_query(struct net_bridge *br,
406                                                     struct in6_addr *group)
407 {
408         struct sk_buff *skb;
409         struct ipv6hdr *ip6h;
410         struct mld_msg *mldq;
411         struct ethhdr *eth;
412         u8 *hopopt;
413         unsigned long interval;
414
415         skb = netdev_alloc_skb_ip_align(br->dev, sizeof(*eth) + sizeof(*ip6h) +
416                                                  8 + sizeof(*mldq));
417         if (!skb)
418                 goto out;
419
420         skb->protocol = htons(ETH_P_IPV6);
421
422         /* Ethernet header */
423         skb_reset_mac_header(skb);
424         eth = eth_hdr(skb);
425
426         memcpy(eth->h_source, br->dev->dev_addr, 6);
427         ipv6_eth_mc_map(group, eth->h_dest);
428         eth->h_proto = htons(ETH_P_IPV6);
429         skb_put(skb, sizeof(*eth));
430
431         /* IPv6 header + HbH option */
432         skb_set_network_header(skb, skb->len);
433         ip6h = ipv6_hdr(skb);
434
435         *(__force __be32 *)ip6h = htonl(0x60000000);
436         ip6h->payload_len = htons(8 + sizeof(*mldq));
437         ip6h->nexthdr = IPPROTO_HOPOPTS;
438         ip6h->hop_limit = 1;
439         ipv6_addr_set(&ip6h->saddr, 0, 0, 0, 0);
440         ipv6_addr_set(&ip6h->daddr, htonl(0xff020000), 0, 0, htonl(1));
441
442         hopopt = (u8 *)(ip6h + 1);
443         hopopt[0] = IPPROTO_ICMPV6;             /* next hdr */
444         hopopt[1] = 0;                          /* length of HbH */
445         hopopt[2] = IPV6_TLV_ROUTERALERT;       /* Router Alert */
446         hopopt[3] = 2;                          /* Length of RA Option */
447         hopopt[4] = 0;                          /* Type = 0x0000 (MLD) */
448         hopopt[5] = 0;
449         hopopt[6] = IPV6_TLV_PAD0;              /* Pad0 */
450         hopopt[7] = IPV6_TLV_PAD0;              /* Pad0 */
451
452         skb_put(skb, sizeof(*ip6h) + 8);
453
454         /* ICMPv6 */
455         skb_set_transport_header(skb, skb->len);
456         mldq = (struct mld_msg *) icmp6_hdr(skb);
457
458         interval = ipv6_addr_any(group) ? br->multicast_last_member_interval :
459                                           br->multicast_query_response_interval;
460
461         mldq->mld_type = ICMPV6_MGM_QUERY;
462         mldq->mld_code = 0;
463         mldq->mld_cksum = 0;
464         mldq->mld_maxdelay = htons((u16)jiffies_to_msecs(interval));
465         mldq->mld_reserved = 0;
466         ipv6_addr_copy(&mldq->mld_mca, group);
467
468         /* checksum */
469         mldq->mld_cksum = csum_ipv6_magic(&ip6h->saddr, &ip6h->daddr,
470                                           sizeof(*mldq), IPPROTO_ICMPV6,
471                                           csum_partial(mldq,
472                                                        sizeof(*mldq), 0));
473         skb_put(skb, sizeof(*mldq));
474
475         __skb_pull(skb, sizeof(*eth));
476
477 out:
478         return skb;
479 }
480 #endif
481
482 static struct sk_buff *br_multicast_alloc_query(struct net_bridge *br,
483                                                 struct br_ip *addr)
484 {
485         switch (addr->proto) {
486         case htons(ETH_P_IP):
487                 return br_ip4_multicast_alloc_query(br, addr->u.ip4);
488 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
489         case htons(ETH_P_IPV6):
490                 return br_ip6_multicast_alloc_query(br, &addr->u.ip6);
491 #endif
492         }
493         return NULL;
494 }
495
496 static void br_multicast_send_group_query(struct net_bridge_mdb_entry *mp)
497 {
498         struct net_bridge *br = mp->br;
499         struct sk_buff *skb;
500
501         skb = br_multicast_alloc_query(br, &mp->addr);
502         if (!skb)
503                 goto timer;
504
505         netif_rx(skb);
506
507 timer:
508         if (++mp->queries_sent < br->multicast_last_member_count)
509                 mod_timer(&mp->query_timer,
510                           jiffies + br->multicast_last_member_interval);
511 }
512
513 static void br_multicast_group_query_expired(unsigned long data)
514 {
515         struct net_bridge_mdb_entry *mp = (void *)data;
516         struct net_bridge *br = mp->br;
517
518         spin_lock(&br->multicast_lock);
519         if (!netif_running(br->dev) || hlist_unhashed(&mp->mglist) ||
520             mp->queries_sent >= br->multicast_last_member_count)
521                 goto out;
522
523         br_multicast_send_group_query(mp);
524
525 out:
526         spin_unlock(&br->multicast_lock);
527 }
528
529 static void br_multicast_send_port_group_query(struct net_bridge_port_group *pg)
530 {
531         struct net_bridge_port *port = pg->port;
532         struct net_bridge *br = port->br;
533         struct sk_buff *skb;
534
535         skb = br_multicast_alloc_query(br, &pg->addr);
536         if (!skb)
537                 goto timer;
538
539         br_deliver(port, skb);
540
541 timer:
542         if (++pg->queries_sent < br->multicast_last_member_count)
543                 mod_timer(&pg->query_timer,
544                           jiffies + br->multicast_last_member_interval);
545 }
546
547 static void br_multicast_port_group_query_expired(unsigned long data)
548 {
549         struct net_bridge_port_group *pg = (void *)data;
550         struct net_bridge_port *port = pg->port;
551         struct net_bridge *br = port->br;
552
553         spin_lock(&br->multicast_lock);
554         if (!netif_running(br->dev) || hlist_unhashed(&pg->mglist) ||
555             pg->queries_sent >= br->multicast_last_member_count)
556                 goto out;
557
558         br_multicast_send_port_group_query(pg);
559
560 out:
561         spin_unlock(&br->multicast_lock);
562 }
563
564 static struct net_bridge_mdb_entry *br_multicast_get_group(
565         struct net_bridge *br, struct net_bridge_port *port,
566         struct br_ip *group, int hash)
567 {
568         struct net_bridge_mdb_htable *mdb = br->mdb;
569         struct net_bridge_mdb_entry *mp;
570         struct hlist_node *p;
571         unsigned count = 0;
572         unsigned max;
573         int elasticity;
574         int err;
575
576         hlist_for_each_entry(mp, p, &mdb->mhash[hash], hlist[mdb->ver]) {
577                 count++;
578                 if (unlikely(br_ip_equal(group, &mp->addr)))
579                         return mp;
580         }
581
582         elasticity = 0;
583         max = mdb->max;
584
585         if (unlikely(count > br->hash_elasticity && count)) {
586                 if (net_ratelimit())
587                         printk(KERN_INFO "%s: Multicast hash table "
588                                "chain limit reached: %s\n",
589                                br->dev->name, port ? port->dev->name :
590                                                      br->dev->name);
591
592                 elasticity = br->hash_elasticity;
593         }
594
595         if (mdb->size >= max) {
596                 max *= 2;
597                 if (unlikely(max >= br->hash_max)) {
598                         printk(KERN_WARNING "%s: Multicast hash table maximum "
599                                "reached, disabling snooping: %s, %d\n",
600                                br->dev->name, port ? port->dev->name :
601                                                      br->dev->name,
602                                max);
603                         err = -E2BIG;
604 disable:
605                         br->multicast_disabled = 1;
606                         goto err;
607                 }
608         }
609
610         if (max > mdb->max || elasticity) {
611                 if (mdb->old) {
612                         if (net_ratelimit())
613                                 printk(KERN_INFO "%s: Multicast hash table "
614                                        "on fire: %s\n",
615                                        br->dev->name, port ? port->dev->name :
616                                                              br->dev->name);
617                         err = -EEXIST;
618                         goto err;
619                 }
620
621                 err = br_mdb_rehash(&br->mdb, max, elasticity);
622                 if (err) {
623                         printk(KERN_WARNING "%s: Cannot rehash multicast "
624                                "hash table, disabling snooping: "
625                                "%s, %d, %d\n",
626                                br->dev->name, port ? port->dev->name :
627                                                      br->dev->name,
628                                mdb->size, err);
629                         goto disable;
630                 }
631
632                 err = -EAGAIN;
633                 goto err;
634         }
635
636         return NULL;
637
638 err:
639         mp = ERR_PTR(err);
640         return mp;
641 }
642
643 static struct net_bridge_mdb_entry *br_multicast_new_group(
644         struct net_bridge *br, struct net_bridge_port *port,
645         struct br_ip *group)
646 {
647         struct net_bridge_mdb_htable *mdb = br->mdb;
648         struct net_bridge_mdb_entry *mp;
649         int hash;
650
651         if (!mdb) {
652                 if (br_mdb_rehash(&br->mdb, BR_HASH_SIZE, 0))
653                         return NULL;
654                 goto rehash;
655         }
656
657         hash = br_ip_hash(mdb, group);
658         mp = br_multicast_get_group(br, port, group, hash);
659         switch (PTR_ERR(mp)) {
660         case 0:
661                 break;
662
663         case -EAGAIN:
664 rehash:
665                 mdb = br->mdb;
666                 hash = br_ip_hash(mdb, group);
667                 break;
668
669         default:
670                 goto out;
671         }
672
673         mp = kzalloc(sizeof(*mp), GFP_ATOMIC);
674         if (unlikely(!mp))
675                 goto out;
676
677         mp->br = br;
678         mp->addr = *group;
679         setup_timer(&mp->timer, br_multicast_group_expired,
680                     (unsigned long)mp);
681         setup_timer(&mp->query_timer, br_multicast_group_query_expired,
682                     (unsigned long)mp);
683
684         hlist_add_head_rcu(&mp->hlist[mdb->ver], &mdb->mhash[hash]);
685         mdb->size++;
686
687 out:
688         return mp;
689 }
690
691 static int br_multicast_add_group(struct net_bridge *br,
692                                   struct net_bridge_port *port,
693                                   struct br_ip *group)
694 {
695         struct net_bridge_mdb_entry *mp;
696         struct net_bridge_port_group *p;
697         struct net_bridge_port_group **pp;
698         unsigned long now = jiffies;
699         int err;
700
701         spin_lock(&br->multicast_lock);
702         if (!netif_running(br->dev) ||
703             (port && port->state == BR_STATE_DISABLED))
704                 goto out;
705
706         mp = br_multicast_new_group(br, port, group);
707         err = PTR_ERR(mp);
708         if (unlikely(IS_ERR(mp) || !mp))
709                 goto err;
710
711         if (!port) {
712                 hlist_add_head(&mp->mglist, &br->mglist);
713                 mod_timer(&mp->timer, now + br->multicast_membership_interval);
714                 goto out;
715         }
716
717         for (pp = &mp->ports; (p = *pp); pp = &p->next) {
718                 if (p->port == port)
719                         goto found;
720                 if ((unsigned long)p->port < (unsigned long)port)
721                         break;
722         }
723
724         p = kzalloc(sizeof(*p), GFP_ATOMIC);
725         err = -ENOMEM;
726         if (unlikely(!p))
727                 goto err;
728
729         p->addr = *group;
730         p->port = port;
731         p->next = *pp;
732         hlist_add_head(&p->mglist, &port->mglist);
733         setup_timer(&p->timer, br_multicast_port_group_expired,
734                     (unsigned long)p);
735         setup_timer(&p->query_timer, br_multicast_port_group_query_expired,
736                     (unsigned long)p);
737
738         rcu_assign_pointer(*pp, p);
739
740 found:
741         mod_timer(&p->timer, now + br->multicast_membership_interval);
742 out:
743         err = 0;
744
745 err:
746         spin_unlock(&br->multicast_lock);
747         return err;
748 }
749
750 static int br_ip4_multicast_add_group(struct net_bridge *br,
751                                       struct net_bridge_port *port,
752                                       __be32 group)
753 {
754         struct br_ip br_group;
755
756         if (ipv4_is_local_multicast(group))
757                 return 0;
758
759         br_group.u.ip4 = group;
760         br_group.proto = htons(ETH_P_IP);
761
762         return br_multicast_add_group(br, port, &br_group);
763 }
764
765 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
766 static int br_ip6_multicast_add_group(struct net_bridge *br,
767                                       struct net_bridge_port *port,
768                                       const struct in6_addr *group)
769 {
770         struct br_ip br_group;
771
772         if (ipv6_is_local_multicast(group))
773                 return 0;
774
775         ipv6_addr_copy(&br_group.u.ip6, group);
776         br_group.proto = htons(ETH_P_IPV6);
777
778         return br_multicast_add_group(br, port, &br_group);
779 }
780 #endif
781
782 static void br_multicast_router_expired(unsigned long data)
783 {
784         struct net_bridge_port *port = (void *)data;
785         struct net_bridge *br = port->br;
786
787         spin_lock(&br->multicast_lock);
788         if (port->multicast_router != 1 ||
789             timer_pending(&port->multicast_router_timer) ||
790             hlist_unhashed(&port->rlist))
791                 goto out;
792
793         hlist_del_init_rcu(&port->rlist);
794
795 out:
796         spin_unlock(&br->multicast_lock);
797 }
798
799 static void br_multicast_local_router_expired(unsigned long data)
800 {
801 }
802
803 static void __br_multicast_send_query(struct net_bridge *br,
804                                       struct net_bridge_port *port,
805                                       struct br_ip *ip)
806 {
807         struct sk_buff *skb;
808
809         skb = br_multicast_alloc_query(br, ip);
810         if (!skb)
811                 return;
812
813         if (port) {
814                 __skb_push(skb, sizeof(struct ethhdr));
815                 skb->dev = port->dev;
816                 NF_HOOK(PF_BRIDGE, NF_BR_LOCAL_OUT, skb, NULL, skb->dev,
817                         dev_queue_xmit);
818         } else
819                 netif_rx(skb);
820 }
821
822 static void br_multicast_send_query(struct net_bridge *br,
823                                     struct net_bridge_port *port, u32 sent)
824 {
825         unsigned long time;
826         struct br_ip br_group;
827
828         if (!netif_running(br->dev) || br->multicast_disabled ||
829             timer_pending(&br->multicast_querier_timer))
830                 return;
831
832         memset(&br_group.u, 0, sizeof(br_group.u));
833
834         br_group.proto = htons(ETH_P_IP);
835         __br_multicast_send_query(br, port, &br_group);
836
837 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
838         br_group.proto = htons(ETH_P_IPV6);
839         __br_multicast_send_query(br, port, &br_group);
840 #endif
841
842         time = jiffies;
843         time += sent < br->multicast_startup_query_count ?
844                 br->multicast_startup_query_interval :
845                 br->multicast_query_interval;
846         mod_timer(port ? &port->multicast_query_timer :
847                          &br->multicast_query_timer, time);
848 }
849
850 static void br_multicast_port_query_expired(unsigned long data)
851 {
852         struct net_bridge_port *port = (void *)data;
853         struct net_bridge *br = port->br;
854
855         spin_lock(&br->multicast_lock);
856         if (port->state == BR_STATE_DISABLED ||
857             port->state == BR_STATE_BLOCKING)
858                 goto out;
859
860         if (port->multicast_startup_queries_sent <
861             br->multicast_startup_query_count)
862                 port->multicast_startup_queries_sent++;
863
864         br_multicast_send_query(port->br, port,
865                                 port->multicast_startup_queries_sent);
866
867 out:
868         spin_unlock(&br->multicast_lock);
869 }
870
871 void br_multicast_add_port(struct net_bridge_port *port)
872 {
873         port->multicast_router = 1;
874
875         setup_timer(&port->multicast_router_timer, br_multicast_router_expired,
876                     (unsigned long)port);
877         setup_timer(&port->multicast_query_timer,
878                     br_multicast_port_query_expired, (unsigned long)port);
879 }
880
881 void br_multicast_del_port(struct net_bridge_port *port)
882 {
883         del_timer_sync(&port->multicast_router_timer);
884 }
885
886 static void __br_multicast_enable_port(struct net_bridge_port *port)
887 {
888         port->multicast_startup_queries_sent = 0;
889
890         if (try_to_del_timer_sync(&port->multicast_query_timer) >= 0 ||
891             del_timer(&port->multicast_query_timer))
892                 mod_timer(&port->multicast_query_timer, jiffies);
893 }
894
895 void br_multicast_enable_port(struct net_bridge_port *port)
896 {
897         struct net_bridge *br = port->br;
898
899         spin_lock(&br->multicast_lock);
900         if (br->multicast_disabled || !netif_running(br->dev))
901                 goto out;
902
903         __br_multicast_enable_port(port);
904
905 out:
906         spin_unlock(&br->multicast_lock);
907 }
908
909 void br_multicast_disable_port(struct net_bridge_port *port)
910 {
911         struct net_bridge *br = port->br;
912         struct net_bridge_port_group *pg;
913         struct hlist_node *p, *n;
914
915         spin_lock(&br->multicast_lock);
916         hlist_for_each_entry_safe(pg, p, n, &port->mglist, mglist)
917                 br_multicast_del_pg(br, pg);
918
919         if (!hlist_unhashed(&port->rlist))
920                 hlist_del_init_rcu(&port->rlist);
921         del_timer(&port->multicast_router_timer);
922         del_timer(&port->multicast_query_timer);
923         spin_unlock(&br->multicast_lock);
924 }
925
926 static int br_ip4_multicast_igmp3_report(struct net_bridge *br,
927                                          struct net_bridge_port *port,
928                                          struct sk_buff *skb)
929 {
930         struct igmpv3_report *ih;
931         struct igmpv3_grec *grec;
932         int i;
933         int len;
934         int num;
935         int type;
936         int err = 0;
937         __be32 group;
938
939         if (!pskb_may_pull(skb, sizeof(*ih)))
940                 return -EINVAL;
941
942         ih = igmpv3_report_hdr(skb);
943         num = ntohs(ih->ngrec);
944         len = sizeof(*ih);
945
946         for (i = 0; i < num; i++) {
947                 len += sizeof(*grec);
948                 if (!pskb_may_pull(skb, len))
949                         return -EINVAL;
950
951                 grec = (void *)(skb->data + len - sizeof(*grec));
952                 group = grec->grec_mca;
953                 type = grec->grec_type;
954
955                 len += ntohs(grec->grec_nsrcs) * 4;
956                 if (!pskb_may_pull(skb, len))
957                         return -EINVAL;
958
959                 /* We treat this as an IGMPv2 report for now. */
960                 switch (type) {
961                 case IGMPV3_MODE_IS_INCLUDE:
962                 case IGMPV3_MODE_IS_EXCLUDE:
963                 case IGMPV3_CHANGE_TO_INCLUDE:
964                 case IGMPV3_CHANGE_TO_EXCLUDE:
965                 case IGMPV3_ALLOW_NEW_SOURCES:
966                 case IGMPV3_BLOCK_OLD_SOURCES:
967                         break;
968
969                 default:
970                         continue;
971                 }
972
973                 err = br_ip4_multicast_add_group(br, port, group);
974                 if (err)
975                         break;
976         }
977
978         return err;
979 }
980
981 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
982 static int br_ip6_multicast_mld2_report(struct net_bridge *br,
983                                         struct net_bridge_port *port,
984                                         struct sk_buff *skb)
985 {
986         struct icmp6hdr *icmp6h;
987         struct mld2_grec *grec;
988         int i;
989         int len;
990         int num;
991         int err = 0;
992
993         if (!pskb_may_pull(skb, sizeof(*icmp6h)))
994                 return -EINVAL;
995
996         icmp6h = icmp6_hdr(skb);
997         num = ntohs(icmp6h->icmp6_dataun.un_data16[1]);
998         len = sizeof(*icmp6h);
999
1000         for (i = 0; i < num; i++) {
1001                 __be16 *nsrcs, _nsrcs;
1002
1003                 nsrcs = skb_header_pointer(skb,
1004                                            len + offsetof(struct mld2_grec,
1005                                                           grec_nsrcs),
1006                                            sizeof(_nsrcs), &_nsrcs);
1007                 if (!nsrcs)
1008                         return -EINVAL;
1009
1010                 if (!pskb_may_pull(skb,
1011                                    len + sizeof(*grec) +
1012                                    sizeof(struct in6_addr) * ntohs(*nsrcs)))
1013                         return -EINVAL;
1014
1015                 grec = (struct mld2_grec *)(skb->data + len);
1016                 len += sizeof(*grec) + sizeof(struct in6_addr) * ntohs(*nsrcs);
1017
1018                 /* We treat these as MLDv1 reports for now. */
1019                 switch (grec->grec_type) {
1020                 case MLD2_MODE_IS_INCLUDE:
1021                 case MLD2_MODE_IS_EXCLUDE:
1022                 case MLD2_CHANGE_TO_INCLUDE:
1023                 case MLD2_CHANGE_TO_EXCLUDE:
1024                 case MLD2_ALLOW_NEW_SOURCES:
1025                 case MLD2_BLOCK_OLD_SOURCES:
1026                         break;
1027
1028                 default:
1029                         continue;
1030                 }
1031
1032                 err = br_ip6_multicast_add_group(br, port, &grec->grec_mca);
1033                 if (err)
1034                         break;
1035         }
1036
1037         return err;
1038 }
1039 #endif
1040
1041 static void br_multicast_add_router(struct net_bridge *br,
1042                                     struct net_bridge_port *port)
1043 {
1044         struct hlist_node *p;
1045         struct hlist_node **h;
1046
1047         for (h = &br->router_list.first;
1048              (p = *h) &&
1049              (unsigned long)container_of(p, struct net_bridge_port, rlist) >
1050              (unsigned long)port;
1051              h = &p->next)
1052                 ;
1053
1054         port->rlist.pprev = h;
1055         port->rlist.next = p;
1056         rcu_assign_pointer(*h, &port->rlist);
1057         if (p)
1058                 p->pprev = &port->rlist.next;
1059 }
1060
1061 static void br_multicast_mark_router(struct net_bridge *br,
1062                                      struct net_bridge_port *port)
1063 {
1064         unsigned long now = jiffies;
1065
1066         if (!port) {
1067                 if (br->multicast_router == 1)
1068                         mod_timer(&br->multicast_router_timer,
1069                                   now + br->multicast_querier_interval);
1070                 return;
1071         }
1072
1073         if (port->multicast_router != 1)
1074                 return;
1075
1076         if (!hlist_unhashed(&port->rlist))
1077                 goto timer;
1078
1079         br_multicast_add_router(br, port);
1080
1081 timer:
1082         mod_timer(&port->multicast_router_timer,
1083                   now + br->multicast_querier_interval);
1084 }
1085
1086 static void br_multicast_query_received(struct net_bridge *br,
1087                                         struct net_bridge_port *port,
1088                                         int saddr)
1089 {
1090         if (saddr)
1091                 mod_timer(&br->multicast_querier_timer,
1092                           jiffies + br->multicast_querier_interval);
1093         else if (timer_pending(&br->multicast_querier_timer))
1094                 return;
1095
1096         br_multicast_mark_router(br, port);
1097 }
1098
1099 static int br_ip4_multicast_query(struct net_bridge *br,
1100                                   struct net_bridge_port *port,
1101                                   struct sk_buff *skb)
1102 {
1103         struct iphdr *iph = ip_hdr(skb);
1104         struct igmphdr *ih = igmp_hdr(skb);
1105         struct net_bridge_mdb_entry *mp;
1106         struct igmpv3_query *ih3;
1107         struct net_bridge_port_group *p;
1108         struct net_bridge_port_group **pp;
1109         unsigned long max_delay;
1110         unsigned long now = jiffies;
1111         __be32 group;
1112         int err = 0;
1113
1114         spin_lock(&br->multicast_lock);
1115         if (!netif_running(br->dev) ||
1116             (port && port->state == BR_STATE_DISABLED))
1117                 goto out;
1118
1119         br_multicast_query_received(br, port, !!iph->saddr);
1120
1121         group = ih->group;
1122
1123         if (skb->len == sizeof(*ih)) {
1124                 max_delay = ih->code * (HZ / IGMP_TIMER_SCALE);
1125
1126                 if (!max_delay) {
1127                         max_delay = 10 * HZ;
1128                         group = 0;
1129                 }
1130         } else {
1131                 if (!pskb_may_pull(skb, sizeof(struct igmpv3_query))) {
1132                         err = -EINVAL;
1133                         goto out;
1134                 }
1135
1136                 ih3 = igmpv3_query_hdr(skb);
1137                 if (ih3->nsrcs)
1138                         goto out;
1139
1140                 max_delay = ih3->code ?
1141                             IGMPV3_MRC(ih3->code) * (HZ / IGMP_TIMER_SCALE) : 1;
1142         }
1143
1144         if (!group)
1145                 goto out;
1146
1147         mp = br_mdb_ip4_get(br->mdb, group);
1148         if (!mp)
1149                 goto out;
1150
1151         max_delay *= br->multicast_last_member_count;
1152
1153         if (!hlist_unhashed(&mp->mglist) &&
1154             (timer_pending(&mp->timer) ?
1155              time_after(mp->timer.expires, now + max_delay) :
1156              try_to_del_timer_sync(&mp->timer) >= 0))
1157                 mod_timer(&mp->timer, now + max_delay);
1158
1159         for (pp = &mp->ports; (p = *pp); pp = &p->next) {
1160                 if (timer_pending(&p->timer) ?
1161                     time_after(p->timer.expires, now + max_delay) :
1162                     try_to_del_timer_sync(&p->timer) >= 0)
1163                         mod_timer(&p->timer, now + max_delay);
1164         }
1165
1166 out:
1167         spin_unlock(&br->multicast_lock);
1168         return err;
1169 }
1170
1171 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
1172 static int br_ip6_multicast_query(struct net_bridge *br,
1173                                   struct net_bridge_port *port,
1174                                   struct sk_buff *skb)
1175 {
1176         struct ipv6hdr *ip6h = ipv6_hdr(skb);
1177         struct mld_msg *mld = (struct mld_msg *) icmp6_hdr(skb);
1178         struct net_bridge_mdb_entry *mp;
1179         struct mld2_query *mld2q;
1180         struct net_bridge_port_group *p, **pp;
1181         unsigned long max_delay;
1182         unsigned long now = jiffies;
1183         struct in6_addr *group = NULL;
1184         int err = 0;
1185
1186         spin_lock(&br->multicast_lock);
1187         if (!netif_running(br->dev) ||
1188             (port && port->state == BR_STATE_DISABLED))
1189                 goto out;
1190
1191         br_multicast_query_received(br, port, !ipv6_addr_any(&ip6h->saddr));
1192
1193         if (skb->len == sizeof(*mld)) {
1194                 if (!pskb_may_pull(skb, sizeof(*mld))) {
1195                         err = -EINVAL;
1196                         goto out;
1197                 }
1198                 mld = (struct mld_msg *) icmp6_hdr(skb);
1199                 max_delay = msecs_to_jiffies(ntohs(mld->mld_maxdelay));
1200                 if (max_delay)
1201                         group = &mld->mld_mca;
1202         } else if (skb->len >= sizeof(*mld2q)) {
1203                 if (!pskb_may_pull(skb, sizeof(*mld2q))) {
1204                         err = -EINVAL;
1205                         goto out;
1206                 }
1207                 mld2q = (struct mld2_query *)icmp6_hdr(skb);
1208                 if (!mld2q->mld2q_nsrcs)
1209                         group = &mld2q->mld2q_mca;
1210                 max_delay = mld2q->mld2q_mrc ? MLDV2_MRC(ntohs(mld2q->mld2q_mrc)) : 1;
1211         }
1212
1213         if (!group)
1214                 goto out;
1215
1216         mp = br_mdb_ip6_get(br->mdb, group);
1217         if (!mp)
1218                 goto out;
1219
1220         max_delay *= br->multicast_last_member_count;
1221         if (!hlist_unhashed(&mp->mglist) &&
1222             (timer_pending(&mp->timer) ?
1223              time_after(mp->timer.expires, now + max_delay) :
1224              try_to_del_timer_sync(&mp->timer) >= 0))
1225                 mod_timer(&mp->timer, now + max_delay);
1226
1227         for (pp = &mp->ports; (p = *pp); pp = &p->next) {
1228                 if (timer_pending(&p->timer) ?
1229                     time_after(p->timer.expires, now + max_delay) :
1230                     try_to_del_timer_sync(&p->timer) >= 0)
1231                         mod_timer(&p->timer, now + max_delay);
1232         }
1233
1234 out:
1235         spin_unlock(&br->multicast_lock);
1236         return err;
1237 }
1238 #endif
1239
1240 static void br_multicast_leave_group(struct net_bridge *br,
1241                                      struct net_bridge_port *port,
1242                                      struct br_ip *group)
1243 {
1244         struct net_bridge_mdb_htable *mdb;
1245         struct net_bridge_mdb_entry *mp;
1246         struct net_bridge_port_group *p;
1247         unsigned long now;
1248         unsigned long time;
1249
1250         spin_lock(&br->multicast_lock);
1251         if (!netif_running(br->dev) ||
1252             (port && port->state == BR_STATE_DISABLED) ||
1253             timer_pending(&br->multicast_querier_timer))
1254                 goto out;
1255
1256         mdb = br->mdb;
1257         mp = br_mdb_ip_get(mdb, group);
1258         if (!mp)
1259                 goto out;
1260
1261         now = jiffies;
1262         time = now + br->multicast_last_member_count *
1263                      br->multicast_last_member_interval;
1264
1265         if (!port) {
1266                 if (!hlist_unhashed(&mp->mglist) &&
1267                     (timer_pending(&mp->timer) ?
1268                      time_after(mp->timer.expires, time) :
1269                      try_to_del_timer_sync(&mp->timer) >= 0)) {
1270                         mod_timer(&mp->timer, time);
1271
1272                         mp->queries_sent = 0;
1273                         mod_timer(&mp->query_timer, now);
1274                 }
1275
1276                 goto out;
1277         }
1278
1279         for (p = mp->ports; p; p = p->next) {
1280                 if (p->port != port)
1281                         continue;
1282
1283                 if (!hlist_unhashed(&p->mglist) &&
1284                     (timer_pending(&p->timer) ?
1285                      time_after(p->timer.expires, time) :
1286                      try_to_del_timer_sync(&p->timer) >= 0)) {
1287                         mod_timer(&p->timer, time);
1288
1289                         p->queries_sent = 0;
1290                         mod_timer(&p->query_timer, now);
1291                 }
1292
1293                 break;
1294         }
1295
1296 out:
1297         spin_unlock(&br->multicast_lock);
1298 }
1299
1300 static void br_ip4_multicast_leave_group(struct net_bridge *br,
1301                                          struct net_bridge_port *port,
1302                                          __be32 group)
1303 {
1304         struct br_ip br_group;
1305
1306         if (ipv4_is_local_multicast(group))
1307                 return;
1308
1309         br_group.u.ip4 = group;
1310         br_group.proto = htons(ETH_P_IP);
1311
1312         br_multicast_leave_group(br, port, &br_group);
1313 }
1314
1315 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
1316 static void br_ip6_multicast_leave_group(struct net_bridge *br,
1317                                          struct net_bridge_port *port,
1318                                          const struct in6_addr *group)
1319 {
1320         struct br_ip br_group;
1321
1322         if (ipv6_is_local_multicast(group))
1323                 return;
1324
1325         ipv6_addr_copy(&br_group.u.ip6, group);
1326         br_group.proto = htons(ETH_P_IPV6);
1327
1328         br_multicast_leave_group(br, port, &br_group);
1329 }
1330 #endif
1331
1332 static int br_multicast_ipv4_rcv(struct net_bridge *br,
1333                                  struct net_bridge_port *port,
1334                                  struct sk_buff *skb)
1335 {
1336         struct sk_buff *skb2 = skb;
1337         struct iphdr *iph;
1338         struct igmphdr *ih;
1339         unsigned len;
1340         unsigned offset;
1341         int err;
1342
1343         /* We treat OOM as packet loss for now. */
1344         if (!pskb_may_pull(skb, sizeof(*iph)))
1345                 return -EINVAL;
1346
1347         iph = ip_hdr(skb);
1348
1349         if (iph->ihl < 5 || iph->version != 4)
1350                 return -EINVAL;
1351
1352         if (!pskb_may_pull(skb, ip_hdrlen(skb)))
1353                 return -EINVAL;
1354
1355         iph = ip_hdr(skb);
1356
1357         if (unlikely(ip_fast_csum((u8 *)iph, iph->ihl)))
1358                 return -EINVAL;
1359
1360         if (iph->protocol != IPPROTO_IGMP)
1361                 return 0;
1362
1363         len = ntohs(iph->tot_len);
1364         if (skb->len < len || len < ip_hdrlen(skb))
1365                 return -EINVAL;
1366
1367         if (skb->len > len) {
1368                 skb2 = skb_clone(skb, GFP_ATOMIC);
1369                 if (!skb2)
1370                         return -ENOMEM;
1371
1372                 err = pskb_trim_rcsum(skb2, len);
1373                 if (err)
1374                         goto err_out;
1375         }
1376
1377         len -= ip_hdrlen(skb2);
1378         offset = skb_network_offset(skb2) + ip_hdrlen(skb2);
1379         __skb_pull(skb2, offset);
1380         skb_reset_transport_header(skb2);
1381
1382         err = -EINVAL;
1383         if (!pskb_may_pull(skb2, sizeof(*ih)))
1384                 goto out;
1385
1386         switch (skb2->ip_summed) {
1387         case CHECKSUM_COMPLETE:
1388                 if (!csum_fold(skb2->csum))
1389                         break;
1390                 /* fall through */
1391         case CHECKSUM_NONE:
1392                 skb2->csum = 0;
1393                 if (skb_checksum_complete(skb2))
1394                         goto out;
1395         }
1396
1397         err = 0;
1398
1399         BR_INPUT_SKB_CB(skb)->igmp = 1;
1400         ih = igmp_hdr(skb2);
1401
1402         switch (ih->type) {
1403         case IGMP_HOST_MEMBERSHIP_REPORT:
1404         case IGMPV2_HOST_MEMBERSHIP_REPORT:
1405                 BR_INPUT_SKB_CB(skb2)->mrouters_only = 1;
1406                 err = br_ip4_multicast_add_group(br, port, ih->group);
1407                 break;
1408         case IGMPV3_HOST_MEMBERSHIP_REPORT:
1409                 err = br_ip4_multicast_igmp3_report(br, port, skb2);
1410                 break;
1411         case IGMP_HOST_MEMBERSHIP_QUERY:
1412                 err = br_ip4_multicast_query(br, port, skb2);
1413                 break;
1414         case IGMP_HOST_LEAVE_MESSAGE:
1415                 br_ip4_multicast_leave_group(br, port, ih->group);
1416                 break;
1417         }
1418
1419 out:
1420         __skb_push(skb2, offset);
1421 err_out:
1422         if (skb2 != skb)
1423                 kfree_skb(skb2);
1424         return err;
1425 }
1426
1427 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
1428 static int br_multicast_ipv6_rcv(struct net_bridge *br,
1429                                  struct net_bridge_port *port,
1430                                  struct sk_buff *skb)
1431 {
1432         struct sk_buff *skb2 = skb;
1433         struct ipv6hdr *ip6h;
1434         struct icmp6hdr *icmp6h;
1435         u8 nexthdr;
1436         unsigned len;
1437         int offset;     /* ipv6_skip_exthdr() returns a negative value on failure */
1438         int err;
1439
1440         if (!pskb_may_pull(skb, sizeof(*ip6h)))
1441                 return -EINVAL;
1442
1443         ip6h = ipv6_hdr(skb);
1444
1445         /*
1446          * We're interested in MLD messages only.
1447          *  - Version is 6
1448          *  - MLD messages always carry the Router Alert hop-by-hop option
1449          *  - We do not support jumbograms.
1450          */
1451         if (ip6h->version != 6 ||
1452             ip6h->nexthdr != IPPROTO_HOPOPTS ||
1453             ip6h->payload_len == 0)
1454                 return 0;
1455
1456         len = ntohs(ip6h->payload_len);
1457         if (skb->len < len)
1458                 return -EINVAL;
1459
1460         nexthdr = ip6h->nexthdr;
1461         offset = ipv6_skip_exthdr(skb, sizeof(*ip6h), &nexthdr);
1462
1463         if (offset < 0 || nexthdr != IPPROTO_ICMPV6)
1464                 return 0;
1465
1466         /* Okay, we found ICMPv6 header */
1467         skb2 = skb_clone(skb, GFP_ATOMIC);
1468         if (!skb2)
1469                 return -ENOMEM;
1470
1471         len -= offset - skb_network_offset(skb2);
1472
1473         __skb_pull(skb2, offset);
1474         skb_reset_transport_header(skb2);
1475
1476         err = -EINVAL;
1477         if (!pskb_may_pull(skb2, sizeof(*icmp6h)))
1478                 goto out;
1479
1480         icmp6h = icmp6_hdr(skb2);
1481
1482         switch (icmp6h->icmp6_type) {
1483         case ICMPV6_MGM_QUERY:
1484         case ICMPV6_MGM_REPORT:
1485         case ICMPV6_MGM_REDUCTION:
1486         case ICMPV6_MLD2_REPORT:
1487                 break;
1488         default:
1489                 err = 0;
1490                 goto out;
1491         }
1492
1493         /* Okay, we found MLD message. Check further. */
1494         if (skb2->len > len) {
1495                 err = pskb_trim_rcsum(skb2, len);
1496                 if (err)
1497                         goto out;
1498         }
1499
1500         switch (skb2->ip_summed) {
1501         case CHECKSUM_COMPLETE:
1502                 if (!csum_fold(skb2->csum))
1503                         break;
1504                 /* fall through */
1505         case CHECKSUM_NONE:
1506                 skb2->csum = 0;
1507                 if (skb_checksum_complete(skb2))
1508                         goto out;
1509         }
1510
1511         err = 0;
1512
1513         BR_INPUT_SKB_CB(skb)->igmp = 1;
1514
1515         switch (icmp6h->icmp6_type) {
1516         case ICMPV6_MGM_REPORT:
1517             {
1518                 struct mld_msg *mld = (struct mld_msg *)icmp6h;
1519                 BR_INPUT_SKB_CB(skb2)->mrouters_only = 1;
1520                 err = br_ip6_multicast_add_group(br, port, &mld->mld_mca);
1521                 break;
1522             }
1523         case ICMPV6_MLD2_REPORT:
1524                 err = br_ip6_multicast_mld2_report(br, port, skb2);
1525                 break;
1526         case ICMPV6_MGM_QUERY:
1527                 err = br_ip6_multicast_query(br, port, skb2);
1528                 break;
1529         case ICMPV6_MGM_REDUCTION:
1530             {
1531                 struct mld_msg *mld = (struct mld_msg *)icmp6h;
1532                 br_ip6_multicast_leave_group(br, port, &mld->mld_mca);
1533             }
1534         }
1535
1536 out:
1537         __skb_push(skb2, offset);
1538         if (skb2 != skb)
1539                 kfree_skb(skb2);
1540         return err;
1541 }
1542 #endif
1543
1544 int br_multicast_rcv(struct net_bridge *br, struct net_bridge_port *port,
1545                      struct sk_buff *skb)
1546 {
1547         BR_INPUT_SKB_CB(skb)->igmp = 0;
1548         BR_INPUT_SKB_CB(skb)->mrouters_only = 0;
1549
1550         if (br->multicast_disabled)
1551                 return 0;
1552
1553         switch (skb->protocol) {
1554         case htons(ETH_P_IP):
1555                 return br_multicast_ipv4_rcv(br, port, skb);
1556 #if defined(CONFIG_IPV6) || defined(CONFIG_IPV6_MODULE)
1557         case htons(ETH_P_IPV6):
1558                 return br_multicast_ipv6_rcv(br, port, skb);
1559 #endif
1560         }
1561
1562         return 0;
1563 }
1564
1565 static void br_multicast_query_expired(unsigned long data)
1566 {
1567         struct net_bridge *br = (void *)data;
1568
1569         spin_lock(&br->multicast_lock);
1570         if (br->multicast_startup_queries_sent <
1571             br->multicast_startup_query_count)
1572                 br->multicast_startup_queries_sent++;
1573
1574         br_multicast_send_query(br, NULL, br->multicast_startup_queries_sent);
1575
1576         spin_unlock(&br->multicast_lock);
1577 }
1578
1579 void br_multicast_init(struct net_bridge *br)
1580 {
1581         br->hash_elasticity = 4;
1582         br->hash_max = 512;
1583
1584         br->multicast_router = 1;
1585         br->multicast_last_member_count = 2;
1586         br->multicast_startup_query_count = 2;
1587
1588         br->multicast_last_member_interval = HZ;
1589         br->multicast_query_response_interval = 10 * HZ;
1590         br->multicast_startup_query_interval = 125 * HZ / 4;
1591         br->multicast_query_interval = 125 * HZ;
1592         br->multicast_querier_interval = 255 * HZ;
1593         br->multicast_membership_interval = 260 * HZ;
1594
1595         spin_lock_init(&br->multicast_lock);
1596         setup_timer(&br->multicast_router_timer,
1597                     br_multicast_local_router_expired, 0);
1598         setup_timer(&br->multicast_querier_timer,
1599                     br_multicast_local_router_expired, 0);
1600         setup_timer(&br->multicast_query_timer, br_multicast_query_expired,
1601                     (unsigned long)br);
1602 }
1603
1604 void br_multicast_open(struct net_bridge *br)
1605 {
1606         br->multicast_startup_queries_sent = 0;
1607
1608         if (br->multicast_disabled)
1609                 return;
1610
1611         mod_timer(&br->multicast_query_timer, jiffies);
1612 }
1613
1614 void br_multicast_stop(struct net_bridge *br)
1615 {
1616         struct net_bridge_mdb_htable *mdb;
1617         struct net_bridge_mdb_entry *mp;
1618         struct hlist_node *p, *n;
1619         u32 ver;
1620         int i;
1621
1622         del_timer_sync(&br->multicast_router_timer);
1623         del_timer_sync(&br->multicast_querier_timer);
1624         del_timer_sync(&br->multicast_query_timer);
1625
1626         spin_lock_bh(&br->multicast_lock);
1627         mdb = br->mdb;
1628         if (!mdb)
1629                 goto out;
1630
1631         br->mdb = NULL;
1632
1633         ver = mdb->ver;
1634         for (i = 0; i < mdb->max; i++) {
1635                 hlist_for_each_entry_safe(mp, p, n, &mdb->mhash[i],
1636                                           hlist[ver]) {
1637                         del_timer(&mp->timer);
1638                         del_timer(&mp->query_timer);
1639                         call_rcu_bh(&mp->rcu, br_multicast_free_group);
1640                 }
1641         }
1642
1643         if (mdb->old) {
1644                 spin_unlock_bh(&br->multicast_lock);
1645                 rcu_barrier_bh();
1646                 spin_lock_bh(&br->multicast_lock);
1647                 WARN_ON(mdb->old);
1648         }
1649
1650         mdb->old = mdb;
1651         call_rcu_bh(&mdb->rcu, br_mdb_free);
1652
1653 out:
1654         spin_unlock_bh(&br->multicast_lock);
1655 }
1656
1657 int br_multicast_set_router(struct net_bridge *br, unsigned long val)
1658 {
1659         int err = -ENOENT;
1660
1661         spin_lock_bh(&br->multicast_lock);
1662         if (!netif_running(br->dev))
1663                 goto unlock;
1664
1665         switch (val) {
1666         case 0:
1667         case 2:
1668                 del_timer(&br->multicast_router_timer);
1669                 /* fall through */
1670         case 1:
1671                 br->multicast_router = val;
1672                 err = 0;
1673                 break;
1674
1675         default:
1676                 err = -EINVAL;
1677                 break;
1678         }
1679
1680 unlock:
1681         spin_unlock_bh(&br->multicast_lock);
1682
1683         return err;
1684 }
1685
1686 int br_multicast_set_port_router(struct net_bridge_port *p, unsigned long val)
1687 {
1688         struct net_bridge *br = p->br;
1689         int err = -ENOENT;
1690
1691         spin_lock(&br->multicast_lock);
1692         if (!netif_running(br->dev) || p->state == BR_STATE_DISABLED)
1693                 goto unlock;
1694
1695         switch (val) {
1696         case 0:
1697         case 1:
1698         case 2:
1699                 p->multicast_router = val;
1700                 err = 0;
1701
1702                 if (val < 2 && !hlist_unhashed(&p->rlist))
1703                         hlist_del_init_rcu(&p->rlist);
1704
1705                 if (val == 1)
1706                         break;
1707
1708                 del_timer(&p->multicast_router_timer);
1709
1710                 if (val == 0)
1711                         break;
1712
1713                 br_multicast_add_router(br, p);
1714                 break;
1715
1716         default:
1717                 err = -EINVAL;
1718                 break;
1719         }
1720
1721 unlock:
1722         spin_unlock(&br->multicast_lock);
1723
1724         return err;
1725 }
1726
1727 int br_multicast_toggle(struct net_bridge *br, unsigned long val)
1728 {
1729         struct net_bridge_port *port;
1730         int err = -ENOENT;
1731
1732         spin_lock(&br->multicast_lock);
1733         if (!netif_running(br->dev))
1734                 goto unlock;
1735
1736         err = 0;
1737         if (br->multicast_disabled == !val)
1738                 goto unlock;
1739
1740         br->multicast_disabled = !val;
1741         if (br->multicast_disabled)
1742                 goto unlock;
1743
1744         if (br->mdb) {
1745                 if (br->mdb->old) {
1746                         err = -EEXIST;
1747 rollback:
1748                         br->multicast_disabled = !!val;
1749                         goto unlock;
1750                 }
1751
1752                 err = br_mdb_rehash(&br->mdb, br->mdb->max,
1753                                     br->hash_elasticity);
1754                 if (err)
1755                         goto rollback;
1756         }
1757
1758         br_multicast_open(br);
1759         list_for_each_entry(port, &br->port_list, list) {
1760                 if (port->state == BR_STATE_DISABLED ||
1761                     port->state == BR_STATE_BLOCKING)
1762                         continue;
1763
1764                 __br_multicast_enable_port(port);
1765         }
1766
1767 unlock:
1768         spin_unlock(&br->multicast_lock);
1769
1770         return err;
1771 }
1772
1773 int br_multicast_set_hash_max(struct net_bridge *br, unsigned long val)
1774 {
1775         int err = -ENOENT;
1776         u32 old;
1777
1778         spin_lock(&br->multicast_lock);
1779         if (!netif_running(br->dev))
1780                 goto unlock;
1781
1782         err = -EINVAL;
1783         if (!is_power_of_2(val))
1784                 goto unlock;
1785         if (br->mdb && val < br->mdb->size)
1786                 goto unlock;
1787
1788         err = 0;
1789
1790         old = br->hash_max;
1791         br->hash_max = val;
1792
1793         if (br->mdb) {
1794                 if (br->mdb->old) {
1795                         err = -EEXIST;
1796 rollback:
1797                         br->hash_max = old;
1798                         goto unlock;
1799                 }
1800
1801                 err = br_mdb_rehash(&br->mdb, br->hash_max,
1802                                     br->hash_elasticity);
1803                 if (err)
1804                         goto rollback;
1805         }
1806
1807 unlock:
1808         spin_unlock(&br->multicast_lock);
1809
1810         return err;
1811 }