}
}
+/* Helper to allocate iovec buffers for all vqs. */
+static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
+{
+ int i;
+ for (i = 0; i < dev->nvqs; ++i) {
+ dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
+ UIO_MAXIOV, GFP_KERNEL);
+ dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
+ GFP_KERNEL);
+ dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
+ UIO_MAXIOV, GFP_KERNEL);
+
+ if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
+ !dev->vqs[i].heads)
+ goto err_nomem;
+ }
+ return 0;
+err_nomem:
+ for (; i >= 0; --i) {
+ kfree(dev->vqs[i].indirect);
+ kfree(dev->vqs[i].log);
+ kfree(dev->vqs[i].heads);
+ }
+ return -ENOMEM;
+}
+
+static void vhost_dev_free_iovecs(struct vhost_dev *dev)
+{
+ int i;
+ for (i = 0; i < dev->nvqs; ++i) {
+ kfree(dev->vqs[i].indirect);
+ dev->vqs[i].indirect = NULL;
+ kfree(dev->vqs[i].log);
+ dev->vqs[i].log = NULL;
+ kfree(dev->vqs[i].heads);
+ dev->vqs[i].heads = NULL;
+ }
+}
+
long vhost_dev_init(struct vhost_dev *dev,
struct vhost_virtqueue *vqs, int nvqs)
{
dev->worker = NULL;
for (i = 0; i < dev->nvqs; ++i) {
+ dev->vqs[i].log = NULL;
+ dev->vqs[i].indirect = NULL;
+ dev->vqs[i].heads = NULL;
dev->vqs[i].dev = dev;
mutex_init(&dev->vqs[i].mutex);
vhost_vq_reset(dev, dev->vqs + i);
if (err)
goto err_cgroup;
+ err = vhost_dev_alloc_iovecs(dev);
+ if (err)
+ goto err_cgroup;
+
return 0;
err_cgroup:
kthread_stop(worker);
vhost_dev_cleanup(dev);
memory->nregions = 0;
- dev->memory = memory;
+ RCU_INIT_POINTER(dev->memory, memory);
return 0;
}
fput(dev->vqs[i].call);
vhost_vq_reset(dev, dev->vqs + i);
}
+ vhost_dev_free_iovecs(dev);
if (dev->log_ctx)
eventfd_ctx_put(dev->log_ctx);
dev->log_ctx = NULL;
fput(dev->log_file);
dev->log_file = NULL;
/* No one will access memory at this point */
- kfree(dev->memory);
- dev->memory = NULL;
+ kfree(rcu_dereference_protected(dev->memory,
+ lockdep_is_held(&dev->mutex)));
+ RCU_INIT_POINTER(dev->memory, NULL);
if (dev->mm)
mmput(dev->mm);
dev->mm = NULL;
/* Make sure 64 bit math will not overflow. */
if (a > ULONG_MAX - (unsigned long)log_base ||
a + (unsigned long)log_base > ULONG_MAX)
- return -EFAULT;
+ return 0;
return access_ok(VERIFY_WRITE, log_base + a,
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
{
- return memory_access_ok(dev, dev->memory, 1);
+ struct vhost_memory *mp;
+
+ mp = rcu_dereference_protected(dev->memory,
+ lockdep_is_held(&dev->mutex));
+ return memory_access_ok(dev, mp, 1);
}
/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static int vq_log_access_ok(struct vhost_virtqueue *vq, void __user *log_base)
{
- return vq_memory_access_ok(log_base, vq->dev->memory,
+ struct vhost_memory *mp;
+
+ mp = rcu_dereference_protected(vq->dev->memory,
+ lockdep_is_held(&vq->mutex));
+ return vq_memory_access_ok(log_base, mp,
vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
sizeof *vq->used +
kfree(newmem);
return -EFAULT;
}
- oldmem = d->memory;
+ oldmem = rcu_dereference_protected(d->memory,
+ lockdep_is_held(&d->mutex));
rcu_assign_pointer(d->memory, newmem);
synchronize_rcu();
kfree(oldmem);
int r;
if (!write_length)
return 0;
+ write_length += write_address % VHOST_PAGE_SIZE;
write_address /= VHOST_PAGE_SIZE;
for (;;) {
u64 base = (u64)(unsigned long)log_base;
if (write_length <= VHOST_PAGE_SIZE)
break;
write_length -= VHOST_PAGE_SIZE;
- write_address += VHOST_PAGE_SIZE;
+ write_address += 1;
}
return r;
}
}
ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
- ARRAY_SIZE(vq->indirect));
+ UIO_MAXIOV);
if (unlikely(ret < 0)) {
vq_err(vq, "Translation failure %d in indirect.\n", ret);
return ret;