Merge git://git.kernel.org/pub/scm/linux/kernel/git/gregkh/tty-2.6
[safe/jmp/linux-2.6] / drivers / vhost / vhost.c
index 5be11c9..3b83382 100644 (file)
@@ -236,6 +236,10 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
                               int log_all)
 {
        int i;
+
+        if (!mem)
+                return 0;
+
        for (i = 0; i < mem->nregions; ++i) {
                struct vhost_memory_region *m = mem->regions + i;
                unsigned long a = m->userspace_addr;
@@ -316,10 +320,8 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
 {
        struct vhost_memory mem, *newmem, *oldmem;
        unsigned long size = offsetof(struct vhost_memory, regions);
-       long r;
-       r = copy_from_user(&mem, m, size);
-       if (r)
-               return r;
+       if (copy_from_user(&mem, m, size))
+               return -EFAULT;
        if (mem.padding)
                return -EOPNOTSUPP;
        if (mem.nregions > VHOST_MEMORY_MAX_NREGIONS)
@@ -329,15 +331,16 @@ static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
                return -ENOMEM;
 
        memcpy(newmem, &mem, size);
-       r = copy_from_user(newmem->regions, m->regions,
-                          mem.nregions * sizeof *m->regions);
-       if (r) {
+       if (copy_from_user(newmem->regions, m->regions,
+                          mem.nregions * sizeof *m->regions)) {
                kfree(newmem);
-               return r;
+               return -EFAULT;
        }
 
-       if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL)))
+       if (!memory_access_ok(d, newmem, vhost_has_feature(d, VHOST_F_LOG_ALL))) {
+               kfree(newmem);
                return -EFAULT;
+       }
        oldmem = d->memory;
        rcu_assign_pointer(d->memory, newmem);
        synchronize_rcu();
@@ -370,7 +373,7 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
        r = get_user(idx, idxp);
        if (r < 0)
                return r;
-       if (idx > d->nvqs)
+       if (idx >= d->nvqs)
                return -ENOBUFS;
 
        vq = d->vqs + idx;
@@ -385,9 +388,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        r = -EBUSY;
                        break;
                }
-               r = copy_from_user(&s, argp, sizeof s);
-               if (r < 0)
+               if (copy_from_user(&s, argp, sizeof s)) {
+                       r = -EFAULT;
                        break;
+               }
                if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
                        r = -EINVAL;
                        break;
@@ -401,9 +405,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        r = -EBUSY;
                        break;
                }
-               r = copy_from_user(&s, argp, sizeof s);
-               if (r < 0)
+               if (copy_from_user(&s, argp, sizeof s)) {
+                       r = -EFAULT;
                        break;
+               }
                if (s.num > 0xffff) {
                        r = -EINVAL;
                        break;
@@ -415,12 +420,14 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
        case VHOST_GET_VRING_BASE:
                s.index = idx;
                s.num = vq->last_avail_idx;
-               r = copy_to_user(argp, &s, sizeof s);
+               if (copy_to_user(argp, &s, sizeof s))
+                       r = -EFAULT;
                break;
        case VHOST_SET_VRING_ADDR:
-               r = copy_from_user(&a, argp, sizeof a);
-               if (r < 0)
+               if (copy_from_user(&a, argp, sizeof a)) {
+                       r = -EFAULT;
                        break;
+               }
                if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
                        r = -EOPNOTSUPP;
                        break;
@@ -473,9 +480,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                vq->used = (void __user *)(unsigned long)a.used_user_addr;
                break;
        case VHOST_SET_VRING_KICK:
-               r = copy_from_user(&f, argp, sizeof f);
-               if (r < 0)
+               if (copy_from_user(&f, argp, sizeof f)) {
+                       r = -EFAULT;
                        break;
+               }
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
                if (IS_ERR(eventfp)) {
                        r = PTR_ERR(eventfp);
@@ -488,9 +496,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        filep = eventfp;
                break;
        case VHOST_SET_VRING_CALL:
-               r = copy_from_user(&f, argp, sizeof f);
-               if (r < 0)
+               if (copy_from_user(&f, argp, sizeof f)) {
+                       r = -EFAULT;
                        break;
+               }
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
                if (IS_ERR(eventfp)) {
                        r = PTR_ERR(eventfp);
@@ -506,9 +515,10 @@ static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
                        filep = eventfp;
                break;
        case VHOST_SET_VRING_ERR:
-               r = copy_from_user(&f, argp, sizeof f);
-               if (r < 0)
+               if (copy_from_user(&f, argp, sizeof f)) {
+                       r = -EFAULT;
                        break;
+               }
                eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
                if (IS_ERR(eventfp)) {
                        r = PTR_ERR(eventfp);
@@ -571,9 +581,10 @@ long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
                r = vhost_set_memory(d, argp);
                break;
        case VHOST_SET_LOG_BASE:
-               r = copy_from_user(&p, argp, sizeof p);
-               if (r < 0)
+               if (copy_from_user(&p, argp, sizeof p)) {
+                       r = -EFAULT;
                        break;
+               }
                if ((u64)(unsigned long)p != p) {
                        r = -EFAULT;
                        break;
@@ -711,8 +722,8 @@ int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
        return 0;
 }
 
-int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
-                  struct iovec iov[], int iov_size)
+static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
+                         struct iovec iov[], int iov_size)
 {
        const struct vhost_memory_region *reg;
        struct vhost_memory *mem;
@@ -737,7 +748,7 @@ int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
                _iov = iov + ret;
                size = reg->memory_size - addr + reg->guest_phys_addr;
                _iov->iov_len = min((u64)len, size);
-               _iov->iov_base = (void *)(unsigned long)
+               _iov->iov_base = (void __user *)(unsigned long)
                        (reg->userspace_addr + addr - reg->guest_phys_addr);
                s += size;
                addr += size;
@@ -802,7 +813,7 @@ static unsigned get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
        count = indirect->len / sizeof desc;
        /* Buffers are chained via a 16 bit next field, so
         * we can have at most 2^16 of these. */
-       if (count > USHORT_MAX + 1) {
+       if (count > USHRT_MAX + 1) {
                vq_err(vq, "Indirect buffer length too big: %d\n",
                       indirect->len);
                return -E2BIG;
@@ -991,7 +1002,7 @@ void vhost_discard_vq_desc(struct vhost_virtqueue *vq)
  * want to notify the guest, using eventfd. */
 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 {
-       struct vring_used_elem *used;
+       struct vring_used_elem __user *used;
 
        /* The virtqueue contains a ring of used buffers.  Get a pointer to the
         * next entry in that used ring. */
@@ -1015,7 +1026,8 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
                smp_wmb();
                /* Log used ring entry write. */
                log_write(vq->log_base,
-                         vq->log_addr + ((void *)used - (void *)vq->used),
+                         vq->log_addr +
+                          ((void __user *)used - (void __user *)vq->used),
                          sizeof *used);
                /* Log used index update. */
                log_write(vq->log_base,
@@ -1031,7 +1043,12 @@ int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
 /* This actually signals the guest, using eventfd. */
 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
 {
-       __u16 flags = 0;
+       __u16 flags;
+       /* Flush out used index updates. This is paired
+        * with the barrier that the Guest executes when enabling
+        * interrupts. */
+       smp_mb();
+
        if (get_user(flags, &vq->avail->flags)) {
                vq_err(vq, "Failed to get flags");
                return;