channel.cpp 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. /**
  2. \file channel.cpp
  3. \author michael.zohner@ec-spride.de
  4. \copyright ABY - A Framework for Efficient Mixed-protocol Secure Two-party Computation
  5. Copyright (C) 2019 Engineering Cryptographic Protocols Group, TU Darmstadt
  6. This program is free software: you can redistribute it and/or modify
  7. it under the terms of the GNU Lesser General Public License as published
  8. by the Free Software Foundation, either version 3 of the License, or
  9. (at your option) any later version.
  10. ABY is distributed in the hope that it will be useful,
  11. but WITHOUT ANY WARRANTY; without even the implied warranty of
  12. MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  13. GNU Lesser General Public License for more details.
  14. You should have received a copy of the GNU Lesser General Public License
  15. along with this program. If not, see <http://www.gnu.org/licenses/>.
  16. */
  17. #include "channel.h"
  18. #include "typedefs.h"
  19. #include "rcvthread.h"
  20. #include "sndthread.h"
  21. #include <cassert>
  22. #include <cstring>
  23. channel::channel(uint8_t channelid, RcvThread* rcver, SndThread* snder)
  24. : m_bChannelID(channelid), m_cRcver(rcver), m_cSnder(snder),
  25. m_eRcved(std::make_unique<CEvent>()), m_eFin(std::make_unique<CEvent>()),
  26. m_bSndAlive(true), m_bRcvAlive(true),
  27. m_qRcvedBlocks(rcver->add_listener(channelid, m_eRcved.get(), m_eFin.get())),
  28. m_qRcvedBlocks_mutex_(rcver->get_listener_mutex(channelid))
  29. {
  30. assert(rcver->getlock() == snder->getlock());
  31. }
  32. channel::~channel() {
  33. if(m_bRcvAlive) {
  34. m_cRcver->remove_listener(m_bChannelID);
  35. }
  36. }
  37. void channel::send(uint8_t* buf, uint64_t nbytes) {
  38. assert(m_bSndAlive);
  39. m_cSnder->add_snd_task(m_bChannelID, nbytes, buf);
  40. }
  41. void channel::blocking_send(CEvent* eventcaller, uint8_t* buf, uint64_t nbytes) {
  42. assert(m_bSndAlive);
  43. m_cSnder->add_event_snd_task(eventcaller, m_bChannelID, nbytes, buf);
  44. eventcaller->Wait();
  45. }
  46. void channel::send_id_len(uint8_t* buf, uint64_t nbytes, uint64_t id, uint64_t len) {
  47. assert(m_bSndAlive);
  48. m_cSnder->add_snd_task_start_len(m_bChannelID, nbytes, buf, id, len);
  49. }
  50. void channel::blocking_send_id_len(CEvent* eventcaller, uint8_t* buf, uint64_t nbytes, uint64_t id, uint64_t len) {
  51. assert(m_bSndAlive);
  52. m_cSnder->add_event_snd_task_start_len(eventcaller, m_bChannelID, nbytes, buf, id, len);
  53. eventcaller->Wait();
  54. }
  55. //buf needs to be freed, data contains the payload
  56. uint8_t* channel::blocking_receive_id_len(uint8_t** data, uint64_t* id, uint64_t* len) {
  57. uint8_t* buf = blocking_receive();
  58. *data = buf;
  59. *id = *((uint64_t*) *data);
  60. (*data) += sizeof(uint64_t);
  61. *len = *((uint64_t*) *data);
  62. (*data) += sizeof(uint64_t);
  63. return buf;
  64. }
  65. bool channel::queue_empty() const {
  66. std::lock_guard<std::mutex> lock(m_qRcvedBlocks_mutex_);
  67. bool qempty = m_qRcvedBlocks->empty();
  68. return qempty;
  69. }
  70. uint8_t* channel::blocking_receive() {
  71. assert(m_bRcvAlive);
  72. while(queue_empty())
  73. m_eRcved->Wait();
  74. rcv_ctx* ret = nullptr;
  75. uint8_t* ret_block = nullptr;
  76. {
  77. std::lock_guard<std::mutex> lock(m_qRcvedBlocks_mutex_);
  78. ret = (rcv_ctx*) m_qRcvedBlocks->front();
  79. ret_block = ret->buf;
  80. m_qRcvedBlocks->pop();
  81. }
  82. free(ret);
  83. return ret_block;
  84. }
  85. void channel::blocking_receive(uint8_t* rcvbuf, uint64_t rcvsize) {
  86. assert(m_bRcvAlive);
  87. while(queue_empty())
  88. m_eRcved->Wait();
  89. std::unique_lock<std::mutex> lock(m_qRcvedBlocks_mutex_);
  90. rcv_ctx* ret = (rcv_ctx*) m_qRcvedBlocks->front();
  91. uint8_t* ret_block = ret->buf;
  92. uint64_t rcved_this_call = ret->rcvbytes;
  93. if(rcved_this_call == rcvsize) {
  94. m_qRcvedBlocks->pop();
  95. lock.unlock();
  96. free(ret);
  97. } else if(rcvsize < rcved_this_call) {
  98. //if the block contains too much data, copy only the receive size
  99. ret->rcvbytes -= rcvsize;
  100. uint8_t* newbuf = (uint8_t*) malloc(ret->rcvbytes);
  101. memcpy(newbuf, ret->buf+rcvsize, ret->rcvbytes);
  102. ret->buf = newbuf;
  103. lock.unlock();
  104. rcved_this_call = rcvsize;
  105. } else {
  106. //I want to receive more data than are in that block. Perform recursive call (might become troublesome for too many recursion steps)
  107. m_qRcvedBlocks->pop();
  108. lock.unlock();
  109. free(ret);
  110. uint8_t* new_rcvbuf_start = rcvbuf + rcved_this_call;
  111. uint64_t new_rcvsize = rcvsize -rcved_this_call;
  112. blocking_receive(new_rcvbuf_start, new_rcvsize);
  113. }
  114. memcpy(rcvbuf, ret_block, rcved_this_call);
  115. free(ret_block);
  116. }
  117. bool channel::is_alive() {
  118. return (!(queue_empty() && m_eFin->IsSet()));
  119. }
  120. bool channel::data_available() {
  121. return !queue_empty();
  122. }
  123. void channel::signal_end() {
  124. m_cSnder->signal_end(m_bChannelID);
  125. m_bSndAlive = false;
  126. }
  127. void channel::wait_for_fin() {
  128. m_eFin->Wait();
  129. m_bRcvAlive = false;
  130. }
  131. void channel::synchronize_end() {
  132. if(m_bSndAlive)
  133. signal_end();
  134. if(m_bRcvAlive)
  135. m_cRcver->flush_queue(m_bChannelID);
  136. if(m_bRcvAlive)
  137. wait_for_fin();
  138. }