rpfm_server/server_websocket.rs
1//---------------------------------------------------------------------------//
2// Copyright (c) 2017-2026 Ismael Gutiérrez González. All rights reserved.
3//
4// This file is part of the Rusted PackFile Manager (RPFM) project,
5// which can be found here: https://github.com/Frodo45127/rpfm.
6//
7// This file is licensed under the MIT license, which can be found here:
8// https://github.com/Frodo45127/rpfm/blob/master/LICENSE.
9//---------------------------------------------------------------------------//
10
11//! WebSocket upgrade handler and message multiplexer for the `/ws` endpoint.
12//!
13//! On upgrade, the handler either reuses an existing [`Session`] (when the
14//! client supplies `?session_id=N`) or creates a new one. From then on the
15//! socket carries a stream of JSON-encoded [`IpcMessage<Command>`] frames
16//! from the client and [`IpcMessage<Response>`] frames back. Each command
17//! is dispatched into the session's dedicated background thread, whose
18//! responses are forwarded back over the same socket with the originating
19//! request `id` preserved so the client can correlate them.
20//!
//! Graceful disconnect (`Command::ClientDisconnecting`) tears the session
//! down immediately; if it was the last live session, telemetry is flushed
//! and the server process exits. Hard disconnects (socket close without that
//! command) leave the session in a 5-minute grace period so the client can
//! reconnect with the same `session_id` and pick up where it left off.
26//!
27//! [`Session`]: crate::session::Session
28//! [`IpcMessage<Command>`]: rpfm_ipc::messages::Message
29//! [`IpcMessage<Response>`]: rpfm_ipc::messages::Message
30
31use axum::{
32 extract::ws::{Message, WebSocket, WebSocketUpgrade},
33 extract::{Query, State},
34 response::IntoResponse
35};
36use futures::stream::StreamExt;
37use futures::sink::SinkExt;
38use serde::Deserialize;
39use tokio::sync::mpsc;
40
41use std::sync::Arc;
42
43use rpfm_ipc::messages::{Command, Message as IpcMessage, Response};
44use rpfm_telemetry::{info, warn};
45
46use crate::session::{DEFAULT_SESSION_TIMEOUT_SECS, SessionId, SessionManager, recv_response};
47
48//-------------------------------------------------------------------------------//
49// Enums & Structs
50//-------------------------------------------------------------------------------//
51
52
53/// Query parameters for WebSocket connection.
54#[derive(Debug, Deserialize)]
55pub struct WsQueryParams {
56
57 /// Optional session ID to connect to an existing session.
58 pub session_id: Option<SessionId>,
59}
60
61//-------------------------------------------------------------------------------//
62// Implementations
63//-------------------------------------------------------------------------------//
64
65/// WebSocket handler to upgrade the connection and handle messages.
66///
67/// Accepts an optional `session_id` query parameter to reconnect to an existing session.
68/// Example: `ws://localhost:45127/ws?session_id=123`
69pub(crate) async fn ws_handler(
70 State(session_manager): State<Arc<SessionManager>>,
71 Query(params): Query<WsQueryParams>,
72 ws: WebSocketUpgrade,
73) -> impl IntoResponse {
74 ws.max_message_size(usize::MAX)
75 .max_frame_size(usize::MAX)
76 .on_upgrade(move |socket| handle_socket(socket, session_manager, params.session_id))
77}
78
79/// Function to handle a WebSocket connection.
80///
81/// Each WebSocket connection gets its own session with an isolated background thread.
82/// If a session_id is provided and that session exists, the client reconnects to it.
83async fn handle_socket(socket: WebSocket, session_manager: Arc<SessionManager>, requested_session_id: Option<SessionId>) {
84
85 // Get or create a session for this client connection.
86 let (session, is_new) = session_manager.get_or_create_session(requested_session_id);
87 let session_id = session.id();
88
89 if is_new {
90 info!("New WebSocket client connected, created session ID: {}", session_id);
91 } else {
92 info!("WebSocket client reconnected to existing session ID: {}", session_id);
93 }
94
95 let (mut sink, mut receiver) = socket.split();
96 let (tx, mut rx) = mpsc::unbounded_channel::<IpcMessage<Response>>();
97
98 // Send the session ID to the client immediately after connection.
99 let session_connected_msg = IpcMessage {
100 id: 0, // Special ID for connection message
101 data: Response::SessionConnected(session_id),
102 };
103 if let Ok(json) = serde_json::to_string(&session_connected_msg) {
104 let _ = sink.send(Message::Text(json.into())).await;
105 }
106
107 // Task to send responses back to the client.
108 let sender_task = tokio::spawn(async move {
109 while let Some(response_msg) = rx.recv().await {
110 match serde_json::to_string(&response_msg) {
111 Ok(json) => {
112 if sink.send(Message::Text(json.into())).await.is_err() {
113 break;
114 }
115 }
116 Err(error) => {
117 let error_msg = IpcMessage {
118 id: response_msg.id,
119 data: Response::Error(format!("Serialization error: {}", error)),
120 };
121
122 if let Ok(json) = serde_json::to_string(&error_msg) {
123 let _ = sink.send(Message::Text(json.into())).await;
124 }
125 }
126 }
127 }
128 });
129
130 // Track whether the client requested a graceful disconnect.
131 let mut graceful_disconnect = false;
132
133 // Loop to receive commands from the client.
134 while let Some(msg) = receiver.next().await {
135 if let Ok(msg) = msg {
136 match msg {
137 Message::Text(t) => {
138 // Try to parse the message to check for ClientDisconnecting.
139 match serde_json::from_str::<IpcMessage<Command>>(&t) {
140 Ok(msg) => {
141 info!("Session {}: Received command [ID {}]: {:?}", session.id(), msg.id, msg.data);
142
143 // Handle ClientDisconnecting specially - it needs access to session_manager.
144 if matches!(msg.data, Command::ClientDisconnecting) {
145 // Send success response before cleanup.
146 let response_msg = IpcMessage {
147 id: msg.id,
148 data: Response::Success,
149 };
150 let _ = tx.send(response_msg);
151 graceful_disconnect = true;
152 break;
153 }
154
155 // Route other commands through the session's background thread.
156 let tx = tx.clone();
157 let session = session.clone();
158 tokio::spawn(async move {
159 let mut receiver = session.send(msg.data);
160 let response = recv_response(&mut receiver).await;
161 let response_msg = IpcMessage {
162 id: msg.id,
163 data: response,
164 };
165 let _ = tx.send(response_msg);
166 });
167 }
168 Err(error) => {
169 warn!("Session {}: Deserialization error: {}", session.id(), error);
170
171 // Try to extract the message ID from the malformed message so we can
172 // send an error response back to the client.
173 if let Some(id) = serde_json::from_str::<serde_json::Value>(&t)
174 .ok()
175 .and_then(|v| v.get("id")?.as_u64()) {
176 let error_msg = IpcMessage {
177 id,
178 data: Response::Error(format!("Server failed to deserialize command: {}", error)),
179 };
180 let _ = tx.send(error_msg);
181 }
182
183 // TODO: Handle the error case when the message ID cannot be extracted.
184 }
185 }
186 }
187 Message::Close(_) => {
188 info!("Session {}: Client disconnected", session_id);
189 break;
190 }
191 _ => {}
192 }
193 } else {
194 info!("Session {}: Client disconnected (error)", session_id);
195 break;
196 }
197 }
198
199 sender_task.abort();
200
201 // Client requested graceful disconnect - remove session immediately.
202 if graceful_disconnect {
203 info!("Session {}: Client requested graceful disconnect, removing session immediately", session_id);
204 session_manager.remove_session(session_id);
205
206 // Check if this was the last session and shutdown the server if so.
207 if session_manager.session_count() == 0 {
208 info!("No more active sessions, shutting down server...");
209 rpfm_telemetry::flush("Server Action Telemetry");
210 std::process::exit(0);
211 }
212 }
213
214 // Unexpected disconnect - mark session for timeout cleanup.
215 else {
216 SessionManager::client_disconnected(session_manager.clone(),session_id);
217 info!("Session {} client disconnected, session will timeout in {} minutes if not reconnected", session_id, DEFAULT_SESSION_TIMEOUT_SECS / 60);
218 }
219}