coroutine.hpp 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. #ifndef __COROUTINE_HPP__
  2. #define __COROUTINE_HPP__
  3. #include <vector>
  4. #include "corotypes.hpp"
  5. #include "mpcio.hpp"
  6. // The top-level coroutine runner will call run_coroutines with
  7. // a MPCTIO, and we should call its send() method. Subcoroutines that
  8. // launch their own coroutines (and coroutine running) will call
  9. // run_coroutines with a yield_t instead, which we should just call, in
  10. // order to yield to the next higher level of coroutine runner.
  11. static inline void send_or_yield(MPCTIO &tio) { tio.send(); }
  12. static inline void send_or_yield(yield_t &yield) { yield(); }
  13. // Get and set communication_nthreads for an MPCTIO; for a yield_t, this
  14. // is a no-op.
  15. static inline int getset_communication_nthreads(MPCTIO &tio, int nthreads = 0) {
  16. return tio.comm_nthreads(nthreads);
  17. }
  18. static inline int getset_communication_nthreads(yield_t &yield, int nthreads = 0) {
  19. return 0;
  20. }
  21. // Use this version if you have a variable number of coroutines (or a
  22. // larger constant number than is supported below).
  23. template <typename T>
  24. inline void run_coroutines(T &mpctio_or_yield, std::vector<coro_t> &coroutines) {
  25. // If there's more than one coroutine, at most one of them can have
  26. // communication_nthreads larger than 1 (see mpcio.hpp for details).
  27. // For now, we set them _all_ to 1 (if there's more than one of
  28. // them), and restore communication_nthreads when they're all done.
  29. int saved_communication_nthreads = 0;
  30. if (coroutines.size() > 1) {
  31. saved_communication_nthreads =
  32. getset_communication_nthreads(mpctio_or_yield, 1);
  33. }
  34. // Loop until all the coroutines are finished
  35. bool finished = false;
  36. while(!finished) {
  37. // If this current function is not itself a coroutine (i.e.,
  38. // this is the top-level function that launches all the
  39. // coroutines), here's where to call send(). Otherwise, call
  40. // yield() here to let other coroutines at this level run.
  41. send_or_yield(mpctio_or_yield);
  42. finished = true;
  43. for (auto &c : coroutines) {
  44. // This tests if coroutine c still has work to do (is not
  45. // finished)
  46. if (c) {
  47. finished = false;
  48. // Resume coroutine c from the point it yield()ed
  49. c();
  50. }
  51. }
  52. }
  53. if (saved_communication_nthreads > 0) {
  54. getset_communication_nthreads(mpctio_or_yield,
  55. saved_communication_nthreads);
  56. }
  57. }
  58. // Use one of these versions if you have a small fixed number of
  59. // coroutines. You can of course also use the above, but the API for
  60. // this version is simpler.
  61. template <typename T>
  62. inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1)
  63. {
  64. std::vector<coro_t> coroutines;
  65. coroutines.emplace_back(l1);
  66. run_coroutines(mpctio_or_yield, coroutines);
  67. }
  68. template <typename T>
  69. inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1,
  70. const coro_lambda_t &l2)
  71. {
  72. std::vector<coro_t> coroutines;
  73. coroutines.emplace_back(l1);
  74. coroutines.emplace_back(l2);
  75. run_coroutines(mpctio_or_yield, coroutines);
  76. }
  77. template <typename T>
  78. inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1,
  79. const coro_lambda_t &l2, const coro_lambda_t &l3)
  80. {
  81. std::vector<coro_t> coroutines;
  82. coroutines.emplace_back(l1);
  83. coroutines.emplace_back(l2);
  84. coroutines.emplace_back(l3);
  85. run_coroutines(mpctio_or_yield, coroutines);
  86. }
  87. template <typename T>
  88. inline void run_coroutines(T &mpctio_or_yield, const coro_lambda_t &l1,
  89. const coro_lambda_t &l2, const coro_lambda_t &l3,
  90. const coro_lambda_t &l4)
  91. {
  92. std::vector<coro_t> coroutines;
  93. coroutines.emplace_back(l1);
  94. coroutines.emplace_back(l2);
  95. coroutines.emplace_back(l3);
  96. coroutines.emplace_back(l4);
  97. run_coroutines(mpctio_or_yield, coroutines);
  98. }
  99. #endif