lib.rs 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. //! Simple communication layer for passing messages among multiple parties.
  2. #![warn(missing_docs)]
  3. mod communicator;
  4. pub mod tcp;
  5. pub mod unix;
  6. pub use crate::communicator::{Communicator, MyFut};
  7. use bincode::error::{DecodeError, EncodeError};
  8. use std::collections::HashMap;
  9. use std::io::Error as IoError;
  10. use std::sync::mpsc::{RecvError, SendError};
  11. /// Trait that captures the requirements for data types to be sent/received.
  12. pub trait Serializable: Clone + Send + 'static + bincode::Encode + bincode::Decode {}
  13. impl<T> Serializable for T where T: Clone + Send + 'static + bincode::Encode + bincode::Decode {}
  14. /// C++-style Future type. Represents data of type T that we expect to receive.
  15. pub trait Fut<T> {
  16. /// Wait until the data has arrived and obtain it.
  17. fn get(self) -> Result<T, Error>;
  18. }
  19. /// Recorded communication statistics for one point-to-point channel.
  20. #[derive(Debug, Default, Clone, Copy, serde::Serialize)]
  21. pub struct CommunicationStats {
  22. /// Number of messages received.
  23. pub num_msgs_received: usize,
  24. /// Number of bytes received over all messages.
  25. pub num_bytes_received: usize,
  26. /// Number of messages sent.
  27. pub num_msgs_sent: usize,
  28. /// Number of bytes sent over all messages.
  29. pub num_bytes_sent: usize,
  30. }
  31. /// Abstract communication interface between multiple parties
  32. pub trait AbstractCommunicator {
  33. /// Future type to represent expected data.
  34. type Fut<T: Serializable>: Fut<T>;
  35. /// How many parties N there are in total.
  36. fn get_num_parties(&self) -> usize;
  37. /// My party id in [0, N).
  38. fn get_my_id(&self) -> usize;
  39. /// Send a message of type T to given party.
  40. fn send<T: Serializable>(&mut self, party_id: usize, val: T) -> Result<(), Error>;
  41. /// Send a message of multiple elements of type T to given party.
  42. fn send_slice<T: Serializable>(&mut self, party_id: usize, val: &[T]) -> Result<(), Error>;
  43. /// Send a message of type T to next party.
  44. fn send_next<T: Serializable>(&mut self, val: T) -> Result<(), Error> {
  45. self.send((self.get_my_id() + 1) % self.get_num_parties(), val)
  46. }
  47. /// Send a message of multiple elements of type T to next party.
  48. fn send_slice_next<T: Serializable>(&mut self, val: &[T]) -> Result<(), Error> {
  49. self.send_slice((self.get_my_id() + 1) % self.get_num_parties(), val)
  50. }
  51. /// Send a message of type T to previous party.
  52. fn send_previous<T: Serializable>(&mut self, val: T) -> Result<(), Error> {
  53. self.send(
  54. (self.get_num_parties() + self.get_my_id() - 1) % self.get_num_parties(),
  55. val,
  56. )
  57. }
  58. /// Send a message of multiple elements of type T to previous party.
  59. fn send_slice_previous<T: Serializable>(&mut self, val: &[T]) -> Result<(), Error> {
  60. self.send_slice(
  61. (self.get_num_parties() + self.get_my_id() - 1) % self.get_num_parties(),
  62. val,
  63. )
  64. }
  65. /// Send a message of type T all parties.
  66. fn broadcast<T: Serializable>(&mut self, val: T) -> Result<(), Error> {
  67. let my_id = self.get_my_id();
  68. for party_id in 0..self.get_num_parties() {
  69. if party_id == my_id {
  70. continue;
  71. }
  72. self.send(party_id, val.clone())?;
  73. }
  74. Ok(())
  75. }
  76. /// Expect to receive message of type T from given party. Use the returned future to obtain
  77. /// the message once it has arrived.
  78. fn receive<T: Serializable>(&mut self, party_id: usize) -> Result<Self::Fut<T>, Error>;
  79. /// Expect to receive message of type T from the next party. Use the returned future to obtain
  80. /// the message once it has arrived.
  81. fn receive_next<T: Serializable>(&mut self) -> Result<Self::Fut<T>, Error> {
  82. self.receive((self.get_my_id() + 1) % self.get_num_parties())
  83. }
  84. /// Expect to receive message of type T from the previous party. Use the returned future to obtain
  85. /// the message once it has arrived.
  86. fn receive_previous<T: Serializable>(&mut self) -> Result<Self::Fut<T>, Error> {
  87. self.receive((self.get_num_parties() + self.get_my_id() - 1) % self.get_num_parties())
  88. }
  89. /// Shutdown the communication system.
  90. fn shutdown(&mut self);
  91. /// Obtain statistics about how many messages/bytes were send/received.
  92. fn get_stats(&self) -> HashMap<usize, CommunicationStats>;
  93. /// Reset statistics.
  94. fn reset_stats(&mut self);
  95. }
  96. /// Custom error type.
  97. #[derive(Debug)]
  98. pub enum Error {
  99. /// The connection has not been established.
  100. ConnectionSetupError,
  101. /// The API was not used correctly.
  102. LogicError(String),
  103. /// Some std::io::Error appeared.
  104. IoError(IoError),
  105. /// Some std::sync::mpsc::RecvError appeared.
  106. RecvError(RecvError),
  107. /// Some std::sync::mpsc::SendError appeared.
  108. SendError(String),
  109. /// Some bincode::error::DecodeError appeared.
  110. EncodeError(EncodeError),
  111. /// Some bincode::error::DecodeError appeared.
  112. DecodeError(DecodeError),
  113. /// Serialization of data failed.
  114. SerializationError(String),
  115. /// Deserialization of data failed.
  116. DeserializationError(String),
  117. }
  118. /// Enable automatic conversions from std::io::Error.
  119. impl From<IoError> for Error {
  120. fn from(e: IoError) -> Error {
  121. Error::IoError(e)
  122. }
  123. }
  124. /// Enable automatic conversions from std::sync::mpsc::RecvError.
  125. impl From<RecvError> for Error {
  126. fn from(e: RecvError) -> Error {
  127. Error::RecvError(e)
  128. }
  129. }
  130. /// Enable automatic conversions from std::sync::mpsc::SendError.
  131. impl<T> From<SendError<T>> for Error {
  132. fn from(e: SendError<T>) -> Error {
  133. Error::SendError(e.to_string())
  134. }
  135. }
  136. /// Enable automatic conversions from bincode::error::EncodeError.
  137. impl From<EncodeError> for Error {
  138. fn from(e: EncodeError) -> Error {
  139. Error::EncodeError(e)
  140. }
  141. }
  142. /// Enable automatic conversions from bincode::error::DecodeError.
  143. impl From<DecodeError> for Error {
  144. fn from(e: DecodeError) -> Error {
  145. Error::DecodeError(e)
  146. }
  147. }