]> git.karo-electronics.de Git - karo-tx-linux.git/blob - net/vmw_vsock/virtio_transport_common.c
Merge branch 'for-linus-4.10-rc1' of git://git.kernel.org/pub/scm/linux/kernel/git...
[karo-tx-linux.git] / net / vmw_vsock / virtio_transport_common.c
1 /*
2  * common code for virtio vsock
3  *
4  * Copyright (C) 2013-2015 Red Hat, Inc.
5  * Author: Asias He <asias@redhat.com>
6  *         Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This work is licensed under the terms of the GNU GPL, version 2.
9  */
10 #include <linux/spinlock.h>
11 #include <linux/module.h>
12 #include <linux/ctype.h>
13 #include <linux/list.h>
14 #include <linux/virtio.h>
15 #include <linux/virtio_ids.h>
16 #include <linux/virtio_config.h>
17 #include <linux/virtio_vsock.h>
18
19 #include <net/sock.h>
20 #include <net/af_vsock.h>
21
22 #define CREATE_TRACE_POINTS
23 #include <trace/events/vsock_virtio_transport_common.h>
24
/* Grace period a closing connection gets before it is forcibly reset. */
#define VSOCK_CLOSE_TIMEOUT (HZ * 8)
27
28 static const struct virtio_transport *virtio_transport_get_ops(void)
29 {
30         const struct vsock_transport *t = vsock_core_get_transport();
31
32         return container_of(t, struct virtio_transport, transport);
33 }
34
35 static struct virtio_vsock_pkt *
36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info,
37                            size_t len,
38                            u32 src_cid,
39                            u32 src_port,
40                            u32 dst_cid,
41                            u32 dst_port)
42 {
43         struct virtio_vsock_pkt *pkt;
44         int err;
45
46         pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
47         if (!pkt)
48                 return NULL;
49
50         pkt->hdr.type           = cpu_to_le16(info->type);
51         pkt->hdr.op             = cpu_to_le16(info->op);
52         pkt->hdr.src_cid        = cpu_to_le64(src_cid);
53         pkt->hdr.dst_cid        = cpu_to_le64(dst_cid);
54         pkt->hdr.src_port       = cpu_to_le32(src_port);
55         pkt->hdr.dst_port       = cpu_to_le32(dst_port);
56         pkt->hdr.flags          = cpu_to_le32(info->flags);
57         pkt->len                = len;
58         pkt->hdr.len            = cpu_to_le32(len);
59         pkt->reply              = info->reply;
60
61         if (info->msg && len > 0) {
62                 pkt->buf = kmalloc(len, GFP_KERNEL);
63                 if (!pkt->buf)
64                         goto out_pkt;
65                 err = memcpy_from_msg(pkt->buf, info->msg, len);
66                 if (err)
67                         goto out;
68         }
69
70         trace_virtio_transport_alloc_pkt(src_cid, src_port,
71                                          dst_cid, dst_port,
72                                          len,
73                                          info->type,
74                                          info->op,
75                                          info->flags);
76
77         return pkt;
78
79 out:
80         kfree(pkt->buf);
81 out_pkt:
82         kfree(pkt);
83         return NULL;
84 }
85
86 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk,
87                                           struct virtio_vsock_pkt_info *info)
88 {
89         u32 src_cid, src_port, dst_cid, dst_port;
90         struct virtio_vsock_sock *vvs;
91         struct virtio_vsock_pkt *pkt;
92         u32 pkt_len = info->pkt_len;
93
94         src_cid = vm_sockets_get_local_cid();
95         src_port = vsk->local_addr.svm_port;
96         if (!info->remote_cid) {
97                 dst_cid = vsk->remote_addr.svm_cid;
98                 dst_port = vsk->remote_addr.svm_port;
99         } else {
100                 dst_cid = info->remote_cid;
101                 dst_port = info->remote_port;
102         }
103
104         vvs = vsk->trans;
105
106         /* we can send less than pkt_len bytes */
107         if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE)
108                 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
109
110         /* virtio_transport_get_credit might return less than pkt_len credit */
111         pkt_len = virtio_transport_get_credit(vvs, pkt_len);
112
113         /* Do not send zero length OP_RW pkt */
114         if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW)
115                 return pkt_len;
116
117         pkt = virtio_transport_alloc_pkt(info, pkt_len,
118                                          src_cid, src_port,
119                                          dst_cid, dst_port);
120         if (!pkt) {
121                 virtio_transport_put_credit(vvs, pkt_len);
122                 return -ENOMEM;
123         }
124
125         virtio_transport_inc_tx_pkt(vvs, pkt);
126
127         return virtio_transport_get_ops()->send_pkt(pkt);
128 }
129
130 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs,
131                                         struct virtio_vsock_pkt *pkt)
132 {
133         vvs->rx_bytes += pkt->len;
134 }
135
136 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs,
137                                         struct virtio_vsock_pkt *pkt)
138 {
139         vvs->rx_bytes -= pkt->len;
140         vvs->fwd_cnt += pkt->len;
141 }
142
143 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt)
144 {
145         spin_lock_bh(&vvs->tx_lock);
146         pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt);
147         pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc);
148         spin_unlock_bh(&vvs->tx_lock);
149 }
150 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt);
151
152 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit)
153 {
154         u32 ret;
155
156         spin_lock_bh(&vvs->tx_lock);
157         ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
158         if (ret > credit)
159                 ret = credit;
160         vvs->tx_cnt += ret;
161         spin_unlock_bh(&vvs->tx_lock);
162
163         return ret;
164 }
165 EXPORT_SYMBOL_GPL(virtio_transport_get_credit);
166
167 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit)
168 {
169         spin_lock_bh(&vvs->tx_lock);
170         vvs->tx_cnt -= credit;
171         spin_unlock_bh(&vvs->tx_lock);
172 }
173 EXPORT_SYMBOL_GPL(virtio_transport_put_credit);
174
175 static int virtio_transport_send_credit_update(struct vsock_sock *vsk,
176                                                int type,
177                                                struct virtio_vsock_hdr *hdr)
178 {
179         struct virtio_vsock_pkt_info info = {
180                 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE,
181                 .type = type,
182         };
183
184         return virtio_transport_send_pkt_info(vsk, &info);
185 }
186
187 static ssize_t
188 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk,
189                                    struct msghdr *msg,
190                                    size_t len)
191 {
192         struct virtio_vsock_sock *vvs = vsk->trans;
193         struct virtio_vsock_pkt *pkt;
194         size_t bytes, total = 0;
195         int err = -EFAULT;
196
197         spin_lock_bh(&vvs->rx_lock);
198         while (total < len && !list_empty(&vvs->rx_queue)) {
199                 pkt = list_first_entry(&vvs->rx_queue,
200                                        struct virtio_vsock_pkt, list);
201
202                 bytes = len - total;
203                 if (bytes > pkt->len - pkt->off)
204                         bytes = pkt->len - pkt->off;
205
206                 /* sk_lock is held by caller so no one else can dequeue.
207                  * Unlock rx_lock since memcpy_to_msg() may sleep.
208                  */
209                 spin_unlock_bh(&vvs->rx_lock);
210
211                 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes);
212                 if (err)
213                         goto out;
214
215                 spin_lock_bh(&vvs->rx_lock);
216
217                 total += bytes;
218                 pkt->off += bytes;
219                 if (pkt->off == pkt->len) {
220                         virtio_transport_dec_rx_pkt(vvs, pkt);
221                         list_del(&pkt->list);
222                         virtio_transport_free_pkt(pkt);
223                 }
224         }
225         spin_unlock_bh(&vvs->rx_lock);
226
227         /* Send a credit pkt to peer */
228         virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM,
229                                             NULL);
230
231         return total;
232
233 out:
234         if (total)
235                 err = total;
236         return err;
237 }
238
239 ssize_t
240 virtio_transport_stream_dequeue(struct vsock_sock *vsk,
241                                 struct msghdr *msg,
242                                 size_t len, int flags)
243 {
244         if (flags & MSG_PEEK)
245                 return -EOPNOTSUPP;
246
247         return virtio_transport_stream_do_dequeue(vsk, msg, len);
248 }
249 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue);
250
251 int
252 virtio_transport_dgram_dequeue(struct vsock_sock *vsk,
253                                struct msghdr *msg,
254                                size_t len, int flags)
255 {
256         return -EOPNOTSUPP;
257 }
258 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue);
259
260 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk)
261 {
262         struct virtio_vsock_sock *vvs = vsk->trans;
263         s64 bytes;
264
265         spin_lock_bh(&vvs->rx_lock);
266         bytes = vvs->rx_bytes;
267         spin_unlock_bh(&vvs->rx_lock);
268
269         return bytes;
270 }
271 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data);
272
273 static s64 virtio_transport_has_space(struct vsock_sock *vsk)
274 {
275         struct virtio_vsock_sock *vvs = vsk->trans;
276         s64 bytes;
277
278         bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt);
279         if (bytes < 0)
280                 bytes = 0;
281
282         return bytes;
283 }
284
285 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk)
286 {
287         struct virtio_vsock_sock *vvs = vsk->trans;
288         s64 bytes;
289
290         spin_lock_bh(&vvs->tx_lock);
291         bytes = virtio_transport_has_space(vsk);
292         spin_unlock_bh(&vvs->tx_lock);
293
294         return bytes;
295 }
296 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space);
297
298 int virtio_transport_do_socket_init(struct vsock_sock *vsk,
299                                     struct vsock_sock *psk)
300 {
301         struct virtio_vsock_sock *vvs;
302
303         vvs = kzalloc(sizeof(*vvs), GFP_KERNEL);
304         if (!vvs)
305                 return -ENOMEM;
306
307         vsk->trans = vvs;
308         vvs->vsk = vsk;
309         if (psk) {
310                 struct virtio_vsock_sock *ptrans = psk->trans;
311
312                 vvs->buf_size   = ptrans->buf_size;
313                 vvs->buf_size_min = ptrans->buf_size_min;
314                 vvs->buf_size_max = ptrans->buf_size_max;
315                 vvs->peer_buf_alloc = ptrans->peer_buf_alloc;
316         } else {
317                 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE;
318                 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE;
319                 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE;
320         }
321
322         vvs->buf_alloc = vvs->buf_size;
323
324         spin_lock_init(&vvs->rx_lock);
325         spin_lock_init(&vvs->tx_lock);
326         INIT_LIST_HEAD(&vvs->rx_queue);
327
328         return 0;
329 }
330 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init);
331
332 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk)
333 {
334         struct virtio_vsock_sock *vvs = vsk->trans;
335
336         return vvs->buf_size;
337 }
338 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size);
339
340 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk)
341 {
342         struct virtio_vsock_sock *vvs = vsk->trans;
343
344         return vvs->buf_size_min;
345 }
346 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size);
347
348 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk)
349 {
350         struct virtio_vsock_sock *vvs = vsk->trans;
351
352         return vvs->buf_size_max;
353 }
354 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size);
355
356 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val)
357 {
358         struct virtio_vsock_sock *vvs = vsk->trans;
359
360         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
361                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
362         if (val < vvs->buf_size_min)
363                 vvs->buf_size_min = val;
364         if (val > vvs->buf_size_max)
365                 vvs->buf_size_max = val;
366         vvs->buf_size = val;
367         vvs->buf_alloc = val;
368 }
369 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size);
370
371 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val)
372 {
373         struct virtio_vsock_sock *vvs = vsk->trans;
374
375         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
376                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
377         if (val > vvs->buf_size)
378                 vvs->buf_size = val;
379         vvs->buf_size_min = val;
380 }
381 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size);
382
383 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val)
384 {
385         struct virtio_vsock_sock *vvs = vsk->trans;
386
387         if (val > VIRTIO_VSOCK_MAX_BUF_SIZE)
388                 val = VIRTIO_VSOCK_MAX_BUF_SIZE;
389         if (val < vvs->buf_size)
390                 vvs->buf_size = val;
391         vvs->buf_size_max = val;
392 }
393 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size);
394
395 int
396 virtio_transport_notify_poll_in(struct vsock_sock *vsk,
397                                 size_t target,
398                                 bool *data_ready_now)
399 {
400         if (vsock_stream_has_data(vsk))
401                 *data_ready_now = true;
402         else
403                 *data_ready_now = false;
404
405         return 0;
406 }
407 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in);
408
409 int
410 virtio_transport_notify_poll_out(struct vsock_sock *vsk,
411                                  size_t target,
412                                  bool *space_avail_now)
413 {
414         s64 free_space;
415
416         free_space = vsock_stream_has_space(vsk);
417         if (free_space > 0)
418                 *space_avail_now = true;
419         else if (free_space == 0)
420                 *space_avail_now = false;
421
422         return 0;
423 }
424 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out);
425
426 int virtio_transport_notify_recv_init(struct vsock_sock *vsk,
427         size_t target, struct vsock_transport_recv_notify_data *data)
428 {
429         return 0;
430 }
431 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init);
432
433 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk,
434         size_t target, struct vsock_transport_recv_notify_data *data)
435 {
436         return 0;
437 }
438 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block);
439
440 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk,
441         size_t target, struct vsock_transport_recv_notify_data *data)
442 {
443         return 0;
444 }
445 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue);
446
447 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk,
448         size_t target, ssize_t copied, bool data_read,
449         struct vsock_transport_recv_notify_data *data)
450 {
451         return 0;
452 }
453 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue);
454
455 int virtio_transport_notify_send_init(struct vsock_sock *vsk,
456         struct vsock_transport_send_notify_data *data)
457 {
458         return 0;
459 }
460 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init);
461
462 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk,
463         struct vsock_transport_send_notify_data *data)
464 {
465         return 0;
466 }
467 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block);
468
469 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk,
470         struct vsock_transport_send_notify_data *data)
471 {
472         return 0;
473 }
474 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue);
475
476 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk,
477         ssize_t written, struct vsock_transport_send_notify_data *data)
478 {
479         return 0;
480 }
481 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue);
482
483 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk)
484 {
485         struct virtio_vsock_sock *vvs = vsk->trans;
486
487         return vvs->buf_size;
488 }
489 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat);
490
491 bool virtio_transport_stream_is_active(struct vsock_sock *vsk)
492 {
493         return true;
494 }
495 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active);
496
497 bool virtio_transport_stream_allow(u32 cid, u32 port)
498 {
499         return true;
500 }
501 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow);
502
503 int virtio_transport_dgram_bind(struct vsock_sock *vsk,
504                                 struct sockaddr_vm *addr)
505 {
506         return -EOPNOTSUPP;
507 }
508 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind);
509
510 bool virtio_transport_dgram_allow(u32 cid, u32 port)
511 {
512         return false;
513 }
514 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow);
515
516 int virtio_transport_connect(struct vsock_sock *vsk)
517 {
518         struct virtio_vsock_pkt_info info = {
519                 .op = VIRTIO_VSOCK_OP_REQUEST,
520                 .type = VIRTIO_VSOCK_TYPE_STREAM,
521         };
522
523         return virtio_transport_send_pkt_info(vsk, &info);
524 }
525 EXPORT_SYMBOL_GPL(virtio_transport_connect);
526
527 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode)
528 {
529         struct virtio_vsock_pkt_info info = {
530                 .op = VIRTIO_VSOCK_OP_SHUTDOWN,
531                 .type = VIRTIO_VSOCK_TYPE_STREAM,
532                 .flags = (mode & RCV_SHUTDOWN ?
533                           VIRTIO_VSOCK_SHUTDOWN_RCV : 0) |
534                          (mode & SEND_SHUTDOWN ?
535                           VIRTIO_VSOCK_SHUTDOWN_SEND : 0),
536         };
537
538         return virtio_transport_send_pkt_info(vsk, &info);
539 }
540 EXPORT_SYMBOL_GPL(virtio_transport_shutdown);
541
542 int
543 virtio_transport_dgram_enqueue(struct vsock_sock *vsk,
544                                struct sockaddr_vm *remote_addr,
545                                struct msghdr *msg,
546                                size_t dgram_len)
547 {
548         return -EOPNOTSUPP;
549 }
550 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue);
551
552 ssize_t
553 virtio_transport_stream_enqueue(struct vsock_sock *vsk,
554                                 struct msghdr *msg,
555                                 size_t len)
556 {
557         struct virtio_vsock_pkt_info info = {
558                 .op = VIRTIO_VSOCK_OP_RW,
559                 .type = VIRTIO_VSOCK_TYPE_STREAM,
560                 .msg = msg,
561                 .pkt_len = len,
562         };
563
564         return virtio_transport_send_pkt_info(vsk, &info);
565 }
566 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue);
567
568 void virtio_transport_destruct(struct vsock_sock *vsk)
569 {
570         struct virtio_vsock_sock *vvs = vsk->trans;
571
572         kfree(vvs);
573 }
574 EXPORT_SYMBOL_GPL(virtio_transport_destruct);
575
576 static int virtio_transport_reset(struct vsock_sock *vsk,
577                                   struct virtio_vsock_pkt *pkt)
578 {
579         struct virtio_vsock_pkt_info info = {
580                 .op = VIRTIO_VSOCK_OP_RST,
581                 .type = VIRTIO_VSOCK_TYPE_STREAM,
582                 .reply = !!pkt,
583         };
584
585         /* Send RST only if the original pkt is not a RST pkt */
586         if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
587                 return 0;
588
589         return virtio_transport_send_pkt_info(vsk, &info);
590 }
591
592 /* Normally packets are associated with a socket.  There may be no socket if an
593  * attempt was made to connect to a socket that does not exist.
594  */
595 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
596 {
597         struct virtio_vsock_pkt_info info = {
598                 .op = VIRTIO_VSOCK_OP_RST,
599                 .type = le16_to_cpu(pkt->hdr.type),
600                 .reply = true,
601         };
602
603         /* Send RST only if the original pkt is not a RST pkt */
604         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
605                 return 0;
606
607         pkt = virtio_transport_alloc_pkt(&info, 0,
608                                          le64_to_cpu(pkt->hdr.dst_cid),
609                                          le32_to_cpu(pkt->hdr.dst_port),
610                                          le64_to_cpu(pkt->hdr.src_cid),
611                                          le32_to_cpu(pkt->hdr.src_port));
612         if (!pkt)
613                 return -ENOMEM;
614
615         return virtio_transport_get_ops()->send_pkt(pkt);
616 }
617
618 static void virtio_transport_wait_close(struct sock *sk, long timeout)
619 {
620         if (timeout) {
621                 DEFINE_WAIT_FUNC(wait, woken_wake_function);
622
623                 add_wait_queue(sk_sleep(sk), &wait);
624
625                 do {
626                         if (sk_wait_event(sk, &timeout,
627                                           sock_flag(sk, SOCK_DONE), &wait))
628                                 break;
629                 } while (!signal_pending(current) && timeout);
630
631                 remove_wait_queue(sk_sleep(sk), &wait);
632         }
633 }
634
635 static void virtio_transport_do_close(struct vsock_sock *vsk,
636                                       bool cancel_timeout)
637 {
638         struct sock *sk = sk_vsock(vsk);
639
640         sock_set_flag(sk, SOCK_DONE);
641         vsk->peer_shutdown = SHUTDOWN_MASK;
642         if (vsock_stream_has_data(vsk) <= 0)
643                 sk->sk_state = SS_DISCONNECTING;
644         sk->sk_state_change(sk);
645
646         if (vsk->close_work_scheduled &&
647             (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) {
648                 vsk->close_work_scheduled = false;
649
650                 vsock_remove_sock(vsk);
651
652                 /* Release refcnt obtained when we scheduled the timeout */
653                 sock_put(sk);
654         }
655 }
656
657 static void virtio_transport_close_timeout(struct work_struct *work)
658 {
659         struct vsock_sock *vsk =
660                 container_of(work, struct vsock_sock, close_work.work);
661         struct sock *sk = sk_vsock(vsk);
662
663         sock_hold(sk);
664         lock_sock(sk);
665
666         if (!sock_flag(sk, SOCK_DONE)) {
667                 (void)virtio_transport_reset(vsk, NULL);
668
669                 virtio_transport_do_close(vsk, false);
670         }
671
672         vsk->close_work_scheduled = false;
673
674         release_sock(sk);
675         sock_put(sk);
676 }
677
678 /* User context, vsk->sk is locked */
679 static bool virtio_transport_close(struct vsock_sock *vsk)
680 {
681         struct sock *sk = &vsk->sk;
682
683         if (!(sk->sk_state == SS_CONNECTED ||
684               sk->sk_state == SS_DISCONNECTING))
685                 return true;
686
687         /* Already received SHUTDOWN from peer, reply with RST */
688         if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) {
689                 (void)virtio_transport_reset(vsk, NULL);
690                 return true;
691         }
692
693         if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK)
694                 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK);
695
696         if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING))
697                 virtio_transport_wait_close(sk, sk->sk_lingertime);
698
699         if (sock_flag(sk, SOCK_DONE)) {
700                 return true;
701         }
702
703         sock_hold(sk);
704         INIT_DELAYED_WORK(&vsk->close_work,
705                           virtio_transport_close_timeout);
706         vsk->close_work_scheduled = true;
707         schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT);
708         return false;
709 }
710
711 void virtio_transport_release(struct vsock_sock *vsk)
712 {
713         struct sock *sk = &vsk->sk;
714         bool remove_sock = true;
715
716         lock_sock(sk);
717         if (sk->sk_type == SOCK_STREAM)
718                 remove_sock = virtio_transport_close(vsk);
719         release_sock(sk);
720
721         if (remove_sock)
722                 vsock_remove_sock(vsk);
723 }
724 EXPORT_SYMBOL_GPL(virtio_transport_release);
725
726 static int
727 virtio_transport_recv_connecting(struct sock *sk,
728                                  struct virtio_vsock_pkt *pkt)
729 {
730         struct vsock_sock *vsk = vsock_sk(sk);
731         int err;
732         int skerr;
733
734         switch (le16_to_cpu(pkt->hdr.op)) {
735         case VIRTIO_VSOCK_OP_RESPONSE:
736                 sk->sk_state = SS_CONNECTED;
737                 sk->sk_socket->state = SS_CONNECTED;
738                 vsock_insert_connected(vsk);
739                 sk->sk_state_change(sk);
740                 break;
741         case VIRTIO_VSOCK_OP_INVALID:
742                 break;
743         case VIRTIO_VSOCK_OP_RST:
744                 skerr = ECONNRESET;
745                 err = 0;
746                 goto destroy;
747         default:
748                 skerr = EPROTO;
749                 err = -EINVAL;
750                 goto destroy;
751         }
752         return 0;
753
754 destroy:
755         virtio_transport_reset(vsk, pkt);
756         sk->sk_state = SS_UNCONNECTED;
757         sk->sk_err = skerr;
758         sk->sk_error_report(sk);
759         return err;
760 }
761
762 static int
763 virtio_transport_recv_connected(struct sock *sk,
764                                 struct virtio_vsock_pkt *pkt)
765 {
766         struct vsock_sock *vsk = vsock_sk(sk);
767         struct virtio_vsock_sock *vvs = vsk->trans;
768         int err = 0;
769
770         switch (le16_to_cpu(pkt->hdr.op)) {
771         case VIRTIO_VSOCK_OP_RW:
772                 pkt->len = le32_to_cpu(pkt->hdr.len);
773                 pkt->off = 0;
774
775                 spin_lock_bh(&vvs->rx_lock);
776                 virtio_transport_inc_rx_pkt(vvs, pkt);
777                 list_add_tail(&pkt->list, &vvs->rx_queue);
778                 spin_unlock_bh(&vvs->rx_lock);
779
780                 sk->sk_data_ready(sk);
781                 return err;
782         case VIRTIO_VSOCK_OP_CREDIT_UPDATE:
783                 sk->sk_write_space(sk);
784                 break;
785         case VIRTIO_VSOCK_OP_SHUTDOWN:
786                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV)
787                         vsk->peer_shutdown |= RCV_SHUTDOWN;
788                 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND)
789                         vsk->peer_shutdown |= SEND_SHUTDOWN;
790                 if (vsk->peer_shutdown == SHUTDOWN_MASK &&
791                     vsock_stream_has_data(vsk) <= 0)
792                         sk->sk_state = SS_DISCONNECTING;
793                 if (le32_to_cpu(pkt->hdr.flags))
794                         sk->sk_state_change(sk);
795                 break;
796         case VIRTIO_VSOCK_OP_RST:
797                 virtio_transport_do_close(vsk, true);
798                 break;
799         default:
800                 err = -EINVAL;
801                 break;
802         }
803
804         virtio_transport_free_pkt(pkt);
805         return err;
806 }
807
808 static void
809 virtio_transport_recv_disconnecting(struct sock *sk,
810                                     struct virtio_vsock_pkt *pkt)
811 {
812         struct vsock_sock *vsk = vsock_sk(sk);
813
814         if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST)
815                 virtio_transport_do_close(vsk, true);
816 }
817
818 static int
819 virtio_transport_send_response(struct vsock_sock *vsk,
820                                struct virtio_vsock_pkt *pkt)
821 {
822         struct virtio_vsock_pkt_info info = {
823                 .op = VIRTIO_VSOCK_OP_RESPONSE,
824                 .type = VIRTIO_VSOCK_TYPE_STREAM,
825                 .remote_cid = le64_to_cpu(pkt->hdr.src_cid),
826                 .remote_port = le32_to_cpu(pkt->hdr.src_port),
827                 .reply = true,
828         };
829
830         return virtio_transport_send_pkt_info(vsk, &info);
831 }
832
833 /* Handle server socket */
834 static int
835 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt)
836 {
837         struct vsock_sock *vsk = vsock_sk(sk);
838         struct vsock_sock *vchild;
839         struct sock *child;
840
841         if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) {
842                 virtio_transport_reset(vsk, pkt);
843                 return -EINVAL;
844         }
845
846         if (sk_acceptq_is_full(sk)) {
847                 virtio_transport_reset(vsk, pkt);
848                 return -ENOMEM;
849         }
850
851         child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL,
852                                sk->sk_type, 0);
853         if (!child) {
854                 virtio_transport_reset(vsk, pkt);
855                 return -ENOMEM;
856         }
857
858         sk->sk_ack_backlog++;
859
860         lock_sock_nested(child, SINGLE_DEPTH_NESTING);
861
862         child->sk_state = SS_CONNECTED;
863
864         vchild = vsock_sk(child);
865         vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid),
866                         le32_to_cpu(pkt->hdr.dst_port));
867         vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid),
868                         le32_to_cpu(pkt->hdr.src_port));
869
870         vsock_insert_connected(vchild);
871         vsock_enqueue_accept(sk, child);
872         virtio_transport_send_response(vchild, pkt);
873
874         release_sock(child);
875
876         sk->sk_data_ready(sk);
877         return 0;
878 }
879
880 static bool virtio_transport_space_update(struct sock *sk,
881                                           struct virtio_vsock_pkt *pkt)
882 {
883         struct vsock_sock *vsk = vsock_sk(sk);
884         struct virtio_vsock_sock *vvs = vsk->trans;
885         bool space_available;
886
887         /* buf_alloc and fwd_cnt is always included in the hdr */
888         spin_lock_bh(&vvs->tx_lock);
889         vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc);
890         vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt);
891         space_available = virtio_transport_has_space(vsk);
892         spin_unlock_bh(&vvs->tx_lock);
893         return space_available;
894 }
895
896 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
897  * lock.
898  */
899 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
900 {
901         struct sockaddr_vm src, dst;
902         struct vsock_sock *vsk;
903         struct sock *sk;
904         bool space_available;
905
906         vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid),
907                         le32_to_cpu(pkt->hdr.src_port));
908         vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid),
909                         le32_to_cpu(pkt->hdr.dst_port));
910
911         trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port,
912                                         dst.svm_cid, dst.svm_port,
913                                         le32_to_cpu(pkt->hdr.len),
914                                         le16_to_cpu(pkt->hdr.type),
915                                         le16_to_cpu(pkt->hdr.op),
916                                         le32_to_cpu(pkt->hdr.flags),
917                                         le32_to_cpu(pkt->hdr.buf_alloc),
918                                         le32_to_cpu(pkt->hdr.fwd_cnt));
919
920         if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
921                 (void)virtio_transport_reset_no_sock(pkt);
922                 goto free_pkt;
923         }
924
925         /* The socket must be in connected or bound table
926          * otherwise send reset back
927          */
928         sk = vsock_find_connected_socket(&src, &dst);
929         if (!sk) {
930                 sk = vsock_find_bound_socket(&dst);
931                 if (!sk) {
932                         (void)virtio_transport_reset_no_sock(pkt);
933                         goto free_pkt;
934                 }
935         }
936
937         vsk = vsock_sk(sk);
938
939         space_available = virtio_transport_space_update(sk, pkt);
940
941         lock_sock(sk);
942
943         /* Update CID in case it has changed after a transport reset event */
944         vsk->local_addr.svm_cid = dst.svm_cid;
945
946         if (space_available)
947                 sk->sk_write_space(sk);
948
949         switch (sk->sk_state) {
950         case VSOCK_SS_LISTEN:
951                 virtio_transport_recv_listen(sk, pkt);
952                 virtio_transport_free_pkt(pkt);
953                 break;
954         case SS_CONNECTING:
955                 virtio_transport_recv_connecting(sk, pkt);
956                 virtio_transport_free_pkt(pkt);
957                 break;
958         case SS_CONNECTED:
959                 virtio_transport_recv_connected(sk, pkt);
960                 break;
961         case SS_DISCONNECTING:
962                 virtio_transport_recv_disconnecting(sk, pkt);
963                 virtio_transport_free_pkt(pkt);
964                 break;
965         default:
966                 virtio_transport_free_pkt(pkt);
967                 break;
968         }
969         release_sock(sk);
970
971         /* Release refcnt obtained when we fetched this socket out of the
972          * bound or connected list.
973          */
974         sock_put(sk);
975         return;
976
977 free_pkt:
978         virtio_transport_free_pkt(pkt);
979 }
980 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt);
981
982 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt)
983 {
984         kfree(pkt->buf);
985         kfree(pkt);
986 }
987 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt);
988
989 MODULE_LICENSE("GPL v2");
990 MODULE_AUTHOR("Asias He");
991 MODULE_DESCRIPTION("common code for virtio vsock");