Skip to main content

rpfm_server/
server_websocket.rs

1//---------------------------------------------------------------------------//
2// Copyright (c) 2017-2026 Ismael Gutiérrez González. All rights reserved.
3//
4// This file is part of the Rusted PackFile Manager (RPFM) project,
5// which can be found here: https://github.com/Frodo45127/rpfm.
6//
7// This file is licensed under the MIT license, which can be found here:
8// https://github.com/Frodo45127/rpfm/blob/master/LICENSE.
9//---------------------------------------------------------------------------//
10
11//! WebSocket upgrade handler and message multiplexer for the `/ws` endpoint.
12//!
13//! On upgrade, the handler either reuses an existing [`Session`] (when the
14//! client supplies `?session_id=N`) or creates a new one. From then on the
15//! socket carries a stream of JSON-encoded [`IpcMessage<Command>`] frames
16//! from the client and [`IpcMessage<Response>`] frames back. Each command
17//! is dispatched into the session's dedicated background thread, whose
18//! responses are forwarded back over the same socket with the originating
19//! request `id` preserved so the client can correlate them.
20//!
21//! Graceful disconnect (`Command::ClientDisconnecting`) tears the session
22//! down immediately and flushes telemetry. Hard disconnects (socket close
23//! without that command) leave the session in a 5-minute grace period so
24//! the client can reconnect with the same `session_id` and pick up where it
25//! left off.
26//!
27//! [`Session`]: crate::session::Session
28//! [`IpcMessage<Command>`]: rpfm_ipc::messages::Message
29//! [`IpcMessage<Response>`]: rpfm_ipc::messages::Message
30
31use axum::{
32    extract::ws::{Message, WebSocket, WebSocketUpgrade},
33    extract::{Query, State},
34    response::IntoResponse
35};
36use futures::stream::StreamExt;
37use futures::sink::SinkExt;
38use serde::Deserialize;
39use tokio::sync::mpsc;
40
41use std::sync::Arc;
42
43use rpfm_ipc::messages::{Command, Message as IpcMessage, Response};
44use rpfm_telemetry::{info, warn};
45
46use crate::session::{DEFAULT_SESSION_TIMEOUT_SECS, SessionId, SessionManager, recv_response};
47
48//-------------------------------------------------------------------------------//
49//                              Enums & Structs
50//-------------------------------------------------------------------------------//
51
52
53/// Query parameters for WebSocket connection.
54#[derive(Debug, Deserialize)]
55pub struct WsQueryParams {
56
57    /// Optional session ID to connect to an existing session.
58    pub session_id: Option<SessionId>,
59}
60
61//-------------------------------------------------------------------------------//
62//                             Implementations
63//-------------------------------------------------------------------------------//
64
65/// WebSocket handler to upgrade the connection and handle messages.
66///
67/// Accepts an optional `session_id` query parameter to reconnect to an existing session.
68/// Example: `ws://localhost:45127/ws?session_id=123`
69pub(crate) async fn ws_handler(
70    State(session_manager): State<Arc<SessionManager>>,
71    Query(params): Query<WsQueryParams>,
72    ws: WebSocketUpgrade,
73) -> impl IntoResponse {
74    ws.max_message_size(usize::MAX)
75        .max_frame_size(usize::MAX)
76        .on_upgrade(move |socket| handle_socket(socket, session_manager, params.session_id))
77}
78
79/// Function to handle a WebSocket connection.
80///
81/// Each WebSocket connection gets its own session with an isolated background thread.
82/// If a session_id is provided and that session exists, the client reconnects to it.
83async fn handle_socket(socket: WebSocket, session_manager: Arc<SessionManager>, requested_session_id: Option<SessionId>) {
84
85    // Get or create a session for this client connection.
86    let (session, is_new) = session_manager.get_or_create_session(requested_session_id);
87    let session_id = session.id();
88
89    if is_new {
90        info!("New WebSocket client connected, created session ID: {}", session_id);
91    } else {
92        info!("WebSocket client reconnected to existing session ID: {}", session_id);
93    }
94
95    let (mut sink, mut receiver) = socket.split();
96    let (tx, mut rx) = mpsc::unbounded_channel::<IpcMessage<Response>>();
97
98    // Send the session ID to the client immediately after connection.
99    let session_connected_msg = IpcMessage {
100        id: 0, // Special ID for connection message
101        data: Response::SessionConnected(session_id),
102    };
103    if let Ok(json) = serde_json::to_string(&session_connected_msg) {
104        let _ = sink.send(Message::Text(json.into())).await;
105    }
106
107    // Task to send responses back to the client.
108    let sender_task = tokio::spawn(async move {
109        while let Some(response_msg) = rx.recv().await {
110            match serde_json::to_string(&response_msg) {
111                Ok(json) => {
112                    if sink.send(Message::Text(json.into())).await.is_err() {
113                        break;
114                    }
115                }
116                Err(error) => {
117                    let error_msg = IpcMessage {
118                        id: response_msg.id,
119                        data: Response::Error(format!("Serialization error: {}", error)),
120                    };
121
122                    if let Ok(json) = serde_json::to_string(&error_msg) {
123                        let _ = sink.send(Message::Text(json.into())).await;
124                    }
125                }
126            }
127        }
128    });
129
130    // Track whether the client requested a graceful disconnect.
131    let mut graceful_disconnect = false;
132
133    // Loop to receive commands from the client.
134    while let Some(msg) = receiver.next().await {
135        if let Ok(msg) = msg {
136            match msg {
137                Message::Text(t) => {
138                    // Try to parse the message to check for ClientDisconnecting.
139                    match serde_json::from_str::<IpcMessage<Command>>(&t) {
140                        Ok(msg) => {
141                            info!("Session {}: Received command [ID {}]: {:?}", session.id(), msg.id, msg.data);
142
143                            // Handle ClientDisconnecting specially - it needs access to session_manager.
144                            if matches!(msg.data, Command::ClientDisconnecting) {
145                                // Send success response before cleanup.
146                                let response_msg = IpcMessage {
147                                    id: msg.id,
148                                    data: Response::Success,
149                                };
150                                let _ = tx.send(response_msg);
151                                graceful_disconnect = true;
152                                break;
153                            }
154
155                            // Route other commands through the session's background thread.
156                            let tx = tx.clone();
157                            let session = session.clone();
158                            tokio::spawn(async move {
159                                let mut receiver = session.send(msg.data);
160                                let response = recv_response(&mut receiver).await;
161                                let response_msg = IpcMessage {
162                                    id: msg.id,
163                                    data: response,
164                                };
165                                let _ = tx.send(response_msg);
166                            });
167                        }
168                        Err(error) => {
169                            warn!("Session {}: Deserialization error: {}", session.id(), error);
170
171                            // Try to extract the message ID from the malformed message so we can
172                            // send an error response back to the client.
173                            if let Some(id) = serde_json::from_str::<serde_json::Value>(&t)
174                                .ok()
175                                .and_then(|v| v.get("id")?.as_u64()) {
176                                let error_msg = IpcMessage {
177                                    id,
178                                    data: Response::Error(format!("Server failed to deserialize command: {}", error)),
179                                };
180                                let _ = tx.send(error_msg);
181                            }
182
183                            // TODO: Handle the error case when the message ID cannot be extracted.
184                        }
185                    }
186                }
187                Message::Close(_) => {
188                    info!("Session {}: Client disconnected", session_id);
189                    break;
190                }
191                _ => {}
192            }
193        } else {
194            info!("Session {}: Client disconnected (error)", session_id);
195            break;
196        }
197    }
198
199    sender_task.abort();
200
201    // Client requested graceful disconnect - remove session immediately.
202    if graceful_disconnect {
203        info!("Session {}: Client requested graceful disconnect, removing session immediately", session_id);
204        session_manager.remove_session(session_id);
205
206        // Check if this was the last session and shutdown the server if so.
207        if session_manager.session_count() == 0 {
208            info!("No more active sessions, shutting down server...");
209            rpfm_telemetry::flush("Server Action Telemetry");
210            std::process::exit(0);
211        }
212    }
213
214    // Unexpected disconnect - mark session for timeout cleanup.
215    else {
216        SessionManager::client_disconnected(session_manager.clone(),session_id);
217        info!("Session {} client disconnected, session will timeout in {} minutes if not reconnected", session_id, DEFAULT_SESSION_TIMEOUT_SECS / 60);
218    }
219}