virtio_transport.c

/*
 * virtio transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <[email protected]>
 *         Stefan Hajnoczi <[email protected]>
 *
 * Some of the code is taken from Gerd Hoffmann <[email protected]>'s
 * early virtio-vsock proof-of-concept bits.
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/atomic.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <linux/mutex.h>
#include <net/af_vsock.h>

static struct workqueue_struct *virtio_vsock_workqueue;
static struct virtio_vsock *the_virtio_vsock;
static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */

struct virtio_vsock {
	struct virtio_device *vdev;
	struct virtqueue *vqs[VSOCK_VQ_MAX];

	/* Virtqueue processing is deferred to a workqueue */
	struct work_struct tx_work;
	struct work_struct rx_work;
	struct work_struct event_work;

	/* The following fields are protected by tx_lock.  vqs[VSOCK_VQ_TX]
	 * must be accessed with tx_lock held.
	 */
	struct mutex tx_lock;

	struct work_struct send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;

	atomic_t queued_replies;

	/* The following fields are protected by rx_lock.  vqs[VSOCK_VQ_RX]
	 * must be accessed with rx_lock held.
	 */
	struct mutex rx_lock;
	int rx_buf_nr;
	int rx_buf_max_nr;

	/* The following fields are protected by event_lock.
	 * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
	 */
	struct mutex event_lock;
	struct virtio_vsock_event event_list[8];

	u32 guest_cid;
};

static struct virtio_vsock *virtio_vsock_get(void)
{
	return the_virtio_vsock;
}

static u32 virtio_transport_get_local_cid(void)
{
	struct virtio_vsock *vsock = virtio_vsock_get();

	if (!vsock)
		return VMADDR_CID_ANY;

	return vsock->guest_cid;
}

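/* Worker that drains send_pkt_list into the TX virtqueue.  Packets that do
 * not fit are put back on the list and retried once TX completions free up
 * descriptors.
 */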
static void
virtio_transport_send_pkt_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, send_pkt_work);
	struct virtqueue *vq;
	bool added = false;
	bool restart_rx = false;

	mutex_lock(&vsock->tx_lock);

	vq = vsock->vqs[VSOCK_VQ_TX];

	for (;;) {
		struct virtio_vsock_pkt *pkt;
		struct scatterlist hdr, buf, *sgs[2];
		int ret, in_sg = 0, out_sg = 0;
		bool reply;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		reply = pkt->reply;

		sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
		sgs[out_sg++] = &hdr;
		if (pkt->buf) {
			sg_init_one(&buf, pkt->buf, pkt->len);
			sgs[out_sg++] = &buf;
		}

		ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL);
		/* Usually this means that there is no more space available in
		 * the vq
		 */
		if (ret < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (reply) {
			struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
			int val;

			val = atomic_dec_return(&vsock->queued_replies);

			/* Do we now have resources to resume rx processing? */
			if (val + 1 == virtqueue_get_vring_size(rx_vq))
				restart_rx = true;
		}

		added = true;
	}

	if (added)
		virtqueue_kick(vq);

	mutex_unlock(&vsock->tx_lock);

	if (restart_rx)
		queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

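/* Queue a packet for transmission and kick the send worker.  Returns the
 * payload length so the caller can report how many bytes were queued.
 */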
static int
virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct virtio_vsock *vsock;
	int len = pkt->len;

	vsock = virtio_vsock_get();
	if (!vsock) {
		virtio_transport_free_pkt(pkt);
		return -ENODEV;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
	return len;
}

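/* Post fresh receive buffers to the RX virtqueue until it is full.
 * Called with rx_lock held.
 */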
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
{
	int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE;
	struct virtio_vsock_pkt *pkt;
	struct scatterlist hdr, buf, *sgs[2];
	struct virtqueue *vq;
	int ret;

	vq = vsock->vqs[VSOCK_VQ_RX];

	do {
		pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
		if (!pkt)
			break;

		pkt->buf = kmalloc(buf_len, GFP_KERNEL);
		if (!pkt->buf) {
			virtio_transport_free_pkt(pkt);
			break;
		}

		pkt->len = buf_len;

		sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr));
		sgs[0] = &hdr;

		sg_init_one(&buf, pkt->buf, buf_len);
		sgs[1] = &buf;
		ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL);
		if (ret) {
			virtio_transport_free_pkt(pkt);
			break;
		}
		vsock->rx_buf_nr++;
	} while (vq->num_free);
	if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
		vsock->rx_buf_max_nr = vsock->rx_buf_nr;
	virtqueue_kick(vq);
}

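/* TX completion worker: free packets the device has consumed and, if any
 * were reclaimed, restart the send worker since descriptors are available
 * again.
 */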
static void virtio_transport_tx_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, tx_work);
	struct virtqueue *vq;
	bool added = false;

	vq = vsock->vqs[VSOCK_VQ_TX];
	mutex_lock(&vsock->tx_lock);
	do {
		struct virtio_vsock_pkt *pkt;
		unsigned int len;

		virtqueue_disable_cb(vq);
		while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) {
			virtio_transport_free_pkt(pkt);
			added = true;
		}
	} while (!virtqueue_enable_cb(vq));
	mutex_unlock(&vsock->tx_lock);

	if (added)
		queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}

/* Is there space left for replies to rx packets? */
static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
{
	struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < virtqueue_get_vring_size(vq);
}

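/* RX worker: hand received packets to the common transport code and refill
 * the virtqueue when the number of posted buffers drops below half of the
 * high-water mark.
 */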
static void virtio_transport_rx_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, rx_work);
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_RX];

	mutex_lock(&vsock->rx_lock);

	do {
		virtqueue_disable_cb(vq);
		for (;;) {
			struct virtio_vsock_pkt *pkt;
			unsigned int len;

			if (!virtio_transport_more_replies(vsock)) {
				/* Stop rx until the device processes already
				 * pending replies.  Leave rx virtqueue
				 * callbacks disabled.
				 */
				goto out;
			}

			pkt = virtqueue_get_buf(vq, &len);
			if (!pkt)
				break;

			vsock->rx_buf_nr--;

			/* Drop short/long packets */
			if (unlikely(len < sizeof(pkt->hdr) ||
				     len > sizeof(pkt->hdr) + pkt->len)) {
				virtio_transport_free_pkt(pkt);
				continue;
			}

			pkt->len = len - sizeof(pkt->hdr);
			virtio_transport_recv_pkt(pkt);
		}
	} while (!virtqueue_enable_cb(vq));

out:
	if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
		virtio_vsock_rx_fill(vsock);
	mutex_unlock(&vsock->rx_lock);
}

/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
				       struct virtio_vsock_event *event)
{
	struct scatterlist sg;
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_EVENT];

	sg_init_one(&sg, event, sizeof(*event));

	return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
}

/* event_lock must be held */
static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
{
	size_t i;

	for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
		struct virtio_vsock_event *event = &vsock->event_list[i];

		virtio_vsock_event_fill_one(vsock, event);
	}

	virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
}

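/* Mark a socket as disconnected with ECONNRESET; used when the transport is
 * reset or the device is removed.
 */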
static void virtio_vsock_reset_sock(struct sock *sk)
{
	lock_sock(sk);
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = ECONNRESET;
	sk->sk_error_report(sk);
	release_sock(sk);
}

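/* Read the guest CID from the device configuration space (little-endian). */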
static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
{
	struct virtio_device *vdev = vsock->vdev;
	u64 guest_cid;

	vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
			  &guest_cid, sizeof(guest_cid));
	vsock->guest_cid = le64_to_cpu(guest_cid);
}

/* event_lock must be held */
static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
				      struct virtio_vsock_event *event)
{
	switch (le32_to_cpu(event->id)) {
	case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
		virtio_vsock_update_guest_cid(vsock);
		vsock_for_each_connected_socket(virtio_vsock_reset_sock);
		break;
	}
}

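/* Event worker: handle device events and repost the event buffers. */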
static void virtio_transport_event_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, event_work);
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_EVENT];

	mutex_lock(&vsock->event_lock);

	do {
		struct virtio_vsock_event *event;
		unsigned int len;

		virtqueue_disable_cb(vq);
		while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
			if (len == sizeof(*event))
				virtio_vsock_event_handle(vsock, event);

			virtio_vsock_event_fill_one(vsock, event);
		}
	} while (!virtqueue_enable_cb(vq));

	virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);

	mutex_unlock(&vsock->event_lock);
}

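/* Virtqueue callbacks run in atomic context, so defer the real work to the
 * workqueue.
 */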
static void virtio_vsock_event_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->event_work);
}

static void virtio_vsock_tx_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->tx_work);
}

static void virtio_vsock_rx_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

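/* Transport ops handed to the vsock core.  Most callbacks are provided by
 * the common virtio transport code; only send_pkt is specific to this
 * driver.
 */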
static struct virtio_transport virtio_transport = {
	.transport = {
		.get_local_cid = virtio_transport_get_local_cid,

		.init = virtio_transport_do_socket_init,
		.destruct = virtio_transport_destruct,
		.release = virtio_transport_release,
		.connect = virtio_transport_connect,
		.shutdown = virtio_transport_shutdown,

		.dgram_bind = virtio_transport_dgram_bind,
		.dgram_dequeue = virtio_transport_dgram_dequeue,
		.dgram_enqueue = virtio_transport_dgram_enqueue,
		.dgram_allow = virtio_transport_dgram_allow,

		.stream_dequeue = virtio_transport_stream_dequeue,
		.stream_enqueue = virtio_transport_stream_enqueue,
		.stream_has_data = virtio_transport_stream_has_data,
		.stream_has_space = virtio_transport_stream_has_space,
		.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
		.stream_is_active = virtio_transport_stream_is_active,
		.stream_allow = virtio_transport_stream_allow,

		.notify_poll_in = virtio_transport_notify_poll_in,
		.notify_poll_out = virtio_transport_notify_poll_out,
		.notify_recv_init = virtio_transport_notify_recv_init,
		.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init = virtio_transport_notify_send_init,
		.notify_send_pre_block = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

		.set_buffer_size = virtio_transport_set_buffer_size,
		.set_min_buffer_size = virtio_transport_set_min_buffer_size,
		.set_max_buffer_size = virtio_transport_set_max_buffer_size,
		.get_buffer_size = virtio_transport_get_buffer_size,
		.get_min_buffer_size = virtio_transport_get_min_buffer_size,
		.get_max_buffer_size = virtio_transport_get_max_buffer_size,
	},

	.send_pkt = virtio_transport_send_pkt,
};

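/* Set up virtqueues, workers and buffers for the (single) virtio-vsock
 * device.
 */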
static int virtio_vsock_probe(struct virtio_device *vdev)
{
	vq_callback_t *callbacks[] = {
		virtio_vsock_rx_done,
		virtio_vsock_tx_done,
		virtio_vsock_event_done,
	};
	static const char * const names[] = {
		"rx",
		"tx",
		"event",
	};
	struct virtio_vsock *vsock = NULL;
	int ret;

	ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
	if (ret)
		return ret;

	/* Only one virtio-vsock device per guest is supported */
	if (the_virtio_vsock) {
		ret = -EBUSY;
		goto out;
	}

	vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
	if (!vsock) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->vdev = vdev;

	ret = vsock->vdev->config->find_vqs(vsock->vdev, VSOCK_VQ_MAX,
					    vsock->vqs, callbacks, names);
	if (ret < 0)
		goto out;

	virtio_vsock_update_guest_cid(vsock);

	vsock->rx_buf_nr = 0;
	vsock->rx_buf_max_nr = 0;
	atomic_set(&vsock->queued_replies, 0);

	vdev->priv = vsock;
	the_virtio_vsock = vsock;
	mutex_init(&vsock->tx_lock);
	mutex_init(&vsock->rx_lock);
	mutex_init(&vsock->event_lock);
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
	INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
	INIT_WORK(&vsock->event_work, virtio_transport_event_work);
	INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);

	mutex_lock(&vsock->rx_lock);
	virtio_vsock_rx_fill(vsock);
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->event_lock);
	virtio_vsock_event_fill(vsock);
	mutex_unlock(&vsock->event_lock);

	mutex_unlock(&the_virtio_vsock_mutex);
	return 0;

out:
	kfree(vsock);
	mutex_unlock(&the_virtio_vsock_mutex);
	return ret;
}

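/* Tear down on device removal: reset connected sockets, drain pending work
 * and free any packets still held by the virtqueues or the send list.
 */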
static void virtio_vsock_remove(struct virtio_device *vdev)
{
	struct virtio_vsock *vsock = vdev->priv;
	struct virtio_vsock_pkt *pkt;

	flush_work(&vsock->rx_work);
	flush_work(&vsock->tx_work);
	flush_work(&vsock->event_work);
	flush_work(&vsock->send_pkt_work);

	/* Reset all connected sockets when the device disappear */
	vsock_for_each_connected_socket(virtio_vsock_reset_sock);

	vdev->config->reset(vdev);

	mutex_lock(&vsock->rx_lock);
	while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
		virtio_transport_free_pkt(pkt);
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->tx_lock);
	while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
		virtio_transport_free_pkt(pkt);
	mutex_unlock(&vsock->tx_lock);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	mutex_lock(&the_virtio_vsock_mutex);
	the_virtio_vsock = NULL;
	mutex_unlock(&the_virtio_vsock_mutex);

	vdev->config->del_vqs(vdev);

	kfree(vsock);
}

static struct virtio_device_id id_table[] = {
	{ VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
	{ 0 },
};

static unsigned int features[] = {
};

static struct virtio_driver virtio_vsock_driver = {
	.feature_table = features,
	.feature_table_size = ARRAY_SIZE(features),
	.driver.name = KBUILD_MODNAME,
	.driver.owner = THIS_MODULE,
	.id_table = id_table,
	.probe = virtio_vsock_probe,
	.remove = virtio_vsock_remove,
};

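/* Register the transport with the vsock core, then register the virtio
 * driver.
 */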
static int __init virtio_vsock_init(void)
{
	int ret;

	virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
	if (!virtio_vsock_workqueue)
		return -ENOMEM;
	ret = vsock_core_init(&virtio_transport.transport);
	if (ret)
		goto out_wq;

	ret = register_virtio_driver(&virtio_vsock_driver);
	if (ret)
		goto out_vci;
	return 0;

out_vci:
	vsock_core_exit();
out_wq:
	destroy_workqueue(virtio_vsock_workqueue);
	return ret;
}

static void __exit virtio_vsock_exit(void)
{
	unregister_virtio_driver(&virtio_vsock_driver);
	vsock_core_exit();
	destroy_workqueue(virtio_vsock_workqueue);
}

module_init(virtio_vsock_init);
module_exit(virtio_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("virtio transport for vsock");
MODULE_DEVICE_TABLE(virtio, id_table);