NetLabel: Cleanup the LSM domain hash functions
[safe/jmp/linux-2.6] / net / netlabel / netlabel_cipso_v4.c
index a6ce1d6..becf91a 100644 (file)
 #include <net/genetlink.h>
 #include <net/netlabel.h>
 #include <net/cipso_ipv4.h>
+#include <asm/atomic.h>
 
 #include "netlabel_user.h"
 #include "netlabel_cipso_v4.h"
+#include "netlabel_mgmt.h"
 
 /* Argument struct for cipso_v4_doi_walk() */
 struct netlbl_cipsov4_doiwalk_arg {
@@ -59,7 +61,7 @@ static struct genl_family netlbl_cipsov4_gnl_family = {
 };
 
 /* NetLabel Netlink attribute policy */
-static struct nla_policy netlbl_cipsov4_genl_policy[NLBL_CIPSOV4_A_MAX + 1] = {
+static const struct nla_policy netlbl_cipsov4_genl_policy[NLBL_CIPSOV4_A_MAX + 1] = {
        [NLBL_CIPSOV4_A_DOI] = { .type = NLA_U32 },
        [NLBL_CIPSOV4_A_MTYPE] = { .type = NLA_U32 },
        [NLBL_CIPSOV4_A_TAG] = { .type = NLA_U8 },
@@ -129,13 +131,13 @@ static int netlbl_cipsov4_add_common(struct genl_info *info,
                return -EINVAL;
 
        nla_for_each_nested(nla, info->attrs[NLBL_CIPSOV4_A_TAGLST], nla_rem)
-               if (nla->nla_type == NLBL_CIPSOV4_A_TAG) {
-                       if (iter > CIPSO_V4_TAG_MAXCNT)
+               if (nla_type(nla) == NLBL_CIPSOV4_A_TAG) {
+                       if (iter >= CIPSO_V4_TAG_MAXCNT)
                                return -EINVAL;
                        doi_def->tags[iter++] = nla_get_u8(nla);
                }
-       if (iter < CIPSO_V4_TAG_MAXCNT)
-               doi_def->tags[iter] = CIPSO_V4_TAG_INVALID;
+       while (iter < CIPSO_V4_TAG_MAXCNT)
+               doi_def->tags[iter++] = CIPSO_V4_TAG_INVALID;
 
        return 0;
 }
@@ -162,6 +164,7 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
        struct nlattr *nla_b;
        int nla_a_rem;
        int nla_b_rem;
+       u32 iter;
 
        if (!info->attrs[NLBL_CIPSOV4_A_TAGLST] ||
            !info->attrs[NLBL_CIPSOV4_A_MLSLVLLST])
@@ -185,20 +188,31 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
        ret_val = netlbl_cipsov4_add_common(info, doi_def);
        if (ret_val != 0)
                goto add_std_failure;
+       ret_val = -EINVAL;
 
        nla_for_each_nested(nla_a,
                            info->attrs[NLBL_CIPSOV4_A_MLSLVLLST],
                            nla_a_rem)
-               if (nla_a->nla_type == NLBL_CIPSOV4_A_MLSLVL) {
+               if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSLVL) {
+                       if (nla_validate_nested(nla_a,
+                                           NLBL_CIPSOV4_A_MAX,
+                                           netlbl_cipsov4_genl_policy) != 0)
+                                       goto add_std_failure;
                        nla_for_each_nested(nla_b, nla_a, nla_b_rem)
-                               switch (nla_b->nla_type) {
+                               switch (nla_type(nla_b)) {
                                case NLBL_CIPSOV4_A_MLSLVLLOC:
+                                       if (nla_get_u32(nla_b) >
+                                           CIPSO_V4_MAX_LOC_LVLS)
+                                               goto add_std_failure;
                                        if (nla_get_u32(nla_b) >=
                                            doi_def->map.std->lvl.local_size)
                                             doi_def->map.std->lvl.local_size =
                                                     nla_get_u32(nla_b) + 1;
                                        break;
                                case NLBL_CIPSOV4_A_MLSLVLREM:
+                                       if (nla_get_u32(nla_b) >
+                                           CIPSO_V4_MAX_REM_LVLS)
+                                               goto add_std_failure;
                                        if (nla_get_u32(nla_b) >=
                                            doi_def->map.std->lvl.cipso_size)
                                             doi_def->map.std->lvl.cipso_size =
@@ -206,9 +220,6 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
                                        break;
                                }
                }
-       if (doi_def->map.std->lvl.local_size > CIPSO_V4_MAX_LOC_LVLS ||
-           doi_def->map.std->lvl.cipso_size > CIPSO_V4_MAX_REM_LVLS)
-               goto add_std_failure;
        doi_def->map.std->lvl.local = kcalloc(doi_def->map.std->lvl.local_size,
                                              sizeof(u32),
                                              GFP_KERNEL);
@@ -223,18 +234,17 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
                ret_val = -ENOMEM;
                goto add_std_failure;
        }
+       for (iter = 0; iter < doi_def->map.std->lvl.local_size; iter++)
+               doi_def->map.std->lvl.local[iter] = CIPSO_V4_INV_LVL;
+       for (iter = 0; iter < doi_def->map.std->lvl.cipso_size; iter++)
+               doi_def->map.std->lvl.cipso[iter] = CIPSO_V4_INV_LVL;
        nla_for_each_nested(nla_a,
                            info->attrs[NLBL_CIPSOV4_A_MLSLVLLST],
                            nla_a_rem)
-               if (nla_a->nla_type == NLBL_CIPSOV4_A_MLSLVL) {
+               if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSLVL) {
                        struct nlattr *lvl_loc;
                        struct nlattr *lvl_rem;
 
-                       if (nla_validate_nested(nla_a,
-                                             NLBL_CIPSOV4_A_MAX,
-                                             netlbl_cipsov4_genl_policy) != 0)
-                               goto add_std_failure;
-
                        lvl_loc = nla_find_nested(nla_a,
                                                  NLBL_CIPSOV4_A_MLSLVLLOC);
                        lvl_rem = nla_find_nested(nla_a,
@@ -256,20 +266,26 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
                nla_for_each_nested(nla_a,
                                    info->attrs[NLBL_CIPSOV4_A_MLSCATLST],
                                    nla_a_rem)
-                       if (nla_a->nla_type == NLBL_CIPSOV4_A_MLSCAT) {
+                       if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSCAT) {
                                if (nla_validate_nested(nla_a,
                                              NLBL_CIPSOV4_A_MAX,
                                              netlbl_cipsov4_genl_policy) != 0)
                                        goto add_std_failure;
                                nla_for_each_nested(nla_b, nla_a, nla_b_rem)
-                                       switch (nla_b->nla_type) {
+                                       switch (nla_type(nla_b)) {
                                        case NLBL_CIPSOV4_A_MLSCATLOC:
+                                               if (nla_get_u32(nla_b) >
+                                                   CIPSO_V4_MAX_LOC_CATS)
+                                                       goto add_std_failure;
                                                if (nla_get_u32(nla_b) >=
                                              doi_def->map.std->cat.local_size)
                                             doi_def->map.std->cat.local_size =
                                                     nla_get_u32(nla_b) + 1;
                                                break;
                                        case NLBL_CIPSOV4_A_MLSCATREM:
+                                               if (nla_get_u32(nla_b) >
+                                                   CIPSO_V4_MAX_REM_CATS)
+                                                       goto add_std_failure;
                                                if (nla_get_u32(nla_b) >=
                                              doi_def->map.std->cat.cipso_size)
                                             doi_def->map.std->cat.cipso_size =
@@ -277,11 +293,8 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
                                                break;
                                        }
                        }
-               if (doi_def->map.std->cat.local_size > CIPSO_V4_MAX_LOC_CATS ||
-                   doi_def->map.std->cat.cipso_size > CIPSO_V4_MAX_REM_CATS)
-                       goto add_std_failure;
                doi_def->map.std->cat.local = kcalloc(
-                                             doi_def->map.std->cat.local_size,
+                                             doi_def->map.std->cat.local_size,
                                              sizeof(u32),
                                              GFP_KERNEL);
                if (doi_def->map.std->cat.local == NULL) {
@@ -289,17 +302,21 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
                        goto add_std_failure;
                }
                doi_def->map.std->cat.cipso = kcalloc(
-                                             doi_def->map.std->cat.cipso_size,
+                                             doi_def->map.std->cat.cipso_size,
                                              sizeof(u32),
                                              GFP_KERNEL);
                if (doi_def->map.std->cat.cipso == NULL) {
                        ret_val = -ENOMEM;
                        goto add_std_failure;
                }
+               for (iter = 0; iter < doi_def->map.std->cat.local_size; iter++)
+                       doi_def->map.std->cat.local[iter] = CIPSO_V4_INV_CAT;
+               for (iter = 0; iter < doi_def->map.std->cat.cipso_size; iter++)
+                       doi_def->map.std->cat.cipso[iter] = CIPSO_V4_INV_CAT;
                nla_for_each_nested(nla_a,
                                    info->attrs[NLBL_CIPSOV4_A_MLSCATLST],
                                    nla_a_rem)
-                       if (nla_a->nla_type == NLBL_CIPSOV4_A_MLSCAT) {
+                       if (nla_type(nla_a) == NLBL_CIPSOV4_A_MLSCAT) {
                                struct nlattr *cat_loc;
                                struct nlattr *cat_rem;
 
@@ -310,10 +327,10 @@ static int netlbl_cipsov4_add_std(struct genl_info *info)
                                if (cat_loc == NULL || cat_rem == NULL)
                                        goto add_std_failure;
                                doi_def->map.std->cat.local[
-                                                       nla_get_u32(cat_loc)] =
+                                                       nla_get_u32(cat_loc)] =
                                        nla_get_u32(cat_rem);
                                doi_def->map.std->cat.cipso[
-                                                       nla_get_u32(cat_rem)] =
+                                                       nla_get_u32(cat_rem)] =
                                        nla_get_u32(cat_loc);
                        }
        }
@@ -404,15 +421,19 @@ static int netlbl_cipsov4_add(struct sk_buff *skb, struct genl_info *info)
                ret_val = netlbl_cipsov4_add_pass(info);
                break;
        }
+       if (ret_val == 0)
+               atomic_inc(&netlabel_mgmt_protocount);
 
        audit_buf = netlbl_audit_start_common(AUDIT_MAC_CIPSOV4_ADD,
                                              &audit_info);
-       audit_log_format(audit_buf,
-                        " cipso_doi=%u cipso_type=%s res=%u",
-                        doi,
-                        type_str,
-                        ret_val == 0 ? 1 : 0);
-       audit_log_end(audit_buf);
+       if (audit_buf != NULL) {
+               audit_log_format(audit_buf,
+                                " cipso_doi=%u cipso_type=%s res=%u",
+                                doi,
+                                type_str,
+                                ret_val == 0 ? 1 : 0);
+               audit_log_end(audit_buf);
+       }
 
        return ret_val;
 }
@@ -452,17 +473,13 @@ static int netlbl_cipsov4_list(struct sk_buff *skb, struct genl_info *info)
        }
 
 list_start:
-       ans_skb = nlmsg_new(NLMSG_GOODSIZE * nlsze_mult, GFP_KERNEL);
+       ans_skb = nlmsg_new(NLMSG_DEFAULT_SIZE * nlsze_mult, GFP_KERNEL);
        if (ans_skb == NULL) {
                ret_val = -ENOMEM;
                goto list_failure;
        }
-       data = netlbl_netlink_hdr_put(ans_skb,
-                                     info->snd_pid,
-                                     info->snd_seq,
-                                     netlbl_cipsov4_gnl_family.id,
-                                     0,
-                                     NLBL_CIPSOV4_C_LIST);
+       data = genlmsg_put_reply(ans_skb, info, &netlbl_cipsov4_gnl_family,
+                                0, NLBL_CIPSOV4_C_LIST);
        if (data == NULL) {
                ret_val = -ENOMEM;
                goto list_failure;
@@ -568,7 +585,7 @@ list_start:
 
        genlmsg_end(ans_skb, data);
 
-       ret_val = genlmsg_unicast(ans_skb, info->snd_pid);
+       ret_val = genlmsg_reply(ans_skb, info);
        if (ret_val != 0)
                goto list_failure;
 
@@ -607,12 +624,9 @@ static int netlbl_cipsov4_listall_cb(struct cipso_v4_doi *doi_def, void *arg)
        struct netlbl_cipsov4_doiwalk_arg *cb_arg = arg;
        void *data;
 
-       data = netlbl_netlink_hdr_put(cb_arg->skb,
-                                     NETLINK_CB(cb_arg->nl_cb->skb).pid,
-                                     cb_arg->seq,
-                                     netlbl_cipsov4_gnl_family.id,
-                                     NLM_F_MULTI,
-                                     NLBL_CIPSOV4_C_LISTALL);
+       data = genlmsg_put(cb_arg->skb, NETLINK_CB(cb_arg->nl_cb->skb).pid,
+                          cb_arg->seq, &netlbl_cipsov4_gnl_family,
+                          NLM_F_MULTI, NLBL_CIPSOV4_C_LISTALL);
        if (data == NULL)
                goto listall_cb_failure;
 
@@ -684,14 +698,18 @@ static int netlbl_cipsov4_remove(struct sk_buff *skb, struct genl_info *info)
        ret_val = cipso_v4_doi_remove(doi,
                                      &audit_info,
                                      netlbl_cipsov4_doi_free);
+       if (ret_val == 0)
+               atomic_dec(&netlabel_mgmt_protocount);
 
        audit_buf = netlbl_audit_start_common(AUDIT_MAC_CIPSOV4_DEL,
                                              &audit_info);
-       audit_log_format(audit_buf,
-                        " cipso_doi=%u res=%u",
-                        doi,
-                        ret_val == 0 ? 1 : 0);
-       audit_log_end(audit_buf);
+       if (audit_buf != NULL) {
+               audit_log_format(audit_buf,
+                                " cipso_doi=%u res=%u",
+                                doi,
+                                ret_val == 0 ? 1 : 0);
+               audit_log_end(audit_buf);
+       }
 
        return ret_val;
 }