Merge branch 'master' of git://git.kernel.org/pub/scm/linux/kernel/git/linville/wirel...
[safe/jmp/linux-2.6] / net / core / sock.c
index 7633422..37fe9b6 100644 (file)
 #include <linux/net_tstamp.h>
 #include <net/xfrm.h>
 #include <linux/ipsec.h>
+#include <net/cls_cgroup.h>
 
 #include <linux/filter.h>
 
@@ -142,7 +143,7 @@ static struct lock_class_key af_family_slock_keys[AF_MAX];
  * strings build-time, so that runtime initialization of socket
  * locks is fast):
  */
-static const char *af_family_key_strings[AF_MAX+1] = {
+static const char *const af_family_key_strings[AF_MAX+1] = {
   "sk_lock-AF_UNSPEC", "sk_lock-AF_UNIX"     , "sk_lock-AF_INET"     ,
   "sk_lock-AF_AX25"  , "sk_lock-AF_IPX"      , "sk_lock-AF_APPLETALK",
   "sk_lock-AF_NETROM", "sk_lock-AF_BRIDGE"   , "sk_lock-AF_ATMPVC"   ,
@@ -158,7 +159,7 @@ static const char *af_family_key_strings[AF_MAX+1] = {
   "sk_lock-AF_IEEE802154",
   "sk_lock-AF_MAX"
 };
-static const char *af_family_slock_key_strings[AF_MAX+1] = {
+static const char *const af_family_slock_key_strings[AF_MAX+1] = {
   "slock-AF_UNSPEC", "slock-AF_UNIX"     , "slock-AF_INET"     ,
   "slock-AF_AX25"  , "slock-AF_IPX"      , "slock-AF_APPLETALK",
   "slock-AF_NETROM", "slock-AF_BRIDGE"   , "slock-AF_ATMPVC"   ,
@@ -174,7 +175,7 @@ static const char *af_family_slock_key_strings[AF_MAX+1] = {
   "slock-AF_IEEE802154",
   "slock-AF_MAX"
 };
-static const char *af_family_clock_key_strings[AF_MAX+1] = {
+static const char *const af_family_clock_key_strings[AF_MAX+1] = {
   "clock-AF_UNSPEC", "clock-AF_UNIX"     , "clock-AF_INET"     ,
   "clock-AF_AX25"  , "clock-AF_IPX"      , "clock-AF_APPLETALK",
   "clock-AF_NETROM", "clock-AF_BRIDGE"   , "clock-AF_ATMPVC"   ,
@@ -217,6 +218,11 @@ __u32 sysctl_rmem_default __read_mostly = SK_RMEM_MAX;
 int sysctl_optmem_max __read_mostly = sizeof(unsigned long)*(2*UIO_MAXIOV+512);
 EXPORT_SYMBOL(sysctl_optmem_max);
 
+#if defined(CONFIG_CGROUPS) && !defined(CONFIG_NET_CLS_CGROUP)
+int net_cls_subsys_id = -1;
+EXPORT_SYMBOL_GPL(net_cls_subsys_id);
+#endif
+
 static int sock_set_timeout(long *timeo_p, char __user *optval, int optlen)
 {
        struct timeval tv;
@@ -274,25 +280,27 @@ static void sock_disable_timestamp(struct sock *sk, int flag)
 
 int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
 {
-       int err = 0;
+       int err;
        int skb_len;
+       unsigned long flags;
+       struct sk_buff_head *list = &sk->sk_receive_queue;
 
        /* Cast sk->rcvbuf to unsigned... It's pointless, but reduces
           number of warnings when compiling with -W --ANK
         */
        if (atomic_read(&sk->sk_rmem_alloc) + skb->truesize >=
            (unsigned)sk->sk_rcvbuf) {
-               err = -ENOMEM;
-               goto out;
+               atomic_inc(&sk->sk_drops);
+               return -ENOMEM;
        }
 
        err = sk_filter(sk, skb);
        if (err)
-               goto out;
+               return err;
 
        if (!sk_rmem_schedule(sk, skb->truesize)) {
-               err = -ENOBUFS;
-               goto out;
+               atomic_inc(&sk->sk_drops);
+               return -ENOBUFS;
        }
 
        skb->dev = NULL;
@@ -305,12 +313,19 @@ int sock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
         */
        skb_len = skb->len;
 
-       skb_queue_tail(&sk->sk_receive_queue, skb);
+       /* we escape from rcu protected region, make sure we dont leak
+        * a norefcounted dst
+        */
+       skb_dst_force(skb);
+
+       spin_lock_irqsave(&list->lock, flags);
+       skb->dropcount = atomic_read(&sk->sk_drops);
+       __skb_queue_tail(list, skb);
+       spin_unlock_irqrestore(&list->lock, flags);
 
        if (!sock_flag(sk, SOCK_DEAD))
                sk->sk_data_ready(sk, skb_len);
-out:
-       return err;
+       return 0;
 }
 EXPORT_SYMBOL(sock_queue_rcv_skb);
 
@@ -323,6 +338,10 @@ int sk_receive_skb(struct sock *sk, struct sk_buff *skb, const int nested)
 
        skb->dev = NULL;
 
+       if (sk_rcvqueues_full(sk, skb)) {
+               atomic_inc(&sk->sk_drops);
+               goto discard_and_relse;
+       }
        if (nested)
                bh_lock_sock_nested(sk);
        else
@@ -336,8 +355,12 @@ int sk_receive_skb(struct sock *sk, struct sk_buff *skb, const int nested)
                rc = sk_backlog_rcv(sk, skb);
 
                mutex_release(&sk->sk_lock.dep_map, 1, _RET_IP_);
-       } else
-               sk_add_backlog(sk, skb);
+       } else if (sk_add_backlog(sk, skb)) {
+               bh_unlock_sock(sk);
+               atomic_inc(&sk->sk_drops);
+               goto discard_and_relse;
+       }
+
        bh_unlock_sock(sk);
 out:
        sock_put(sk);
@@ -348,12 +371,19 @@ discard_and_relse:
 }
 EXPORT_SYMBOL(sk_receive_skb);
 
+void sk_reset_txq(struct sock *sk)
+{
+       sk_tx_queue_clear(sk);
+}
+EXPORT_SYMBOL(sk_reset_txq);
+
 struct dst_entry *__sk_dst_check(struct sock *sk, u32 cookie)
 {
-       struct dst_entry *dst = sk->sk_dst_cache;
+       struct dst_entry *dst = __sk_dst_get(sk);
 
        if (dst && dst->obsolete && dst->ops->check(dst, cookie) == NULL) {
-               sk->sk_dst_cache = NULL;
+               sk_tx_queue_clear(sk);
+               rcu_assign_pointer(sk->sk_dst_cache, NULL);
                dst_release(dst);
                return NULL;
        }
@@ -406,17 +436,18 @@ static int sock_bindtodevice(struct sock *sk, char __user *optval, int optlen)
        if (copy_from_user(devname, optval, optlen))
                goto out;
 
-       if (devname[0] == '\0') {
-               index = 0;
-       } else {
-               struct net_device *dev = dev_get_by_name(net, devname);
+       index = 0;
+       if (devname[0] != '\0') {
+               struct net_device *dev;
 
+               rcu_read_lock();
+               dev = dev_get_by_name_rcu(net, devname);
+               if (dev)
+                       index = dev->ifindex;
+               rcu_read_unlock();
                ret = -ENODEV;
                if (!dev)
                        goto out;
-
-               index = dev->ifindex;
-               dev_put(dev);
        }
 
        lock_sock(sk);
@@ -446,7 +477,7 @@ static inline void sock_valbool_flag(struct sock *sk, int bit, int valbool)
  */
 
 int sock_setsockopt(struct socket *sock, int level, int optname,
-                   char __user *optval, int optlen)
+                   char __user *optval, unsigned int optlen)
 {
        struct sock *sk = sock->sk;
        int val;
@@ -482,6 +513,8 @@ int sock_setsockopt(struct socket *sock, int level, int optname,
                sk->sk_reuse = valbool;
                break;
        case SO_TYPE:
+       case SO_PROTOCOL:
+       case SO_DOMAIN:
        case SO_ERROR:
                ret = -ENOPROTOOPT;
                break;
@@ -700,6 +733,12 @@ set_rcvbuf:
 
                /* We implement the SO_SNDLOWAT etc to
                   not be settable (1003.1g 5.3) */
+       case SO_RXQ_OVFL:
+               if (valbool)
+                       sock_set_flag(sk, SOCK_RXQ_OVFL);
+               else
+                       sock_reset_flag(sk, SOCK_RXQ_OVFL);
+               break;
        default:
                ret = -ENOPROTOOPT;
                break;
@@ -721,7 +760,7 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                struct timeval tm;
        } v;
 
-       unsigned int lv = sizeof(int);
+       int lv = sizeof(int);
        int len;
 
        if (get_user(len, optlen))
@@ -764,6 +803,14 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                v.val = sk->sk_type;
                break;
 
+       case SO_PROTOCOL:
+               v.val = sk->sk_protocol;
+               break;
+
+       case SO_DOMAIN:
+               v.val = sk->sk_family;
+               break;
+
        case SO_ERROR:
                v.val = -sock_error(sk);
                if (v.val == 0)
@@ -891,6 +938,10 @@ int sock_getsockopt(struct socket *sock, int level, int optname,
                v.val = sk->sk_mark;
                break;
 
+       case SO_RXQ_OVFL:
+               v.val = !!sock_flag(sk, SOCK_RXQ_OVFL);
+               break;
+
        default:
                return -ENOPROTOOPT;
        }
@@ -929,7 +980,8 @@ static void sock_copy(struct sock *nsk, const struct sock *osk)
        void *sptr = nsk->sk_security;
 #endif
        BUILD_BUG_ON(offsetof(struct sock, sk_copy_start) !=
-                    sizeof(osk->sk_node) + sizeof(osk->sk_refcnt));
+                    sizeof(osk->sk_node) + sizeof(osk->sk_refcnt) +
+                    sizeof(osk->sk_tx_queue_mapping));
        memcpy(&nsk->sk_copy_start, &osk->sk_copy_start,
               osk->sk_prot->obj_size - offsetof(struct sock, sk_copy_start));
 #ifdef CONFIG_SECURITY_NETWORK
@@ -973,6 +1025,7 @@ static struct sock *sk_prot_alloc(struct proto *prot, gfp_t priority,
 
                if (!try_module_get(prot->owner))
                        goto out_free_sec;
+               sk_tx_queue_clear(sk);
        }
 
        return sk;
@@ -1003,6 +1056,17 @@ static void sk_prot_free(struct proto *prot, struct sock *sk)
        module_put(owner);
 }
 
+#ifdef CONFIG_CGROUPS
+void sock_update_classid(struct sock *sk)
+{
+       u32 classid = task_cls_classid(current);
+
+       if (classid && classid != sk->sk_classid)
+               sk->sk_classid = classid;
+}
+EXPORT_SYMBOL(sock_update_classid);
+#endif
+
 /**
  *     sk_alloc - All socket objects are allocated here
  *     @net: the applicable net namespace
@@ -1026,6 +1090,8 @@ struct sock *sk_alloc(struct net *net, int family, gfp_t priority,
                sock_lock_init(sk);
                sock_net_set(sk, get_net(net));
                atomic_set(&sk->sk_wmem_alloc, 1);
+
+               sock_update_classid(sk);
        }
 
        return sk;
@@ -1039,7 +1105,8 @@ static void __sk_free(struct sock *sk)
        if (sk->sk_destruct)
                sk->sk_destruct(sk);
 
-       filter = rcu_dereference(sk->sk_filter);
+       filter = rcu_dereference_check(sk->sk_filter,
+                                      atomic_read(&sk->sk_wmem_alloc) == 0);
        if (filter) {
                sk_filter_uncharge(sk, filter);
                rcu_assign_pointer(sk->sk_filter, NULL);
@@ -1104,6 +1171,7 @@ struct sock *sk_clone(const struct sock *sk, const gfp_t priority)
                sock_lock_init(newsk);
                bh_lock_sock(newsk);
                newsk->sk_backlog.head  = newsk->sk_backlog.tail = NULL;
+               newsk->sk_backlog.len = 0;
 
                atomic_set(&newsk->sk_rmem_alloc, 0);
                /*
@@ -1117,7 +1185,7 @@ struct sock *sk_clone(const struct sock *sk, const gfp_t priority)
                skb_queue_head_init(&newsk->sk_async_wait_queue);
 #endif
 
-               rwlock_init(&newsk->sk_dst_lock);
+               spin_lock_init(&newsk->sk_dst_lock);
                rwlock_init(&newsk->sk_callback_lock);
                lockdep_set_class_and_name(&newsk->sk_callback_lock,
                                af_callback_keys + newsk->sk_family,
@@ -1167,10 +1235,14 @@ struct sock *sk_clone(const struct sock *sk, const gfp_t priority)
                 */
                sk_refcnt_debug_inc(newsk);
                sk_set_socket(newsk, NULL);
-               newsk->sk_sleep  = NULL;
+               newsk->sk_wq = NULL;
 
                if (newsk->sk_prot->sockets_allocated)
                        percpu_counter_inc(newsk->sk_prot->sockets_allocated);
+
+               if (sock_flag(newsk, SOCK_TIMESTAMP) ||
+                   sock_flag(newsk, SOCK_TIMESTAMPING_RX_SOFTWARE))
+                       net_enable_timestamp();
        }
 out:
        return newsk;
@@ -1183,6 +1255,7 @@ void sk_setup_caps(struct sock *sk, struct dst_entry *dst)
        sk->sk_route_caps = dst->dev->features;
        if (sk->sk_route_caps & NETIF_F_GSO)
                sk->sk_route_caps |= NETIF_F_GSO_SOFTWARE;
+       sk->sk_route_caps &= ~sk->sk_route_nocaps;
        if (sk_can_gso(sk)) {
                if (dst->header_len) {
                        sk->sk_route_caps &= ~NETIF_F_GSO_MASK;
@@ -1196,12 +1269,12 @@ EXPORT_SYMBOL_GPL(sk_setup_caps);
 
 void __init sk_init(void)
 {
-       if (num_physpages <= 4096) {
+       if (totalram_pages <= 4096) {
                sysctl_wmem_max = 32767;
                sysctl_rmem_max = 32767;
                sysctl_wmem_default = 32767;
                sysctl_rmem_default = 32767;
-       } else if (num_physpages >= 131072) {
+       } else if (totalram_pages >= 131072) {
                sysctl_wmem_max = 131071;
                sysctl_rmem_max = 131071;
        }
@@ -1218,17 +1291,22 @@ void __init sk_init(void)
 void sock_wfree(struct sk_buff *skb)
 {
        struct sock *sk = skb->sk;
-       int res;
+       unsigned int len = skb->truesize;
 
-       /* In case it might be waiting for more memory. */
-       res = atomic_sub_return(skb->truesize, &sk->sk_wmem_alloc);
-       if (!sock_flag(sk, SOCK_USE_WRITE_QUEUE))
+       if (!sock_flag(sk, SOCK_USE_WRITE_QUEUE)) {
+               /*
+                * Keep a reference on sk_wmem_alloc, this will be released
+                * after sk_write_space() call
+                */
+               atomic_sub(len - 1, &sk->sk_wmem_alloc);
                sk->sk_write_space(sk);
+               len = 1;
+       }
        /*
-        * if sk_wmem_alloc reached 0, we are last user and should
-        * free this sock, as sk_free() call could not do it.
+        * if sk_wmem_alloc reaches 0, we must finish what sk_free()
+        * could not do because of in-flight packets
         */
-       if (res == 0)
+       if (atomic_sub_and_test(len, &sk->sk_wmem_alloc))
                __sk_free(sk);
 }
 EXPORT_SYMBOL(sock_wfree);
@@ -1346,7 +1424,7 @@ static long sock_wait_for_wmem(struct sock *sk, long timeo)
                if (signal_pending(current))
                        break;
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-               prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE);
+               prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
                if (atomic_read(&sk->sk_wmem_alloc) < sk->sk_sndbuf)
                        break;
                if (sk->sk_shutdown & SEND_SHUTDOWN)
@@ -1355,7 +1433,7 @@ static long sock_wait_for_wmem(struct sock *sk, long timeo)
                        break;
                timeo = schedule_timeout(timeo);
        }
-       finish_wait(sk->sk_sleep, &wait);
+       finish_wait(sk_sleep(sk), &wait);
        return timeo;
 }
 
@@ -1482,6 +1560,7 @@ static void __release_sock(struct sock *sk)
                do {
                        struct sk_buff *next = skb->next;
 
+                       WARN_ON_ONCE(skb_dst_is_noref(skb));
                        skb->next = NULL;
                        sk_backlog_rcv(sk, skb);
 
@@ -1498,6 +1577,12 @@ static void __release_sock(struct sock *sk)
 
                bh_lock_sock(sk);
        } while ((skb = sk->sk_backlog.head) != NULL);
+
+       /*
+        * Doing the zeroing here guarantee we can not loop forever
+        * while a wild producer attempts to flood us.
+        */
+       sk->sk_backlog.len = 0;
 }
 
 /**
@@ -1515,11 +1600,11 @@ int sk_wait_data(struct sock *sk, long *timeo)
        int rc;
        DEFINE_WAIT(wait);
 
-       prepare_to_wait(sk->sk_sleep, &wait, TASK_INTERRUPTIBLE);
+       prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
        set_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
        rc = sk_wait_event(sk, timeo, !skb_queue_empty(&sk->sk_receive_queue));
        clear_bit(SOCK_ASYNC_WAITDATA, &sk->sk_socket->flags);
-       finish_wait(sk->sk_sleep, &wait);
+       finish_wait(sk_sleep(sk), &wait);
        return rc;
 }
 EXPORT_SYMBOL(sk_wait_data);
@@ -1687,7 +1772,7 @@ int sock_no_shutdown(struct socket *sock, int how)
 EXPORT_SYMBOL(sock_no_shutdown);
 
 int sock_no_setsockopt(struct socket *sock, int level, int optname,
-                   char __user *optval, int optlen)
+                   char __user *optval, unsigned int optlen)
 {
        return -EOPNOTSUPP;
 }
@@ -1741,41 +1826,53 @@ EXPORT_SYMBOL(sock_no_sendpage);
 
 static void sock_def_wakeup(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
-       if (sk_has_sleeper(sk))
-               wake_up_interruptible_all(sk->sk_sleep);
-       read_unlock(&sk->sk_callback_lock);
+       struct socket_wq *wq;
+
+       rcu_read_lock();
+       wq = rcu_dereference(sk->sk_wq);
+       if (wq_has_sleeper(wq))
+               wake_up_interruptible_all(&wq->wait);
+       rcu_read_unlock();
 }
 
 static void sock_def_error_report(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
-       if (sk_has_sleeper(sk))
-               wake_up_interruptible_poll(sk->sk_sleep, POLLERR);
+       struct socket_wq *wq;
+
+       rcu_read_lock();
+       wq = rcu_dereference(sk->sk_wq);
+       if (wq_has_sleeper(wq))
+               wake_up_interruptible_poll(&wq->wait, POLLERR);
        sk_wake_async(sk, SOCK_WAKE_IO, POLL_ERR);
-       read_unlock(&sk->sk_callback_lock);
+       rcu_read_unlock();
 }
 
 static void sock_def_readable(struct sock *sk, int len)
 {
-       read_lock(&sk->sk_callback_lock);
-       if (sk_has_sleeper(sk))
-               wake_up_interruptible_sync_poll(sk->sk_sleep, POLLIN |
+       struct socket_wq *wq;
+
+       rcu_read_lock();
+       wq = rcu_dereference(sk->sk_wq);
+       if (wq_has_sleeper(wq))
+               wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
                                                POLLRDNORM | POLLRDBAND);
        sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
-       read_unlock(&sk->sk_callback_lock);
+       rcu_read_unlock();
 }
 
 static void sock_def_write_space(struct sock *sk)
 {
-       read_lock(&sk->sk_callback_lock);
+       struct socket_wq *wq;
+
+       rcu_read_lock();
 
        /* Do not wake up a writer until he can make "significant"
         * progress.  --DaveM
         */
        if ((atomic_read(&sk->sk_wmem_alloc) << 1) <= sk->sk_sndbuf) {
-               if (sk_has_sleeper(sk))
-                       wake_up_interruptible_sync_poll(sk->sk_sleep, POLLOUT |
+               wq = rcu_dereference(sk->sk_wq);
+               if (wq_has_sleeper(wq))
+                       wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
                                                POLLWRNORM | POLLWRBAND);
 
                /* Should agree with poll, otherwise some programs break */
@@ -1783,7 +1880,7 @@ static void sock_def_write_space(struct sock *sk)
                        sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
        }
 
-       read_unlock(&sk->sk_callback_lock);
+       rcu_read_unlock();
 }
 
 static void sock_def_destruct(struct sock *sk)
@@ -1837,12 +1934,12 @@ void sock_init_data(struct socket *sock, struct sock *sk)
 
        if (sock) {
                sk->sk_type     =       sock->type;
-               sk->sk_sleep    =       &sock->wait;
+               sk->sk_wq       =       sock->wq;
                sock->sk        =       sk;
        } else
-               sk->sk_sleep    =       NULL;
+               sk->sk_wq       =       NULL;
 
-       rwlock_init(&sk->sk_dst_lock);
+       spin_lock_init(&sk->sk_dst_lock);
        rwlock_init(&sk->sk_callback_lock);
        lockdep_set_class_and_name(&sk->sk_callback_lock,
                        af_callback_keys + sk->sk_family,
@@ -2008,7 +2105,7 @@ EXPORT_SYMBOL(sock_common_recvmsg);
  *     Set socket options on an inet socket.
  */
 int sock_common_setsockopt(struct socket *sock, int level, int optname,
-                          char __user *optval, int optlen)
+                          char __user *optval, unsigned int optlen)
 {
        struct sock *sk = sock->sk;
 
@@ -2018,7 +2115,7 @@ EXPORT_SYMBOL(sock_common_setsockopt);
 
 #ifdef CONFIG_COMPAT
 int compat_sock_common_setsockopt(struct socket *sock, int level, int optname,
-                                 char __user *optval, int optlen)
+                                 char __user *optval, unsigned int optlen)
 {
        struct sock *sk = sock->sk;
 
@@ -2097,13 +2194,13 @@ int sock_prot_inuse_get(struct net *net, struct proto *prot)
 }
 EXPORT_SYMBOL_GPL(sock_prot_inuse_get);
 
-static int sock_inuse_init_net(struct net *net)
+static int __net_init sock_inuse_init_net(struct net *net)
 {
        net->core.inuse = alloc_percpu(struct prot_inuse);
        return net->core.inuse ? 0 : -ENOMEM;
 }
 
-static void sock_inuse_exit_net(struct net *net)
+static void __net_exit sock_inuse_exit_net(struct net *net)
 {
        free_percpu(net->core.inuse);
 }
@@ -2185,13 +2282,10 @@ int proto_register(struct proto *prot, int alloc_slab)
                }
 
                if (prot->rsk_prot != NULL) {
-                       static const char mask[] = "request_sock_%s";
-
-                       prot->rsk_prot->slab_name = kmalloc(strlen(prot->name) + sizeof(mask) - 1, GFP_KERNEL);
+                       prot->rsk_prot->slab_name = kasprintf(GFP_KERNEL, "request_sock_%s", prot->name);
                        if (prot->rsk_prot->slab_name == NULL)
                                goto out_free_sock_slab;
 
-                       sprintf(prot->rsk_prot->slab_name, mask, prot->name);
                        prot->rsk_prot->slab = kmem_cache_create(prot->rsk_prot->slab_name,
                                                                 prot->rsk_prot->obj_size, 0,
                                                                 SLAB_HWCACHE_ALIGN, NULL);
@@ -2204,14 +2298,11 @@ int proto_register(struct proto *prot, int alloc_slab)
                }
 
                if (prot->twsk_prot != NULL) {
-                       static const char mask[] = "tw_sock_%s";
-
-                       prot->twsk_prot->twsk_slab_name = kmalloc(strlen(prot->name) + sizeof(mask) - 1, GFP_KERNEL);
+                       prot->twsk_prot->twsk_slab_name = kasprintf(GFP_KERNEL, "tw_sock_%s", prot->name);
 
                        if (prot->twsk_prot->twsk_slab_name == NULL)
                                goto out_free_request_sock_slab;
 
-                       sprintf(prot->twsk_prot->twsk_slab_name, mask, prot->name);
                        prot->twsk_prot->twsk_slab =
                                kmem_cache_create(prot->twsk_prot->twsk_slab_name,
                                                  prot->twsk_prot->twsk_obj_size,
@@ -2238,7 +2329,8 @@ out_free_request_sock_slab:
                prot->rsk_prot->slab = NULL;
        }
 out_free_request_sock_slab_name:
-       kfree(prot->rsk_prot->slab_name);
+       if (prot->rsk_prot)
+               kfree(prot->rsk_prot->slab_name);
 out_free_sock_slab:
        kmem_cache_destroy(prot->slab);
        prot->slab = NULL;