/*
 * Copyright (c) 2013-2014, Google, Inc. All rights reserved
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#include "vqueue.h"

#include <assert.h>
#include <err.h>
#include <lib/sm.h>
#include <lk/pow2.h>
#include <stddef.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <trace.h>

#include <arch/arch_ops.h>
#include <kernel/vm.h>

#include <lib/trusty/uio.h>

#include <virtio/virtio_ring.h>

#define LOCAL_TRACE 0

#define VQ_LOCK_FLAGS SPIN_LOCK_FLAG_INTERRUPTS

/* Arbitrary limit to ensure vring size doesn't overflow */
#define VQ_MAX_RING_NUM 256
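
/*
 * Rough sizing sketch (assuming the standard split-ring layout from
 * virtio_ring.h): for num = VQ_MAX_RING_NUM = 256 and align = PAGE_SIZE,
 *
 *   descriptors + avail: 16 * 256 + 2 * (3 + 256) = 4614 -> aligned to 8192
 *   used ring:           2 * 3 + 8 * 256          = 2054
 *
 * so vring_size(256, PAGE_SIZE) is roughly 10246 bytes, i.e. three 4K pages
 * after the round_up() in vqueue_init() below. The cap keeps the ring mapping
 * small and far from any overflow in the size computation.
 */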

int vqueue_init(struct vqueue* vq,
                uint32_t id,
                ext_mem_client_id_t client_id,
                ext_mem_obj_id_t shared_mem_id,
                uint num,
                ulong align,
                void* priv,
                vqueue_cb_t notify_cb,
                vqueue_cb_t kick_cb) {
    status_t ret;
    void* vptr = NULL;

    DEBUG_ASSERT(vq);

    if (num > VQ_MAX_RING_NUM) {
        LTRACEF("vring too large: %u\n", num);
        return ERR_INVALID_ARGS;
    }

    if (align == 0 || !ispow2(align)) {
        LTRACEF("bad vring alignment: %lu\n", align);
        return ERR_INVALID_ARGS;
    }

    vq->vring_sz = vring_size(num, align);
    ret = ext_mem_map_obj_id(vmm_get_kernel_aspace(), "vqueue", client_id,
                             shared_mem_id, 0, 0,
                             round_up(vq->vring_sz, PAGE_SIZE), &vptr,
                             PAGE_SIZE_SHIFT, 0, ARCH_MMU_FLAG_PERM_NO_EXECUTE);
    if (ret != NO_ERROR) {
        LTRACEF("cannot map vring (%d)\n", ret);
        return (int)ret;
    }

    vring_init(&vq->vring, num, vptr, align);

    vq->id = id;
    vq->priv = priv;
    vq->notify_cb = notify_cb;
    vq->kick_cb = kick_cb;
    vq->vring_addr = (vaddr_t)vptr;

    event_init(&vq->avail_event, false, 0);

    return NO_ERROR;
}
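
/*
 * A minimal caller-side sketch of bringing up a queue (hypothetical values;
 * my_notify/my_kick and the id/num/align choices are illustrative, not part
 * of this file):
 *
 *   static struct vqueue vq;
 *
 *   int rc = vqueue_init(&vq, 0, client_id, shared_mem_id,
 *                        256, PAGE_SIZE, dev, my_notify, my_kick);
 *   if (rc != NO_ERROR)
 *       return rc;
 *   ...
 *   vqueue_destroy(&vq);
 */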

void vqueue_destroy(struct vqueue* vq) {
    vaddr_t vring_addr;
    spin_lock_saved_state_t state;

    DEBUG_ASSERT(vq);

    spin_lock_save(&vq->slock, &state, VQ_LOCK_FLAGS);
    vring_addr = vq->vring_addr;
    vq->vring_addr = (vaddr_t)NULL;
    vq->vring_sz = 0;
    spin_unlock_restore(&vq->slock, state, VQ_LOCK_FLAGS);

    vmm_free_region(vmm_get_kernel_aspace(), vring_addr);
}

void vqueue_signal_avail(struct vqueue* vq) {
    spin_lock_saved_state_t state;

    spin_lock_save(&vq->slock, &state, VQ_LOCK_FLAGS);
    if (vq->vring_addr)
        vq->vring.used->flags |= VRING_USED_F_NO_NOTIFY;
    event_signal(&vq->avail_event, false);
    spin_unlock_restore(&vq->slock, state, VQ_LOCK_FLAGS);
}

/* The other side of virtio pushes buffers into our avail ring, and pulls them
 * off our used ring. We do the reverse. We take buffers off the avail ring,
 * and put them onto the used ring.
 */
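
/*
 * A sketch of the typical consume/return cycle built on the functions below
 * (error handling elided; the iov storage, the mmu flags, and the IN/OUT
 * mapped lists are caller-provided and the names used here are hypothetical
 * placeholders):
 *
 *   struct vqueue_buf buf;  // caller points in_iovs/out_iovs at its arrays
 *
 *   while (vqueue_get_avail_buf(vq, &buf) == NO_ERROR) {
 *       vqueue_map_iovs(client_id, &buf.in_iovs, in_mmu_flags, in_list);
 *       vqueue_map_iovs(client_id, &buf.out_iovs, out_mmu_flags, out_list);
 *       // ... process the request, write the reply into out_iovs ...
 *       vqueue_add_buf(vq, &buf, bytes_written);
 *   }
 *
 * Mappings may either be kept cached for reuse (see vqueue_map_iovs()) or
 * released explicitly with vqueue_unmap_iovs().
 */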

static int _vqueue_get_avail_buf_locked(struct vqueue* vq,
                                        struct vqueue_buf* iovbuf) {
    uint16_t next_idx;
    struct vring_desc* desc;

    DEBUG_ASSERT(vq);
    DEBUG_ASSERT(iovbuf);

    if (!vq->vring_addr) {
        /* there is no vring - return an error */
        return ERR_CHANNEL_CLOSED;
    }

    /* the idx counter is free-running, so check that it is no more
     * than the ring size ahead of the last value we observed... this
     * should *never* happen, but we should be careful. */
    uint16_t avail_cnt;
    __builtin_sub_overflow(vq->vring.avail->idx, vq->last_avail_idx,
                           &avail_cnt);
    if (unlikely(avail_cnt > (uint16_t)vq->vring.num)) {
        /* this state is not recoverable */
        panic("vq %u: new avail idx out of range (old %u new %u)\n", vq->id,
              vq->last_avail_idx, vq->vring.avail->idx);
    }

    if (vq->last_avail_idx == vq->vring.avail->idx) {
        /* the ring looks empty: re-enable notifications, then re-check to
         * close the race with the other side adding a buffer concurrently */
        event_unsignal(&vq->avail_event);
        vq->vring.used->flags &= ~VRING_USED_F_NO_NOTIFY;
        smp_mb();
        if (vq->last_avail_idx == vq->vring.avail->idx) {
            /* no buffers left */
            return ERR_NOT_ENOUGH_BUFFER;
        }
        vq->vring.used->flags |= VRING_USED_F_NO_NOTIFY;
        event_signal(&vq->avail_event, false);
    }
    smp_rmb();

    next_idx = vq->vring.avail->ring[vq->last_avail_idx % vq->vring.num];
    __builtin_add_overflow(vq->last_avail_idx, 1, &vq->last_avail_idx);

    if (unlikely(next_idx >= vq->vring.num)) {
        /* the index of the first descriptor in the chain is out of range.
           The vring is in a non-recoverable state: we cannot even return
           an error to the other side */
        panic("vq %u: head out of range %u (max %u)\n", vq->id, next_idx,
              vq->vring.num);
    }

    iovbuf->head = next_idx;
    iovbuf->in_iovs.used = 0;
    iovbuf->in_iovs.len = 0;
    iovbuf->out_iovs.used = 0;
    iovbuf->out_iovs.len = 0;

    do {
        struct vqueue_iovs* iovlist;

        if (unlikely(next_idx >= vq->vring.num)) {
            /* The descriptor chain is in an invalid state.
             * Abort message handling, return an error to the
             * other side, and let it deal with it.
             */
            LTRACEF("vq %p: head out of range %u (max %u)\n", vq, next_idx,
                    vq->vring.num);
            return ERR_NOT_VALID;
        }

        desc = &vq->vring.desc[next_idx];
        if (desc->flags & VRING_DESC_F_WRITE)
            iovlist = &iovbuf->out_iovs;
        else
            iovlist = &iovbuf->in_iovs;

        if (iovlist->used < iovlist->cnt) {
            /* .iov_base will be set when we map this iov */
            iovlist->iovs[iovlist->used].iov_len = desc->len;
            iovlist->shared_mem_id[iovlist->used] =
                    (ext_mem_obj_id_t)desc->addr;
            assert(iovlist->shared_mem_id[iovlist->used] == desc->addr);
            iovlist->used++;
            iovlist->len += desc->len;
        } else {
            return ERR_TOO_BIG;
        }

        /* go to the next entry in the descriptor chain */
        next_idx = desc->next;
    } while (desc->flags & VRING_DESC_F_NEXT);

    return NO_ERROR;
}

int vqueue_get_avail_buf(struct vqueue* vq, struct vqueue_buf* iovbuf) {
    spin_lock_saved_state_t state;

    spin_lock_save(&vq->slock, &state, VQ_LOCK_FLAGS);
    int ret = _vqueue_get_avail_buf_locked(vq, iovbuf);
    spin_unlock_restore(&vq->slock, state, VQ_LOCK_FLAGS);
    return ret;
}
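
/*
 * The caller owns the iovec and shared_mem_id storage referenced through
 * struct vqueue_buf; a descriptor chain longer than the provided cnt makes
 * _vqueue_get_avail_buf_locked() return ERR_TOO_BIG. A minimal setup sketch
 * (array sizes and variable names are illustrative only):
 *
 *   struct iovec_kern in_iovs[8], out_iovs[8];
 *   ext_mem_obj_id_t in_ids[8], out_ids[8];
 *   struct vqueue_buf buf = {0};
 *
 *   buf.in_iovs.iovs = in_iovs;
 *   buf.in_iovs.shared_mem_id = in_ids;
 *   buf.in_iovs.cnt = countof(in_iovs);
 *   buf.out_iovs.iovs = out_iovs;
 *   buf.out_iovs.shared_mem_id = out_ids;
 *   buf.out_iovs.cnt = countof(out_iovs);
 *
 *   int rc = vqueue_get_avail_buf(vq, &buf);
 */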

struct vqueue_mem_obj {
    ext_mem_client_id_t client_id;
    ext_mem_obj_id_t id;
    void* iov_base;
    size_t size;
    struct bst_node node;
};

static struct vqueue_mem_obj* vqueue_mem_obj_from_bst_node(
        struct bst_node* node) {
    return containerof(node, struct vqueue_mem_obj, node);
}

static int vqueue_mem_obj_cmp(struct bst_node* a_bst, struct bst_node* b_bst) {
    struct vqueue_mem_obj* a = vqueue_mem_obj_from_bst_node(a_bst);
    struct vqueue_mem_obj* b = vqueue_mem_obj_from_bst_node(b_bst);

    return a->id < b->id ? 1 : a->id > b->id ? -1 : 0;
}

static void vqueue_mem_obj_initialize(struct vqueue_mem_obj* obj,
                                      ext_mem_client_id_t client_id,
                                      ext_mem_obj_id_t id,
                                      void* iov_base,
                                      size_t size) {
    obj->client_id = client_id;
    obj->id = id;
    obj->iov_base = iov_base;
    obj->size = size;
    bst_node_initialize(&obj->node);
}

static bool vqueue_mem_insert(struct bst_root* objs,
                              struct vqueue_mem_obj* obj) {
    return bst_insert(objs, &obj->node, vqueue_mem_obj_cmp);
}

static struct vqueue_mem_obj* vqueue_mem_lookup(struct bst_root* objs,
                                                ext_mem_obj_id_t id) {
    struct vqueue_mem_obj ref_obj;
    ref_obj.id = id;
    return bst_search_type(objs, &ref_obj, vqueue_mem_obj_cmp,
                           struct vqueue_mem_obj, node);
}

static inline void vqueue_mem_delete(struct bst_root* objs,
                                     struct vqueue_mem_obj* obj) {
    bst_delete(objs, &obj->node);
}

int vqueue_map_iovs(ext_mem_client_id_t client_id,
                    struct vqueue_iovs* vqiovs,
                    u_int flags,
                    struct vqueue_mapped_list* mapped_list) {
    uint i;
    int ret;
    size_t size;
    struct vqueue_mem_obj* obj;

    DEBUG_ASSERT(vqiovs);
    DEBUG_ASSERT(vqiovs->shared_mem_id);
    DEBUG_ASSERT(vqiovs->iovs);
    DEBUG_ASSERT(vqiovs->used <= vqiovs->cnt);

    for (i = 0; i < vqiovs->used; i++) {
        /* see if it's already been mapped */
        mutex_acquire(&mapped_list->lock);
        obj = vqueue_mem_lookup(&mapped_list->list, vqiovs->shared_mem_id[i]);
        mutex_release(&mapped_list->lock);

        if (obj && obj->client_id == client_id &&
            vqiovs->iovs[i].iov_len <= obj->size) {
            LTRACEF("iov restored %s id= %lu (base= %p, size= %lu)\n",
                    mapped_list->in_direction ? "IN" : "OUT",
                    (unsigned long)vqiovs->shared_mem_id[i], obj->iov_base,
                    (unsigned long)obj->size);
            vqiovs->iovs[i].iov_base = obj->iov_base;
            continue; /* reuse the previous mapping */
        } else if (obj) {
            /* otherwise, we need to drop the old mapping and remap */
            TRACEF("iov needs remapping for id= %lu\n",
                   (unsigned long)vqiovs->shared_mem_id[i]);
            mutex_acquire(&mapped_list->lock);
            vqueue_mem_delete(&mapped_list->list, obj);
            mutex_release(&mapped_list->lock);
            free(obj);
        }

        /* allocate a tracking object, since the mapping may be reused
         * instead of unmapped after use */
        obj = calloc(1, sizeof(struct vqueue_mem_obj));
        if (unlikely(!obj)) {
            TRACEF("calloc failure for vqueue_mem_obj for iov\n");
            ret = ERR_NO_MEMORY;
            goto err;
        }

        /* map it */
        vqiovs->iovs[i].iov_base = NULL;
        size = round_up(vqiovs->iovs[i].iov_len, PAGE_SIZE);
        ret = ext_mem_map_obj_id(vmm_get_kernel_aspace(), "vqueue-buf",
                                 client_id, vqiovs->shared_mem_id[i], 0, 0,
                                 size, &vqiovs->iovs[i].iov_base,
                                 PAGE_SIZE_SHIFT, 0, flags);
        if (ret) {
            free(obj);
            goto err;
        }

        vqueue_mem_obj_initialize(obj, client_id, vqiovs->shared_mem_id[i],
                                  vqiovs->iovs[i].iov_base, size);

        mutex_acquire(&mapped_list->lock);
        if (unlikely(!vqueue_mem_insert(&mapped_list->list, obj)))
            panic("Unhandled duplicate entry in ext_mem for iov\n");
        mutex_release(&mapped_list->lock);

        LTRACEF("iov saved %s id= %lu (base= %p, size= %lu)\n",
                mapped_list->in_direction ? "IN" : "OUT",
                (unsigned long)vqiovs->shared_mem_id[i],
                vqiovs->iovs[i].iov_base, (unsigned long)size);
    }

    return NO_ERROR;

err:
    while (i) {
        i--;
        vmm_free_region(vmm_get_kernel_aspace(),
                        (vaddr_t)vqiovs->iovs[i].iov_base);
        vqiovs->iovs[i].iov_base = NULL;
    }
    return ret;
}

void vqueue_unmap_iovs(struct vqueue_iovs* vqiovs,
                       struct vqueue_mapped_list* mapped_list) {
    struct vqueue_mem_obj* obj;

    DEBUG_ASSERT(vqiovs);
    DEBUG_ASSERT(vqiovs->shared_mem_id);
    DEBUG_ASSERT(vqiovs->iovs);
    DEBUG_ASSERT(vqiovs->used <= vqiovs->cnt);

    for (uint i = 0; i < vqiovs->used; i++) {
        /* base is expected to be set */
        DEBUG_ASSERT(vqiovs->iovs[i].iov_base);
        vmm_free_region(vmm_get_kernel_aspace(),
                        (vaddr_t)vqiovs->iovs[i].iov_base);
        vqiovs->iovs[i].iov_base = NULL;

        /* remove from list since it has been unmapped */
        mutex_acquire(&mapped_list->lock);
        obj = vqueue_mem_lookup(&mapped_list->list, vqiovs->shared_mem_id[i]);
        if (obj) {
            LTRACEF("iov removed %s id= %lu (base= %p, size= %lu)\n",
                    mapped_list->in_direction ? "IN" : "OUT",
                    (unsigned long)vqiovs->shared_mem_id[i],
                    vqiovs->iovs[i].iov_base,
                    (unsigned long)vqiovs->iovs[i].iov_len);
            vqueue_mem_delete(&mapped_list->list, obj);
            free(obj);
        } else {
            TRACEF("iov mapping not found for id= %lu (base= %p, size= %lu)\n",
                   (unsigned long)vqiovs->shared_mem_id[i],
                   vqiovs->iovs[i].iov_base,
                   (unsigned long)vqiovs->iovs[i].iov_len);
        }
        mutex_release(&mapped_list->lock);
    }
}

int vqueue_unmap_memid(ext_mem_obj_id_t id,
                       struct vqueue_mapped_list* mapped_list[],
                       int list_cnt) {
    struct vqueue_mapped_list* mapped = NULL;
    struct vqueue_mem_obj* obj;
    struct vqueue_iovs fake_vqiovs;
    ext_mem_obj_id_t fake_shared_mem_id[1];
    struct iovec_kern fake_iovs[1];

    /* determine which list this entry is in */
    for (int i = 0; i < list_cnt; i++) {
        mapped = mapped_list[i];
        obj = vqueue_mem_lookup(&mapped->list, id);
        if (obj)
            break;
        mapped = NULL;
    }

    if (mapped) {
        /* fake a vqueue_iovs struct so we can use the common unmap path */
        memset(&fake_vqiovs, 0, sizeof(fake_vqiovs));
        fake_vqiovs.iovs = fake_iovs;
        fake_vqiovs.shared_mem_id = fake_shared_mem_id;
        fake_vqiovs.used = 1;
        fake_vqiovs.cnt = 1;
        fake_vqiovs.iovs[0].iov_base = obj->iov_base;
        fake_vqiovs.iovs[0].iov_len = obj->size;
        fake_vqiovs.shared_mem_id[0] = id;

        /* unmap */
        vqueue_unmap_iovs(&fake_vqiovs, mapped);

        return NO_ERROR;
    }

    return ERR_NOT_FOUND;
}
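
/*
 * Sketch of how a caller might drop a cached mapping when the other side
 * revokes a shared memory object (the two-entry array mirrors the IN/OUT
 * mapped lists used above; the names are illustrative only):
 *
 *   struct vqueue_mapped_list* lists[] = { in_list, out_list };
 *
 *   if (vqueue_unmap_memid(revoked_id, lists, countof(lists)) == ERR_NOT_FOUND)
 *       ;  // nothing was mapped for this id
 */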

static int _vqueue_add_buf_locked(struct vqueue* vq,
                                  struct vqueue_buf* buf,
                                  uint32_t len) {
    struct vring_used_elem* used;

    DEBUG_ASSERT(vq);
    DEBUG_ASSERT(buf);

    if (!vq->vring_addr) {
        /* there is no vring - return an error */
        return ERR_CHANNEL_CLOSED;
    }

    if (buf->head >= vq->vring.num) {
        /* this would probably mean a corrupted vring */
        LTRACEF("vq %p: head (%u) out of range (%u)\n", vq, buf->head,
                vq->vring.num);
        return ERR_NOT_VALID;
    }

    used = &vq->vring.used->ring[vq->vring.used->idx % vq->vring.num];
    used->id = buf->head;
    used->len = len;
    /* make sure the used element is visible before publishing the new idx */
    smp_wmb();
    __builtin_add_overflow(vq->vring.used->idx, 1, &vq->vring.used->idx);
    return NO_ERROR;
}

int vqueue_add_buf(struct vqueue* vq, struct vqueue_buf* buf, uint32_t len) {
    spin_lock_saved_state_t state;

    spin_lock_save(&vq->slock, &state, VQ_LOCK_FLAGS);
    int ret = _vqueue_add_buf_locked(vq, buf, len);
    spin_unlock_restore(&vq->slock, state, VQ_LOCK_FLAGS);
    return ret;
}