diff --git a/src/certificates.rs b/src/certificates.rs index af52ac3..078156b 100644 --- a/src/certificates.rs +++ b/src/certificates.rs @@ -18,7 +18,8 @@ use { /// A struct that holds all loaded certificates and the respective domain /// names. pub(crate) struct CertStore { - // use a Vec of pairs instead of a HashMap because order matters + /// Stores the certificates and the domains they apply to, sorted by domain + /// names, longest matches first certs: Vec<(String, CertifiedKey)>, } @@ -133,7 +134,8 @@ impl CertStore { // load all certificates from directories let mut certs = vec![]; - // try to load fallback certificate and key + // Try to load fallback certificate and key directly from the top level + // certificate directory. It will be loaded as the `.` domain. match load_domain(certs_dir, ".".to_string()) { Err(CertLoadError::EmptyDomain(_)) => { /* there are no fallback keys */ } Err(CertLoadError::NoReadCertDir) => unreachable!(), @@ -153,8 +155,9 @@ impl CertStore { Err(CertLoadError::MissingCert(_)) => { return Err(CertLoadError::MissingCert("fallback".to_string())) } - // if there are files, just push them because there is no domain - // name to check against + // For the fallback keys there is no domain name to verify them + // against, so we can skip that step and only have to do it for the + // other keys below. Ok(key) => certs.push((String::new(), key)), } @@ -165,6 +168,8 @@ impl CertStore { .filter(|x| x.path().is_dir()) { let path = file.path(); + + // the filename should be the domain name let filename = path .file_name() .and_then(OsStr::to_str) @@ -184,14 +189,17 @@ impl CertStore { } certs.sort_unstable_by(|(a, _), (b, _)| { - // try to match as many as possible. If one is a substring of the other, - // the `zip` will make them look equal and make the length decide. + // Try to match as many domain segments as possible. If one is a + // substring of the other, the `zip` will only compare the smaller + // length of either a or b and the for loop will not decide. for (a_part, b_part) in a.split('.').rev().zip(b.split('.').rev()) { if a_part != b_part { + // What we sort by here is not really important, but `str` + // already implements Ord, making it easier for us. return a_part.cmp(b_part); } } - // longer domains first + // Sort longer domains first. a.len().cmp(&b.len()).reverse() }); @@ -203,13 +211,17 @@ impl ResolvesServerCert for CertStore { fn resolve(&self, client_hello: rustls::ClientHello<'_>) -> Option { if let Some(name) = client_hello.server_name() { let name: &str = name.into(); + // The certificate list is sorted so the longest match will always + // appear first. We have to find the first that is either this + // domain or a parent domain of the current one. self.certs .iter() .find(|(s, _)| name.ends_with(s)) + // only the key is interesting .map(|(_, k)| k) .cloned() } else { - // This kind of resolver requires SNI + // This kind of resolver requires SNI. None } }