net/vmw_vsock/virtio_transport_common.c
1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/sched/signal.h>
13 #include <linux/ctype.h>
14 #include <linux/list.h>
15 #include <linux/virtio.h>
16 #include <linux/virtio_ids.h>
17 #include <linux/virtio_config.h>
18 #include <linux/virtio_vsock.h>
19 #include <uapi/linux/vsockmon.h>
20
21 #include <net/sock.h>
22 #include <net/af_vsock.h>
23
24 #define CREATE_TRACE_POINTS
25 #include <trace/events/vsock_virtio_transport_common.h>
26
27 /* How long to wait for graceful shutdown of a connection */
28 #define VSOCK_CLOSE_TIMEOUT (8 * HZ)
29
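/* Map the transport registered with the vsock core back to the
 * enclosing struct virtio_transport.  The container_of() is only valid
 * because the registered transport is assumed to be embedded in a
 * struct virtio_transport, which holds for both the virtio and vhost
 * backends.
 */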
30 static const struct virtio_transport *virtio_transport_get_ops(void)
31 {
32         const struct vsock_transport *t = vsock_core_get_transport();
33
34         return container_of(t, struct virtio_transport, transport);
35 }
36
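/* Allocate a packet and fill in the wire header.  All header fields are
 * stored little-endian; for data (OP_RW) packets the payload is copied
 * out of the caller-supplied msghdr.  Returns NULL if allocation or the
 * copy from user memory fails.
 */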
37 static struct virtio_vsock_pkt *
38 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
39                            size_t len,
40                            u32 src_cid,
41                            u32 src_port,
42                            u32 dst_cid,
43                            u32 dst_port)
44 {
45         struct virtio_vsock_pkt *pkt;
46         int err;
47
48         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
49         if (!pkt)
50                 return NULL;
51
52         pkt->hdr.type           = cpu_to_le16(info->type);
53         pkt->hdr.op             = cpu_to_le16(info->op);
54         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
55         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
56         pkt->hdr.src_port       = cpu_to_le32(src_port);
57         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
58         pkt->hdr.flags          = cpu_to_le32(info->flags);
59         pkt->len                = len;
60         pkt->hdr.len            = cpu_to_le32(len);
61         pkt->reply              = info->reply;
62         pkt->vsk                = info->vsk;
63
64         if (info->msg && len > 0) {
65                 pkt->buf = kmalloc(len, GFP_KERNEL);
66                 if (!pkt->buf)
67                         goto out_pkt;
68                 err = memcpy_from_msg(pkt->buf, info->msg, len);
69                 if (err)
70                         goto out;
71         }
72
73         trace_virtio_transport_alloc_pkt(src_cid, src_port,
74                                          dst_cid, dst_port,
75                                          len,
76                                          info->type,
77                                          info->op,
78                                          info->flags);
79
80         return pkt;
81
82 out:
83         kfree(pkt->buf);
84 out_pkt:
85         kfree(pkt);
86         return NULL;
87 }
88
89 /* Packet capture */
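/* Build the skb handed to vsockmon taps: an af_vsockmon_hdr, followed
 * by the raw little-endian virtio header, followed by the payload.  The
 * virtio operation is collapsed onto the coarser vsockmon operations
 * (CONNECT, DISCONNECT, PAYLOAD, CONTROL).
 */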
90 static struct sk_buff *virtio_transport_build_skb(void *opaque)
91 {
92         struct virtio_vsock_pkt *pkt = opaque;
93         unsigned char *t_hdr, *payload;
94         struct af_vsockmon_hdr *hdr;
95         struct sk_buff *skb;
96
97         skb = alloc_skb(sizeof(*hdr) + sizeof(pkt->hdr) + pkt->len,
98                         GFP_ATOMIC);
99         if (!skb)
100                 return NULL;
101
102         hdr = (struct af_vsockmon_hdr *)skb_put(skb, sizeof(*hdr));
103
104         /* pkt->hdr is little-endian so no need to byteswap here */
105         hdr->src_cid = pkt->hdr.src_cid;
106         hdr->src_port = pkt->hdr.src_port;
107         hdr->dst_cid = pkt->hdr.dst_cid;
108         hdr->dst_port = pkt->hdr.dst_port;
109
110         hdr->transport = cpu_to_le16(AF_VSOCK_TRANSPORT_VIRTIO);
111         hdr->len = cpu_to_le16(sizeof(pkt->hdr));
112         memset(hdr->reserved, 0, sizeof(hdr->reserved));
113
114         switch (le16_to_cpu(pkt->hdr.op)) {
115         case VIRTIO_VSOCK_OP_REQUEST:
116         case VIRTIO_VSOCK_OP_RESPONSE:
117                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONNECT);
118                 break;
119         case VIRTIO_VSOCK_OP_RST:
120         case VIRTIO_VSOCK_OP_SHUTDOWN:
121                 hdr->op = cpu_to_le16(AF_VSOCK_OP_DISCONNECT);
122                 break;
123         case VIRTIO_VSOCK_OP_RW:
124                 hdr->op = cpu_to_le16(AF_VSOCK_OP_PAYLOAD);
125                 break;
126         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
127         case VIRTIO_VSOCK_OP_CREDIT_REQUEST:
128                 hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL);
129                 break;
130         default:
131                 hdr->op = cpu_to_le16(AF_VSOCK_OP_UNKNOWN);
132                 break;
133         }
134
135         t_hdr = skb_put(skb, sizeof(pkt->hdr));
136         memcpy(t_hdr, &pkt->hdr, sizeof(pkt->hdr));
137
138         if (pkt->len) {
139                 payload = skb_put(skb, pkt->len);
140                 memcpy(payload, pkt->buf, pkt->len);
141         }
142
143         return skb;
144 }
145
146 void virtio_transport_deliver_tap_pkt(struct virtio_vsock_pkt *pkt)
147 {
148         vsock_deliver_tap(virtio_transport_build_skb, pkt);
149 }
150 EXPORT_SYMBOL_GPL(virtio_transport_deliver_tap_pkt);
151
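/* Common send path: resolve source/destination addressing, cap the
 * length at one rx buffer, reserve send credit (which may shrink the
 * packet further), then allocate the packet and hand it to the
 * transport's ->send_pkt().  Zero-length OP_RW packets are dropped
 * rather than sent.
 */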
152 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
153                                           struct virtio_vsock_pkt_info *info)
154 {
155         u32 src_cid, src_port, dst_cid, dst_port;
156         struct virtio_vsock_sock *vvs;
157         struct virtio_vsock_pkt *pkt;
158         u32 pkt_len = info->pkt_len;
159
160         src_cid = vm_sockets_get_local_cid();
161         src_port = vsk->local_addr.svm_port;
162         if (!info->remote_cid) {
163                 dst_cid = vsk->remote_addr.svm_cid;
164                 dst_port = vsk->remote_addr.svm_port;
165         } else {
166                 dst_cid = info->remote_cid;
167                 dst_port = info->remote_port;
168         }
169
170         vvs = vsk->trans;
171
172         /* we can send less than pkt_len bytes */
173         if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
174                 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
175
176         /* virtio_transport_get_credit might return less than pkt_len credit */
177         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
178
179         /* Do not send zero length OP_RW pkt */
180         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
181                 return pkt_len;
182
183         pkt = virtio_transport_alloc_pkt(info, pkt_len,
184                                          src_cid, src_port,
185                                          dst_cid, dst_port);
186         if (!pkt) {
187                 virtio_transport_put_credit(vvs, pkt_len);
188                 return -ENOMEM;
189         }
190
191         virtio_transport_inc_tx_pkt(vvs, pkt);
192
193         return virtio_transport_get_ops()->send_pkt(pkt);
194 }
195
196 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
197                                         struct virtio_vsock_pkt *pkt)
198 {
199         vvs->rx_bytes += pkt->len;
200 }
201
202 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
203                                         struct virtio_vsock_pkt *pkt)
204 {
205         vvs->rx_bytes -= pkt->len;
206         vvs->fwd_cnt += pkt->len;
207 }
208
209 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
210 {
211         spin_lock_bh(&vvs->tx_lock);
212         pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
213         pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
214         spin_unlock_bh(&vvs->tx_lock);
215 }
216 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
217
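/* Credit-based flow control: the peer advertises its receive buffer
 * size (buf_alloc) and how many bytes it has consumed (fwd_cnt), so at
 * most peer_buf_alloc - (tx_cnt - peer_fwd_cnt) bytes may be in flight.
 * virtio_transport_get_credit() claims up to @credit bytes of that
 * window and advances tx_cnt; virtio_transport_put_credit() gives
 * unused credit back.  Purely illustrative numbers: with
 * peer_buf_alloc = 262144, tx_cnt = 300000 and peer_fwd_cnt = 200000,
 * 162144 bytes of credit are still available.
 */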
218 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
219 {
220         u32 ret;
221
222         spin_lock_bh(&vvs->tx_lock);
223         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
224         if (ret > credit)
225                 ret = credit;
226         vvs->tx_cnt += ret;
227         spin_unlock_bh(&vvs->tx_lock);
228
229         return ret;
230 }
231 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
232
233 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
234 {
235         spin_lock_bh(&vvs->tx_lock);
236         vvs->tx_cnt -= credit;
237         spin_unlock_bh(&vvs->tx_lock);
238 }
239 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
240
241 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
242                                                int type,
243                                                struct virtio_vsock_hdr *hdr)
244 {
245         struct virtio_vsock_pkt_info info = {
246                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
247                 .type = type,
248                 .vsk = vsk,
249         };
250
251         return virtio_transport_send_pkt_info(vsk, &info);
252 }
253
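/* Copy queued rx data into the user's msghdr.  rx_lock is dropped
 * around memcpy_to_msg() because the copy may sleep; the socket lock
 * held by the caller keeps other readers out in the meantime.  Fully
 * consumed packets are freed, and a credit update is sent afterwards so
 * the peer learns that buffer space has been released.
 */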
254 static ssize_t
255 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
256                                    struct msghdr *msg,
257                                    size_t len)
258 {
259         struct virtio_vsock_sock *vvs = vsk->trans;
260         struct virtio_vsock_pkt *pkt;
261         size_t bytes, total = 0;
262         int err = -EFAULT;
263
264         spin_lock_bh(&vvs->rx_lock);
265         while (total < len && !list_empty(&vvs->rx_queue)) {
266                 pkt = list_first_entry(&vvs->rx_queue,
267                                        struct virtio_vsock_pkt, list);
268
269                 bytes = len - total;
270                 if (bytes > pkt->len - pkt->off)
271                         bytes = pkt->len - pkt->off;
272
273                 /* sk_lock is held by caller so no one else can dequeue.
274                  * Unlock rx_lock since memcpy_to_msg() may sleep.
275                  */
276                 spin_unlock_bh(&vvs->rx_lock);
277
278                 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
279                 if (err)
280                         goto out;
281
282                 spin_lock_bh(&vvs->rx_lock);
283
284                 total += bytes;
285                 pkt->off += bytes;
286                 if (pkt->off == pkt->len) {
287                         virtio_transport_dec_rx_pkt(vvs, pkt);
288                         list_del(&pkt->list);
289                         virtio_transport_free_pkt(pkt);
290                 }
291         }
292         spin_unlock_bh(&vvs->rx_lock);
293
294         /* Send a credit pkt to peer */
295         virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
296                                             NULL);
297
298         return total;
299
300 out:
301         if (total)
302                 err = total;
303         return err;
304 }
305
306 ssize_t
307 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
308                                 struct msghdr *msg,
309                                 size_t len, int flags)
310 {
311         if (flags & MSG_PEEK)
312                 return -EOPNOTSUPP;
313
314         return virtio_transport_stream_do_dequeue(vsk, msg, len);
315 }
316 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
317
318 int
319 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
320                                struct msghdr *msg,
321                                size_t len, int flags)
322 {
323         return -EOPNOTSUPP;
324 }
325 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
326
327 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
328 {
329         struct virtio_vsock_sock *vvs = vsk->trans;
330         s64 bytes;
331
332         spin_lock_bh(&vvs->rx_lock);
333         bytes = vvs->rx_bytes;
334         spin_unlock_bh(&vvs->rx_lock);
335
336         return bytes;
337 }
338 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
339
340 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
341 {
342         struct virtio_vsock_sock *vvs = vsk->trans;
343         s64 bytes;
344
345         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
346         if (bytes < 0)
347                 bytes = 0;
348
349         return bytes;
350 }
351
352 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
353 {
354         struct virtio_vsock_sock *vvs = vsk->trans;
355         s64 bytes;
356
357         spin_lock_bh(&vvs->tx_lock);
358         bytes = virtio_transport_has_space(vsk);
359         spin_unlock_bh(&vvs->tx_lock);
360
361         return bytes;
362 }
363 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
364
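/* Allocate the per-socket transport state.  A child socket created on
 * accept() inherits the buffer-size settings and peer credit of its
 * listening parent; standalone sockets start from the defaults.  The
 * buf_alloc value advertised to the peer starts out equal to buf_size.
 */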
365 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
366                                     struct vsock_sock *psk)
367 {
368         struct virtio_vsock_sock *vvs;
369
370         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
371         if (!vvs)
372                 return -ENOMEM;
373
374         vsk->trans = vvs;
375         vvs->vsk = vsk;
376         if (psk) {
377                 struct virtio_vsock_sock *ptrans = psk->trans;
378
379                 vvs->buf_size   = ptrans->buf_size;
380                 vvs->buf_size_min = ptrans->buf_size_min;
381                 vvs->buf_size_max = ptrans->buf_size_max;
382                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
383         } else {
384                 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
385                 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
386                 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
387         }
388
389         vvs->buf_alloc = vvs->buf_size;
390
391         spin_lock_init(&vvs->rx_lock);
392         spin_lock_init(&vvs->tx_lock);
393         INIT_LIST_HEAD(&vvs->rx_queue);
394
395         return 0;
396 }
397 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
398
399 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
400 {
401         struct virtio_vsock_sock *vvs = vsk->trans;
402
403         return vvs->buf_size;
404 }
405 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
406
407 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
408 {
409         struct virtio_vsock_sock *vvs = vsk->trans;
410
411         return vvs->buf_size_min;
412 }
413 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
414
415 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
416 {
417         struct virtio_vsock_sock *vvs = vsk->trans;
418
419         return vvs->buf_size_max;
420 }
421 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
422
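/* The buffer-size setters clamp to VIRTIO_VSOCK_MAX_BUF_SIZE and keep
 * the buf_size_min <= buf_size <= buf_size_max invariant by dragging
 * the other bounds along.  Changing buf_size also refreshes buf_alloc,
 * the receive-buffer size advertised to the peer in every packet
 * header.
 */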
423 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
424 {
425         struct virtio_vsock_sock *vvs = vsk->trans;
426
427         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
428                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
429         if (val < vvs->buf_size_min)
430                 vvs->buf_size_min = val;
431         if (val > vvs->buf_size_max)
432                 vvs->buf_size_max = val;
433         vvs->buf_size = val;
434         vvs->buf_alloc = val;
435 }
436 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
437
438 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
439 {
440         struct virtio_vsock_sock *vvs = vsk->trans;
441
442         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
443                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
444         if (val > vvs->buf_size)
445                 vvs->buf_size = val;
446         vvs->buf_size_min = val;
447 }
448 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
449
450 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
451 {
452         struct virtio_vsock_sock *vvs = vsk->trans;
453
454         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
455                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
456         if (val < vvs->buf_size)
457                 vvs->buf_size = val;
458         vvs->buf_size_max = val;
459 }
460 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
461
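/* Poll support: readability is derived from the number of bytes queued
 * on the rx path, writability from the credit the peer currently
 * grants.  The notify_* callbacks below are deliberately empty; this
 * transport needs no extra bookkeeping around the socket layer's
 * receive and send loops.
 */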
462 int
463 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
464                                 size_t target,
465                                 bool *data_ready_now)
466 {
467         if (vsock_stream_has_data(vsk))
468                 *data_ready_now = true;
469         else
470                 *data_ready_now = false;
471
472         return 0;
473 }
474 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
475
476 int
477 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
478                                  size_t target,
479                                  bool *space_avail_now)
480 {
481         s64 free_space;
482
483         free_space = vsock_stream_has_space(vsk);
484         if (free_space > 0)
485                 *space_avail_now = true;
486         else if (free_space == 0)
487                 *space_avail_now = false;
488
489         return 0;
490 }
491 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
492
493 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
494         size_t target, struct vsock_transport_recv_notify_data *data)
495 {
496         return 0;
497 }
498 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
499
500 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
501         size_t target, struct vsock_transport_recv_notify_data *data)
502 {
503         return 0;
504 }
505 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
506
507 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
508         size_t target, struct vsock_transport_recv_notify_data *data)
509 {
510         return 0;
511 }
512 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
513
514 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
515         size_t target, ssize_t copied, bool data_read,
516         struct vsock_transport_recv_notify_data *data)
517 {
518         return 0;
519 }
520 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
521
522 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
523         struct vsock_transport_send_notify_data *data)
524 {
525         return 0;
526 }
527 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
528
529 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
530         struct vsock_transport_send_notify_data *data)
531 {
532         return 0;
533 }
534 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
535
536 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
537         struct vsock_transport_send_notify_data *data)
538 {
539         return 0;
540 }
541 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
542
543 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
544         ssize_t written, struct vsock_transport_send_notify_data *data)
545 {
546         return 0;
547 }
548 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
549
550 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
551 {
552         struct virtio_vsock_sock *vvs = vsk->trans;
553
554         return vvs->buf_size;
555 }
556 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
557
558 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
559 {
560         return true;
561 }
562 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
563
564 bool virtio_transport_stream_allow(u32 cid, u32 port)
565 {
566         return true;
567 }
568 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
569
570 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
571                                 struct sockaddr_vm *addr)
572 {
573         return -EOPNOTSUPP;
574 }
575 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
576
577 bool virtio_transport_dgram_allow(u32 cid, u32 port)
578 {
579         return false;
580 }
581 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
582
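/* Connection setup and shutdown are plain control packets:
 * virtio_transport_connect() sends OP_REQUEST, and
 * virtio_transport_shutdown() sends OP_SHUTDOWN with the RCV/SEND flags
 * derived from the shutdown mode.  Both go through the common send path
 * above.
 */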
583 int virtio_transport_connect(struct vsock_sock *vsk)
584 {
585         struct virtio_vsock_pkt_info info = {
586                 .op = VIRTIO_VSOCK_OP_REQUEST,
587                 .type = VIRTIO_VSOCK_TYPE_STREAM,
588                 .vsk = vsk,
589         };
590
591         return virtio_transport_send_pkt_info(vsk, &info);
592 }
593 EXPORT_SYMBOL_GPL(virtio_transport_connect);
594
595 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
596 {
597         struct virtio_vsock_pkt_info info = {
598                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
599                 .type = VIRTIO_VSOCK_TYPE_STREAM,
600                 .flags = (mode & RCV_SHUTDOWN ?
601                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
602                          (mode & SEND_SHUTDOWN ?
603                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
604                 .vsk = vsk,
605         };
606
607         return virtio_transport_send_pkt_info(vsk, &info);
608 }
609 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
610
611 int
612 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
613                                struct sockaddr_vm *remote_addr,
614                                struct msghdr *msg,
615                                size_t dgram_len)
616 {
617         return -EOPNOTSUPP;
618 }
619 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
620
621 ssize_t
622 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
623                                 struct msghdr *msg,
624                                 size_t len)
625 {
626         struct virtio_vsock_pkt_info info = {
627                 .op = VIRTIO_VSOCK_OP_RW,
628                 .type = VIRTIO_VSOCK_TYPE_STREAM,
629                 .msg = msg,
630                 .pkt_len = len,
631                 .vsk = vsk,
632         };
633
634         return virtio_transport_send_pkt_info(vsk, &info);
635 }
636 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
637
638 void virtio_transport_destruct(struct vsock_sock *vsk)
639 {
640         struct virtio_vsock_sock *vvs = vsk->trans;
641
642         kfree(vvs);
643 }
644 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
645
646 static int virtio_transport_reset(struct vsock_sock *vsk,
647                                   struct virtio_vsock_pkt *pkt)
648 {
649         struct virtio_vsock_pkt_info info = {
650                 .op = VIRTIO_VSOCK_OP_RST,
651                 .type = VIRTIO_VSOCK_TYPE_STREAM,
652                 .reply = !!pkt,
653                 .vsk = vsk,
654         };
655
656         /* Send RST only if the original pkt is not a RST pkt */
657         if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
658                 return 0;
659
660         return virtio_transport_send_pkt_info(vsk, &info);
661 }
662
663 /* Normally packets are associated with a socket.  There may be no socket if an
664  * attempt was made to connect to a socket that does not exist.
665  */
666 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
667 {
668         struct virtio_vsock_pkt_info info = {
669                 .op = VIRTIO_VSOCK_OP_RST,
670                 .type = le16_to_cpu(pkt->hdr.type),
671                 .reply = true,
672         };
673
674         /* Send RST only if the original pkt is not a RST pkt */
675         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
676                 return 0;
677
678         pkt = virtio_transport_alloc_pkt(&info, 0,
679                                          le64_to_cpu(pkt->hdr.dst_cid),
680                                          le32_to_cpu(pkt->hdr.dst_port),
681                                          le64_to_cpu(pkt->hdr.src_cid),
682                                          le32_to_cpu(pkt->hdr.src_port));
683         if (!pkt)
684                 return -ENOMEM;
685
686         return virtio_transport_get_ops()->send_pkt(pkt);
687 }
688
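/* Connection teardown: virtio_transport_close() announces SHUTDOWN and,
 * if SO_LINGER is set, waits up to sk_lingertime for the peer to finish.
 * If the close cannot complete immediately, a delayed work item is
 * scheduled that forces an RST after VSOCK_CLOSE_TIMEOUT.
 * virtio_transport_do_close() marks the socket done/disconnecting and,
 * when it cancels the pending work, drops the reference taken when the
 * work was scheduled.
 */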
689 static void virtio_transport_wait_close(struct sock *sk, long timeout)
690 {
691         if (timeout) {
692                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
693
694                 add_wait_queue(sk_sleep(sk), &wait);
695
696                 do {
697                         if (sk_wait_event(sk, &timeout,
698                                           sock_flag(sk, SOCK_DONE), &wait))
699                                 break;
700                 } while (!signal_pending(current) && timeout);
701
702                 remove_wait_queue(sk_sleep(sk), &wait);
703         }
704 }
705
706 static void virtio_transport_do_close(struct vsock_sock *vsk,
707                                       bool cancel_timeout)
708 {
709         struct sock *sk = sk_vsock(vsk);
710
711         sock_set_flag(sk, SOCK_DONE);
712         vsk->peer_shutdown = SHUTDOWN_MASK;
713         if (vsock_stream_has_data(vsk) <= 0)
714                 sk->sk_state = SS_DISCONNECTING;
715         sk->sk_state_change(sk);
716
717         if (vsk->close_work_scheduled &&
718             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
719                 vsk->close_work_scheduled = false;
720
721                 vsock_remove_sock(vsk);
722
723                 /* Release refcnt obtained when we scheduled the timeout */
724                 sock_put(sk);
725         }
726 }
727
728 static void virtio_transport_close_timeout(struct work_struct *work)
729 {
730         struct vsock_sock *vsk =
731                 container_of(work, struct vsock_sock, close_work.work);
732         struct sock *sk = sk_vsock(vsk);
733
734         sock_hold(sk);
735         lock_sock(sk);
736
737         if (!sock_flag(sk, SOCK_DONE)) {
738                 (void)virtio_transport_reset(vsk, NULL);
739
740                 virtio_transport_do_close(vsk, false);
741         }
742
743         vsk->close_work_scheduled = false;
744
745         release_sock(sk);
746         sock_put(sk);
747 }
748
749 /* User context, vsk->sk is locked */
750 static bool virtio_transport_close(struct vsock_sock *vsk)
751 {
752         struct sock *sk = &vsk->sk;
753
754         if (!(sk->sk_state == SS_CONNECTED ||
755               sk->sk_state == SS_DISCONNECTING))
756                 return true;
757
758         /* Already received SHUTDOWN from peer, reply with RST */
759         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
760                 (void)virtio_transport_reset(vsk, NULL);
761                 return true;
762         }
763
764         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
765                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
766
767         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
768                 virtio_transport_wait_close(sk, sk->sk_lingertime);
769
770         if (sock_flag(sk, SOCK_DONE)) {
771                 return true;
772         }
773
774         sock_hold(sk);
775         INIT_DELAYED_WORK(&vsk->close_work,
776                           virtio_transport_close_timeout);
777         vsk->close_work_scheduled = true;
778         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
779         return false;
780 }
781
782 void virtio_transport_release(struct vsock_sock *vsk)
783 {
784         struct sock *sk = &vsk->sk;
785         bool remove_sock = true;
786
787         lock_sock(sk);
788         if (sk->sk_type == SOCK_STREAM)
789                 remove_sock = virtio_transport_close(vsk);
790         release_sock(sk);
791
792         if (remove_sock)
793                 vsock_remove_sock(vsk);
794 }
795 EXPORT_SYMBOL_GPL(virtio_transport_release);
796
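/* While connecting, only OP_RESPONSE completes the handshake: the
 * socket becomes connected and is inserted into the connected table.
 * OP_RST resets the socket with ECONNRESET, anything unexpected with
 * EPROTO; OP_INVALID is ignored.
 */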
797 static int
798 virtio_transport_recv_connecting(struct sock *sk,
799                                  struct virtio_vsock_pkt *pkt)
800 {
801         struct vsock_sock *vsk = vsock_sk(sk);
802         int err;
803         int skerr;
804
805         switch (le16_to_cpu(pkt->hdr.op)) {
806         case VIRTIO_VSOCK_OP_RESPONSE:
807                 sk->sk_state = SS_CONNECTED;
808                 sk->sk_socket->state = SS_CONNECTED;
809                 vsock_insert_connected(vsk);
810                 sk->sk_state_change(sk);
811                 break;
812         case VIRTIO_VSOCK_OP_INVALID:
813                 break;
814         case VIRTIO_VSOCK_OP_RST:
815                 skerr = ECONNRESET;
816                 err = 0;
817                 goto destroy;
818         default:
819                 skerr = EPROTO;
820                 err = -EINVAL;
821                 goto destroy;
822         }
823         return 0;
824
825 destroy:
826         virtio_transport_reset(vsk, pkt);
827         sk->sk_state = SS_UNCONNECTED;
828         sk->sk_err = skerr;
829         sk->sk_error_report(sk);
830         return err;
831 }
832
833 static int
834 virtio_transport_recv_connected(struct sock *sk,
835                                 struct virtio_vsock_pkt *pkt)
836 {
837         struct vsock_sock *vsk = vsock_sk(sk);
838         struct virtio_vsock_sock *vvs = vsk->trans;
839         int err = 0;
840
841         switch (le16_to_cpu(pkt->hdr.op)) {
842         case VIRTIO_VSOCK_OP_RW:
843                 pkt->len = le32_to_cpu(pkt->hdr.len);
844                 pkt->off = 0;
845
846                 spin_lock_bh(&vvs->rx_lock);
847                 virtio_transport_inc_rx_pkt(vvs, pkt);
848                 list_add_tail(&pkt->list, &vvs->rx_queue);
849                 spin_unlock_bh(&vvs->rx_lock);
850
851                 sk->sk_data_ready(sk);
852                 return err;
853         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
854                 sk->sk_write_space(sk);
855                 break;
856         case VIRTIO_VSOCK_OP_SHUTDOWN:
857                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
858                         vsk->peer_shutdown |= RCV_SHUTDOWN;
859                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
860                         vsk->peer_shutdown |= SEND_SHUTDOWN;
861                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
862                     vsock_stream_has_data(vsk) <= 0)
863                         sk->sk_state = SS_DISCONNECTING;
864                 if (le32_to_cpu(pkt->hdr.flags))
865                         sk->sk_state_change(sk);
866                 break;
867         case VIRTIO_VSOCK_OP_RST:
868                 virtio_transport_do_close(vsk, true);
869                 break;
870         default:
871                 err = -EINVAL;
872                 break;
873         }
874
875         virtio_transport_free_pkt(pkt);
876         return err;
877 }
878
879 static void
880 virtio_transport_recv_disconnecting(struct sock *sk,
881                                     struct virtio_vsock_pkt *pkt)
882 {
883         struct vsock_sock *vsk = vsock_sk(sk);
884
885         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
886                 virtio_transport_do_close(vsk, true);
887 }
888
889 static int
890 virtio_transport_send_response(struct vsock_sock *vsk,
891                                struct virtio_vsock_pkt *pkt)
892 {
893         struct virtio_vsock_pkt_info info = {
894                 .op = VIRTIO_VSOCK_OP_RESPONSE,
895                 .type = VIRTIO_VSOCK_TYPE_STREAM,
896                 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
897                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
898                 .reply = true,
899                 .vsk = vsk,
900         };
901
902         return virtio_transport_send_pkt_info(vsk, &info);
903 }
904
905 /* Handle server socket */
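/* A listening socket only accepts OP_REQUEST; anything else is answered
 * with a reset.  If the accept backlog has room, a child socket is
 * created, addressed from the packet header, marked connected, queued
 * for accept() and an OP_RESPONSE is sent back to the peer.
 */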
906 static int
907 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
908 {
909         struct vsock_sock *vsk = vsock_sk(sk);
910         struct vsock_sock *vchild;
911         struct sock *child;
912
913         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
914                 virtio_transport_reset(vsk, pkt);
915                 return -EINVAL;
916         }
917
918         if (sk_acceptq_is_full(sk)) {
919                 virtio_transport_reset(vsk, pkt);
920                 return -ENOMEM;
921         }
922
923         child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
924                                sk->sk_type, 0);
925         if (!child) {
926                 virtio_transport_reset(vsk, pkt);
927                 return -ENOMEM;
928         }
929
930         sk->sk_ack_backlog++;
931
932         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
933
934         child->sk_state = SS_CONNECTED;
935
936         vchild = vsock_sk(child);
937         vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
938                         le32_to_cpu(pkt->hdr.dst_port));
939         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
940                         le32_to_cpu(pkt->hdr.src_port));
941
942         vsock_insert_connected(vchild);
943         vsock_enqueue_accept(sk, child);
944         virtio_transport_send_response(vchild, pkt);
945
946         release_sock(child);
947
948         sk->sk_data_ready(sk);
949         return 0;
950 }
951
952 static bool virtio_transport_space_update(struct sock *sk,
953                                           struct virtio_vsock_pkt *pkt)
954 {
955         struct vsock_sock *vsk = vsock_sk(sk);
956         struct virtio_vsock_sock *vvs = vsk->trans;
957         bool space_available;
958
959         /* buf_alloc and fwd_cnt are always included in the hdr */
960         spin_lock_bh(&vvs->tx_lock);
961         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
962         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
963         space_available = virtio_transport_has_space(vsk);
964         spin_unlock_bh(&vvs->tx_lock);
965         return space_available;
966 }
967
968 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
969  * lock.
970  */
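/* Receive dispatch: non-stream packets and packets that match neither a
 * connected nor a bound socket are answered with a reset.  Otherwise
 * the peer's credit information is taken from the header and the packet
 * is handled according to the socket state.  Only the connected state
 * may keep the packet (OP_RW data is queued for the reader); every
 * other path frees it before returning.
 */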
971 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
972 {
973         struct sockaddr_vm src, dst;
974         struct vsock_sock *vsk;
975         struct sock *sk;
976         bool space_available;
977
978         vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
979                         le32_to_cpu(pkt->hdr.src_port));
980         vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
981                         le32_to_cpu(pkt->hdr.dst_port));
982
983         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
984                                         dst.svm_cid, dst.svm_port,
985                                         le32_to_cpu(pkt->hdr.len),
986                                         le16_to_cpu(pkt->hdr.type),
987                                         le16_to_cpu(pkt->hdr.op),
988                                         le32_to_cpu(pkt->hdr.flags),
989                                         le32_to_cpu(pkt->hdr.buf_alloc),
990                                         le32_to_cpu(pkt->hdr.fwd_cnt));
991
992         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
993                 (void)virtio_transport_reset_no_sock(pkt);
994                 goto free_pkt;
995         }
996
997         /* The socket must be in the connected or bound table,
998          * otherwise send a reset back.
999          */
1000         sk = vsock_find_connected_socket(&src, &dst);
1001         if (!sk) {
1002                 sk = vsock_find_bound_socket(&dst);
1003                 if (!sk) {
1004                         (void)virtio_transport_reset_no_sock(pkt);
1005                         goto free_pkt;
1006                 }
1007         }
1008
1009         vsk = vsock_sk(sk);
1010
1011         space_available = virtio_transport_space_update(sk, pkt);
1012
1013         lock_sock(sk);
1014
1015         /* Update CID in case it has changed after a transport reset event */
1016         vsk->local_addr.svm_cid = dst.svm_cid;
1017
1018         if (space_available)
1019                 sk->sk_write_space(sk);
1020
1021         switch (sk->sk_state) {
1022         case VSOCK_SS_LISTEN:
1023                 virtio_transport_recv_listen(sk, pkt);
1024                 virtio_transport_free_pkt(pkt);
1025                 break;
1026         case SS_CONNECTING:
1027                 virtio_transport_recv_connecting(sk, pkt);
1028                 virtio_transport_free_pkt(pkt);
1029                 break;
1030         case SS_CONNECTED:
1031                 virtio_transport_recv_connected(sk, pkt);
1032                 break;
1033         case SS_DISCONNECTING:
1034                 virtio_transport_recv_disconnecting(sk, pkt);
1035                 virtio_transport_free_pkt(pkt);
1036                 break;
1037         default:
1038                 virtio_transport_free_pkt(pkt);
1039                 break;
1040         }
1041         release_sock(sk);
1042
1043         /* Release refcnt obtained when we fetched this socket out of the
1044          * bound or connected list.
1045          */
1046         sock_put(sk);
1047         return;
1048
1049 free_pkt:
1050         virtio_transport_free_pkt(pkt);
1051 }
1052 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
1053
1054 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
1055 {
1056         kfree(pkt->buf);
1057         kfree(pkt);
1058 }
1059 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
1060
1061 MODULE_LICENSE("GPL v2");
1062 MODULE_AUTHOR("Asias He");
1063 MODULE_DESCRIPTION("common code for virtio vsock");