socktop/socktop_connector/src/networking/connection.rs
jasonwitty 08f248c696 Housekeeping and QOL
Non-functional update:

- Refactor stream-of-consciousness code into separate files.
- Combine equivalent functions used in the networking and wasm features.
- Cleanups and version bumps.
2025-09-10 10:39:21 -07:00

186 lines
6.0 KiB
Rust

//! WebSocket connection handling for native (non-WASM) environments.
use crate::config::ConnectorConfig;
use crate::error::{ConnectorError, Result};
use std::io::BufReader;
use std::sync::Arc;
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, connect_async};
use url::Url;
#[cfg(feature = "tls")]
use {
rustls::{self, ClientConfig},
rustls::{
DigitallySignedStruct, RootCertStore, SignatureScheme,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
crypto::ring,
pki_types::{CertificateDer, ServerName, UnixTime},
},
rustls_pemfile::Item,
std::fs::File,
tokio_tungstenite::Connector,
};
pub type WsStream = WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>;
/// Connect to the agent described by `config` and return the WebSocket stream.
///
/// When `config.tls_ca_path` is set, a `ws://` URL is promoted to `wss://` and the
/// connection goes through the CA-pinned TLS path. Otherwise the parsed URL is used
/// unchanged.
///
/// # Errors
/// Fails if `config.url` cannot be parsed, or if the selected connect path fails.
pub async fn connect_to_agent(config: &ConnectorConfig) -> Result<WsStream> {
    #[cfg(feature = "tls")]
    ensure_crypto_provider();

    let mut target = Url::parse(&config.url)?;
    match &config.tls_ca_path {
        None => {
            // No custom CA configured: connect with the URL exactly as given.
            connect_without_ca_and_config(target.as_str(), config).await
        }
        Some(ca_path) => {
            // A custom CA only makes sense over TLS, so upgrade plain `ws` to `wss`.
            if target.scheme() == "ws" {
                // Best-effort upgrade; the result of `set_scheme` is ignored.
                let _ = target.set_scheme("wss");
            }
            connect_with_ca_and_config(target.as_str(), ca_path, config).await
        }
    }
}
async fn connect_without_ca_and_config(url: &str, config: &ConnectorConfig) -> Result<WsStream> {
let mut req = url.into_client_request()?;
// Apply WebSocket protocol configuration
if let Some(version) = &config.ws_version {
req.headers_mut().insert(
"Sec-WebSocket-Version",
version
.parse()
.map_err(|_| ConnectorError::protocol_error("Invalid WebSocket version"))?,
);
}
if let Some(protocols) = &config.ws_protocols {
let protocols_str = protocols.join(", ");
req.headers_mut().insert(
"Sec-WebSocket-Protocol",
protocols_str
.parse()
.map_err(|_| ConnectorError::protocol_error("Invalid WebSocket protocols"))?,
);
}
let (ws, _) = connect_async(req).await?;
Ok(ws)
}
#[cfg(feature = "tls")]
async fn connect_with_ca_and_config(
url: &str,
ca_path: &str,
config: &ConnectorConfig,
) -> Result<WsStream> {
// Initialize the crypto provider for rustls
let _ = rustls::crypto::ring::default_provider().install_default();
let mut root = RootCertStore::empty();
let mut reader = BufReader::new(File::open(ca_path)?);
let mut der_certs = Vec::new();
while let Ok(Some(item)) = rustls_pemfile::read_one(&mut reader) {
if let Item::X509Certificate(der) = item {
der_certs.push(der);
}
}
root.add_parsable_certificates(der_certs);
let mut cfg = ClientConfig::builder()
.with_root_certificates(root)
.with_no_client_auth();
let mut req = url.into_client_request()?;
// Apply WebSocket protocol configuration
if let Some(version) = &config.ws_version {
req.headers_mut().insert(
"Sec-WebSocket-Version",
version
.parse()
.map_err(|_| ConnectorError::protocol_error("Invalid WebSocket version"))?,
);
}
if let Some(protocols) = &config.ws_protocols {
let protocols_str = protocols.join(", ");
req.headers_mut().insert(
"Sec-WebSocket-Protocol",
protocols_str
.parse()
.map_err(|_| ConnectorError::protocol_error("Invalid WebSocket protocols"))?,
);
}
if !config.verify_hostname {
#[derive(Debug)]
struct NoVerify;
impl ServerCertVerifier for NoVerify {
fn verify_server_cert(
&self,
_end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
_server_name: &ServerName,
_ocsp_response: &[u8],
_now: UnixTime,
) -> std::result::Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::ED25519,
SignatureScheme::RSA_PSS_SHA256,
]
}
}
cfg.dangerous().set_certificate_verifier(Arc::new(NoVerify));
eprintln!(
"socktop_connector: hostname verification disabled (default). Set SOCKTOP_VERIFY_NAME=1 to enable strict SAN checking."
);
}
let cfg = Arc::new(cfg);
let (ws, _) = tokio_tungstenite::connect_async_tls_with_config(
req,
None,
config.verify_hostname,
Some(Connector::Rustls(cfg)),
)
.await?;
Ok(ws)
}
/// Fallback used when the crate is built without the `tls` feature.
///
/// Always returns a TLS error, since a custom CA cannot be honoured without TLS support.
#[cfg(not(feature = "tls"))]
async fn connect_with_ca_and_config(
    _url: &str,
    _ca_path: &str,
    _config: &ConnectorConfig,
) -> Result<WsStream> {
    let cause = std::io::Error::new(std::io::ErrorKind::Unsupported, "TLS not available");
    Err(ConnectorError::tls_error("TLS support not compiled in", cause))
}
/// Make ring the process-wide rustls crypto provider if none is installed yet.
#[cfg(feature = "tls")]
fn ensure_crypto_provider() {
    let provider = ring::default_provider();
    // An `Err` only means a default provider is already installed; nothing more to do.
    let _already_installed = provider.install_default().is_err();
}