cls_cgroup: Store classid in struct sock
[safe/jmp/linux-2.6] / net / sched / cls_cgroup.c
index 91a3db4..78ef2c5 100644 (file)
  */
 
 #include <linux/module.h>
+#include <linux/slab.h>
 #include <linux/types.h>
 #include <linux/string.h>
 #include <linux/errno.h>
 #include <linux/skbuff.h>
 #include <linux/cgroup.h>
+#include <linux/rcupdate.h>
 #include <net/rtnetlink.h>
 #include <net/pkt_cls.h>
+#include <net/sock.h>
+#include <net/cls_cgroup.h>
 
-struct cgroup_cls_state
-{
-       struct cgroup_subsys_state css;
-       u32 classid;
+static struct cgroup_subsys_state *cgrp_create(struct cgroup_subsys *ss,
+                                              struct cgroup *cgrp);
+static void cgrp_destroy(struct cgroup_subsys *ss, struct cgroup *cgrp);
+static int cgrp_populate(struct cgroup_subsys *ss, struct cgroup *cgrp);
+
+struct cgroup_subsys net_cls_subsys = {
+       .name           = "net_cls",
+       .create         = cgrp_create,
+       .destroy        = cgrp_destroy,
+       .populate       = cgrp_populate,
+#ifdef CONFIG_NET_CLS_CGROUP
+       .subsys_id      = net_cls_subsys_id,
+#else
+#define net_cls_subsys_id net_cls_subsys.subsys_id
+#endif
+       .module         = THIS_MODULE,
 };
 
+
 static inline struct cgroup_cls_state *cgrp_cls_state(struct cgroup *cgrp)
 {
        return container_of(cgroup_subsys_state(cgrp, net_cls_subsys_id),
@@ -62,13 +79,7 @@ static u64 read_classid(struct cgroup *cgrp, struct cftype *cft)
 
 static int write_classid(struct cgroup *cgrp, struct cftype *cft, u64 value)
 {
-       if (!cgroup_lock_live_group(cgrp))
-               return -ENODEV;
-
        cgrp_cls_state(cgrp)->classid = (u32) value;
-
-       cgroup_unlock();
-
        return 0;
 }
 
@@ -85,14 +96,6 @@ static int cgrp_populate(struct cgroup_subsys *ss, struct cgroup *cgrp)
        return cgroup_add_files(cgrp, ss, ss_files, ARRAY_SIZE(ss_files));
 }
 
-struct cgroup_subsys net_cls_subsys = {
-       .name           = "net_cls",
-       .create         = cgrp_create,
-       .destroy        = cgrp_destroy,
-       .populate       = cgrp_populate,
-       .subsys_id      = net_cls_subsys_id,
-};
-
 struct cls_cgroup_head
 {
        u32                     handle;
@@ -104,8 +107,11 @@ static int cls_cgroup_classify(struct sk_buff *skb, struct tcf_proto *tp,
                               struct tcf_result *res)
 {
        struct cls_cgroup_head *head = tp->root;
-       struct cgroup_cls_state *cs;
-       int ret = 0;
+       u32 classid;
+
+       rcu_read_lock();
+       classid = task_cls_state(current)->classid;
+       rcu_read_unlock();
 
        /*
         * Due to the nature of the classifier it is required to ignore all
@@ -117,21 +123,22 @@ static int cls_cgroup_classify(struct sk_buff *skb, struct tcf_proto *tp,
         * calls by looking at the number of nested bh disable calls because
         * softirqs always disables bh.
         */
-       if (softirq_count() != SOFTIRQ_OFFSET)
-               return -1;
+       if (softirq_count() != SOFTIRQ_OFFSET) {
+               /* If there is an sk_classid we'll use that. */
+               if (!skb->sk)
+                       return -1;
+               classid = skb->sk->sk_classid;
+       }
 
-       rcu_read_lock();
-       cs = task_cls_state(current);
-       if (cs->classid && tcf_em_tree_match(skb, &head->ematches, NULL)) {
-               res->classid = cs->classid;
-               res->class = 0;
-               ret = tcf_exts_exec(skb, &head->exts, res);
-       } else
-               ret = -1;
+       if (!classid)
+               return -1;
 
-       rcu_read_unlock();
+       if (!tcf_em_tree_match(skb, &head->ematches, NULL))
+               return -1;
 
-       return ret;
+       res->classid = classid;
+       res->class = 0;
+       return tcf_exts_exec(skb, &head->exts, res);
 }
 
 static unsigned long cls_cgroup_get(struct tcf_proto *tp, u32 handle)
@@ -167,6 +174,9 @@ static int cls_cgroup_change(struct tcf_proto *tp, unsigned long base,
        struct tcf_exts e;
        int err;
 
+       if (!tca[TCA_OPTIONS])
+               return -EINVAL;
+
        if (head == NULL) {
                if (!handle)
                        return -EINVAL;
@@ -280,12 +290,36 @@ static struct tcf_proto_ops cls_cgroup_ops __read_mostly = {
 
 static int __init init_cgroup_cls(void)
 {
-       return register_tcf_proto_ops(&cls_cgroup_ops);
+       int ret;
+
+       ret = cgroup_load_subsys(&net_cls_subsys);
+       if (ret)
+               goto out;
+
+#ifndef CONFIG_NET_CLS_CGROUP
+       /* We can't use rcu_assign_pointer because this is an int. */
+       smp_wmb();
+       net_cls_subsys_id = net_cls_subsys.subsys_id;
+#endif
+
+       ret = register_tcf_proto_ops(&cls_cgroup_ops);
+       if (ret)
+               cgroup_unload_subsys(&net_cls_subsys);
+
+out:
+       return ret;
 }
 
 static void __exit exit_cgroup_cls(void)
 {
        unregister_tcf_proto_ops(&cls_cgroup_ops);
+
+#ifndef CONFIG_NET_CLS_CGROUP
+       net_cls_subsys_id = -1;
+       synchronize_rcu();
+#endif
+
+       cgroup_unload_subsys(&net_cls_subsys);
 }
 
 module_init(init_cgroup_cls);