diff --git a/src/main.rs b/src/main.rs index 1bdd3b8..bea9ab5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -117,15 +117,21 @@ fn check_path(s: String) -> Result { } } +struct RequestHandle { + pub stream: TlsStream, +} + /// Handle a single client session (request + response). async fn handle_request(stream: TcpStream) -> Result { - let stream = &mut TLS.accept(stream).await?; + let stream = TLS.accept(stream).await?; - match parse_request(stream).await { - Ok(url) => send_response(url, stream).await?, - Err((status, msg)) => send_header(stream, status, msg).await?, + let mut handle = RequestHandle { stream }; + + match parse_request(&mut handle).await { + Ok(url) => send_response(url, &mut handle).await?, + Err((status, msg)) => send_header(&mut handle, status, msg).await?, } - stream.shutdown().await?; + handle.stream.shutdown().await?; Ok(()) } @@ -145,7 +151,7 @@ fn acceptor() -> Result { } /// Return the URL requested by the client. -async fn parse_request(stream: &mut TlsStream) -> std::result::Result { +async fn parse_request(handle: &mut RequestHandle) -> std::result::Result { // Because requests are limited to 1024 bytes (plus 2 bytes for CRLF), we // can use a fixed-sized buffer on the stack, avoiding allocations and // copying, and stopping bad clients from making us use too much memory. @@ -155,7 +161,7 @@ async fn parse_request(stream: &mut TlsStream) -> std::result::Result // Read until CRLF, end-of-stream, or there's no buffer space left. loop { - let bytes_read = stream.read(buf).await.or(Err((59, "Request ended unexpectedly")))?; + let bytes_read = handle.stream.read(buf).await.or(Err((59, "Request ended unexpectedly")))?; len += bytes_read; if request[..len].ends_with(b"\r\n") { break; @@ -169,7 +175,8 @@ async fn parse_request(stream: &mut TlsStream) -> std::result::Result log::info!( "Got request for {:?} from {}", request, - stream + handle + .stream .get_ref() .0 .peer_addr() @@ -193,7 +200,7 @@ async fn parse_request(stream: &mut TlsStream) -> std::result::Result } if let Some(port) = url.port() { // Validate that the port in the URL is the same as for the stream this request came in on. - if port != stream.get_ref().0.local_addr().unwrap().port() { + if port != handle.stream.get_ref().0.local_addr().unwrap().port() { return Err((53, "proxy request refused")); } } @@ -201,13 +208,13 @@ async fn parse_request(stream: &mut TlsStream) -> std::result::Result } /// Send the client the file located at the requested URL. -async fn send_response(url: Url, stream: &mut TlsStream) -> Result { +async fn send_response(url: Url, handle: &mut RequestHandle) -> Result { let mut path = std::path::PathBuf::from(&ARGS.content_dir); if let Some(segments) = url.path_segments() { for segment in segments { if !ARGS.serve_secret && segment.starts_with('.') { // Do not serve anything that looks like a hidden file. - return send_header(stream, 52, "If I told you, it would not be a secret.").await; + return send_header(handle, 52, "If I told you, it would not be a secret.").await; } path.push(&*percent_decode_str(segment).decode_utf8()?); } @@ -221,13 +228,13 @@ async fn send_response(url: Url, stream: &mut TlsStream) -> Result { path.push("index.gmi"); if !path.exists() && path.with_file_name(".directory-listing-ok").exists() { path.pop(); - return list_directory(stream, &path).await; + return list_directory(handle, &path).await; } } else { // if client is not redirected, links may not work as expected without trailing slash let mut url = url; url.set_path(&format!("{}/", url.path())); - return send_header(stream, 31, url.as_str()).await; + return send_header(handle, 31, url.as_str()).await; } } } @@ -236,32 +243,32 @@ async fn send_response(url: Url, stream: &mut TlsStream) -> Result { let mut file = match tokio::fs::File::open(&path).await { Ok(file) => file, Err(e) => { - send_header(stream, 51, "Not found, sorry.").await?; + send_header(handle, 51, "Not found, sorry.").await?; Err(e)? } }; // Send header. if path.extension() == Some(OsStr::new("gmi")) { - send_text_gemini_header(stream).await?; + send_text_gemini_header(handle).await?; } else { let mime = mime_guess::from_path(&path).first_or_octet_stream(); - send_header(stream, 20, mime.essence_str()).await?; + send_header(handle, 20, mime.essence_str()).await?; } // Send body. - tokio::io::copy(&mut file, stream).await?; + tokio::io::copy(&mut file, &mut handle.stream).await?; Ok(()) } -async fn list_directory(stream: &mut TlsStream, path: &Path) -> Result { +async fn list_directory(handle: &mut RequestHandle, path: &Path) -> Result { // https://url.spec.whatwg.org/#path-percent-encode-set const ENCODE_SET: AsciiSet = CONTROLS.add(b' ') .add(b'"').add(b'#').add(b'<').add(b'>') .add(b'?').add(b'`').add(b'{').add(b'}'); log::info!("Listing directory {:?}", path); - send_text_gemini_header(stream).await?; + send_text_gemini_header(handle).await?; let mut entries = tokio::fs::read_dir(path).await?; let mut lines = vec![]; while let Some(entry) = entries.next_entry().await? { @@ -280,25 +287,25 @@ async fn list_directory(stream: &mut TlsStream, path: &Path) -> Resul } lines.sort(); for line in lines { - stream.write_all(line.as_bytes()).await?; + handle.stream.write_all(line.as_bytes()).await?; } Ok(()) } -async fn send_header(stream: &mut TlsStream, status: u8, meta: &str) -> Result { +async fn send_header(handle: &mut RequestHandle, status: u8, meta: &str) -> Result { use std::fmt::Write; let mut response = String::with_capacity(64); write!(response, "{} {}", status, meta)?; log::info!("Responding with status {:?}", response); response.push_str("\r\n"); - stream.write_all(response.as_bytes()).await?; + handle.stream.write_all(response.as_bytes()).await?; Ok(()) } -async fn send_text_gemini_header(stream: &mut TlsStream) -> Result { +async fn send_text_gemini_header(handle: &mut RequestHandle) -> Result { if let Some(lang) = ARGS.language.as_deref() { - send_header(stream, 20, &format!("text/gemini;lang={}", lang)).await + send_header(handle, 20, &format!("text/gemini;lang={}", lang)).await } else { - send_header(stream, 20, "text/gemini").await + send_header(handle, 20, "text/gemini").await } }