mirror of
https://github.com/samsonjs/agate.git
synced 2026-03-25 09:05:50 +00:00
Add support for listening UNIX sockets
A new CLI option --socket enables listening on UNIX sockets. This is similar to the --addr option, but takes a path as argument. If the given path already exists and it is a socket, attempt to remove it before listening. The port check was refactored to avoid the retrieval of the TCP port on every request.
This commit is contained in:
parent
60f76bc617
commit
4b6d238436
1 changed files with 109 additions and 6 deletions
115
src/main.rs
115
src/main.rs
|
|
@ -32,6 +32,11 @@ use {
|
|||
url::{Host, Url},
|
||||
};
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
use std::os::unix::fs::FileTypeExt;
|
||||
#[cfg(target_family = "unix")]
|
||||
use tokio::net::{UnixListener, UnixStream};
|
||||
|
||||
static DEFAULT_PORT: u16 = 1965;
|
||||
|
||||
fn main() {
|
||||
|
|
@ -96,6 +101,48 @@ fn main() {
|
|||
}))
|
||||
};
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
for socketpath in &ARGS.sockets {
|
||||
let arc = mimetypes.clone();
|
||||
|
||||
if socketpath.exists() && socketpath.metadata()
|
||||
.expect("Failed to get existing socket metadata")
|
||||
.file_type()
|
||||
.is_socket() {
|
||||
log::warn!("Socket already exists, attempting to remove {}", socketpath.display());
|
||||
let _ = std::fs::remove_file(socketpath);
|
||||
}
|
||||
|
||||
let listener = match UnixListener::bind(socketpath) {
|
||||
Err(e) => {
|
||||
panic!("Failed to listen on {}: {}", socketpath.display(), e)
|
||||
}
|
||||
Ok(listener) => listener,
|
||||
};
|
||||
|
||||
handles.push(tokio::spawn(async move {
|
||||
log::info!("Started listener on {}", socketpath.display());
|
||||
|
||||
loop {
|
||||
let (stream, _) = listener.accept().await.unwrap_or_else(|e| {
|
||||
panic!("could not accept new connection on {}: {}", socketpath.display(), e)
|
||||
});
|
||||
let arc = arc.clone();
|
||||
tokio::spawn(async {
|
||||
match RequestHandle::new_unix(stream, arc).await {
|
||||
Ok(handle) => match handle.handle().await {
|
||||
Ok(info) => log::info!("{}", info),
|
||||
Err(err) => log::warn!("{}", err),
|
||||
},
|
||||
Err(log_line) => {
|
||||
log::warn!("{}", log_line);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}))
|
||||
};
|
||||
|
||||
futures_util::future::join_all(handles).await;
|
||||
});
|
||||
}
|
||||
|
|
@ -111,6 +158,7 @@ static ARGS: Lazy<Args> = Lazy::new(|| {
|
|||
|
||||
struct Args {
|
||||
addrs: Vec<SocketAddr>,
|
||||
sockets: Vec<PathBuf>,
|
||||
content_dir: PathBuf,
|
||||
certs: Arc<certificates::CertStore>,
|
||||
hostnames: Vec<Host>,
|
||||
|
|
@ -143,6 +191,13 @@ fn args() -> Result<Args> {
|
|||
&format!("Address to listen on (default 0.0.0.0:{DEFAULT_PORT} and [::]:{DEFAULT_PORT}; muliple occurences means listening on multiple interfaces)"),
|
||||
"IP:PORT",
|
||||
);
|
||||
#[cfg(target_family = "unix")]
|
||||
opts.optmulti(
|
||||
"",
|
||||
"socket",
|
||||
"Unix socket to listen on (muliple occurences means listening on multiple sockets)",
|
||||
"PATH",
|
||||
);
|
||||
opts.optmulti(
|
||||
"",
|
||||
"hostname",
|
||||
|
|
@ -290,7 +345,13 @@ fn args() -> Result<Args> {
|
|||
for i in matches.opt_strs("addr") {
|
||||
addrs.push(i.parse()?);
|
||||
}
|
||||
if addrs.is_empty() {
|
||||
|
||||
let mut sockets = vec![];
|
||||
for i in matches.opt_strs("socket") {
|
||||
sockets.push(i.parse()?);
|
||||
}
|
||||
|
||||
if addrs.is_empty() && sockets.is_empty() {
|
||||
addrs = vec![
|
||||
SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), DEFAULT_PORT),
|
||||
SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), DEFAULT_PORT),
|
||||
|
|
@ -299,6 +360,7 @@ fn args() -> Result<Args> {
|
|||
|
||||
Ok(Args {
|
||||
addrs,
|
||||
sockets,
|
||||
content_dir: check_path(matches.opt_get_default("content", "content".into())?)?,
|
||||
certs: Arc::new(certs),
|
||||
hostnames,
|
||||
|
|
@ -338,13 +400,14 @@ fn acceptor() -> TlsAcceptor {
|
|||
TlsAcceptor::from(Arc::new(config))
|
||||
}
|
||||
|
||||
struct RequestHandle {
|
||||
stream: TlsStream<TcpStream>,
|
||||
struct RequestHandle<T> {
|
||||
stream: TlsStream<T>,
|
||||
local_port_check: Option<u16>,
|
||||
log_line: String,
|
||||
metadata: Arc<Mutex<FileOptions>>,
|
||||
}
|
||||
|
||||
impl RequestHandle {
|
||||
impl RequestHandle<TcpStream> {
|
||||
/// Creates a new request handle for the given stream. If establishing the TLS
|
||||
/// session fails, returns a corresponding log line.
|
||||
async fn new(stream: TcpStream, metadata: Arc<Mutex<FileOptions>>) -> Result<Self, String> {
|
||||
|
|
@ -369,9 +432,16 @@ impl RequestHandle {
|
|||
|
||||
let log_line = format!("{local_addr} {peer_addr}",);
|
||||
|
||||
let local_port_check = if ARGS.skip_port_check {
|
||||
None
|
||||
} else {
|
||||
Some(stream.local_addr().unwrap().port())
|
||||
};
|
||||
|
||||
match TLS.accept(stream).await {
|
||||
Ok(stream) => Ok(Self {
|
||||
stream,
|
||||
local_port_check,
|
||||
log_line,
|
||||
metadata,
|
||||
}),
|
||||
|
|
@ -379,7 +449,40 @@ impl RequestHandle {
|
|||
Err(e) => Err(format!("{log_line} \"\" 00 \"TLS error\" error:{e}")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_family = "unix")]
|
||||
impl RequestHandle<UnixStream> {
|
||||
async fn new_unix(
|
||||
stream: UnixStream,
|
||||
metadata: Arc<Mutex<FileOptions>>,
|
||||
) -> Result<Self, String> {
|
||||
let log_line = match stream.local_addr() {
|
||||
Ok(a) => match a.as_pathname() {
|
||||
Some(p) => format!("{} -", p.display()),
|
||||
None => "<unnamed socket> -".to_string(),
|
||||
},
|
||||
Err(_) => "<unnamed socket> -".to_string(),
|
||||
};
|
||||
|
||||
match TLS.accept(stream).await {
|
||||
Ok(stream) => Ok(Self {
|
||||
stream,
|
||||
// TODO add port check for unix sockets, requires extra arg for port
|
||||
local_port_check: None,
|
||||
log_line,
|
||||
metadata,
|
||||
}),
|
||||
// use nonexistent status code 00 if connection was not established
|
||||
Err(e) => Err(format!("{} \"\" 00 \"TLS error\" error:{}", log_line, e)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> RequestHandle<T>
|
||||
where
|
||||
T: AsyncWriteExt + AsyncReadExt + Unpin,
|
||||
{
|
||||
/// Do the necessary actions to handle this request. Returns a corresponding
|
||||
/// log line as Err or Ok, depending on if the request finished with or
|
||||
/// without errors.
|
||||
|
|
@ -476,11 +579,11 @@ impl RequestHandle {
|
|||
}
|
||||
|
||||
// correct port
|
||||
if !ARGS.skip_port_check {
|
||||
if let Some(expected_port) = self.local_port_check {
|
||||
if let Some(port) = url.port() {
|
||||
// Validate that the port in the URL is the same as for the stream this request
|
||||
// came in on.
|
||||
if port != self.stream.get_ref().0.local_addr().unwrap().port() {
|
||||
if port != expected_port {
|
||||
return Err((PROXY_REQUEST_REFUSED, "Proxy request refused"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue