tcp.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. //! Functionality for communicators using TCP sockets.
  use crate::Communicator;
  use crate::{AbstractCommunicator, Error};
  use std::collections::{HashMap, HashSet};
  use std::convert::TryFrom;
  use std::io::{Read, Write};
  use std::net::{TcpListener, TcpStream};
  use std::thread;
  use std::time::Duration;
  /// How this party reaches one particular peer during connection setup.
  ///
  /// The peer either dials in to our listening socket, or we dial out to the
  /// peer at a known host and port.
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub enum NetworkPartyInfo {
      /// Wait until the peer opens a connection to our listener.
      Listen,
      /// Actively open a connection to the peer at `(host, port)`.
      Connect(String, u16),
  }
  18. /// Network connection options
  19. #[derive(Debug, Clone)]
  20. pub struct NetworkOptions {
  21. /// Which address to listen on for incoming connections
  22. pub listen_host: String,
  23. /// Which port to listen on for incoming connections
  24. pub listen_port: u16,
  25. /// Connection info for each party
  26. pub connect_info: Vec<NetworkPartyInfo>,
  27. /// How long to try connecting before aborting
  28. pub connect_timeout_seconds: usize,
  29. }
  30. fn tcp_connect(
  31. my_id: usize,
  32. other_id: usize,
  33. host: &str,
  34. port: u16,
  35. timeout_seconds: usize,
  36. ) -> Result<TcpStream, Error> {
  37. // repeatedly try to connect
  38. fn connect_socket(host: &str, port: u16, timeout_seconds: usize) -> Result<TcpStream, Error> {
  39. // try every 100ms
  40. for _ in 0..(10 * timeout_seconds) {
  41. if let Ok(socket) = TcpStream::connect((host, port)) {
  42. return Ok(socket);
  43. }
  44. thread::sleep(Duration::from_millis(100));
  45. }
  46. match TcpStream::connect((host, port)) {
  47. Ok(socket) => Ok(socket),
  48. Err(e) => Err(Error::IoError(e)),
  49. }
  50. }
  51. // connect to the other party
  52. let mut stream = connect_socket(host, port, timeout_seconds)?;
  53. {
  54. // send our party id
  55. let bytes_written = stream.write(&(my_id as u32).to_be_bytes())?;
  56. if bytes_written != 4 {
  57. return Err(Error::ConnectionSetupError);
  58. }
  59. // check that we talk to the right party
  60. let mut other_id_bytes = [0u8; 4];
  61. stream.read_exact(&mut other_id_bytes)?;
  62. if u32::from_be_bytes(other_id_bytes) != other_id as u32 {
  63. return Err(Error::ConnectionSetupError);
  64. }
  65. }
  66. Ok(stream)
  67. }
  68. fn tcp_accept_connections(
  69. my_id: usize,
  70. options: &NetworkOptions,
  71. ) -> Result<HashMap<usize, TcpStream>, Error> {
  72. // prepare function output
  73. let mut output = HashMap::<usize, TcpStream>::new();
  74. // compute set of parties that should connect to us
  75. let mut expected_parties: HashSet<usize> = options
  76. .connect_info
  77. .iter()
  78. .enumerate()
  79. .filter_map(|(party_id, npi)| {
  80. if party_id != my_id && *npi == NetworkPartyInfo::Listen {
  81. Some(party_id)
  82. } else {
  83. None
  84. }
  85. })
  86. .collect();
  87. // if nobody should connect to us, we are done
  88. if expected_parties.is_empty() {
  89. return Ok(output);
  90. }
  91. // create a listender and iterate over incoming connections
  92. let listener = TcpListener::bind((options.listen_host.clone(), options.listen_port))?;
  93. for mut stream in listener.incoming().filter_map(Result::ok) {
  94. // see which party has connected
  95. let mut other_id_bytes = [0u8; 4];
  96. if stream.read_exact(&mut other_id_bytes).is_err() {
  97. continue;
  98. }
  99. let other_id = u32::from_be_bytes(other_id_bytes) as usize;
  100. // check if we expect this party
  101. if !expected_parties.contains(&other_id) {
  102. continue;
  103. }
  104. // respond with our party id
  105. if stream.write_all(&(my_id as u32).to_be_bytes()).is_err() {
  106. continue;
  107. }
  108. // connection has been established
  109. expected_parties.remove(&other_id);
  110. output.insert(other_id, stream);
  111. // check if we have received connections from every party
  112. if expected_parties.is_empty() {
  113. break;
  114. }
  115. }
  116. if !expected_parties.is_empty() {
  117. Err(Error::ConnectionSetupError)
  118. } else {
  119. Ok(output)
  120. }
  121. }
  122. /// Setup TCP connections
  123. pub fn setup_connection(
  124. num_parties: usize,
  125. my_id: usize,
  126. options: &NetworkOptions,
  127. ) -> Result<HashMap<usize, TcpStream>, Error> {
  128. // make a copy of the options to pass it into the new thread
  129. let options_cpy: NetworkOptions = (*options).clone();
  130. // spawn thread to listen for incoming connections
  131. let listen_thread_handle = thread::spawn(move || tcp_accept_connections(my_id, &options_cpy));
  132. // prepare the map of connection we will return
  133. let mut output = HashMap::with_capacity(num_parties - 1);
  134. // connect to all parties that we are supposed to connect to
  135. for (party_id, info) in options.connect_info.iter().enumerate() {
  136. if party_id == my_id {
  137. continue;
  138. }
  139. match info {
  140. NetworkPartyInfo::Listen => {}
  141. NetworkPartyInfo::Connect(host, port) => {
  142. output.insert(
  143. party_id,
  144. tcp_connect(
  145. my_id,
  146. party_id,
  147. host,
  148. *port,
  149. options.connect_timeout_seconds,
  150. )?,
  151. );
  152. }
  153. }
  154. }
  155. // join the listen thread and obtain the connections that reached us
  156. let accepted_connections = match listen_thread_handle.join() {
  157. Ok(accepted_connections) => accepted_connections,
  158. Err(_) => return Err(Error::ConnectionSetupError),
  159. }?;
  160. // return the union of both maps
  161. output.extend(accepted_connections);
  162. Ok(output)
  163. }
  164. /// Create communicator using TCP connections
  165. pub fn make_tcp_communicator(
  166. num_parties: usize,
  167. my_id: usize,
  168. options: &NetworkOptions,
  169. ) -> Result<impl AbstractCommunicator, Error> {
  170. // create connections with other parties
  171. let stream_map = setup_connection(num_parties, my_id, options)?;
  172. stream_map
  173. .iter()
  174. .for_each(|(_, s)| s.set_nodelay(true).expect("set_nodelay failed"));
  175. // use streams as reader/writer pairs
  176. let rw_map = stream_map
  177. .into_iter()
  178. .map(|(party_id, stream)| (party_id, (stream.try_clone().unwrap(), stream)))
  179. .collect();
  180. // create new communicator
  181. Ok(Communicator::from_reader_writer(num_parties, my_id, rw_map))
  182. }
  183. /// Create communicator using TCP connections via localhost
  184. pub fn make_local_tcp_communicators(num_parties: usize) -> Vec<impl AbstractCommunicator> {
  185. let ports: [u16; 3] = [20_000, 20_001, 20_002];
  186. let opts: Vec<_> = (0..num_parties)
  187. .map(|party_id| NetworkOptions {
  188. listen_host: "localhost".to_owned(),
  189. listen_port: ports[party_id],
  190. connect_info: (0..num_parties)
  191. .map(|other_id| {
  192. if other_id < party_id {
  193. NetworkPartyInfo::Connect("localhost".to_owned(), ports[other_id])
  194. } else {
  195. NetworkPartyInfo::Listen
  196. }
  197. })
  198. .collect(),
  199. connect_timeout_seconds: 3,
  200. })
  201. .collect();
  202. let communicators: Vec<_> = opts
  203. .iter()
  204. .enumerate()
  205. .map(|(party_id, opts)| {
  206. let opts_cpy = (*opts).clone();
  207. thread::spawn(move || make_tcp_communicator(num_parties, party_id, &opts_cpy))
  208. })
  209. .collect();
  210. communicators
  211. .into_iter()
  212. .map(|h| h.join().unwrap().unwrap())
  213. .collect()
  214. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::Fut;
      use std::thread;

      /// Three parties connected over localhost each send one message to every
      /// other party (each party uses a different message type) and check what
      /// they receive.
      #[test]
      fn test_tcp_communicators() {
          let num_parties = 3;
          let msg_0: u8 = 42;
          let msg_1: u32 = 0x_dead_beef;
          let msg_2: [u32; 2] = [0x_1333_3337, 0x_c0ff_ffee];
          // Reuse the library helper instead of duplicating its localhost setup.
          let communicators = make_local_tcp_communicators(num_parties);
          let thread_handles: Vec<_> = communicators
              .into_iter()
              .enumerate()
              .map(|(party_id, mut communicator)| {
                  thread::spawn(move || {
                      // Every party registers its receives before sending.
                      match party_id {
                          0 => {
                              let fut_1 = communicator.receive::<u32>(1).unwrap();
                              let fut_2 = communicator.receive::<[u32; 2]>(2).unwrap();
                              communicator.send(1, msg_0).unwrap();
                              communicator.send(2, msg_0).unwrap();
                              assert_eq!(fut_1.get().unwrap(), msg_1);
                              assert_eq!(fut_2.get().unwrap(), msg_2);
                          }
                          1 => {
                              let fut_0 = communicator.receive::<u8>(0).unwrap();
                              let fut_2 = communicator.receive::<[u32; 2]>(2).unwrap();
                              communicator.send(0, msg_1).unwrap();
                              communicator.send(2, msg_1).unwrap();
                              assert_eq!(fut_0.get().unwrap(), msg_0);
                              assert_eq!(fut_2.get().unwrap(), msg_2);
                          }
                          2 => {
                              let fut_0 = communicator.receive::<u8>(0).unwrap();
                              let fut_1 = communicator.receive::<u32>(1).unwrap();
                              communicator.send(0, msg_2).unwrap();
                              communicator.send(1, msg_2).unwrap();
                              assert_eq!(fut_0.get().unwrap(), msg_0);
                              assert_eq!(fut_1.get().unwrap(), msg_1);
                          }
                          _ => unreachable!("only three parties are set up"),
                      }
                      communicator.shutdown();
                  })
              })
              .collect();
          for handle in thread_handles {
              handle.join().unwrap();
          }
      }
  }