diff --git a/src/main.rs b/src/main.rs index 0b2adf2..e8b22b0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,26 +8,17 @@ use { }, async_tls::TlsAcceptor, once_cell::sync::Lazy, - std::{error::Error, ffi::OsStr, marker::Unpin, str, sync::Arc}, + std::{error::Error, ffi::OsStr, fs::File, io::BufReader, marker::Unpin, sync::Arc}, url::Url, }; -pub type Result = std::result::Result>; - -struct Args { - sock_addr: String, - content_dir: String, - cert_file: String, - key_file: String, -} - fn main() -> Result { block_on(async { let listener = TcpListener::bind(&ARGS.sock_addr).await?; let mut incoming = listener.incoming(); while let Some(Ok(stream)) = incoming.next().await { spawn(async { - if let Err(e) = connection(stream).await { + if let Err(e) = handle_request(stream).await { eprintln!("Error: {:?}", e); } }); @@ -36,6 +27,18 @@ fn main() -> Result { }) } +type Result = std::result::Result>; + +static ARGS: Lazy = + Lazy::new(|| args().expect("usage: agate ")); + +struct Args { + sock_addr: String, + content_dir: String, + cert_file: String, + key_file: String, +} + fn args() -> Option { let mut args = std::env::args().skip(1); Some(Args { @@ -46,29 +49,12 @@ fn args() -> Option { }) } -static ARGS: Lazy = - Lazy::new(|| args().expect("usage: agate ")); - -fn acceptor() -> Result { - use rustls::{ServerConfig, NoClientAuth, internal::pemfile::{certs, pkcs8_private_keys}}; - use std::{io::BufReader, fs::File}; - - let cert_file = File::open(&ARGS.cert_file)?; - let certs = certs(&mut BufReader::new(cert_file)).or(Err("bad cert"))?; - - let key_file = File::open(&ARGS.key_file)?; - let mut keys = pkcs8_private_keys(&mut BufReader::new(key_file)).or(Err("bad key"))?; - - let mut config = ServerConfig::new(NoClientAuth::new()); - config.set_single_cert(certs, keys.remove(0))?; - Ok(TlsAcceptor::from(Arc::new(config))) -} - /// Handle a single client session (request + response). -async fn connection(stream: TcpStream) -> Result { - static ACCEPTOR: Lazy = Lazy::new(|| acceptor().unwrap()); +async fn handle_request(stream: TcpStream) -> Result { + // Perform handshake. + static TLS: Lazy = Lazy::new(|| acceptor().unwrap()); + let mut stream = TLS.accept(stream).await?; - let mut stream = ACCEPTOR.accept(stream).await?; match parse_request(&mut stream).await { Ok(url) => { eprintln!("Got request for {:?}", url); @@ -81,6 +67,21 @@ async fn connection(stream: TcpStream) -> Result { } } +/// TLS configuration. +fn acceptor() -> Result { + use rustls::{ServerConfig, NoClientAuth, internal::pemfile::{certs, pkcs8_private_keys}}; + + let cert_file = File::open(&ARGS.cert_file)?; + let certs = certs(&mut BufReader::new(cert_file)).or(Err("bad cert"))?; + + let key_file = File::open(&ARGS.key_file)?; + let mut keys = pkcs8_private_keys(&mut BufReader::new(key_file)).or(Err("bad key"))?; + + let mut config = ServerConfig::new(NoClientAuth::new()); + config.set_single_cert(certs, keys.remove(0))?; + Ok(TlsAcceptor::from(Arc::new(config))) +} + /// Return the URL requested by the client. async fn parse_request(mut stream: R) -> Result { // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we @@ -101,7 +102,7 @@ async fn parse_request(mut stream: R) -> Result { } buf = &mut request[len..]; } - let request = str::from_utf8(&request[..len - 2])?; + let request = std::str::from_utf8(&request[..len - 2])?; // Handle scheme-relative URLs. let url = if request.starts_with("//") {