//! WebSocket connector for communicating with socktop agents. use flate2::bufread::GzDecoder; use futures_util::{SinkExt, StreamExt}; use prost::Message as _; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::pki_types::{CertificateDer, ServerName, UnixTime}; use rustls::{ClientConfig, RootCertStore}; use rustls::{DigitallySignedStruct, SignatureScheme}; use rustls_pemfile::Item; use std::io::Read; use std::{fs::File, io::BufReader, sync::Arc}; use tokio::net::TcpStream; use tokio_tungstenite::{ Connector, MaybeTlsStream, WebSocketStream, connect_async, connect_async_tls_with_config, tungstenite::Message, tungstenite::client::IntoClientRequest, }; use url::Url; use crate::types::{AgentRequest, AgentResponse, DiskInfo, Metrics, ProcessInfo, ProcessesPayload}; #[cfg(feature = "tls")] fn ensure_crypto_provider() { use std::sync::Once; static INIT: Once = Once::new(); INIT.call_once(|| { let _ = rustls::crypto::ring::default_provider().install_default(); }); } mod pb { // generated by build.rs include!(concat!(env!("OUT_DIR"), "/socktop.rs")); } pub type WsStream = WebSocketStream>; /// Configuration for connecting to a socktop agent #[derive(Debug, Clone)] pub struct ConnectorConfig { pub url: String, pub tls_ca_path: Option, pub verify_hostname: bool, } impl ConnectorConfig { pub fn new(url: impl Into) -> Self { Self { url: url.into(), tls_ca_path: None, verify_hostname: false, } } pub fn with_tls_ca(mut self, ca_path: impl Into) -> Self { self.tls_ca_path = Some(ca_path.into()); self } pub fn with_hostname_verification(mut self, verify: bool) -> Self { self.verify_hostname = verify; self } } /// A WebSocket connector for communicating with socktop agents pub struct SocktopConnector { config: ConnectorConfig, stream: Option, } impl SocktopConnector { /// Create a new connector with the given configuration pub fn new(config: ConnectorConfig) -> Self { Self { config, stream: None, } } /// Connect to the agent pub async fn connect(&mut self) -> Result<(), Box> { let stream = connect_to_agent( &self.config.url, self.config.tls_ca_path.as_deref(), self.config.verify_hostname, ) .await?; self.stream = Some(stream); Ok(()) } /// Send a request to the agent and get the response pub async fn request( &mut self, request: AgentRequest, ) -> Result> { let stream = self.stream.as_mut().ok_or("Not connected")?; match request { AgentRequest::Metrics => { let metrics = request_metrics(stream) .await .ok_or("Failed to get metrics")?; Ok(AgentResponse::Metrics(metrics)) } AgentRequest::Disks => { let disks = request_disks(stream).await.ok_or("Failed to get disks")?; Ok(AgentResponse::Disks(disks)) } AgentRequest::Processes => { let processes = request_processes(stream) .await .ok_or("Failed to get processes")?; Ok(AgentResponse::Processes(processes)) } } } /// Check if the connector is connected pub fn is_connected(&self) -> bool { self.stream.is_some() } /// Disconnect from the agent pub async fn disconnect(&mut self) -> Result<(), Box> { if let Some(mut stream) = self.stream.take() { let _ = stream.close(None).await; } Ok(()) } } // Connect to the agent and return the WS stream async fn connect_to_agent( url: &str, tls_ca: Option<&str>, verify_hostname: bool, ) -> Result> { #[cfg(feature = "tls")] ensure_crypto_provider(); let mut u = Url::parse(url)?; if let Some(ca_path) = tls_ca { if u.scheme() == "ws" { let _ = u.set_scheme("wss"); } return connect_with_ca(u.as_str(), ca_path, verify_hostname).await; } // No TLS - hostname verification is not applicable let (ws, _) = connect_async(u.as_str()).await?; Ok(ws) } #[cfg(feature = "tls")] async fn connect_with_ca( url: &str, ca_path: &str, verify_hostname: bool, ) -> Result> { // Initialize the crypto provider for rustls let _ = rustls::crypto::ring::default_provider().install_default(); let mut root = RootCertStore::empty(); let mut reader = BufReader::new(File::open(ca_path)?); let mut der_certs = Vec::new(); while let Ok(Some(item)) = rustls_pemfile::read_one(&mut reader) { if let Item::X509Certificate(der) = item { der_certs.push(der); } } root.add_parsable_certificates(der_certs); let mut cfg = ClientConfig::builder() .with_root_certificates(root) .with_no_client_auth(); let req = url.into_client_request()?; if !verify_hostname { #[derive(Debug)] struct NoVerify; impl ServerCertVerifier for NoVerify { fn verify_server_cert( &self, _end_entity: &CertificateDer<'_>, _intermediates: &[CertificateDer<'_>], _server_name: &ServerName, _ocsp_response: &[u8], _now: UnixTime, ) -> Result { Ok(ServerCertVerified::assertion()) } fn verify_tls12_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &CertificateDer<'_>, _dss: &DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ SignatureScheme::ECDSA_NISTP256_SHA256, SignatureScheme::ED25519, SignatureScheme::RSA_PSS_SHA256, ] } } cfg.dangerous().set_certificate_verifier(Arc::new(NoVerify)); eprintln!( "socktop_connector: hostname verification disabled (default). Set SOCKTOP_VERIFY_NAME=1 to enable strict SAN checking." ); } let cfg = Arc::new(cfg); let (ws, _) = connect_async_tls_with_config(req, None, verify_hostname, Some(Connector::Rustls(cfg))) .await?; Ok(ws) } #[cfg(not(feature = "tls"))] async fn connect_with_ca( _url: &str, _ca_path: &str, ) -> Result> { Err("TLS support not compiled in".into()) } // Send a "get_metrics" request and await a single JSON reply async fn request_metrics(ws: &mut WsStream) -> Option { if ws.send(Message::Text("get_metrics".into())).await.is_err() { return None; } match ws.next().await { Some(Ok(Message::Binary(b))) => { gunzip_to_string(&b).and_then(|s| serde_json::from_str::(&s).ok()) } Some(Ok(Message::Text(json))) => serde_json::from_str::(&json).ok(), _ => None, } } // Send a "get_disks" request and await a JSON Vec async fn request_disks(ws: &mut WsStream) -> Option> { if ws.send(Message::Text("get_disks".into())).await.is_err() { return None; } match ws.next().await { Some(Ok(Message::Binary(b))) => { gunzip_to_string(&b).and_then(|s| serde_json::from_str::>(&s).ok()) } Some(Ok(Message::Text(json))) => serde_json::from_str::>(&json).ok(), _ => None, } } // Send a "get_processes" request and await a ProcessesPayload decoded from protobuf (binary, may be gzipped) async fn request_processes(ws: &mut WsStream) -> Option { if ws .send(Message::Text("get_processes".into())) .await .is_err() { return None; } match ws.next().await { Some(Ok(Message::Binary(b))) => { let gz = is_gzip(&b); let data = if gz { gunzip_to_vec(&b)? } else { b }; match pb::Processes::decode(data.as_slice()) { Ok(pb) => { let rows: Vec = pb .rows .into_iter() .map(|p: pb::Process| ProcessInfo { pid: p.pid, name: p.name, cpu_usage: p.cpu_usage, mem_bytes: p.mem_bytes, }) .collect(); Some(ProcessesPayload { process_count: pb.process_count as usize, top_processes: rows, }) } Err(e) => { if std::env::var("SOCKTOP_DEBUG").ok().as_deref() == Some("1") { eprintln!("protobuf decode failed: {e}"); } // Fallback: maybe it's JSON (bytes already decompressed if gz) match String::from_utf8(data) { Ok(s) => serde_json::from_str::(&s).ok(), Err(_) => None, } } } } Some(Ok(Message::Text(json))) => serde_json::from_str::(&json).ok(), _ => None, } } // Decompress a gzip-compressed binary frame into a String. fn gunzip_to_string(bytes: &[u8]) -> Option { let mut dec = GzDecoder::new(bytes); let mut out = String::new(); dec.read_to_string(&mut out).ok()?; Some(out) } fn gunzip_to_vec(bytes: &[u8]) -> Option> { let mut dec = GzDecoder::new(bytes); let mut out = Vec::new(); dec.read_to_end(&mut out).ok()?; Some(out) } fn is_gzip(bytes: &[u8]) -> bool { bytes.len() >= 2 && bytes[0] == 0x1f && bytes[1] == 0x8b } /// Convenience function to create a connector and connect in one step. /// /// This function is for non-TLS WebSocket connections (`ws://`). Since there's no /// certificate involved, hostname verification is not applicable. /// /// For TLS connections with certificate pinning, use `connect_to_socktop_agent_with_tls()`. pub async fn connect_to_socktop_agent( url: impl Into, ) -> Result> { let config = ConnectorConfig::new(url); let mut connector = SocktopConnector::new(config); connector.connect().await?; Ok(connector) } /// Convenience function to create a connector with TLS and connect in one step. /// /// This function enables TLS with certificate pinning using the provided CA certificate. /// The `verify_hostname` parameter controls whether the server's hostname is verified /// against the certificate (recommended for production, can be disabled for testing). #[cfg(feature = "tls")] #[cfg_attr(docsrs, doc(cfg(feature = "tls")))] pub async fn connect_to_socktop_agent_with_tls( url: impl Into, ca_path: impl Into, verify_hostname: bool, ) -> Result> { let config = ConnectorConfig::new(url) .with_tls_ca(ca_path) .with_hostname_verification(verify_hostname); let mut connector = SocktopConnector::new(config); connector.connect().await?; Ok(connector) }