1use bytes::Bytes;
2use dashmap::DashMap;
3use std::any::Any;
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::time::Duration;
7use tokio::net::{TcpListener, TcpStream};
8use tokio::sync::{broadcast, mpsc};
9use uuid::Uuid;
10
11use crate::handler::{NoopHandler, ServerHandler};
12use crate::protocol;
13use crate::room::RoomManager;
14use crate::storage::Storage;
15use crate::types::{ClientWire, Result, ServerConfig, ServerWire, SyncError};
16use crate::{error, info, warn};
17
/// Per-connection bookkeeping stored in the server's client map, keyed by the
/// `Uuid` assigned when the connection is accepted.
struct ClientState {
    /// Room the client is currently a member of; `None` while it is in no room.
    room_id: Option<String>,
    /// Remote address the connection was accepted from.
    addr: SocketAddr,
    /// User metadata attached through `ServerHandle::set_client_meta` and read
    /// back with `with_client_meta` / `take_client_meta`.
    metadata: Storage,
}
24
/// A TCP room server. Configure it with [`ServerBuilder`] and start it with
/// [`Server::run`], which hands back a [`ServerHandle`] for runtime control.
pub struct Server {
    /// Settings captured by the builder (bind address, limits, timeouts).
    config: ServerConfig,
    /// Application callbacks invoked on connect, join, leave, broadcast, etc.
    handler: Arc<dyn ServerHandler>,
    /// All rooms; shared with every [`ServerHandle`].
    rooms: Arc<RoomManager>,
    /// State of every connected client, keyed by its connection id.
    clients: Arc<DashMap<Uuid, ClientState>>,
    /// Number of live connection tasks, compared against `config.max_clients`.
    client_count: Arc<std::sync::atomic::AtomicUsize>,
    /// Shutdown signal; `run` replaces it with the channel the handle sends on.
    shutdown_tx: broadcast::Sender<()>,
}
56
/// Runtime control handle returned by [`Server::run`]. It shares the running
/// server's rooms, client map, handler and client counter.
pub struct ServerHandle {
    /// Sender side of the shutdown channel observed by the accept loop and
    /// every connection task.
    shutdown_tx: broadcast::Sender<()>,
    /// Same room registry the server uses.
    rooms: Arc<RoomManager>,
    /// Same client map the server uses.
    clients: Arc<DashMap<Uuid, ClientState>>,
    /// Same callbacks the server uses.
    handler: Arc<dyn ServerHandler>,
    /// Same live-connection counter the server uses.
    client_count: Arc<std::sync::atomic::AtomicUsize>,
}
70
71impl ServerHandle {
72 pub async fn shutdown(&self) {
74 let _ = self.shutdown_tx.send(());
75 }
76
77 #[inline]
79 pub fn create_room(&self, id: &str) -> Result<()> {
80 self.rooms.create(id)?;
81 self.handler.on_room_create(id);
82 Ok(())
83 }
84
85 #[inline]
90 pub fn delete_room(&self, id: &str) -> bool {
91 let existed = self.rooms.delete(id);
92 if existed {
93 self.handler.on_room_delete(id);
94 }
95 existed
96 }
97
98 pub fn room_exists(&self, id: &str) -> bool {
100 self.rooms.get(id).is_some()
101 }
102
103 pub fn room_count(&self) -> usize {
105 self.rooms.len()
106 }
107
108 pub fn get_room_ids(&self) -> Vec<String> {
110 self.rooms.room_ids()
111 }
112
113 pub fn get_room_clients(&self, id: &str) -> Option<Vec<Uuid>> {
115 self.rooms.get(id).map(|r| r.client_ids())
116 }
117
118 pub fn room_client_count(&self, id: &str) -> Option<usize> {
120 self.rooms.get(id).map(|r| r.len())
121 }
122
123 pub fn get_client_channel_len(&self, room_id: &str, client_id: &Uuid) -> Option<usize> {
128 self.rooms.get(room_id)?.channel_len(client_id)
129 }
130
131 pub fn get_room_channel_lens(&self, room_id: &str) -> Option<Vec<(Uuid, usize)>> {
136 self.rooms.get(room_id).map(|r| r.all_channel_lens())
137 }
138
139 pub fn set_room_meta<T: Any + Send + Sync + 'static>(&self, room_id: &str, value: T) -> bool {
142 if let Some(room) = self.rooms.get(room_id) {
143 room.metadata.set(value);
144 true
145 } else {
146 false
147 }
148 }
149
150 pub fn with_room_meta<T: Any + Send + Sync + 'static, R>(
157 &self,
158 room_id: &str,
159 f: impl FnOnce(&T) -> R,
160 ) -> Option<R> {
161 self.rooms.get(room_id)?.metadata.get(f)
162 }
163
164 pub fn take_room_meta<T: Any + Send + Sync + 'static>(&self, room_id: &str) -> Option<T> {
167 self.rooms.get(room_id)?.metadata.take()
168 }
169
170 pub fn room_has_meta(&self, room_id: &str) -> bool {
172 self.rooms.get(room_id).is_some_and(|r| r.metadata.is_set())
173 }
174
175 pub fn set_client_meta<T: Any + Send + Sync + 'static>(
178 &self,
179 client_id: &Uuid,
180 value: T,
181 ) -> bool {
182 if let Some(state) = self.clients.get(client_id) {
183 state.metadata.set(value);
184 true
185 } else {
186 false
187 }
188 }
189
190 pub fn with_client_meta<T: Any + Send + Sync + 'static, R>(
195 &self,
196 client_id: &Uuid,
197 f: impl FnOnce(&T) -> R,
198 ) -> Option<R> {
199 self.clients.get(client_id)?.metadata.get(f)
200 }
201
202 pub fn take_client_meta<T: Any + Send + Sync + 'static>(&self, client_id: &Uuid) -> Option<T> {
205 self.clients.get(client_id)?.metadata.take()
206 }
207
208 pub fn client_has_meta(&self, client_id: &Uuid) -> bool {
210 self.clients
211 .get(client_id)
212 .is_some_and(|c| c.metadata.is_set())
213 }
214
215 pub fn get_client_room(&self, client_id: Uuid) -> Option<String> {
217 self.clients.get(&client_id)?.room_id.clone()
218 }
219
220 pub fn get_client_addr(&self, client_id: Uuid) -> Option<SocketAddr> {
222 self.clients.get(&client_id).map(|c| c.addr)
223 }
224
225 pub fn get_client_count(&self) -> usize {
227 self.client_count.load(std::sync::atomic::Ordering::Relaxed)
228 }
229
230 pub fn kick_client(&self, client_id: &Uuid) -> bool {
233 let room_id = match self.clients.get(client_id) {
234 Some(state) => match &state.room_id {
235 Some(rid) => rid.clone(),
236 None => return false,
237 },
238 None => return false,
239 };
240
241 let notify = ServerWire::PlayerLeft {
242 client_id: *client_id,
243 };
244 if let Some(room) = self.rooms.get(&room_id) {
245 let rt = tokio::runtime::Handle::current();
246 rt.block_on(async {
247 let _ = room.broadcast(*client_id, ¬ify).await;
248 });
249 }
250
251 self.rooms.remove_client(&room_id, client_id);
252 self.handler.on_leave(*client_id, &room_id);
253
254 true
255 }
256}
257
258impl Server {
259 pub fn builder() -> ServerBuilder {
261 ServerBuilder::new()
262 }
263
264 #[deprecated(note = "Use ServerHandle::create_room instead, at runtime")]
268 pub fn pre_create_room(&self, id: &str) -> Result<()> {
269 self.rooms.create(id)?;
270 self.handler.on_room_create(id);
271 Ok(())
272 }
273
274 #[deprecated(note = "Use ServerHandle::delete_room instead, at runtime")]
277 pub fn pre_delete_room(&self, id: &str) -> bool {
278 let existed = self.rooms.delete(id);
279 if existed {
280 self.handler.on_room_delete(id);
281 }
282 existed
283 }
284
285 pub async fn run(self) -> std::io::Result<ServerHandle> {
293 let listener = TcpListener::bind(&self.config.bind_addr).await?;
294 info!("Server listening on {}", self.config.bind_addr);
295
296 let (shutdown_tx, _) = broadcast::channel::<()>(4);
297 let handle = ServerHandle {
298 shutdown_tx: shutdown_tx.clone(),
299 rooms: self.rooms.clone(),
300 clients: self.clients.clone(),
301 handler: self.handler.clone(),
302 client_count: self.client_count.clone(),
303 };
304
305 let mut shutdown_rx = shutdown_tx.subscribe();
306 let mut server = self;
307 server.shutdown_tx = shutdown_tx;
308
309 let server = Arc::new(server);
310
311 tokio::spawn(async move {
312 loop {
313 tokio::select! {
314 result = listener.accept() => {
315 match result {
316 Ok((stream, addr)) => {
317 stream.set_nodelay(true).ok();
318 if !server.handler.on_connect(addr) {
319 warn!("connection rejected by handler: {addr}");
320 drop(stream);
321 continue;
322 }
323
324 let count = server.client_count.load(std::sync::atomic::Ordering::Relaxed);
325 if count >= server.config.max_clients {
326 warn!("max clients reached, rejecting {addr}");
327 drop(stream);
328 continue;
329 }
330
331 server.client_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
332 let srv = server.clone();
333 tokio::spawn(async move {
334 if let Err(e) = srv.handle_connection(stream, addr).await {
335 if !e.is_connection_closed() {
336 warn!("client error: {e}");
337 }
338 }
339 srv.client_count.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
340 });
341 }
342 Err(e) => {
343 error!("accept error: {e}");
344 }
345 }
346 }
347 _ = shutdown_rx.recv() => {
348 info!("Shutting server down");
349 server.handler.on_shutdown();
350 break;
351 }
352 }
353 }
354 });
355
356 Ok(handle)
357 }
358
359 #[inline]
360 async fn handle_connection(
361 self: &Arc<Self>,
362 stream: TcpStream,
363 addr: SocketAddr,
364 ) -> Result<()> {
365 let client_id = Uuid::new_v4();
366 stream.set_nodelay(true).ok();
367 let (read_half, write_half) = stream.into_split();
368 let mut reader = tokio::io::BufReader::with_capacity(6 * 1024 * 1024, read_half);
369 let mut writer = tokio::io::BufWriter::with_capacity(12 * 1024 * 1024, write_half);
370
371 let (write_tx, mut write_rx) = mpsc::channel::<Bytes>(self.config.channel_capacity);
374
375 let writer_handle = tokio::spawn(async move {
377 while let Some(frame) = write_rx.recv().await {
378 if let Err(e) = protocol::write_frame_raw(&mut writer, frame).await {
379 #[allow(clippy::needless_ifs)]
380 if !e.is_connection_closed() {}
381 warn!("write error: {e}");
382 break;
383 }
384 }
385 });
386
387 let result = self
388 .client_loop(client_id, addr, &mut reader, &write_tx)
389 .await;
390
391 if let Some(state) = self.clients.remove(&client_id) {
393 if let Some(room_id) = state.1.room_id {
394 self.cleanup_client(client_id, &room_id).await;
395 }
396 }
397
398 drop(write_tx);
399 let _ = writer_handle.await;
400
401 result
402 }
403
404 #[inline]
405 async fn client_loop(
406 self: &Arc<Self>,
407 client_id: Uuid,
408 addr: SocketAddr,
409 reader: &mut tokio::io::BufReader<tokio::net::tcp::OwnedReadHalf>,
410 write_tx: &mpsc::Sender<Bytes>,
411 ) -> Result<()> {
412 self.clients.insert(
414 client_id,
415 ClientState {
416 room_id: None,
417 addr,
418 metadata: Storage::new(),
419 },
420 );
421
422 let mut shutdown_rx = self.shutdown_tx.subscribe();
423 let mut ping_interval = tokio::time::interval(self.config.ping_interval);
424 let mut awaiting_pong = false;
425
426 ping_interval.tick().await;
428
429 loop {
430 tokio::select! {
431 result = tokio::time::timeout(
435 self.config.idle_timeout,
436 protocol::read_frame_raw(reader, self.config.max_payload),
437 ) => {
438 match result {
439 Ok(Ok(payload)) => {
440 let msg: ClientWire = wincode::deserialize(&payload)
441 .map_err(|e| SyncError::Protocol(format!("deserialize failed: {:?}", e)))?;
442
443 if matches!(msg, ClientWire::Pong) {
445 awaiting_pong = false;
446 continue;
447 }
448
449 self.process_message(client_id, msg, write_tx).await?;
450 }
451 Ok(Err(SyncError::Io(ref e)))
452 if e.kind() == std::io::ErrorKind::UnexpectedEof =>
453 {
454 return Ok(()); }
456 Ok(Err(e)) => return Err(e),
458 Err(_timeout) => return Err(SyncError::IdleTimeout),
460 }
461 }
462
463 _ = ping_interval.tick() => {
466 if awaiting_pong {
467 return Err(SyncError::PingTimeout);
469 }
470 let ping = ServerWire::Ping;
471 self.send_to_client(client_id, write_tx, &ping).await;
472 awaiting_pong = true;
473 }
474
475 _ = shutdown_rx.recv() => {
477 return Err(SyncError::ConnectionClosed);
478 }
479 }
480 }
481 }
482
483 #[inline]
484 async fn process_message(
485 self: &Arc<Self>,
486 client_id: Uuid,
487 msg: ClientWire,
488 write_tx: &mpsc::Sender<Bytes>,
489 ) -> Result<()> {
490 match msg {
491 ClientWire::EchoTest { data } => {
492 let echo = ServerWire::EchoTest { data };
494 self.send_to_client(client_id, write_tx, &echo).await;
495 }
496
497 ClientWire::JoinRoom { room_id, data } => {
498 let addr = match self.clients.get(&client_id) {
500 Some(state) => state.addr,
501 None => return Ok(()),
502 };
503
504 if self.rooms.get(&room_id).is_none() {
506 let err = ServerWire::Error(format!("room not found: {room_id}"));
507 self.send_to_client(client_id, write_tx, &err).await;
508 return Ok(());
509 }
510
511 let (allow_join, reject_reason) =
513 self.handler.on_join(client_id, &room_id, addr, &data);
514 if !allow_join {
515 let reason = reject_reason.unwrap_or_else(|| "join rejected".to_string());
516 let err = ServerWire::Error(reason);
517 self.send_to_client(client_id, write_tx, &err).await;
518 return Ok(());
519 }
520
521 if let Some(mut state) = self.clients.get_mut(&client_id) {
523 if let Some(old_room) = state.room_id.take() {
524 drop(state);
525 self.cleanup_client(client_id, &old_room).await;
526 }
527 }
528
529 {
531 let room = self.rooms.get(&room_id).unwrap();
532 room.insert(client_id, write_tx.clone());
533 }
534
535 if let Some(mut state) = self.clients.get_mut(&client_id) {
537 state.room_id = Some(room_id.clone());
538 }
539
540 let joined = ServerWire::Joined {
541 client_id,
542 room_id: room_id.clone(),
543 };
544 self.send_to_client(client_id, write_tx, &joined).await;
545
546 let notify = ServerWire::PlayerJoined { client_id };
547 if let Some(room) = self.rooms.get(&room_id) {
548 let dropped = room.broadcast(client_id, ¬ify).await;
549 for id in dropped {
550 self.handler.on_backpressure(id, &room_id);
551 }
552 }
553 }
554
555 ClientWire::LeaveRoom => {
556 if let Some(mut state) = self.clients.get_mut(&client_id) {
557 if let Some(room_id) = state.room_id.take() {
558 drop(state);
559 self.cleanup_client(client_id, &room_id).await;
560 }
561 }
562 }
563
564 ClientWire::Ping => {
565 let pong = ServerWire::Pong;
566 self.send_to_client(client_id, write_tx, &pong).await;
567 }
568
569 ClientWire::Pong => {
570 }
572
573 ClientWire::Broadcast { data } => {
574 let room_id = match self.clients.get(&client_id) {
575 Some(state) => match &state.room_id {
576 Some(id) => id.clone(),
577 None => {
578 let err = ServerWire::Error("not in a room".into());
579 self.send_to_client(client_id, write_tx, &err).await;
580 return Ok(());
581 }
582 },
583 None => return Ok(()),
584 };
585
586 let broadcast = ServerWire::Broadcast {
587 sender_id: client_id,
588 data,
589 };
590 let payload = wincode::serialize(&broadcast)
591 .map_err(|e| SyncError::Protocol(format!("serialize failed: {:?}", e)))?;
592 let payload = Bytes::from(payload);
593
594 if let Some(room) = self.rooms.get(&room_id) {
595 let dropped = room.broadcast_raw(client_id, payload).await;
596 for id in dropped {
597 self.handler.on_backpressure(id, &room_id);
598 }
599 }
600
601 if let ServerWire::Broadcast { ref data, .. } = broadcast {
602 self.handler.on_broadcast(client_id, &room_id, data);
603 }
604 }
605 }
606
607 Ok(())
608 }
609
610 #[inline]
611 async fn cleanup_client(self: &Arc<Self>, client_id: Uuid, room_id: &str) {
612 let notify = ServerWire::PlayerLeft { client_id };
613 if let Some(room) = self.rooms.get(room_id) {
614 let dropped = room.broadcast(client_id, ¬ify).await;
615 for id in dropped {
616 self.handler.on_backpressure(id, room_id);
617 }
618 }
619
620 self.rooms.remove_client(room_id, &client_id);
621
622 self.handler.on_leave(client_id, room_id);
623 }
624
625 #[inline]
626 async fn send_to_client(&self, client_id: Uuid, tx: &mpsc::Sender<Bytes>, msg: &ServerWire) {
627 if let Ok(payload) = wincode::serialize(msg) {
628 if tx.try_send(Bytes::from(payload)).is_err() {
630 if let Some(state) = self.clients.get(&client_id) {
631 if let Some(ref room_id) = state.room_id {
632 self.handler.on_backpressure(client_id, room_id);
633 }
634 }
635 }
636 }
637 }
638}
639
/// Builder for [`Server`]; obtain one from [`Server::builder`] or
/// [`ServerBuilder::new`], chain the setters, then call `build`.
pub struct ServerBuilder {
    /// Configuration accumulated by the setter methods.
    config: ServerConfig,
    /// Callbacks the built server will invoke.
    handler: Arc<dyn ServerHandler>,
}
655
656impl ServerBuilder {
657 pub fn new() -> Self {
658 Self {
659 config: ServerConfig::default(),
660 handler: Arc::new(NoopHandler),
661 }
662 }
663}
664
665impl Default for ServerBuilder {
666 fn default() -> Self {
667 Self::new()
668 }
669}
670
671impl ServerBuilder {
672 pub fn bind(mut self, addr: impl Into<String>) -> Self {
674 self.config.bind_addr = addr.into();
675 self
676 }
677
678 pub fn max_clients(mut self, n: usize) -> Self {
680 self.config.max_clients = n;
681 self
682 }
683
684 pub fn max_payload(mut self, n: usize) -> Self {
686 self.config.max_payload = n;
687 self
688 }
689
690 pub fn idle_timeout(mut self, d: Duration) -> Self {
694 self.config.idle_timeout = d;
695 self
696 }
697
698 pub fn ping_interval(mut self, d: Duration) -> Self {
703 self.config.ping_interval = d;
704 self
705 }
706
707 pub fn channel_capacity(mut self, n: usize) -> Self {
711 self.config.channel_capacity = n;
712 self
713 }
714
715 pub fn handler(mut self, h: impl ServerHandler) -> Self {
717 self.handler = Arc::new(h);
718 self
719 }
720
721 pub fn build(self) -> Server {
723 let (tx, _) = broadcast::channel(4);
724 Server {
725 config: self.config.clone(),
726 handler: self.handler,
727 rooms: Arc::new(RoomManager::new()),
728 clients: Arc::new(DashMap::with_capacity(self.config.max_clients)),
729 client_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
730 shutdown_tx: tx,
731 }
732 }
733}