Merge git://git.kernel.org/pub/scm/linux/kernel/git/gregkh/tty-2.6
[safe/jmp/linux-2.6] / drivers / vhost / vhost.c
index c8c25db..3b83382 100644 (file)
@@ -22,6 +22,7 @@
 #include <linux/poll.h>
 #include <linux/file.h>
 #include <linux/highmem.h>
+#include <linux/slab.h>
 
 #include <linux/net.h>
 #include <linux/if_packet.h>
@@ -121,6 +122,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        vq->kick = NULL;
        vq->call_ctx = NULL;
        vq->call = NULL;
+       vq->log_ctx = NULL;
 }
 
 long vhost_dev_init(struct vhost_dev *dev,
@@ -234,6 +236,10 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
                               int log_all)
 {
        int i;
+
+        if (!mem)
+                return 0;
+
        for (i = 0; i < mem->nregions; ++i) {
                struct vhost_memory_region *m = mem->regions + i;
                unsigned long a = m->userspace_addr;
@@ -314,10 +320,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
        struct vhost_memory mem, *newmem, *oldmem;
        unsigned long size = offsetof(struct vhost_memory, regions);
-       long r;
-       r = copy_from_user(&mem, m, size);
-       if (r)
-               return r;
+       if (copy_from_user(&mem, m, size))
+               return -EFAULT;
        if (mem.padding)
                return -EOPNOTSUPP;
        if (mem.nregions > VHOST_MEMORY_MAX_NREGIONS)
@@ -327,15 +331,16 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
                return -ENOMEM;
 
        memcpy(newmem, &mem, size);
-       r = copy_from_user(newmem->regions, m->regions,
-                          mem.nregions * sizeof *m->regions);
-       if (r) {
+       if (copy_from_user(newmem->regions, m->regions,
+                          mem.nregions * sizeof *m->regions)) {
                kfree(newmem);
-               return r;
+               return -EFAULT;
        }
 
-       if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL)))
+       if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL))) {
+               kfree(newmem);
                return -EFAULT;
+       }
        oldmem = d->memory;
        rcu_assign_pointer(d->memory, newmem);
        synchronize_rcu();
@@ -368,7 +373,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
        r = get_user(idx, idxp);
        if (r < 0)
                return r;
-       if (idx > d->nvqs)
+       if (idx >= d->nvqs)
                return -ENOBUFS;
 
        vq = d->vqs + idx;
@@ -383,9 +388,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        r = -EBUSY;
                        break;
                }
-               r = copy_from_user(&s, argp, sizeof s);
-               if (r < 0)
+               if (copy_from_user(&s, argp, sizeof s)) {
+                       r = -EFAULT;
                        break;
+               }
                if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
                        r = -EINVAL;
                        break;
@@ -399,9 +405,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        r = -EBUSY;
                        break;
                }
-               r = copy_from_user(&s, argp, sizeof s);
-               if (r < 0)
+               if (copy_from_user(&s, argp, sizeof s)) {
+                       r = -EFAULT;
                        break;
+               }
                if (s.num > 0xffff) {
                        r = -EINVAL;
                        break;
@@ -413,12 +420,14 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
        case VHOST_GET_VRING_BASE:
                s.index = idx;
                s.num = vq->last_avail_idx;
-               r = copy_to_user(argp, &s, sizeof s);
+               if (copy_to_user(argp, &s, sizeof s))
+                       r = -EFAULT;
                break;
        case VHOST_SET_VRING_ADDR:
-               r = copy_from_user(&a, argp, sizeof a);
-               if (r < 0)
+               if (copy_from_user(&a, argp, sizeof a)) {
+                       r = -EFAULT;
                        break;
+               }
                if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
                        r = -EOPNOTSUPP;
                        break;
@@ -471,12 +480,15 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                vq->used = (void __user *)(unsigned long)a.used_user_addr;
                break;
        case VHOST_SET_VRING_KICK:
-               r = copy_from_user(&f, argp, sizeof f);
-               if (r < 0)
+               if (copy_from_user(&f, argp, sizeof f)) {
+                       r = -EFAULT;
                        break;
+               }
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
-               if (IS_ERR(eventfp))
-                       return PTR_ERR(eventfp);
+               if (IS_ERR(eventfp)) {
+                       r = PTR_ERR(eventfp);
+                       break;
+               }
                if (eventfp != vq->kick) {
                        pollstop = filep = vq->kick;
                        pollstart = vq->kick = eventfp;
@@ -484,12 +496,15 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        filep = eventfp;
                break;
        case VHOST_SET_VRING_CALL:
-               r = copy_from_user(&f, argp, sizeof f);
-               if (r < 0)
+               if (copy_from_user(&f, argp, sizeof f)) {
+                       r = -EFAULT;
                        break;
+               }
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
-               if (IS_ERR(eventfp))
-                       return PTR_ERR(eventfp);
+               if (IS_ERR(eventfp)) {
+                       r = PTR_ERR(eventfp);
+                       break;
+               }
                if (eventfp != vq->call) {
                        filep = vq->call;
                        ctx = vq->call_ctx;
@@ -500,12 +515,15 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        filep = eventfp;
                break;
        case VHOST_SET_VRING_ERR:
-               r = copy_from_user(&f, argp, sizeof f);
-               if (r < 0)
+               if (copy_from_user(&f, argp, sizeof f)) {
+                       r = -EFAULT;
                        break;
+               }
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
-               if (IS_ERR(eventfp))
-                       return PTR_ERR(eventfp);
+               if (IS_ERR(eventfp)) {
+                       r = PTR_ERR(eventfp);
+                       break;
+               }
                if (eventfp != vq->error) {
                        filep = vq->error;
                        vq->error = eventfp;
@@ -563,9 +581,10 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
                r = vhost_set_memory(d, argp);
                break;
        case VHOST_SET_LOG_BASE:
-               r = copy_from_user(&p, argp, sizeof p);
-               if (r < 0)
+               if (copy_from_user(&p, argp, sizeof p)) {
+                       r = -EFAULT;
                        break;
+               }
                if ((u64)(unsigned long)p != p) {
                        r = -EFAULT;
                        break;
@@ -645,8 +664,9 @@ static int set_bit_to_user(int nr, void __user *addr)
        int bit = nr + (log % PAGE_SIZE) * 8;
        int r;
        r = get_user_pages_fast(log, 1, 1, &page);
-       if (r)
+       if (r < 0)
                return r;
+       BUG_ON(r != 1);
        base = kmap_atomic(page, KM_USER0);
        set_bit(bit, base);
        kunmap_atomic(base, KM_USER0);
@@ -685,7 +705,7 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
        int i, r;
 
        /* Make sure data written is seen before log. */
-       wmb();
+       smp_wmb();
        for (i = 0; i < log_num; ++i) {
                u64 l = min(log[i].len, len);
                r = log_write(vq->log_base, log[i].addr, l);
@@ -702,8 +722,8 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
        return 0;
 }
 
-int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
-                  struct iovec iov[], int iov_size)
+static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+                         struct iovec iov[], int iov_size)
 {
        const struct vhost_memory_region *reg;
        struct vhost_memory *mem;
@@ -728,7 +748,7 @@ int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
                _iov = iov + ret;
                size = reg->memory_size - addr + reg->guest_phys_addr;
                _iov->iov_len = min((u64)len, size);
-               _iov->iov_base = (void *)(unsigned long)
+               _iov->iov_base = (void __user *)(unsigned long)
                        (reg->userspace_addr + addr - reg->guest_phys_addr);
                s += size;
                addr += size;
@@ -793,7 +813,7 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
        count = indirect->len / sizeof desc;
        /* Buffers are chained via a 16 bit next field, so
         * we can have at most 2^16 of these. */
-       if (count > USHORT_MAX + 1) {
+       if (count > USHRT_MAX + 1) {
                vq_err(vq, "Indirect buffer length too big: %d\n",
                       indirect->len);
                return -E2BIG;
@@ -884,7 +904,7 @@ unsigned vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
                return vq->num;
 
        /* Only get avail ring entries after they have been exposed by guest. */
-       rmb();
+       smp_rmb();
 
        /* Grab the next descriptor number they're advertising, and increment
         * the index we've seen. */
@@ -982,7 +1002,7 @@ void vhost_discard_vq_desc(struct vhost_virtqueue *vq)
  * want to notify the guest, using eventfd. */
 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 {
-       struct vring_used_elem *used;
+       struct vring_used_elem __user *used;
 
        /* The virtqueue contains a ring of used buffers.  Get a pointer to the
         * next entry in that used ring. */
@@ -996,18 +1016,23 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
                return -EFAULT;
        }
        /* Make sure buffer is written before we update index. */
-       wmb();
+       smp_wmb();
        if (put_user(vq->last_used_idx + 1, &vq->used->idx)) {
                vq_err(vq, "Failed to increment used idx");
                return -EFAULT;
        }
        if (unlikely(vq->log_used)) {
                /* Make sure data is seen before log. */
-               wmb();
-               log_write(vq->log_base, vq->log_addr + sizeof *vq->used->ring *
-                         (vq->last_used_idx % vq->num),
-                         sizeof *vq->used->ring);
-               log_write(vq->log_base, vq->log_addr, sizeof *vq->used->ring);
+               smp_wmb();
+               /* Log used ring entry write. */
+               log_write(vq->log_base,
+                         vq->log_addr +
+                          ((void __user *)used - (void __user *)vq->used),
+                         sizeof *used);
+               /* Log used index update. */
+               log_write(vq->log_base,
+                         vq->log_addr + offsetof(struct vring_used, idx),
+                         sizeof vq->used->idx);
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }
@@ -1018,7 +1043,12 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 /* This actually signals the guest, using eventfd. */
 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
-       __u16 flags = 0;
+       __u16 flags;
+       /* Flush out used index updates. This is paired
+        * with the barrier that the Guest executes when enabling
+        * interrupts. */
+       smp_mb();
+
        if (get_user(flags, &vq->avail->flags)) {
                vq_err(vq, "Failed to get flags");
                return;
@@ -1060,7 +1090,7 @@ bool vhost_enable_notify(struct vhost_virtqueue *vq)
        }
        /* They could have slipped one in as we were doing that: make
         * sure it's written, then check again. */
-       mb();
+       smp_mb();
        r = get_user(avail_idx, &vq->avail->idx);
        if (r) {
                vq_err(vq, "Failed to check avail idx at %p: %d\n",