/*
 * common code for virtio vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/sched/signal.h>
#include <linux/ctype.h>
#include <linux/list.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <uapi/linux/vsockmon.h>

#include <net/sock.h>
#include <net/af_vsock.h>

#define CREATE_TRACE_POINTS
#include <trace/events/vsock_virtio_transport_common.h>

/* How long to wait for graceful shutdown of a connection */
#define VSOCK_CLOSE_TIMEOUT (8 * HZ)

static const struct virtio_transport *virtio_transport_get_ops(void)
{
        const struct vsock_transport *t = vsock_core_get_transport();

        return container_of(t, struct virtio_transport, transport);
}

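/*
 * Allocate a packet and fill in the virtio_vsock_hdr from @info.  For
 * data packets the payload is copied out of info->msg into a freshly
 * allocated buffer.  Returns NULL on allocation or copy failure.
 */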
static struct virtio_vsock_pkt *
virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
                           size_t len,
                           u32 src_cid,
                           u32 src_port,
                           u32 dst_cid,
                           u32 dst_port)
{
        struct virtio_vsock_pkt *pkt;
        int err;

        pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
        if (!pkt)
                return NULL;

        pkt->hdr.type           = cpu_to_le16(info->type);
        pkt->hdr.op             = cpu_to_le16(info->op);
        pkt->hdr.src_cid        = cpu_to_le64(src_cid);
        pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
        pkt->hdr.src_port       = cpu_to_le32(src_port);
        pkt->hdr.dst_port       = cpu_to_le32(dst_port);
        pkt->hdr.flags          = cpu_to_le32(info->flags);
        pkt->len                = len;
        pkt->hdr.len            = cpu_to_le32(len);
        pkt->reply              = info->reply;
        pkt->vsk                = info->vsk;

        if (info->msg && len > 0) {
                pkt->buf = kmalloc(len, GFP_KERNEL);
                if (!pkt->buf)
                        goto out_pkt;
                err = memcpy_from_msg(pkt->buf, info->msg, len);
                if (err)
                        goto out;
        }

        trace_virtio_transport_alloc_pkt(src_cid, src_port,
                                         dst_cid, dst_port,
                                         len,
                                         info->type,
                                         info->op,
                                         info->flags);

        return pkt;

out:
        kfree(pkt->buf);
out_pkt:
        kfree(pkt);
        return NULL;
}

/* Packet capture */
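/*
 * Build an skb for delivery to a vsockmon tap: an af_vsockmon_hdr,
 * followed by the raw virtio_vsock_hdr, followed by the payload (if
 * any).  Both headers are little-endian on the wire, so the address
 * fields are copied without byteswapping.
 */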
static struct sk_buff *virtio_transport_build_skb(void *opaque)
{
        struct virtio_vsock_pkt *pkt = opaque;
        unsigned char *t_hdr, *payload;
        struct af_vsockmon_hdr *hdr;
        struct sk_buff *skb;

        skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
                        GFP_ATOMIC);
        if (!skb)
                return NULL;

        hdr = skb_put(skb, sizeof(*hdr));

        /* pkt->hdr is little-endian so no need to byteswap here */
        hdr->src_cid = pkt->hdr.src_cid;
        hdr->src_port = pkt->hdr.src_port;
        hdr->dst_cid = pkt->hdr.dst_cid;
        hdr->dst_port = pkt->hdr.dst_port;

        hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
        hdr->len = cpu_to_le16(sizeof(pkt->hdr));
        memset(hdr->reserved, 0, sizeof(hdr->reserved));

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_REQUEST:
        case VIRTIO_VSOCK_OP_RESPONSE:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
                break;
        case VIRTIO_VSOCK_OP_RST:
        case VIRTIO_VSOCK_OP_SHUTDOWN:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
                break;
        case VIRTIO_VSOCK_OP_RW:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
                break;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
        case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
                break;
        default:
                hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
                break;
        }

        t_hdr = skb_put_data(skb, &pkt->hdr, sizeof(pkt->hdr));

        if (pkt->len) {
                payload = skb_put_data(skb, pkt->buf, pkt->len);
        }

        return skb;
}

void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
{
        vsock_deliver_tap(virtio_transport_build_skb, pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);

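/*
 * Common send path: fill in source/destination addresses, clamp the
 * length to the default rx buffer size and to the tx credit that is
 * currently available, then allocate the packet and pass it to the
 * transport's send_pkt().
 */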
static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
                                          struct virtio_vsock_pkt_info *info)
{
        u32 src_cid, src_port, dst_cid, dst_port;
        struct virtio_vsock_sock *vvs;
        struct virtio_vsock_pkt *pkt;
        u32 pkt_len = info->pkt_len;

        src_cid = vm_sockets_get_local_cid();
        src_port = vsk->local_addr.svm_port;
        if (!info->remote_cid) {
                dst_cid = vsk->remote_addr.svm_cid;
                dst_port = vsk->remote_addr.svm_port;
        } else {
                dst_cid = info->remote_cid;
                dst_port = info->remote_port;
        }

        vvs = vsk->trans;

        /* we can send less than pkt_len bytes */
        if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
                pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;

        /* virtio_transport_get_credit might return less than pkt_len credit */
        pkt_len = virtio_transport_get_credit(vvs, pkt_len);

        /* Do not send zero length OP_RW pkt */
        if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
                return pkt_len;

        pkt = virtio_transport_alloc_pkt(info, pkt_len,
                                         src_cid, src_port,
                                         dst_cid, dst_port);
        if (!pkt) {
                virtio_transport_put_credit(vvs, pkt_len);
                return -ENOMEM;
        }

        virtio_transport_inc_tx_pkt(vvs, pkt);

        return virtio_transport_get_ops()->send_pkt(pkt);
}

static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes += pkt->len;
}

static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
                                        struct virtio_vsock_pkt *pkt)
{
        vvs->rx_bytes -= pkt->len;
        vvs->fwd_cnt += pkt->len;
}

void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
{
        spin_lock_bh(&vvs->tx_lock);
        pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
        pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);

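/*
 * Reserve up to @credit bytes of tx credit.  The peer can buffer
 * peer_buf_alloc bytes in total and has consumed peer_fwd_cnt of the
 * tx_cnt bytes sent so far, so peer_buf_alloc - (tx_cnt - peer_fwd_cnt)
 * bytes of space remain on its side.  Returns the amount actually
 * reserved, which may be less than requested.
 */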
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        u32 ret;

        spin_lock_bh(&vvs->tx_lock);
        ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (ret > credit)
                ret = credit;
        vvs->tx_cnt += ret;
        spin_unlock_bh(&vvs->tx_lock);

        return ret;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_credit);

void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
{
        spin_lock_bh(&vvs->tx_lock);
        vvs->tx_cnt -= credit;
        spin_unlock_bh(&vvs->tx_lock);
}
EXPORT_SYMBOL_GPL(virtio_transport_put_credit);

static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
                                               int type,
                                               struct virtio_vsock_hdr *hdr)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
                .type = type,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

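/*
 * Copy queued rx data into @msg.  The rx_lock is dropped around
 * memcpy_to_msg() because it may sleep; the socket lock held by the
 * caller prevents concurrent dequeues.  Fully consumed packets are
 * freed, while a partially consumed packet keeps its ->off offset and
 * stays at the head of the queue.  A credit update is sent afterwards
 * so the peer learns about the freed buffer space.
 */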
static ssize_t
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
                                   struct msghdr *msg,
                                   size_t len)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        struct virtio_vsock_pkt *pkt;
        size_t bytes, total = 0;
        int err = -EFAULT;

        spin_lock_bh(&vvs->rx_lock);
        while (total < len && !list_empty(&vvs->rx_queue)) {
                pkt = list_first_entry(&vvs->rx_queue,
                                       struct virtio_vsock_pkt, list);

                bytes = len - total;
                if (bytes > pkt->len - pkt->off)
                        bytes = pkt->len - pkt->off;

                /* sk_lock is held by caller so no one else can dequeue.
                 * Unlock rx_lock since memcpy_to_msg() may sleep.
                 */
                spin_unlock_bh(&vvs->rx_lock);

                err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
                if (err)
                        goto out;

                spin_lock_bh(&vvs->rx_lock);

                total += bytes;
                pkt->off += bytes;
                if (pkt->off == pkt->len) {
                        virtio_transport_dec_rx_pkt(vvs, pkt);
                        list_del(&pkt->list);
                        virtio_transport_free_pkt(pkt);
                }
        }
        spin_unlock_bh(&vvs->rx_lock);

        /* Send a credit pkt to peer */
        virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
                                            NULL);

        return total;

out:
        if (total)
                err = total;
        return err;
}

ssize_t
virtio_transport_stream_dequeue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len, int flags)
{
        if (flags & MSG_PEEK)
                return -EOPNOTSUPP;

        return virtio_transport_stream_do_dequeue(vsk, msg, len);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);

int
virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
                               struct msghdr *msg,
                               size_t len, int flags)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);

s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->rx_lock);
        bytes = vvs->rx_bytes;
        spin_unlock_bh(&vvs->rx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);

static s64 virtio_transport_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
        if (bytes < 0)
                bytes = 0;

        return bytes;
}

s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;
        s64 bytes;

        spin_lock_bh(&vvs->tx_lock);
        bytes = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);

        return bytes;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);

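/*
 * Allocate the per-socket transport state.  A child socket created for
 * an incoming connection inherits the buffer size limits and the
 * peer's advertised credit from the listener @psk; otherwise the
 * defaults are used.
 */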
int virtio_transport_do_socket_init(struct vsock_sock *vsk,
                                    struct vsock_sock *psk)
{
        struct virtio_vsock_sock *vvs;

        vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
        if (!vvs)
                return -ENOMEM;

        vsk->trans = vvs;
        vvs->vsk = vsk;
        if (psk) {
                struct virtio_vsock_sock *ptrans = psk->trans;

                vvs->buf_size   = ptrans->buf_size;
                vvs->buf_size_min = ptrans->buf_size_min;
                vvs->buf_size_max = ptrans->buf_size_max;
                vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
        } else {
                vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
                vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
                vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
        }

        vvs->buf_alloc = vvs->buf_size;

        spin_lock_init(&vvs->rx_lock);
        spin_lock_init(&vvs->tx_lock);
        INIT_LIST_HEAD(&vvs->rx_queue);

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);

u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);

u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_min;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);

u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size_max;
}
EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);

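/*
 * Buffer size setters (backing the SO_VM_SOCKETS_BUFFER_* socket
 * options): clamp the requested value to VIRTIO_VSOCK_MAX_BUF_SIZE and
 * keep buf_size_min <= buf_size <= buf_size_max by adjusting whichever
 * bound the new value crosses.
 */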
void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size_min)
                vvs->buf_size_min = val;
        if (val > vvs->buf_size_max)
                vvs->buf_size_max = val;
        vvs->buf_size = val;
        vvs->buf_alloc = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);

void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val > vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_min = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);

void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
                val = VIRTIO_VSOCK_MAX_BUF_SIZE;
        if (val < vvs->buf_size)
                vvs->buf_size = val;
        vvs->buf_size_max = val;
}
EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);

int
virtio_transport_notify_poll_in(struct vsock_sock *vsk,
                                size_t target,
                                bool *data_ready_now)
{
        if (vsock_stream_has_data(vsk))
                *data_ready_now = true;
        else
                *data_ready_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);

int
virtio_transport_notify_poll_out(struct vsock_sock *vsk,
                                 size_t target,
                                 bool *space_avail_now)
{
        s64 free_space;

        free_space = vsock_stream_has_space(vsk);
        if (free_space > 0)
                *space_avail_now = true;
        else if (free_space == 0)
                *space_avail_now = false;

        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);

int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);

int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);

int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
        size_t target, struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);

int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
        size_t target, ssize_t copied, bool data_read,
        struct vsock_transport_recv_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);

int virtio_transport_notify_send_init(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);

int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);

int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
        struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);

int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
        ssize_t written, struct vsock_transport_send_notify_data *data)
{
        return 0;
}
EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);

u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        return vvs->buf_size;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);

bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);

bool virtio_transport_stream_allow(u32 cid, u32 port)
{
        return true;
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);

int virtio_transport_dgram_bind(struct vsock_sock *vsk,
                                struct sockaddr_vm *addr)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);

bool virtio_transport_dgram_allow(u32 cid, u32 port)
{
        return false;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);

int virtio_transport_connect(struct vsock_sock *vsk)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_REQUEST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_connect);

int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_SHUTDOWN,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .flags = (mode & RCV_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
                         (mode & SEND_SHUTDOWN ?
                          VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_shutdown);

int
virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
                               struct sockaddr_vm *remote_addr,
                               struct msghdr *msg,
                               size_t dgram_len)
{
        return -EOPNOTSUPP;
}
EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);

ssize_t
virtio_transport_stream_enqueue(struct vsock_sock *vsk,
                                struct msghdr *msg,
                                size_t len)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RW,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .msg = msg,
                .pkt_len = len,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}
EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);

void virtio_transport_destruct(struct vsock_sock *vsk)
{
        struct virtio_vsock_sock *vvs = vsk->trans;

        kfree(vvs);
}
EXPORT_SYMBOL_GPL(virtio_transport_destruct);

static int virtio_transport_reset(struct vsock_sock *vsk,
                                  struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .reply = !!pkt,
                .vsk = vsk,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Normally packets are associated with a socket.  There may be no socket if an
 * attempt was made to connect to a socket that does not exist.
 */
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RST,
                .type = le16_to_cpu(pkt->hdr.type),
                .reply = true,
        };

        /* Send RST only if the original pkt is not a RST pkt */
        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                return 0;

        pkt = virtio_transport_alloc_pkt(&info, 0,
                                         le64_to_cpu(pkt->hdr.dst_cid),
                                         le32_to_cpu(pkt->hdr.dst_port),
                                         le64_to_cpu(pkt->hdr.src_cid),
                                         le32_to_cpu(pkt->hdr.src_port));
        if (!pkt)
                return -ENOMEM;

        return virtio_transport_get_ops()->send_pkt(pkt);
}

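/*
 * Linger for up to @timeout waiting for the connection to finish
 * closing, i.e. for SOCK_DONE to be set by an incoming RST, or for a
 * signal.
 */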
static void virtio_transport_wait_close(struct sock *sk, long timeout)
{
        if (timeout) {
                DEFINE_WAIT_FUNC(wait, woken_wake_function);

                add_wait_queue(sk_sleep(sk), &wait);

                do {
                        if (sk_wait_event(sk, &timeout,
                                          sock_flag(sk, SOCK_DONE), &wait))
                                break;
                } while (!signal_pending(current) && timeout);

                remove_wait_queue(sk_sleep(sk), &wait);
        }
}

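/*
 * Finish tearing down a connection: mark the socket SOCK_DONE, treat
 * both directions as shut down by the peer, and move to
 * SS_DISCONNECTING once no unread data remains.  If the delayed close
 * work is still pending it is cancelled (unless called from the work
 * itself), the socket is removed and the reference taken when the
 * timeout was scheduled is dropped.
 */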
static void virtio_transport_do_close(struct vsock_sock *vsk,
                                      bool cancel_timeout)
{
        struct sock *sk = sk_vsock(vsk);

        sock_set_flag(sk, SOCK_DONE);
        vsk->peer_shutdown = SHUTDOWN_MASK;
        if (vsock_stream_has_data(vsk) <= 0)
                sk->sk_state = SS_DISCONNECTING;
        sk->sk_state_change(sk);

        if (vsk->close_work_scheduled &&
            (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
                vsk->close_work_scheduled = false;

                vsock_remove_sock(vsk);

                /* Release refcnt obtained when we scheduled the timeout */
                sock_put(sk);
        }
}

static void virtio_transport_close_timeout(struct work_struct *work)
{
        struct vsock_sock *vsk =
                container_of(work, struct vsock_sock, close_work.work);
        struct sock *sk = sk_vsock(vsk);

        sock_hold(sk);
        lock_sock(sk);

        if (!sock_flag(sk, SOCK_DONE)) {
                (void)virtio_transport_reset(vsk, NULL);

                virtio_transport_do_close(vsk, false);
        }

        vsk->close_work_scheduled = false;

        release_sock(sk);
        sock_put(sk);
}

/* User context, vsk->sk is locked */
static bool virtio_transport_close(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;

        if (!(sk->sk_state == SS_CONNECTED ||
              sk->sk_state == SS_DISCONNECTING))
                return true;

        /* Already received SHUTDOWN from peer, reply with RST */
        if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
                (void)virtio_transport_reset(vsk, NULL);
                return true;
        }

        if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
                (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);

        if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
                virtio_transport_wait_close(sk, sk->sk_lingertime);

        if (sock_flag(sk, SOCK_DONE)) {
                return true;
        }

        sock_hold(sk);
        INIT_DELAYED_WORK(&vsk->close_work,
                          virtio_transport_close_timeout);
        vsk->close_work_scheduled = true;
        schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
        return false;
}

void virtio_transport_release(struct vsock_sock *vsk)
{
        struct sock *sk = &vsk->sk;
        bool remove_sock = true;

        lock_sock(sk);
        if (sk->sk_type == SOCK_STREAM)
                remove_sock = virtio_transport_close(vsk);
        release_sock(sk);

        if (remove_sock)
                vsock_remove_sock(vsk);
}
EXPORT_SYMBOL_GPL(virtio_transport_release);

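/*
 * Connect handshake, client side: a RESPONSE from the peer moves the
 * socket to SS_CONNECTED, while an RST or any unexpected op resets the
 * connection and reports the error on the socket.
 */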
static int
virtio_transport_recv_connecting(struct sock *sk,
                                 struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        int err;
        int skerr;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RESPONSE:
                sk->sk_state = SS_CONNECTED;
                sk->sk_socket->state = SS_CONNECTED;
                vsock_insert_connected(vsk);
                sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_INVALID:
                break;
        case VIRTIO_VSOCK_OP_RST:
                skerr = ECONNRESET;
                err = 0;
                goto destroy;
        default:
                skerr = EPROTO;
                err = -EINVAL;
                goto destroy;
        }
        return 0;

destroy:
        virtio_transport_reset(vsk, pkt);
        sk->sk_state = SS_UNCONNECTED;
        sk->sk_err = skerr;
        sk->sk_error_report(sk);
        return err;
}

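/*
 * Established connection: OP_RW packets are queued on rx_queue and
 * freed later by the dequeue path, credit updates wake up writers,
 * SHUTDOWN updates peer_shutdown, and RST closes the socket.  Packets
 * that are not queued are freed before returning.
 */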
static int
virtio_transport_recv_connected(struct sock *sk,
                                struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        int err = 0;

        switch (le16_to_cpu(pkt->hdr.op)) {
        case VIRTIO_VSOCK_OP_RW:
                pkt->len = le32_to_cpu(pkt->hdr.len);
                pkt->off = 0;

                spin_lock_bh(&vvs->rx_lock);
                virtio_transport_inc_rx_pkt(vvs, pkt);
                list_add_tail(&pkt->list, &vvs->rx_queue);
                spin_unlock_bh(&vvs->rx_lock);

                sk->sk_data_ready(sk);
                return err;
        case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
                sk->sk_write_space(sk);
                break;
        case VIRTIO_VSOCK_OP_SHUTDOWN:
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
                        vsk->peer_shutdown |= RCV_SHUTDOWN;
                if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
                        vsk->peer_shutdown |= SEND_SHUTDOWN;
                if (vsk->peer_shutdown == SHUTDOWN_MASK &&
                    vsock_stream_has_data(vsk) <= 0)
                        sk->sk_state = SS_DISCONNECTING;
                if (le32_to_cpu(pkt->hdr.flags))
                        sk->sk_state_change(sk);
                break;
        case VIRTIO_VSOCK_OP_RST:
                virtio_transport_do_close(vsk, true);
                break;
        default:
                err = -EINVAL;
                break;
        }

        virtio_transport_free_pkt(pkt);
        return err;
}

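/* While disconnecting, only an RST from the peer completes the close. */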
static void
virtio_transport_recv_disconnecting(struct sock *sk,
                                    struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);

        if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
                virtio_transport_do_close(vsk, true);
}

static int
virtio_transport_send_response(struct vsock_sock *vsk,
                               struct virtio_vsock_pkt *pkt)
{
        struct virtio_vsock_pkt_info info = {
                .op = VIRTIO_VSOCK_OP_RESPONSE,
                .type = VIRTIO_VSOCK_TYPE_STREAM,
                .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
                .remote_port = le32_to_cpu(pkt->hdr.src_port),
                .reply = true,
                .vsk = vsk,
        };

        return virtio_transport_send_pkt_info(vsk, &info);
}

/* Handle server socket */
static int
virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct vsock_sock *vchild;
        struct sock *child;

        if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
                virtio_transport_reset(vsk, pkt);
                return -EINVAL;
        }

        if (sk_acceptq_is_full(sk)) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
                               sk->sk_type, 0);
        if (!child) {
                virtio_transport_reset(vsk, pkt);
                return -ENOMEM;
        }

        sk->sk_ack_backlog++;

        lock_sock_nested(child, SINGLE_DEPTH_NESTING);

        child->sk_state = SS_CONNECTED;

        vchild = vsock_sk(child);
        vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));
        vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));

        vsock_insert_connected(vchild);
        vsock_enqueue_accept(sk, child);
        virtio_transport_send_response(vchild, pkt);

        release_sock(child);

        sk->sk_data_ready(sk);
        return 0;
}

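/*
 * Every packet carries the sender's buf_alloc and fwd_cnt; use them to
 * refresh our view of the peer's credit and report whether there is
 * now space available for transmission.
 */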
static bool virtio_transport_space_update(struct sock *sk,
                                          struct virtio_vsock_pkt *pkt)
{
        struct vsock_sock *vsk = vsock_sk(sk);
        struct virtio_vsock_sock *vvs = vsk->trans;
        bool space_available;

        /* buf_alloc and fwd_cnt are always included in the hdr */
        spin_lock_bh(&vvs->tx_lock);
        vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
        vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
        space_available = virtio_transport_has_space(vsk);
        spin_unlock_bh(&vvs->tx_lock);
        return space_available;
}

/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
 * lock.
 */
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
{
        struct sockaddr_vm src, dst;
        struct vsock_sock *vsk;
        struct sock *sk;
        bool space_available;

        vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
                        le32_to_cpu(pkt->hdr.src_port));
        vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
                        le32_to_cpu(pkt->hdr.dst_port));

        trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
                                        dst.svm_cid, dst.svm_port,
                                        le32_to_cpu(pkt->hdr.len),
                                        le16_to_cpu(pkt->hdr.type),
                                        le16_to_cpu(pkt->hdr.op),
                                        le32_to_cpu(pkt->hdr.flags),
                                        le32_to_cpu(pkt->hdr.buf_alloc),
                                        le32_to_cpu(pkt->hdr.fwd_cnt));

        if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
                (void)virtio_transport_reset_no_sock(pkt);
                goto free_pkt;
        }

        /* The socket must be in the connected or bound table,
         * otherwise send a reset back.
         */
        sk = vsock_find_connected_socket(&src, &dst);
        if (!sk) {
                sk = vsock_find_bound_socket(&dst);
                if (!sk) {
                        (void)virtio_transport_reset_no_sock(pkt);
                        goto free_pkt;
                }
        }

        vsk = vsock_sk(sk);

        space_available = virtio_transport_space_update(sk, pkt);

        lock_sock(sk);

        /* Update CID in case it has changed after a transport reset event */
        vsk->local_addr.svm_cid = dst.svm_cid;

        if (space_available)
                sk->sk_write_space(sk);

        switch (sk->sk_state) {
        case VSOCK_SS_LISTEN:
                virtio_transport_recv_listen(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case SS_CONNECTING:
                virtio_transport_recv_connecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        case SS_CONNECTED:
                virtio_transport_recv_connected(sk, pkt);
                break;
        case SS_DISCONNECTING:
                virtio_transport_recv_disconnecting(sk, pkt);
                virtio_transport_free_pkt(pkt);
                break;
        default:
                virtio_transport_free_pkt(pkt);
                break;
        }
        release_sock(sk);

        /* Release refcnt obtained when we fetched this socket out of the
         * bound or connected list.
         */
        sock_put(sk);
        return;

free_pkt:
        virtio_transport_free_pkt(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);

void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
{
        kfree(pkt->buf);
        kfree(pkt);
}
EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);

MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("common code for virtio vsock");