IPoIB/cm: Fix racy use of receive WR/SGL in ipoib_cm_post_receive_nonsrq()
[safe/jmp/linux-2.6] / drivers / infiniband / ulp / ipoib / ipoib_cm.c
index 6223fc3..37bf67b 100644 (file)
@@ -111,18 +111,20 @@ static int ipoib_cm_post_receive_srq(struct net_device *dev, int id)
 }
 
 static int ipoib_cm_post_receive_nonsrq(struct net_device *dev,
-                                       struct ipoib_cm_rx *rx, int id)
+                                       struct ipoib_cm_rx *rx,
+                                       struct ib_recv_wr *wr,
+                                       struct ib_sge *sge, int id)
 {
        struct ipoib_dev_priv *priv = netdev_priv(dev);
        struct ib_recv_wr *bad_wr;
        int i, ret;
 
-       priv->cm.rx_wr.wr_id = id | IPOIB_OP_CM | IPOIB_OP_RECV;
+       wr->wr_id = id | IPOIB_OP_CM | IPOIB_OP_RECV;
 
        for (i = 0; i < IPOIB_CM_RX_SG; ++i)
-               priv->cm.rx_sge[i].addr = rx->rx_ring[id].mapping[i];
+               sge[i].addr = rx->rx_ring[id].mapping[i];
 
-       ret = ib_post_recv(rx->qp, &priv->cm.rx_wr, &bad_wr);
+       ret = ib_post_recv(rx->qp, wr, &bad_wr);
        if (unlikely(ret)) {
                ipoib_warn(priv, "post recv failed for buf %d (%d)\n", id, ret);
                ipoib_cm_dma_unmap_rx(priv, IPOIB_CM_RX_SG - 1,
@@ -320,10 +322,33 @@ static int ipoib_cm_modify_rx_qp(struct net_device *dev,
        return 0;
 }
 
+static void ipoib_cm_init_rx_wr(struct net_device *dev,
+                               struct ib_recv_wr *wr,
+                               struct ib_sge *sge)
+{
+       struct ipoib_dev_priv *priv = netdev_priv(dev);
+       int i;
+
+       for (i = 0; i < priv->cm.num_frags; ++i)
+               sge[i].lkey = priv->mr->lkey;
+
+       sge[0].length = IPOIB_CM_HEAD_SIZE;
+       for (i = 1; i < priv->cm.num_frags; ++i)
+               sge[i].length = PAGE_SIZE;
+
+       wr->next    = NULL;
+       wr->sg_list = priv->cm.rx_sge;
+       wr->num_sge = priv->cm.num_frags;
+}
+
 static int ipoib_cm_nonsrq_init_rx(struct net_device *dev, struct ib_cm_id *cm_id,
                                   struct ipoib_cm_rx *rx)
 {
        struct ipoib_dev_priv *priv = netdev_priv(dev);
+       struct {
+               struct ib_recv_wr wr;
+               struct ib_sge sge[IPOIB_CM_RX_SG];
+       } *t;
        int ret;
        int i;
 
@@ -331,6 +356,14 @@ static int ipoib_cm_nonsrq_init_rx(struct net_device *dev, struct ib_cm_id *cm_i
        if (!rx->rx_ring)
                return -ENOMEM;
 
+       t = kmalloc(sizeof *t, GFP_KERNEL);
+       if (!t) {
+               ret = -ENOMEM;
+               goto err_free;
+       }
+
+       ipoib_cm_init_rx_wr(dev, &t->wr, t->sge);
+
        spin_lock_irq(&priv->lock);
 
        if (priv->cm.nonsrq_conn_qp >= ipoib_max_conn_qp) {
@@ -349,8 +382,8 @@ static int ipoib_cm_nonsrq_init_rx(struct net_device *dev, struct ib_cm_id *cm_i
                        ipoib_warn(priv, "failed to allocate receive buffer %d\n", i);
                                ret = -ENOMEM;
                                goto err_count;
-                       }
-               ret = ipoib_cm_post_receive_nonsrq(dev, rx, i);
+               }
+               ret = ipoib_cm_post_receive_nonsrq(dev, rx, &t->wr, t->sge, i);
                if (ret) {
                        ipoib_warn(priv, "ipoib_cm_post_receive_nonsrq "
                                   "failed for buf %d\n", i);
@@ -361,6 +394,8 @@ static int ipoib_cm_nonsrq_init_rx(struct net_device *dev, struct ib_cm_id *cm_i
 
        rx->recv_count = ipoib_recvq_size;
 
+       kfree(t);
+
        return 0;
 
 err_count:
@@ -369,6 +404,7 @@ err_count:
        spin_unlock_irq(&priv->lock);
 
 err_free:
+       kfree(t);
        ipoib_cm_free_rx_ring(dev, rx->rx_ring);
 
        return ret;
@@ -637,7 +673,10 @@ repost:
                        ipoib_warn(priv, "ipoib_cm_post_receive_srq failed "
                                   "for buf %d\n", wr_id);
        } else {
-               if (unlikely(ipoib_cm_post_receive_nonsrq(dev, p, wr_id))) {
+               if (unlikely(ipoib_cm_post_receive_nonsrq(dev, p,
+                                                         &priv->cm.rx_wr,
+                                                         priv->cm.rx_sge,
+                                                         wr_id))) {
                        --p->recv_count;
                        ipoib_warn(priv, "ipoib_cm_post_receive_nonsrq failed "
                                   "for buf %d\n", wr_id);
@@ -1502,15 +1541,7 @@ int ipoib_cm_dev_init(struct net_device *dev)
                priv->cm.num_frags  = IPOIB_CM_RX_SG;
        }
 
-       for (i = 0; i < priv->cm.num_frags; ++i)
-               priv->cm.rx_sge[i].lkey = priv->mr->lkey;
-
-       priv->cm.rx_sge[0].length = IPOIB_CM_HEAD_SIZE;
-       for (i = 1; i < priv->cm.num_frags; ++i)
-               priv->cm.rx_sge[i].length = PAGE_SIZE;
-       priv->cm.rx_wr.next = NULL;
-       priv->cm.rx_wr.sg_list = priv->cm.rx_sge;
-       priv->cm.rx_wr.num_sge = priv->cm.num_frags;
+       ipoib_cm_init_rx_wr(dev, &priv->cm.rx_wr, priv->cm.rx_sge);
 
        if (ipoib_cm_has_srq(dev)) {
                for (i = 0; i < ipoib_recvq_size; ++i) {