diff --git a/src/main.rs b/src/main.rs index 0f563d5..39f116f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,6 +32,11 @@ use { url::{Host, Url}, }; +#[cfg(target_family = "unix")] +use std::os::unix::fs::FileTypeExt; +#[cfg(target_family = "unix")] +use tokio::net::{UnixListener, UnixStream}; + static DEFAULT_PORT: u16 = 1965; fn main() { @@ -96,6 +101,48 @@ fn main() { })) }; + #[cfg(target_family = "unix")] + for socketpath in &ARGS.sockets { + let arc = mimetypes.clone(); + + if socketpath.exists() && socketpath.metadata() + .expect("Failed to get existing socket metadata") + .file_type() + .is_socket() { + log::warn!("Socket already exists, attempting to remove {}", socketpath.display()); + let _ = std::fs::remove_file(socketpath); + } + + let listener = match UnixListener::bind(socketpath) { + Err(e) => { + panic!("Failed to listen on {}: {}", socketpath.display(), e) + } + Ok(listener) => listener, + }; + + handles.push(tokio::spawn(async move { + log::info!("Started listener on {}", socketpath.display()); + + loop { + let (stream, _) = listener.accept().await.unwrap_or_else(|e| { + panic!("could not accept new connection on {}: {}", socketpath.display(), e) + }); + let arc = arc.clone(); + tokio::spawn(async { + match RequestHandle::new_unix(stream, arc).await { + Ok(handle) => match handle.handle().await { + Ok(info) => log::info!("{}", info), + Err(err) => log::warn!("{}", err), + }, + Err(log_line) => { + log::warn!("{}", log_line); + } + } + }); + } + })) + }; + futures_util::future::join_all(handles).await; }); } @@ -111,6 +158,7 @@ static ARGS: Lazy = Lazy::new(|| { struct Args { addrs: Vec, + sockets: Vec, content_dir: PathBuf, certs: Arc, hostnames: Vec, @@ -143,6 +191,13 @@ fn args() -> Result { &format!("Address to listen on (default 0.0.0.0:{DEFAULT_PORT} and [::]:{DEFAULT_PORT}; muliple occurences means listening on multiple interfaces)"), "IP:PORT", ); + #[cfg(target_family = "unix")] + opts.optmulti( + "", + "socket", + "Unix socket to listen on (muliple occurences means listening on multiple sockets)", + "PATH", + ); opts.optmulti( "", "hostname", @@ -290,7 +345,13 @@ fn args() -> Result { for i in matches.opt_strs("addr") { addrs.push(i.parse()?); } - if addrs.is_empty() { + + let mut sockets = vec![]; + for i in matches.opt_strs("socket") { + sockets.push(i.parse()?); + } + + if addrs.is_empty() && sockets.is_empty() { addrs = vec![ SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), DEFAULT_PORT), SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), DEFAULT_PORT), @@ -299,6 +360,7 @@ fn args() -> Result { Ok(Args { addrs, + sockets, content_dir: check_path(matches.opt_get_default("content", "content".into())?)?, certs: Arc::new(certs), hostnames, @@ -338,13 +400,14 @@ fn acceptor() -> TlsAcceptor { TlsAcceptor::from(Arc::new(config)) } -struct RequestHandle { - stream: TlsStream, +struct RequestHandle { + stream: TlsStream, + local_port_check: Option, log_line: String, metadata: Arc>, } -impl RequestHandle { +impl RequestHandle { /// Creates a new request handle for the given stream. If establishing the TLS /// session fails, returns a corresponding log line. async fn new(stream: TcpStream, metadata: Arc>) -> Result { @@ -369,9 +432,16 @@ impl RequestHandle { let log_line = format!("{local_addr} {peer_addr}",); + let local_port_check = if ARGS.skip_port_check { + None + } else { + Some(stream.local_addr().unwrap().port()) + }; + match TLS.accept(stream).await { Ok(stream) => Ok(Self { stream, + local_port_check, log_line, metadata, }), @@ -379,7 +449,40 @@ impl RequestHandle { Err(e) => Err(format!("{log_line} \"\" 00 \"TLS error\" error:{e}")), } } +} +#[cfg(target_family = "unix")] +impl RequestHandle { + async fn new_unix( + stream: UnixStream, + metadata: Arc>, + ) -> Result { + let log_line = match stream.local_addr() { + Ok(a) => match a.as_pathname() { + Some(p) => format!("{} -", p.display()), + None => " -".to_string(), + }, + Err(_) => " -".to_string(), + }; + + match TLS.accept(stream).await { + Ok(stream) => Ok(Self { + stream, + // TODO add port check for unix sockets, requires extra arg for port + local_port_check: None, + log_line, + metadata, + }), + // use nonexistent status code 00 if connection was not established + Err(e) => Err(format!("{} \"\" 00 \"TLS error\" error:{}", log_line, e)), + } + } +} + +impl RequestHandle +where + T: AsyncWriteExt + AsyncReadExt + Unpin, +{ /// Do the necessary actions to handle this request. Returns a corresponding /// log line as Err or Ok, depending on if the request finished with or /// without errors. @@ -476,11 +579,11 @@ impl RequestHandle { } // correct port - if !ARGS.skip_port_check { + if let Some(expected_port) = self.local_port_check { if let Some(port) = url.port() { // Validate that the port in the URL is the same as for the stream this request // came in on. - if port != self.stream.get_ref().0.local_addr().unwrap().port() { + if port != expected_port { return Err((PROXY_REQUEST_REFUSED, "Proxy request refused")); } }