//! WebSocket upgrade and per-connection handler (request-driven). use axum::{ extract::ws::{Message, WebSocket}, extract::{Query, State, WebSocketUpgrade}, response::Response, }; use flate2::{Compression, write::GzEncoder}; use futures_util::StreamExt; use once_cell::sync::OnceCell; use std::collections::HashMap; use std::io::Write; use tokio::sync::Mutex; use crate::metrics::{collect_disks, collect_fast_metrics, collect_processes_all}; use crate::proto::pb; use crate::state::AppState; // Compression threshold based on typical payload size const COMPRESSION_THRESHOLD: usize = 768; // Reusable buffer for compression to avoid allocations struct CompressionCache { processes_vec: Vec, } impl CompressionCache { fn new() -> Self { Self { processes_vec: Vec::with_capacity(512), // Typical process count } } } static COMPRESSION_CACHE: OnceCell> = OnceCell::new(); pub async fn ws_handler( ws: WebSocketUpgrade, State(state): State, Query(q): Query>, ) -> Response { // optional auth if let Some(expected) = state.auth_token.as_ref() && q.get("token") != Some(expected) { return ws.on_upgrade(|socket| async move { let _ = socket.close().await; }); } ws.on_upgrade(move |socket| handle_socket(socket, state)) } async fn handle_socket(mut socket: WebSocket, state: AppState) { state .client_count .fetch_add(1, std::sync::atomic::Ordering::Relaxed); while let Some(Ok(msg)) = socket.next().await { match msg { Message::Text(ref text) if text == "get_metrics" => { let m = collect_fast_metrics(&state).await; let _ = send_json(&mut socket, &m).await; } Message::Text(ref text) if text == "get_disks" => { let d = collect_disks(&state).await; let _ = send_json(&mut socket, &d).await; } Message::Text(ref text) if text == "get_processes" => { let payload = collect_processes_all(&state).await; // Map to protobuf message // Get cached buffers let cache = COMPRESSION_CACHE.get_or_init(|| Mutex::new(CompressionCache::new())); let mut cache = cache.lock().await; // Reuse process vector to build the list cache.processes_vec.clear(); cache .processes_vec .extend(payload.top_processes.into_iter().map(|p| pb::Process { pid: p.pid, name: p.name, cpu_usage: p.cpu_usage, mem_bytes: p.mem_bytes, })); let pb = pb::Processes { process_count: payload.process_count as u64, rows: std::mem::take(&mut cache.processes_vec), }; let mut buf = Vec::with_capacity(8 * 1024); if prost::Message::encode(&pb, &mut buf).is_err() { let _ = socket.send(Message::Close(None)).await; } else { // compress if large if buf.len() <= COMPRESSION_THRESHOLD { let _ = socket.send(Message::Binary(buf)).await; } else { // Create a new encoder for each message to ensure proper gzip headers let mut encoder = GzEncoder::new(Vec::with_capacity(buf.len()), Compression::fast()); match encoder.write_all(&buf).and_then(|_| encoder.finish()) { Ok(compressed) => { let _ = socket.send(Message::Binary(compressed)).await; } Err(_) => { let _ = socket.send(Message::Binary(buf)).await; } } } } drop(cache); // Explicit drop to release mutex early } Message::Close(_) => break, _ => {} } } state .client_count .fetch_sub(1, std::sync::atomic::Ordering::Relaxed); } // Small, cheap gzip for larger payloads; send text for small. async fn send_json(ws: &mut WebSocket, value: &T) -> Result<(), axum::Error> { let json = serde_json::to_string(value).expect("serialize"); if json.len() <= COMPRESSION_THRESHOLD { return ws.send(Message::Text(json)).await; } let mut enc = GzEncoder::new(Vec::new(), Compression::fast()); enc.write_all(json.as_bytes()).ok(); let bin = enc.finish().unwrap_or_else(|_| json.into_bytes()); ws.send(Message::Binary(bin)).await } #[cfg(test)] mod tests { use super::*; use prost::Message as ProstMessage; use sysinfo::System; #[tokio::test] async fn test_process_list_not_empty() { // Initialize system data first to ensure we have processes let mut sys = System::new_all(); sys.refresh_all(); // Create state and put the refreshed system in it let state = AppState::new(); { let mut sys_lock = state.sys.lock().await; *sys_lock = sys; } // Get processes directly using the collection function let processes = collect_processes_all(&state).await; // Convert to protobuf message format let cache = COMPRESSION_CACHE.get_or_init(|| Mutex::new(CompressionCache::new())); let mut cache = cache.lock().await; // Reuse process vector to build the list cache.processes_vec.clear(); cache .processes_vec .extend(processes.top_processes.into_iter().map(|p| pb::Process { pid: p.pid, name: p.name, cpu_usage: p.cpu_usage, mem_bytes: p.mem_bytes, })); // Create the protobuf message let pb = pb::Processes { process_count: processes.process_count as u64, rows: cache.processes_vec.clone(), }; // Test protobuf encoding/decoding let mut buf = Vec::new(); prost::Message::encode(&pb, &mut buf).expect("Failed to encode protobuf"); let decoded = pb::Processes::decode(buf.as_slice()).expect("Failed to decode protobuf"); // Print debug info println!("Process count: {}", pb.process_count); println!("Process vector length: {}", pb.rows.len()); println!("Encoded size: {} bytes", buf.len()); println!("Decoded process count: {}", decoded.rows.len()); // Print first few processes if available for (i, process) in pb.rows.iter().take(5).enumerate() { println!( "Process {}: {} (PID: {}) CPU: {:.1}% MEM: {} bytes", i + 1, process.name, process.pid, process.cpu_usage, process.mem_bytes ); } // Validate assert!(!pb.rows.is_empty(), "Process list should not be empty"); assert!( pb.process_count > 0, "Process count should be greater than 0" ); assert_eq!( pb.process_count as usize, pb.rows.len(), "Process count mismatch with actual rows" ); } }