socktop/socktop_connector/src/connector.rs
jasonwitty e51054811c Refactor for additional socktop connector library
- socktop connector allows you to communicate with a socktop agent directly from your code without needing to implement the agent API yourself.
- will also be used for a non-TUI implementation of "socktop collector" in the future.
- moved to the Rust 2024 edition to take advantage of some new features that helped with the refactor.
- fixed everything that broke with the update.
- added Rust docs for the lib.
2025-09-04 05:30:25 -07:00

365 lines
12 KiB
Rust

//! WebSocket connector for communicating with socktop agents.
use flate2::bufread::GzDecoder;
use futures_util::{SinkExt, StreamExt};
use prost::Message as _;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, RootCertStore};
use rustls::{DigitallySignedStruct, SignatureScheme};
use rustls_pemfile::Item;
use std::io::Read;
use std::{fs::File, io::BufReader, sync::Arc};
use tokio::net::TcpStream;
use tokio_tungstenite::{
Connector, MaybeTlsStream, WebSocketStream, connect_async, connect_async_tls_with_config,
tungstenite::Message, tungstenite::client::IntoClientRequest,
};
use url::Url;
use crate::types::{AgentRequest, AgentResponse, DiskInfo, Metrics, ProcessInfo, ProcessesPayload};
/// Install rustls' `ring` crypto provider as the process-wide default, at most once.
///
/// The result of `install_default` is discarded on purpose: if some provider is
/// already installed, that one is simply kept.
#[cfg(feature = "tls")]
fn ensure_crypto_provider() {
    static INSTALL_RING_PROVIDER: std::sync::Once = std::sync::Once::new();
    INSTALL_RING_PROVIDER.call_once(|| {
        drop(rustls::crypto::ring::default_provider().install_default());
    });
}
/// Protobuf message types generated by `build.rs` into `$OUT_DIR/socktop.rs`.
///
/// They are decoded through prost's `Message` trait (e.g. `pb::Processes` in
/// `request_processes`).
mod pb {
// generated by build.rs
include!(concat!(env!("OUT_DIR"), "/socktop.rs"));
}
/// WebSocket stream to an agent, carried over either plain TCP or TLS.
pub type WsStream = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Connection settings for a socktop agent.
///
/// Start from [`ConnectorConfig::new`] and refine with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ConnectorConfig {
    /// WebSocket URL of the agent (`ws://` or `wss://`).
    pub url: String,
    /// Path to a PEM file holding the CA certificate(s) to trust. When set, the
    /// connection is made over TLS and a `ws://` URL is upgraded to `wss://`.
    pub tls_ca_path: Option<String>,
    /// Whether the server certificate's names must match the URL host (TLS only).
    pub verify_hostname: bool,
}

impl ConnectorConfig {
    /// Create a config for `url` with no CA file and hostname verification off.
    pub fn new(url: impl Into<String>) -> Self {
        let url = url.into();
        Self {
            url,
            tls_ca_path: None,
            verify_hostname: false,
        }
    }

    /// Trust the CA certificate(s) in the PEM file at `ca_path`, which enables TLS.
    pub fn with_tls_ca(self, ca_path: impl Into<String>) -> Self {
        Self {
            tls_ca_path: Some(ca_path.into()),
            ..self
        }
    }

    /// Enable or disable checking the certificate's names against the URL host.
    pub fn with_hostname_verification(self, verify: bool) -> Self {
        Self {
            verify_hostname: verify,
            ..self
        }
    }
}
/// A WebSocket client bound to a single socktop agent.
///
/// Build it with [`SocktopConnector::new`], open the connection with
/// [`connect`](Self::connect), then issue [`request`](Self::request)s. Because
/// `request` takes `&mut self`, requests go out one at a time.
pub struct SocktopConnector {
    config: ConnectorConfig,
    stream: Option<WsStream>,
}

impl SocktopConnector {
    /// Build an unconnected connector from `config`; no network I/O happens here.
    pub fn new(config: ConnectorConfig) -> Self {
        Self {
            stream: None,
            config,
        }
    }

    /// Open the connection described by the config, replacing any stream already held.
    ///
    /// # Errors
    /// Returns an error if the URL is invalid or the connection/handshake fails; the
    /// previously held stream (if any) is kept in that case.
    pub async fn connect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        let cfg = &self.config;
        let ws = connect_to_agent(&cfg.url, cfg.tls_ca_path.as_deref(), cfg.verify_hostname)
            .await?;
        self.stream = Some(ws);
        Ok(())
    }

    /// Send `request` to the agent and wait for the matching response.
    ///
    /// # Errors
    /// Returns `"Not connected"` when no stream is held, or a
    /// `"Failed to get …"` error when sending, receiving or decoding fails.
    pub async fn request(
        &mut self,
        request: AgentRequest,
    ) -> Result<AgentResponse, Box<dyn std::error::Error>> {
        let Some(stream) = self.stream.as_mut() else {
            return Err("Not connected".into());
        };
        let reply = match request {
            AgentRequest::Metrics => request_metrics(stream)
                .await
                .map(AgentResponse::Metrics)
                .ok_or("Failed to get metrics"),
            AgentRequest::Disks => request_disks(stream)
                .await
                .map(AgentResponse::Disks)
                .ok_or("Failed to get disks"),
            AgentRequest::Processes => request_processes(stream)
                .await
                .map(AgentResponse::Processes)
                .ok_or("Failed to get processes"),
        };
        reply.map_err(Into::into)
    }

    /// Whether a stream is currently held.
    ///
    /// This does not probe the connection: it stays `true` after a failed request
    /// until [`disconnect`](Self::disconnect) is called.
    pub fn is_connected(&self) -> bool {
        matches!(self.stream, Some(_))
    }

    /// Close the connection (if any) and forget the stream.
    ///
    /// Errors while closing are ignored, so this always returns `Ok(())`.
    pub async fn disconnect(&mut self) -> Result<(), Box<dyn std::error::Error>> {
        if let Some(mut ws) = self.stream.take() {
            // Best effort: the peer may already have gone away.
            let _ = ws.close(None).await;
        }
        Ok(())
    }
}
/// Connect to the agent at `url` and return the WebSocket stream.
///
/// When `tls_ca` is given, TLS is implied: a `ws://` URL is upgraded to `wss://`
/// and the handshake is delegated to `connect_with_ca`. Without a CA the URL is
/// used as-is and hostname verification does not apply.
async fn connect_to_agent(
    url: &str,
    tls_ca: Option<&str>,
    verify_hostname: bool,
) -> Result<WsStream, Box<dyn std::error::Error>> {
    #[cfg(feature = "tls")]
    ensure_crypto_provider();
    let mut target = Url::parse(url)?;
    match tls_ca {
        Some(ca_path) => {
            if target.scheme() == "ws" {
                // ws -> wss is a valid scheme change; the result is ignored as before.
                let _ = target.set_scheme("wss");
            }
            connect_with_ca(target.as_str(), ca_path, verify_hostname).await
        }
        None => {
            let (ws, _) = connect_async(target.as_str()).await?;
            Ok(ws)
        }
    }
}
/// Open a TLS WebSocket connection whose server certificate must chain to the CA
/// certificate(s) in the PEM file at `ca_path`.
///
/// When `verify_hostname` is `false`, two checks are still enforced: the
/// certificate chain is validated against the pinned CA, and the TLS handshake
/// signatures are verified. Only the check that the certificate's subject names
/// match the URL host is skipped.
///
/// # Errors
/// Fails if the CA file cannot be read, is malformed, or holds no usable
/// certificate. Also fails if `url` is not a valid WebSocket request, or if the
/// TCP/TLS/WebSocket handshake fails.
#[cfg(feature = "tls")]
async fn connect_with_ca(
    url: &str,
    ca_path: &str,
    verify_hostname: bool,
) -> Result<WsStream, Box<dyn std::error::Error>> {
    // Idempotent; repeated here so this function is safe to call on its own.
    ensure_crypto_provider();
    let roots = Arc::new(load_ca_roots(ca_path)?);
    let mut cfg = ClientConfig::builder()
        .with_root_certificates(Arc::clone(&roots))
        .with_no_client_auth();
    let req = url.into_client_request()?;
    if !verify_hostname {
        // Relax only the name check. Accepting arbitrary certificates or
        // signatures here would make the pinned CA meaningless.
        cfg.dangerous()
            .set_certificate_verifier(Arc::new(SkipHostnameVerifier::new(roots)));
        eprintln!(
            "socktop_connector: hostname verification disabled (default); the certificate must still chain to the configured CA. Use ConnectorConfig::with_hostname_verification(true) to enable strict SAN checking."
        );
    }
    // The third argument is `disable_nagle`, which is unrelated to hostname
    // verification; `false` matches what `connect_async` uses for plain `ws://`.
    let (ws, _) = connect_async_tls_with_config(
        req,
        None,
        false,
        Some(Connector::Rustls(Arc::new(cfg))),
    )
    .await?;
    Ok(ws)
}

/// Read every PEM `CERTIFICATE` entry in `ca_path` into a root store.
///
/// Malformed PEM is reported instead of silently truncating the list. A file
/// with no usable certificate is rejected, since an empty store would only fail
/// later with a far less helpful handshake error.
#[cfg(feature = "tls")]
fn load_ca_roots(ca_path: &str) -> Result<RootCertStore, Box<dyn std::error::Error>> {
    let mut reader = BufReader::new(File::open(ca_path)?);
    let mut der_certs = Vec::new();
    while let Some(item) = rustls_pemfile::read_one(&mut reader)? {
        if let Item::X509Certificate(der) = item {
            der_certs.push(der);
        }
    }
    let mut roots = RootCertStore::empty();
    let (added, _rejected) = roots.add_parsable_certificates(der_certs);
    if added == 0 {
        return Err(format!("no usable CA certificate found in {ca_path}").into());
    }
    Ok(roots)
}

/// Server-certificate verifier used when hostname verification is turned off.
///
/// It requires the end-entity certificate to chain to one of `roots` (and to be
/// valid at the current time). It also verifies TLS 1.2/1.3 handshake signatures
/// with the `ring` provider's algorithms. It does not compare the certificate's
/// names with the requested server name.
#[cfg(feature = "tls")]
#[derive(Debug)]
struct SkipHostnameVerifier {
    roots: Arc<RootCertStore>,
    algorithms: rustls::crypto::WebPkiSupportedAlgorithms,
}

#[cfg(feature = "tls")]
impl SkipHostnameVerifier {
    /// Build a verifier that trusts `roots` and uses the `ring` signature algorithms.
    fn new(roots: Arc<RootCertStore>) -> Self {
        Self {
            roots,
            algorithms: rustls::crypto::ring::default_provider().signature_verification_algorithms,
        }
    }
}

#[cfg(feature = "tls")]
impl ServerCertVerifier for SkipHostnameVerifier {
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        let cert = rustls::server::ParsedCertificate::try_from(end_entity)?;
        rustls::client::verify_server_cert_signed_by_trust_anchor(
            &cert,
            &self.roots,
            intermediates,
            now,
            self.algorithms.all,
        )?;
        Ok(ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls12_signature(message, cert, dss, &self.algorithms)
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls13_signature(message, cert, dss, &self.algorithms)
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        self.algorithms.supported_schemes()
    }
}
/// Fallback used when the crate is built without the `tls` feature.
///
/// The signature mirrors the TLS variant (including `verify_hostname`) so that
/// `connect_to_agent` compiles either way. It always fails.
#[cfg(not(feature = "tls"))]
async fn connect_with_ca(
    _url: &str,
    _ca_path: &str,
    _verify_hostname: bool,
) -> Result<WsStream, Box<dyn std::error::Error>> {
    Err("TLS support not compiled in".into())
}
// Send a "get_metrics" request and await a single JSON reply
async fn request_metrics(ws: &mut WsStream) -> Option<Metrics> {
if ws.send(Message::Text("get_metrics".into())).await.is_err() {
return None;
}
match ws.next().await {
Some(Ok(Message::Binary(b))) => {
gunzip_to_string(&b).and_then(|s| serde_json::from_str::<Metrics>(&s).ok())
}
Some(Ok(Message::Text(json))) => serde_json::from_str::<Metrics>(&json).ok(),
_ => None,
}
}
// Send a "get_disks" request and await a JSON Vec<DiskInfo>
async fn request_disks(ws: &mut WsStream) -> Option<Vec<DiskInfo>> {
if ws.send(Message::Text("get_disks".into())).await.is_err() {
return None;
}
match ws.next().await {
Some(Ok(Message::Binary(b))) => {
gunzip_to_string(&b).and_then(|s| serde_json::from_str::<Vec<DiskInfo>>(&s).ok())
}
Some(Ok(Message::Text(json))) => serde_json::from_str::<Vec<DiskInfo>>(&json).ok(),
_ => None,
}
}
// Send a "get_processes" request and await a ProcessesPayload decoded from protobuf (binary, may be gzipped)
async fn request_processes(ws: &mut WsStream) -> Option<ProcessesPayload> {
if ws
.send(Message::Text("get_processes".into()))
.await
.is_err()
{
return None;
}
match ws.next().await {
Some(Ok(Message::Binary(b))) => {
let gz = is_gzip(&b);
let data = if gz { gunzip_to_vec(&b)? } else { b };
match pb::Processes::decode(data.as_slice()) {
Ok(pb) => {
let rows: Vec<ProcessInfo> = pb
.rows
.into_iter()
.map(|p: pb::Process| ProcessInfo {
pid: p.pid,
name: p.name,
cpu_usage: p.cpu_usage,
mem_bytes: p.mem_bytes,
})
.collect();
Some(ProcessesPayload {
process_count: pb.process_count as usize,
top_processes: rows,
})
}
Err(e) => {
if std::env::var("SOCKTOP_DEBUG").ok().as_deref() == Some("1") {
eprintln!("protobuf decode failed: {e}");
}
// Fallback: maybe it's JSON (bytes already decompressed if gz)
match String::from_utf8(data) {
Ok(s) => serde_json::from_str::<ProcessesPayload>(&s).ok(),
Err(_) => None,
}
}
}
}
Some(Ok(Message::Text(json))) => serde_json::from_str::<ProcessesPayload>(&json).ok(),
_ => None,
}
}
// Decompress a gzip-compressed binary frame into a String.
fn gunzip_to_string(bytes: &[u8]) -> Option<String> {
let mut dec = GzDecoder::new(bytes);
let mut out = String::new();
dec.read_to_string(&mut out).ok()?;
Some(out)
}
fn gunzip_to_vec(bytes: &[u8]) -> Option<Vec<u8>> {
let mut dec = GzDecoder::new(bytes);
let mut out = Vec::new();
dec.read_to_end(&mut out).ok()?;
Some(out)
}
/// Whether `bytes` starts with the two-byte gzip magic number `1f 8b`.
fn is_gzip(bytes: &[u8]) -> bool {
    matches!(bytes, [0x1f, 0x8b, ..])
}
/// Create a [`SocktopConnector`] for `url` and connect it in one step.
///
/// Intended for plain WebSocket connections (`ws://`). No CA certificate is
/// configured, so hostname verification is not applicable.
///
/// For TLS with a pinned CA certificate, use `connect_to_socktop_agent_with_tls()`.
///
/// # Errors
/// Propagates any error from [`SocktopConnector::connect`].
pub async fn connect_to_socktop_agent(
    url: impl Into<String>,
) -> Result<SocktopConnector, Box<dyn std::error::Error>> {
    let mut connector = SocktopConnector::new(ConnectorConfig::new(url));
    connector.connect().await?;
    Ok(connector)
}
/// Create a TLS-enabled [`SocktopConnector`] and connect it in one step.
///
/// The CA certificate(s) in the PEM file at `ca_path` are trusted for the server.
/// `verify_hostname` controls whether the server certificate's names must also
/// match the URL host; keep it on in production, and disable it only for testing.
///
/// # Errors
/// Propagates any error from [`SocktopConnector::connect`].
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub async fn connect_to_socktop_agent_with_tls(
    url: impl Into<String>,
    ca_path: impl Into<String>,
    verify_hostname: bool,
) -> Result<SocktopConnector, Box<dyn std::error::Error>> {
    let mut connector = SocktopConnector::new(
        ConnectorConfig::new(url)
            .with_tls_ca(ca_path)
            .with_hostname_verification(verify_hostname),
    );
    connector.connect().await?;
    Ok(connector)
}