lib.rs 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. pub mod communicator;
  2. pub mod tcp;
  3. pub mod unix;
  4. use bincode::error::{DecodeError, EncodeError};
  5. use std::collections::HashMap;
  6. use std::io::Error as IoError;
  7. use std::sync::mpsc::{RecvError, SendError};
  8. pub trait Serializable: Clone + Send + 'static + bincode::Encode + bincode::Decode {}
  9. impl<T> Serializable for T where T: Clone + Send + 'static + bincode::Encode + bincode::Decode {}
  10. /// Represent data of type T that we expect to receive
  11. pub trait Fut<T> {
  12. /// Wait until the data has arrived and obtain it.
  13. fn get(self) -> Result<T, Error>;
  14. }
  15. #[derive(Debug, Clone, Copy)]
  16. pub struct CommunicationStats {
  17. pub num_msgs_received: usize,
  18. pub num_bytes_received: usize,
  19. pub num_msgs_sent: usize,
  20. pub num_bytes_sent: usize,
  21. }
  22. /// Abstract communication interface between multiple parties
  23. pub trait AbstractCommunicator {
  24. type Fut<T: Serializable>: Fut<T>;
  25. /// How many parties N there are in total
  26. fn get_num_parties(&self) -> usize;
  27. /// My party id in [0, N)
  28. fn get_my_id(&self) -> usize;
  29. /// Send a message of type T to given party
  30. fn send<T: Serializable>(&mut self, party_id: usize, val: T) -> Result<(), Error>;
  31. /// Send a message of type T to next party
  32. fn send_next<T: Serializable>(&mut self, val: T) -> Result<(), Error> {
  33. self.send((self.get_my_id() + 1) % self.get_num_parties(), val)
  34. }
  35. /// Send a message of type T to previous party
  36. fn send_previous<T: Serializable>(&mut self, val: T) -> Result<(), Error> {
  37. self.send(
  38. (self.get_num_parties() + self.get_my_id() - 1) % self.get_num_parties(),
  39. val,
  40. )
  41. }
  42. /// Send a message of type T all parties
  43. fn broadcast<T: Serializable>(&mut self, val: T) -> Result<(), Error> {
  44. let my_id = self.get_my_id();
  45. for party_id in 0..self.get_num_parties() {
  46. if party_id == my_id {
  47. continue;
  48. }
  49. self.send(party_id, val.clone())?;
  50. }
  51. Ok(())
  52. }
  53. /// Expect to receive message of type T from given party. Use the returned future to obtain
  54. /// the message once it has arrived.
  55. fn receive<T: Serializable>(&mut self, party_id: usize) -> Result<Self::Fut<T>, Error>;
  56. /// Expect to receive message of type T from the next party. Use the returned future to obtain
  57. /// the message once it has arrived.
  58. fn receive_next<T: Serializable>(&mut self) -> Result<Self::Fut<T>, Error> {
  59. self.receive((self.get_my_id() + 1) % self.get_num_parties())
  60. }
  61. /// Expect to receive message of type T from the previous party. Use the returned future to obtain
  62. /// the message once it has arrived.
  63. fn receive_previous<T: Serializable>(&mut self) -> Result<Self::Fut<T>, Error> {
  64. self.receive((self.get_num_parties() + self.get_my_id() - 1) % self.get_num_parties())
  65. }
  66. /// Shutdown the communication system
  67. fn shutdown(&mut self) -> HashMap<usize, CommunicationStats>;
  68. }
  69. /// Custom error type
  70. #[derive(Debug)]
  71. pub enum Error {
  72. /// The connection has not been established
  73. ConnectionSetupError,
  74. /// The API was not used correctly
  75. LogicError(String),
  76. /// Some std::io::Error appeared
  77. IoError(IoError),
  78. /// Some std::sync::mpsc::RecvError appeared
  79. RecvError(RecvError),
  80. /// Some std::sync::mpsc::SendError appeared
  81. SendError(String),
  82. /// Some bincode::error::DecodeError appeared
  83. EncodeError(EncodeError),
  84. /// Some bincode::error::DecodeError appeared
  85. DecodeError(DecodeError),
  86. /// Serialization of data failed
  87. SerializationError(String),
  88. /// Deserialization of data failed
  89. DeserializationError(String),
  90. }
  91. /// Enable automatic conversions from std::io::Error
  92. impl From<IoError> for Error {
  93. fn from(e: IoError) -> Error {
  94. Error::IoError(e)
  95. }
  96. }
  97. /// Enable automatic conversions from std::sync::mpsc::RecvError
  98. impl From<RecvError> for Error {
  99. fn from(e: RecvError) -> Error {
  100. Error::RecvError(e)
  101. }
  102. }
  103. /// Enable automatic conversions from std::sync::mpsc::SendError
  104. impl<T> From<SendError<T>> for Error {
  105. fn from(e: SendError<T>) -> Error {
  106. Error::SendError(e.to_string())
  107. }
  108. }
  109. /// Enable automatic conversions from bincode::error::EncodeError
  110. impl From<EncodeError> for Error {
  111. fn from(e: EncodeError) -> Error {
  112. Error::EncodeError(e)
  113. }
  114. }
  115. /// Enable automatic conversions from bincode::error::DecodeError
  116. impl From<DecodeError> for Error {
  117. fn from(e: DecodeError) -> Error {
  118. Error::DecodeError(e)
  119. }
  120. }