1use bytes::Bytes;
2use tokio::io::{BufReader, BufWriter};
3use tokio::net::TcpStream;
4use uuid::Uuid;
5
6use crate::protocol;
7use crate::types::{ClientWire, ServerEvent, ServerWire, SyncError};
8
9pub struct Client {
11 reader: BufReader<tokio::net::tcp::OwnedReadHalf>,
12 writer: BufWriter<tokio::net::tcp::OwnedWriteHalf>,
13 max_payload: usize,
14 client_id: Option<Uuid>,
15}
16
17impl Client {
18 pub async fn connect(addr: &str) -> Result<Self, SyncError> {
20 Self::builder().connect(addr).await
21 }
22
23 pub fn builder() -> ClientBuilder {
25 ClientBuilder::new()
26 }
27
28 pub fn client_id(&self) -> Option<Uuid> {
33 self.client_id
34 }
35
36 pub async fn join(&mut self, room_id: &str, data: Option<&[u8]>) -> Result<(), SyncError> {
38 let msg = if let Some(data) = data {
39 ClientWire::JoinRoom {
40 room_id: room_id.into(),
41 data: data.to_vec(),
42 }
43 } else {
44 ClientWire::JoinRoom {
45 room_id: room_id.into(),
46 data: Vec::new(),
47 }
48 };
49 self.send(&msg).await
50 }
51
52 pub async fn echo_test(&mut self, data: &[u8]) -> Result<(), SyncError> {
55 let msg = ClientWire::EchoTest {
56 data: data.to_vec(),
57 };
58 self.send(&msg).await
59 }
60
61 pub async fn leave(&mut self) -> Result<(), SyncError> {
63 self.send(&ClientWire::LeaveRoom).await
64 }
65
66 #[inline]
68 pub async fn ping(&mut self) -> Result<(), SyncError> {
69 self.send(&ClientWire::Ping).await
70 }
71
72 #[inline]
74 pub async fn broadcast(&mut self, data: &[u8]) -> Result<(), SyncError> {
75 let msg = ClientWire::Broadcast {
76 data: data.to_vec(),
77 };
78 self.send(&msg).await
79 }
80
81 #[inline]
85 pub async fn recv(&mut self) -> Result<Option<ServerEvent>, SyncError> {
86 loop {
87 let payload = match protocol::read_frame_raw(&mut self.reader, self.max_payload).await {
88 Ok(p) => p,
89 Err(ref e) if e.is_connection_closed() => {
90 return Ok(None);
91 }
92 Err(e) => return Err(e),
93 };
94
95 let wire: ServerWire = wincode::deserialize(&payload)
96 .map_err(|e| SyncError::Protocol(format!("deserialize failed: {:?}", e)))?;
97
98 if matches!(wire, ServerWire::Ping) {
100 self.send(&ClientWire::Pong).await?;
101 continue;
102 }
103
104 if matches!(wire, ServerWire::Pong) {
106 continue;
107 }
108
109 if let ServerWire::Joined { client_id, .. } = &wire {
111 self.client_id = Some(*client_id);
112 }
113
114 return Ok(Some(Self::wire_to_event(wire)));
115 }
116 }
117
118 #[inline]
119 async fn send(&mut self, msg: &ClientWire) -> Result<(), SyncError> {
120 protocol::write_frame(&mut self.writer, msg).await
121 }
122
123 #[inline]
124 fn wire_to_event(wire: ServerWire) -> ServerEvent {
125 match wire {
126 ServerWire::Joined { client_id, room_id } => ServerEvent::Joined { client_id, room_id },
127 ServerWire::PlayerJoined { client_id } => ServerEvent::PlayerJoined { client_id },
128 ServerWire::PlayerLeft { client_id } => ServerEvent::PlayerLeft { client_id },
129 ServerWire::Error(msg) => ServerEvent::Error(msg),
130 ServerWire::Broadcast { sender_id, data } => ServerEvent::Broadcast {
131 sender_id,
132 data: Bytes::from(data),
133 },
134 ServerWire::EchoTest { data } => ServerEvent::EchoTest {
135 data: Bytes::from(data),
136 },
137 _ => unreachable!("ping/pong should be handled in recv()"),
139 }
140 }
141}
142
/// Configures and opens a [`Client`] connection.
#[derive(Debug, Clone)]
pub struct ClientBuilder {
    /// Payload size limit handed to the client's frame reader.
    max_payload: usize,
}

impl ClientBuilder {
    /// Default value for `max_payload`: 256 KiB.
    const DEFAULT_MAX_PAYLOAD: usize = 256 * 1024;

    /// Creates a builder with default settings (`max_payload` = 256 KiB).
    pub fn new() -> Self {
        Self {
            max_payload: Self::DEFAULT_MAX_PAYLOAD,
        }
    }
}

impl Default for ClientBuilder {
    /// Equivalent to [`ClientBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
161
162impl ClientBuilder {
163 pub fn max_payload(mut self, n: usize) -> Self {
165 self.max_payload = n;
166 self
167 }
168
169 pub async fn connect(self, addr: &str) -> Result<Client, SyncError> {
171 let stream = TcpStream::connect(addr).await.map_err(|e| {
172 if e.kind() == std::io::ErrorKind::ConnectionRefused {
173 SyncError::ConnectionRefused
174 } else {
175 SyncError::Io(e)
176 }
177 })?;
178 stream.set_nodelay(true).ok(); let (read_half, write_half) = stream.into_split();
180
181 Ok(Client {
182 reader: BufReader::with_capacity(4 * 1024 * 1024, read_half),
183 writer: BufWriter::with_capacity(6 * 1024 * 1024, write_half),
184 max_payload: self.max_payload,
185 client_id: None,
186 })
187 }
188}