Add support for listening UNIX sockets

A new CLI option --socket enables listening on UNIX sockets. This is
similar to the --addr option, but takes a path as argument.

If the given path already exists and it is a socket, attempt to remove
it before listening.

The port check was refactored to avoid the retrieval of the TCP port on
every request.
This commit is contained in:
equalsraf 2023-03-16 15:21:29 +00:00 committed by Matt Brubeck
parent 60f76bc617
commit 4b6d238436

View file

@ -32,6 +32,11 @@ use {
url::{Host, Url},
};
#[cfg(target_family = "unix")]
use std::os::unix::fs::FileTypeExt;
#[cfg(target_family = "unix")]
use tokio::net::{UnixListener, UnixStream};
static DEFAULT_PORT: u16 = 1965;
fn main() {
@ -96,6 +101,48 @@ fn main() {
}))
};
#[cfg(target_family = "unix")]
for socketpath in &ARGS.sockets {
let arc = mimetypes.clone();
if socketpath.exists() && socketpath.metadata()
.expect("Failed to get existing socket metadata")
.file_type()
.is_socket() {
log::warn!("Socket already exists, attempting to remove {}", socketpath.display());
let _ = std::fs::remove_file(socketpath);
}
let listener = match UnixListener::bind(socketpath) {
Err(e) => {
panic!("Failed to listen on {}: {}", socketpath.display(), e)
}
Ok(listener) => listener,
};
handles.push(tokio::spawn(async move {
log::info!("Started listener on {}", socketpath.display());
loop {
let (stream, _) = listener.accept().await.unwrap_or_else(|e| {
panic!("could not accept new connection on {}: {}", socketpath.display(), e)
});
let arc = arc.clone();
tokio::spawn(async {
match RequestHandle::new_unix(stream, arc).await {
Ok(handle) => match handle.handle().await {
Ok(info) => log::info!("{}", info),
Err(err) => log::warn!("{}", err),
},
Err(log_line) => {
log::warn!("{}", log_line);
}
}
});
}
}))
};
futures_util::future::join_all(handles).await;
});
}
@ -111,6 +158,7 @@ static ARGS: Lazy<Args> = Lazy::new(|| {
struct Args {
addrs: Vec<SocketAddr>,
sockets: Vec<PathBuf>,
content_dir: PathBuf,
certs: Arc<certificates::CertStore>,
hostnames: Vec<Host>,
@ -143,6 +191,13 @@ fn args() -> Result<Args> {
&format!("Address to listen on (default 0.0.0.0:{DEFAULT_PORT} and [::]:{DEFAULT_PORT}; muliple occurences means listening on multiple interfaces)"),
"IP:PORT",
);
#[cfg(target_family = "unix")]
opts.optmulti(
"",
"socket",
"Unix socket to listen on (muliple occurences means listening on multiple sockets)",
"PATH",
);
opts.optmulti(
"",
"hostname",
@ -290,7 +345,13 @@ fn args() -> Result<Args> {
for i in matches.opt_strs("addr") {
addrs.push(i.parse()?);
}
if addrs.is_empty() {
let mut sockets = vec![];
for i in matches.opt_strs("socket") {
sockets.push(i.parse()?);
}
if addrs.is_empty() && sockets.is_empty() {
addrs = vec![
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), DEFAULT_PORT),
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), DEFAULT_PORT),
@ -299,6 +360,7 @@ fn args() -> Result<Args> {
Ok(Args {
addrs,
sockets,
content_dir: check_path(matches.opt_get_default("content", "content".into())?)?,
certs: Arc::new(certs),
hostnames,
@ -338,13 +400,14 @@ fn acceptor() -> TlsAcceptor {
TlsAcceptor::from(Arc::new(config))
}
struct RequestHandle {
stream: TlsStream<TcpStream>,
struct RequestHandle<T> {
stream: TlsStream<T>,
local_port_check: Option<u16>,
log_line: String,
metadata: Arc<Mutex<FileOptions>>,
}
impl RequestHandle {
impl RequestHandle<TcpStream> {
/// Creates a new request handle for the given stream. If establishing the TLS
/// session fails, returns a corresponding log line.
async fn new(stream: TcpStream, metadata: Arc<Mutex<FileOptions>>) -> Result<Self, String> {
@ -369,9 +432,16 @@ impl RequestHandle {
let log_line = format!("{local_addr} {peer_addr}",);
let local_port_check = if ARGS.skip_port_check {
None
} else {
Some(stream.local_addr().unwrap().port())
};
match TLS.accept(stream).await {
Ok(stream) => Ok(Self {
stream,
local_port_check,
log_line,
metadata,
}),
@ -379,7 +449,40 @@ impl RequestHandle {
Err(e) => Err(format!("{log_line} \"\" 00 \"TLS error\" error:{e}")),
}
}
}
#[cfg(target_family = "unix")]
impl RequestHandle<UnixStream> {
async fn new_unix(
stream: UnixStream,
metadata: Arc<Mutex<FileOptions>>,
) -> Result<Self, String> {
let log_line = match stream.local_addr() {
Ok(a) => match a.as_pathname() {
Some(p) => format!("{} -", p.display()),
None => "<unnamed socket> -".to_string(),
},
Err(_) => "<unnamed socket> -".to_string(),
};
match TLS.accept(stream).await {
Ok(stream) => Ok(Self {
stream,
// TODO add port check for unix sockets, requires extra arg for port
local_port_check: None,
log_line,
metadata,
}),
// use nonexistent status code 00 if connection was not established
Err(e) => Err(format!("{} \"\" 00 \"TLS error\" error:{}", log_line, e)),
}
}
}
impl<T> RequestHandle<T>
where
T: AsyncWriteExt + AsyncReadExt + Unpin,
{
/// Do the necessary actions to handle this request. Returns a corresponding
/// log line as Err or Ok, depending on if the request finished with or
/// without errors.
@ -476,11 +579,11 @@ impl RequestHandle {
}
// correct port
if !ARGS.skip_port_check {
if let Some(expected_port) = self.local_port_check {
if let Some(port) = url.port() {
// Validate that the port in the URL is the same as for the stream this request
// came in on.
if port != self.stream.get_ref().0.local_addr().unwrap().port() {
if port != expected_port {
return Err((PROXY_REQUEST_REFUSED, "Proxy request refused"));
}
}