Spaces:
Runtime error
Runtime error
| use std::error::Error; | |
| use std::net::{IpAddr, Ipv6Addr, SocketAddr}; | |
| use std::sync::Arc; | |
| use std::{env, thread}; | |
| use crate::ferron_request_handler::request_handler; | |
| use crate::ferron_util::load_tls::{load_certs, load_private_key}; | |
| use crate::ferron_util::sni::CustomSniResolver; | |
| use crate::ferron_util::validate_config::{prepare_config_for_validation, validate_config}; | |
| use crate::ferron_common::{LogMessage, ServerModule, ServerModuleHandlers}; | |
| use async_channel::Sender; | |
| use chrono::prelude::*; | |
| use futures_util::StreamExt; | |
| use h3_quinn::quinn; | |
| use h3_quinn::quinn::crypto::rustls::QuicServerConfig; | |
| use http::Response; | |
| use http_body_util::{BodyExt, StreamBody}; | |
| use hyper::body::{Buf, Bytes, Frame, Incoming}; | |
| use hyper::service::service_fn; | |
| use hyper::Request; | |
| use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; | |
| use ocsp_stapler::Stapler; | |
| use rustls::crypto::ring::cipher_suite::*; | |
| use rustls::crypto::ring::default_provider; | |
| use rustls::crypto::ring::kx_group::*; | |
| use rustls::server::{Acceptor, WebPkiClientVerifier}; | |
| use rustls::sign::CertifiedKey; | |
| use rustls::version::{TLS12, TLS13}; | |
| use rustls::{RootCertStore, ServerConfig}; | |
| use rustls_acme::acme::ACME_TLS_ALPN_NAME; | |
| use rustls_acme::caches::DirCache; | |
| use rustls_acme::{is_tls_alpn_challenge, AcmeConfig, ResolvesServerCertAcme, UseChallenge}; | |
| use rustls_native_certs::load_native_certs; | |
| use tokio::fs; | |
| use tokio::io::{AsyncWriteExt, BufWriter}; | |
| use tokio::net::{TcpListener, TcpStream}; | |
| use tokio::runtime::Handle; | |
| use tokio::sync::Mutex; | |
| use tokio::time; | |
| use tokio_rustls::server::TlsStream; | |
| use tokio_rustls::LazyConfigAcceptor; | |
| use yaml_rust2::Yaml; | |
| // Enum for maybe TLS stream | |
| enum MaybeTlsStream { | |
| Tls(TlsStream<TcpStream>), | |
| Plain(TcpStream), | |
| } | |
| // Function to accept and handle incoming QUIC connections | |
| async fn accept_quic_connection( | |
| connection_attempt: quinn::Incoming, | |
| local_address: SocketAddr, | |
| config: Arc<Yaml>, | |
| logger: Sender<LogMessage>, | |
| modules: Arc<Vec<Box<dyn ServerModule + std::marker::Send + Sync>>>, | |
| ) { | |
| let remote_address = connection_attempt.remote_address(); | |
| let logger_clone = logger.clone(); | |
| tokio::task::spawn(async move { | |
| match connection_attempt.await { | |
| Ok(connection) => { | |
| let mut h3_conn: h3::server::Connection<h3_quinn::Connection, Bytes> = | |
| match h3::server::Connection::new(h3_quinn::Connection::new(connection)).await { | |
| Ok(h3_conn) => h3_conn, | |
| Err(err) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| }; | |
| loop { | |
| match h3_conn.accept().await { | |
| Ok(Some(resolver)) => { | |
| let config = config.clone(); | |
| let remote_address = remote_address; | |
| let logger_clone = logger_clone.clone(); | |
| let modules = modules.clone(); | |
| tokio::spawn(async move { | |
| let handlers_vec = modules | |
| .iter() | |
| .map(|module| module.get_handlers(Handle::current())); | |
| let (request, stream) = resolver; | |
| let (mut send, receive) = stream.split(); | |
| let request_body_stream = futures_util::stream::unfold( | |
| (receive, false), | |
| async move |(mut receive, mut is_body_finished)| loop { | |
| if !is_body_finished { | |
| match receive.recv_data().await { | |
| Ok(Some(mut data)) => { | |
| return Some(( | |
| Ok(Frame::data(data.copy_to_bytes(data.remaining()))), | |
| (receive, false), | |
| )) | |
| } | |
| Ok(None) => is_body_finished = true, | |
| Err(err) => { | |
| return Some(( | |
| Err(std::io::Error::other(err.to_string())), | |
| (receive, false), | |
| )) | |
| } | |
| } | |
| } else { | |
| match receive.recv_trailers().await { | |
| Ok(Some(trailers)) => { | |
| return Some((Ok(Frame::trailers(trailers)), (receive, true))) | |
| } | |
| Ok(None) => is_body_finished = true, | |
| Err(err) => { | |
| return Some(( | |
| Err(std::io::Error::other(err.to_string())), | |
| (receive, true), | |
| )) | |
| } | |
| } | |
| } | |
| }, | |
| ); | |
| let request_body = BodyExt::boxed(StreamBody::new(request_body_stream)); | |
| let (request_parts, _) = request.into_parts(); | |
| let request = Request::from_parts(request_parts, request_body); | |
| let handlers_vec_clone = handlers_vec | |
| .clone() | |
| .collect::<Vec<Box<dyn ServerModuleHandlers + Send>>>(); | |
| let response = match request_handler( | |
| request, | |
| remote_address, | |
| local_address, | |
| true, | |
| config, | |
| logger_clone.clone(), | |
| handlers_vec_clone, | |
| None, | |
| None, | |
| ) | |
| .await | |
| { | |
| Ok(response) => response, | |
| Err(err) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| }; | |
| let (response_parts, mut response_body) = response.into_parts(); | |
| if let Err(err) = send | |
| .send_response(Response::from_parts(response_parts, ())) | |
| .await | |
| { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| let mut had_trailers = false; | |
| while let Some(chunk) = response_body.frame().await { | |
| match chunk { | |
| Ok(frame) => { | |
| if frame.is_data() { | |
| match frame.into_data() { | |
| Ok(data) => { | |
| if let Err(err) = send.send_data(data).await { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| } | |
| Err(_) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| "Error serving HTTP/3 connection: the frame isn't really a data frame".to_string(), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| } | |
| } else if frame.is_trailers() { | |
| match frame.into_trailers() { | |
| Ok(trailers) => { | |
| had_trailers = true; | |
| if let Err(err) = send.send_trailers(trailers).await { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| } | |
| Err(_) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| "Error serving HTTP/3 connection: the frame isn't really a trailers frame".to_string(), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| } | |
| } | |
| } | |
| Err(err) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| } | |
| } | |
| if !had_trailers { | |
| if let Err(err) = send.finish().await { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| }); | |
| } | |
| Ok(None) => break, | |
| Err(err) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP/3 connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| } | |
| } | |
| } | |
| Err(err) => { | |
| logger_clone | |
| .send(LogMessage::new( | |
| format!("Cannot accept a connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| }); | |
| } | |
| // Function to accept and handle incoming connections | |
| async fn accept_connection( | |
| stream: TcpStream, | |
| remote_address: SocketAddr, | |
| tls_config_option: Option<(Arc<ServerConfig>, Option<Arc<ServerConfig>>)>, | |
| acme_http01_resolver_option: Option<Arc<ResolvesServerCertAcme>>, | |
| config: Arc<Yaml>, | |
| logger: Sender<LogMessage>, | |
| modules: Arc<Vec<Box<dyn ServerModule + std::marker::Send + Sync>>>, | |
| http3_enabled: Option<u16>, | |
| ) { | |
| // Disable Nagle algorithm to improve performance | |
| if let Err(err) = stream.set_nodelay(true) { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot disable Nagle algorithm: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| }; | |
| let config = config.clone(); | |
| let local_address = match stream.local_addr() { | |
| Ok(local_address) => local_address, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot obtain local address of the connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| }; | |
| let logger_clone = logger.clone(); | |
| tokio::task::spawn(async move { | |
| let maybe_tls_stream = if let Some((tls_config, acme_config_option)) = tls_config_option { | |
| let start_handshake = match LazyConfigAcceptor::new(Acceptor::default(), stream).await { | |
| Ok(start_handshake) => start_handshake, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Error during TLS handshake: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| }; | |
| if let Some(acme_config) = acme_config_option { | |
| if is_tls_alpn_challenge(&start_handshake.client_hello()) { | |
| match start_handshake.into_stream(acme_config).await { | |
| Ok(_) => (), | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Error during TLS handshake: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| }; | |
| return; | |
| } | |
| } | |
| let tls_stream = match start_handshake.into_stream(tls_config).await { | |
| Ok(tls_stream) => tls_stream, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Error during TLS handshake: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| return; | |
| } | |
| }; | |
| MaybeTlsStream::Tls(tls_stream) | |
| } else { | |
| MaybeTlsStream::Plain(stream) | |
| }; | |
| if let MaybeTlsStream::Tls(tls_stream) = maybe_tls_stream { | |
| let alpn_protocol = tls_stream.get_ref().1.alpn_protocol(); | |
| let is_http2; | |
| if config["global"]["enableHTTP2"].as_bool().unwrap_or(true) { | |
| if alpn_protocol == Some("h2".as_bytes()) { | |
| is_http2 = true; | |
| } else { | |
| // Don't allow HTTP/2 if "h2" ALPN offering was't present | |
| is_http2 = false; | |
| } | |
| } else { | |
| is_http2 = false; | |
| } | |
| let io = TokioIo::new(tls_stream); | |
| let handlers_vec = modules | |
| .iter() | |
| .map(|module| module.get_handlers(Handle::current())); | |
| if is_http2 { | |
| let mut http2_builder = hyper::server::conn::http2::Builder::new(TokioExecutor::new()); | |
| http2_builder.timer(TokioTimer::new()); | |
| if let Some(initial_window_size) = | |
| config["global"]["http2Settings"]["initialWindowSize"].as_i64() | |
| { | |
| http2_builder.initial_stream_window_size(initial_window_size as u32); | |
| } | |
| if let Some(max_frame_size) = config["global"]["http2Settings"]["maxFrameSize"].as_i64() { | |
| http2_builder.max_frame_size(max_frame_size as u32); | |
| } | |
| if let Some(max_concurrent_streams) = | |
| config["global"]["http2Settings"]["maxConcurrentStreams"].as_i64() | |
| { | |
| http2_builder.max_concurrent_streams(max_concurrent_streams as u32); | |
| } | |
| if let Some(max_header_list_size) = | |
| config["global"]["http2Settings"]["maxHeaderListSize"].as_i64() | |
| { | |
| http2_builder.max_header_list_size(max_header_list_size as u32); | |
| } | |
| if let Some(enable_connect_protocol) = | |
| config["global"]["http2Settings"]["enableConnectProtocol"].as_bool() | |
| { | |
| if enable_connect_protocol { | |
| http2_builder.enable_connect_protocol(); | |
| } | |
| } | |
| if let Err(err) = http2_builder | |
| .serve_connection( | |
| io, | |
| service_fn(move |request: Request<Incoming>| { | |
| let config = config.clone(); | |
| let logger = logger_clone.clone(); | |
| let handlers_vec_clone = handlers_vec | |
| .clone() | |
| .collect::<Vec<Box<dyn ServerModuleHandlers + Send>>>(); | |
| let acme_http01_resolver_option_clone = acme_http01_resolver_option.clone(); | |
| let (request_parts, request_body) = request.into_parts(); | |
| let request = Request::from_parts( | |
| request_parts, | |
| request_body | |
| .map_err(|e| std::io::Error::other(e.to_string())) | |
| .boxed(), | |
| ); | |
| request_handler( | |
| request, | |
| remote_address, | |
| local_address, | |
| true, | |
| config, | |
| logger, | |
| handlers_vec_clone, | |
| acme_http01_resolver_option_clone, | |
| http3_enabled, | |
| ) | |
| }), | |
| ) | |
| .await | |
| { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Error serving HTTPS connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } else { | |
| let mut http1_builder = hyper::server::conn::http1::Builder::new(); | |
| // The timer is neccessary for the header timeout to work to mitigate Slowloris. | |
| http1_builder.timer(TokioTimer::new()); | |
| if let Err(err) = http1_builder | |
| .serve_connection( | |
| io, | |
| service_fn(move |request: Request<Incoming>| { | |
| let config = config.clone(); | |
| let logger = logger_clone.clone(); | |
| let handlers_vec_clone = handlers_vec | |
| .clone() | |
| .collect::<Vec<Box<dyn ServerModuleHandlers + Send>>>(); | |
| let acme_http01_resolver_option_clone = acme_http01_resolver_option.clone(); | |
| let (request_parts, request_body) = request.into_parts(); | |
| let request = Request::from_parts( | |
| request_parts, | |
| request_body | |
| .map_err(|e| std::io::Error::other(e.to_string())) | |
| .boxed(), | |
| ); | |
| request_handler( | |
| request, | |
| remote_address, | |
| local_address, | |
| true, | |
| config, | |
| logger, | |
| handlers_vec_clone, | |
| acme_http01_resolver_option_clone, | |
| http3_enabled, | |
| ) | |
| }), | |
| ) | |
| .with_upgrades() | |
| .await | |
| { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Error serving HTTPS connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| } else if let MaybeTlsStream::Plain(stream) = maybe_tls_stream { | |
| let io = TokioIo::new(stream); | |
| let handlers_vec = modules | |
| .iter() | |
| .map(|module| module.get_handlers(Handle::current())); | |
| let mut http1_builder = hyper::server::conn::http1::Builder::new(); | |
| // The timer is neccessary for the header timeout to work to mitigate Slowloris. | |
| http1_builder.timer(TokioTimer::new()); | |
| if let Err(err) = http1_builder | |
| .serve_connection( | |
| io, | |
| service_fn(move |request: Request<Incoming>| { | |
| let config = config.clone(); | |
| let logger = logger_clone.clone(); | |
| let handlers_vec_clone = handlers_vec | |
| .clone() | |
| .collect::<Vec<Box<dyn ServerModuleHandlers + Send>>>(); | |
| let acme_http01_resolver_option_clone = acme_http01_resolver_option.clone(); | |
| let (request_parts, request_body) = request.into_parts(); | |
| let request = Request::from_parts( | |
| request_parts, | |
| request_body | |
| .map_err(|e| std::io::Error::other(e.to_string())) | |
| .boxed(), | |
| ); | |
| request_handler( | |
| request, | |
| remote_address, | |
| local_address, | |
| false, | |
| config, | |
| logger, | |
| handlers_vec_clone, | |
| acme_http01_resolver_option_clone, | |
| http3_enabled, | |
| ) | |
| }), | |
| ) | |
| .with_upgrades() | |
| .await | |
| { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Error serving HTTP connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| }); | |
| } | |
| // Main server event loop | |
| async fn server_event_loop( | |
| yaml_config: Arc<Yaml>, | |
| logger: Sender<LogMessage>, | |
| modules: Vec<Box<dyn ServerModule + Send + Sync>>, | |
| module_error: Option<anyhow::Error>, | |
| modules_optional_builtin: Vec<String>, | |
| first_startup: bool, | |
| ) -> Result<(), Box<dyn Error + Send + Sync>> { | |
| if let Some(module_error) = module_error { | |
| logger | |
| .send(LogMessage::new(module_error.to_string(), true)) | |
| .await | |
| .unwrap_or_default(); | |
| Err(module_error)? | |
| } | |
| let prepared_config = match prepare_config_for_validation(&yaml_config) { | |
| Ok(prepared_config) => prepared_config, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Server configuration validation failed: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Server configuration validation failed: {}", | |
| err | |
| )))? | |
| } | |
| }; | |
| for (config_to_validate, is_global, is_location) in prepared_config { | |
| match validate_config( | |
| config_to_validate, | |
| is_global, | |
| is_location, | |
| &modules_optional_builtin, | |
| ) { | |
| Ok(unused_properties) => { | |
| for unused_property in unused_properties { | |
| logger | |
| .send(LogMessage::new( | |
| format!( | |
| "Unused configuration property detected: \"{}\". You might load an appropriate module to use this configuration property", | |
| unused_property | |
| ), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Server configuration validation failed: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Server configuration validation failed: {}", | |
| err | |
| )))? | |
| } | |
| }; | |
| } | |
| let mut crypto_provider = default_provider(); | |
| if let Some(cipher_suite) = yaml_config["global"]["cipherSuite"].as_vec() { | |
| let mut cipher_suites = Vec::new(); | |
| let cipher_suite_iter = cipher_suite.iter(); | |
| for cipher_suite_yaml in cipher_suite_iter { | |
| if let Some(cipher_suite) = cipher_suite_yaml.as_str() { | |
| let cipher_suite_to_add = match cipher_suite { | |
| "TLS_AES_128_GCM_SHA256" => TLS13_AES_128_GCM_SHA256, | |
| "TLS_AES_256_GCM_SHA384" => TLS13_AES_256_GCM_SHA384, | |
| "TLS_CHACHA20_POLY1305_SHA256" => TLS13_CHACHA20_POLY1305_SHA256, | |
| "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, | |
| "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, | |
| "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => { | |
| TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 | |
| } | |
| "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256" => TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, | |
| "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384" => TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, | |
| "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => { | |
| TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 | |
| } | |
| _ => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("The \"{}\" cipher suite is not supported", cipher_suite), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "The \"{}\" cipher suite is not supported", | |
| cipher_suite | |
| )))? | |
| } | |
| }; | |
| cipher_suites.push(cipher_suite_to_add); | |
| } | |
| } | |
| crypto_provider.cipher_suites = cipher_suites; | |
| } | |
| if let Some(ecdh_curves) = yaml_config["global"]["ecdhCurve"].as_vec() { | |
| let mut kx_groups = Vec::new(); | |
| let ecdh_curves_iter = ecdh_curves.iter(); | |
| for ecdh_curve_yaml in ecdh_curves_iter { | |
| if let Some(ecdh_curve) = ecdh_curve_yaml.as_str() { | |
| let kx_group_to_add = match ecdh_curve { | |
| "secp256r1" => SECP256R1, | |
| "secp384r1" => SECP384R1, | |
| "x25519" => X25519, | |
| _ => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("The \"{}\" ECDH curve is not supported", ecdh_curve), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "The \"{}\" ECDH curve is not supported", | |
| ecdh_curve | |
| )))? | |
| } | |
| }; | |
| kx_groups.push(kx_group_to_add); | |
| } | |
| } | |
| crypto_provider.kx_groups = kx_groups; | |
| } | |
| let crypto_provider_cloned = crypto_provider.clone(); | |
| let mut sni_resolver = CustomSniResolver::new(); | |
| let mut certified_keys = Vec::new(); | |
| let mut automatic_tls_enabled = false; | |
| let mut acme_letsencrypt_production = true; | |
| let acme_use_http_challenge = yaml_config["global"]["useAutomaticTLSHTTPChallenge"] | |
| .as_bool() | |
| .unwrap_or(false); | |
| let acme_challenge_type = if acme_use_http_challenge { | |
| UseChallenge::Http01 | |
| } else { | |
| UseChallenge::TlsAlpn01 | |
| }; | |
| // Read automatic TLS configuration | |
| if let Some(read_automatic_tls_enabled) = yaml_config["global"]["enableAutomaticTLS"].as_bool() { | |
| automatic_tls_enabled = read_automatic_tls_enabled; | |
| } | |
| let acme_contact = yaml_config["global"]["automaticTLSContactEmail"].as_str(); | |
| let acme_cache = yaml_config["global"]["automaticTLSContactCacheDirectory"] | |
| .as_str() | |
| .map(|s| s.to_string()) | |
| .map(DirCache::new); | |
| if let Some(read_acme_letsencrypt_production) = | |
| yaml_config["global"]["automaticTLSLetsEncryptProduction"].as_bool() | |
| { | |
| acme_letsencrypt_production = read_acme_letsencrypt_production; | |
| } | |
| if !automatic_tls_enabled { | |
| // Load public certificate and private key | |
| if let Some(cert_path) = yaml_config["global"]["cert"].as_str() { | |
| if let Some(key_path) = yaml_config["global"]["key"].as_str() { | |
| let certs = match load_certs(cert_path) { | |
| Ok(certs) => certs, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot load the \"{}\" TLS certificate: {}", cert_path, err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot load the \"{}\" TLS certificate: {}", | |
| cert_path, err | |
| )))? | |
| } | |
| }; | |
| let key = match load_private_key(key_path) { | |
| Ok(key) => key, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot load the \"{}\" private key: {}", cert_path, err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot load the \"{}\" private key: {}", | |
| cert_path, err | |
| )))? | |
| } | |
| }; | |
| let signing_key = match crypto_provider_cloned.key_provider.load_private_key(key) { | |
| Ok(key) => key, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot load the \"{}\" private key: {}", cert_path, err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot load the \"{}\" private key: {}", | |
| cert_path, err | |
| )))? | |
| } | |
| }; | |
| let certified_key = CertifiedKey::new(certs, signing_key); | |
| sni_resolver.load_fallback_cert_key(Arc::new(certified_key)); | |
| } | |
| } | |
| if let Some(sni) = yaml_config["global"]["sni"].as_hash() { | |
| let sni_hostnames = sni.keys(); | |
| for sni_hostname_unknown in sni_hostnames { | |
| if let Some(sni_hostname) = sni_hostname_unknown.as_str() { | |
| if let Some(cert_path) = sni[sni_hostname_unknown]["cert"].as_str() { | |
| if let Some(key_path) = sni[sni_hostname_unknown]["key"].as_str() { | |
| let certs = match load_certs(cert_path) { | |
| Ok(certs) => certs, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot load the \"{}\" TLS certificate: {}", cert_path, err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot load the \"{}\" TLS certificate: {}", | |
| cert_path, err | |
| )))? | |
| } | |
| }; | |
| let key = match load_private_key(key_path) { | |
| Ok(key) => key, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot load the \"{}\" private key: {}", cert_path, err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot load the \"{}\" private key: {}", | |
| cert_path, err | |
| )))? | |
| } | |
| }; | |
| let signing_key = match crypto_provider_cloned.key_provider.load_private_key(key) { | |
| Ok(key) => key, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot load the \"{}\" private key: {}", cert_path, err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot load the \"{}\" private key: {}", | |
| cert_path, err | |
| )))? | |
| } | |
| }; | |
| let certified_key_arc = Arc::new(CertifiedKey::new(certs, signing_key)); | |
| sni_resolver.load_host_cert_key(sni_hostname, certified_key_arc.clone()); | |
| certified_keys.push(certified_key_arc); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| // Build TLS configuration | |
| let tls_config_builder_wants_versions = | |
| ServerConfig::builder_with_provider(Arc::new(crypto_provider_cloned)); | |
| // Very simple minimum and maximum TLS version logic for now... | |
| let min_tls_version_option = yaml_config["global"]["tlsMinVersion"].as_str(); | |
| let max_tls_version_option = yaml_config["global"]["tlsMaxVersion"].as_str(); | |
| let tls_config_builder_wants_verifier = match min_tls_version_option { | |
| Some("TLSv1.3") => match max_tls_version_option { | |
| Some("TLSv1.2") => { | |
| logger | |
| .send(LogMessage::new( | |
| String::from("The maximum TLS version is older than the minimum TLS version"), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(String::from( | |
| "The maximum TLS version is older than the minimum TLS version" | |
| )))? | |
| } | |
| Some("TLSv1.3") | None => { | |
| match tls_config_builder_wants_versions.with_protocol_versions(&[&TLS13]) { | |
| Ok(builder) => builder, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Couldn't create the TLS server configuration: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Couldn't create the TLS server configuration: {}", | |
| err | |
| )))? | |
| } | |
| } | |
| } | |
| _ => { | |
| logger | |
| .send(LogMessage::new( | |
| String::from("Invalid maximum TLS version"), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(String::from("Invalid maximum TLS version")))? | |
| } | |
| }, | |
| Some("TLSv1.2") | None => match max_tls_version_option { | |
| Some("TLSv1.2") => { | |
| match tls_config_builder_wants_versions.with_protocol_versions(&[&TLS12]) { | |
| Ok(builder) => builder, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Couldn't create the TLS server configuration: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Couldn't create the TLS server configuration: {}", | |
| err | |
| )))? | |
| } | |
| } | |
| } | |
| Some("TLSv1.3") | None => { | |
| match tls_config_builder_wants_versions.with_protocol_versions(&[&TLS12, &TLS13]) { | |
| Ok(builder) => builder, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Couldn't create the TLS server configuration: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Couldn't create the TLS server configuration: {}", | |
| err | |
| )))? | |
| } | |
| } | |
| } | |
| _ => { | |
| logger | |
| .send(LogMessage::new( | |
| String::from("Invalid maximum TLS version"), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(String::from("Invalid maximum TLS version")))? | |
| } | |
| }, | |
| _ => { | |
| logger | |
| .send(LogMessage::new( | |
| String::from("Invalid minimum TLS version"), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(String::from("Invalid minimum TLS version")))? | |
| } | |
| }; | |
| let tls_config_builder_wants_server_cert = | |
| match yaml_config["global"]["useClientCertificate"].as_bool() { | |
| Some(true) => { | |
| let mut roots = RootCertStore::empty(); | |
| let certs_result = load_native_certs(); | |
| if !certs_result.errors.is_empty() { | |
| logger | |
| .send(LogMessage::new( | |
| format!( | |
| "Couldn't load the native certificate store: {}", | |
| certs_result.errors[0] | |
| ), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Couldn't load the native certificate store: {}", | |
| certs_result.errors[0] | |
| )))? | |
| } | |
| let certs = certs_result.certs; | |
| for cert in certs { | |
| match roots.add(cert) { | |
| Ok(_) => (), | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!( | |
| "Couldn't add a certificate to the certificate store: {}", | |
| err | |
| ), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Couldn't add a certificate to the certificate store: {}", | |
| err | |
| )))? | |
| } | |
| } | |
| } | |
| tls_config_builder_wants_verifier | |
| .with_client_cert_verifier(WebPkiClientVerifier::builder(Arc::new(roots)).build()?) | |
| } | |
| _ => tls_config_builder_wants_verifier.with_no_client_auth(), | |
| }; | |
| let mut tls_config; | |
| let mut addr = SocketAddr::from((IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 80)); | |
| let mut addr_tls = SocketAddr::from((IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), 443)); | |
| let mut tls_enabled = false; | |
| let mut non_tls_disabled = false; | |
| // Install a process-wide cryptography provider. If it fails, then warn about it. | |
| if crypto_provider.install_default().is_err() && first_startup { | |
| logger | |
| .send(LogMessage::new( | |
| "Cannot install a process-wide cryptography provider".to_string(), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!( | |
| "Cannot install a process-wide cryptography provider" | |
| ))?; | |
| } | |
| // Read port configurations from YAML | |
| if let Some(read_port) = yaml_config["global"]["port"].as_i64() { | |
| addr = SocketAddr::from(( | |
| IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), | |
| match read_port.try_into() { | |
| Ok(port) => port, | |
| Err(_) => { | |
| logger | |
| .send(LogMessage::new(String::from("Invalid HTTP port"), true)) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!("Invalid HTTP port"))? | |
| } | |
| }, | |
| )); | |
| } else if let Some(read_port) = yaml_config["global"]["port"].as_str() { | |
| addr = match read_port.parse() { | |
| Ok(addr) => addr, | |
| Err(_) => { | |
| logger | |
| .send(LogMessage::new(String::from("Invalid HTTP port"), true)) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!("Invalid HTTP port"))? | |
| } | |
| }; | |
| } | |
| if let Some(read_tls_enabled) = yaml_config["global"]["secure"].as_bool() { | |
| tls_enabled = read_tls_enabled; | |
| if let Some(read_non_tls_disabled) = | |
| yaml_config["global"]["disableNonEncryptedServer"].as_bool() | |
| { | |
| non_tls_disabled = read_non_tls_disabled; | |
| } | |
| } | |
| if let Some(read_port) = yaml_config["global"]["sport"].as_i64() { | |
| addr_tls = SocketAddr::from(( | |
| IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)), | |
| match read_port.try_into() { | |
| Ok(port) => port, | |
| Err(_) => { | |
| logger | |
| .send(LogMessage::new(String::from("Invalid HTTPS port"), true)) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!("Invalid HTTPS port"))? | |
| } | |
| }, | |
| )); | |
| } else if let Some(read_port) = yaml_config["global"]["sport"].as_str() { | |
| addr_tls = match read_port.parse() { | |
| Ok(addr) => addr, | |
| Err(_) => { | |
| logger | |
| .send(LogMessage::new(String::from("Invalid HTTPS port"), true)) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!("Invalid HTTPS port"))? | |
| } | |
| }; | |
| } | |
| // Get domains for ACME configuration | |
| let mut acme_domains = Vec::new(); | |
| if let Some(hosts_config) = yaml_config["hosts"].as_vec() { | |
| for host_yaml in hosts_config.iter() { | |
| if let Some(host) = host_yaml.as_hash() { | |
| if let Some(domain_yaml) = host.get(&Yaml::from_str("domain")) { | |
| if let Some(domain) = domain_yaml.as_str() { | |
| if !domain.contains("*") { | |
| acme_domains.push(domain); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| // Create ACME configuration | |
| let mut acme_config = AcmeConfig::new(acme_domains).challenge_type(acme_challenge_type); | |
| if let Some(acme_contact_unwrapped) = acme_contact { | |
| acme_config = acme_config.contact_push(format!("mailto:{}", acme_contact_unwrapped)); | |
| } | |
| let mut acme_config_with_cache = acme_config.cache_option(acme_cache); | |
| acme_config_with_cache = | |
| acme_config_with_cache.directory_lets_encrypt(acme_letsencrypt_production); | |
| let (acme_config, acme_http01_resolver) = if tls_enabled && automatic_tls_enabled { | |
| let mut acme_state = acme_config_with_cache.state(); | |
| let acme_resolver = acme_state.resolver(); | |
| // Create TLS configuration | |
| tls_config = if yaml_config["global"]["enableOCSPStapling"] | |
| .as_bool() | |
| .unwrap_or(true) | |
| { | |
| tls_config_builder_wants_server_cert | |
| .with_cert_resolver(Arc::new(Stapler::new(acme_resolver.clone()))) | |
| } else { | |
| tls_config_builder_wants_server_cert.with_cert_resolver(acme_resolver.clone()) | |
| }; | |
| let acme_logger = logger.clone(); | |
| tokio::spawn(async move { | |
| while let Some(acme_result) = acme_state.next().await { | |
| if let Err(acme_error) = acme_result { | |
| acme_logger | |
| .send(LogMessage::new( | |
| format!("Error while obtaining a TLS certificate: {}", acme_error), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| }); | |
| if acme_use_http_challenge { | |
| (None, Some(acme_resolver)) | |
| } else { | |
| let mut acme_config = tls_config.clone(); | |
| acme_config.alpn_protocols.push(ACME_TLS_ALPN_NAME.to_vec()); | |
| (Some(acme_config), None) | |
| } | |
| } else { | |
| // Create TLS configuration | |
| tls_config = if yaml_config["global"]["enableOCSPStapling"] | |
| .as_bool() | |
| .unwrap_or(true) | |
| { | |
| let ocsp_stapler_arc = Arc::new(Stapler::new(Arc::new(sni_resolver))); | |
| for certified_key in certified_keys.iter() { | |
| ocsp_stapler_arc.preload(certified_key.clone()); | |
| } | |
| tls_config_builder_wants_server_cert.with_cert_resolver(ocsp_stapler_arc.clone()) | |
| } else { | |
| tls_config_builder_wants_server_cert.with_cert_resolver(Arc::new(sni_resolver)) | |
| }; | |
| // Drop the ACME configuration | |
| drop(acme_config_with_cache); | |
| (None, None) | |
| }; | |
| let quic_config = if tls_enabled | |
| && yaml_config["global"]["enableHTTP3"] | |
| .as_bool() | |
| .unwrap_or(false) | |
| { | |
| let mut quic_tls_config = tls_config.clone(); | |
| quic_tls_config.max_early_data_size = u32::MAX; | |
| quic_tls_config.alpn_protocols = vec![b"h3".to_vec(), b"h3-29".to_vec()]; | |
| let quic_config = quinn::ServerConfig::with_crypto(Arc::new(match QuicServerConfig::try_from( | |
| quic_tls_config, | |
| ) { | |
| Ok(quinn_config) => quinn_config, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("There was a problem when starting HTTP/3 server: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "There was a problem when starting HTTP/3 server: {}", | |
| err | |
| )))? | |
| } | |
| })); | |
| Some(quic_config) | |
| } else { | |
| None | |
| }; | |
| // Configure ALPN protocols | |
| let mut alpn_protocols = vec![b"http/1.1".to_vec(), b"http/1.0".to_vec()]; | |
| if yaml_config["global"]["enableHTTP2"] | |
| .as_bool() | |
| .unwrap_or(true) | |
| { | |
| alpn_protocols.insert(0, b"h2".to_vec()); | |
| } | |
| tls_config.alpn_protocols = alpn_protocols; | |
| let tls_config_arc = Arc::new(tls_config); | |
| let acme_config_arc = acme_config.map(Arc::new); | |
| let mut listener = None; | |
| let mut listener_tls = None; | |
| let mut listener_quic = None; | |
| // Bind to the specified ports | |
| if !non_tls_disabled { | |
| println!("HTTP server is listening at {}", addr); | |
| listener = Some(match TcpListener::bind(addr).await { | |
| Ok(listener) => listener, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot listen to HTTP port: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot listen to HTTP port: {}", | |
| err | |
| )))? | |
| } | |
| }); | |
| } | |
| if tls_enabled { | |
| println!("HTTPS server is listening at {}", addr_tls); | |
| listener_tls = Some(match TcpListener::bind(addr_tls).await { | |
| Ok(listener) => listener, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot listen to HTTPS port: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot listen to HTTPS port: {}", | |
| err | |
| )))? | |
| } | |
| }); | |
| if let Some(quic_config) = quic_config { | |
| println!("HTTP/3 server is listening at {}", addr_tls); | |
| listener_quic = Some(match quinn::Endpoint::server(quic_config, addr_tls) { | |
| Ok(listener) => listener, | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot listen to HTTP/3 port: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!(format!( | |
| "Cannot listen to HTTP/3 port: {}", | |
| err | |
| )))? | |
| } | |
| }); | |
| } | |
| } | |
| // Wrap the modules vector in an Arc | |
| let modules_arc = Arc::new(modules); | |
| let http3_enabled = if listener_tls.is_some() { | |
| Some(addr_tls.port()) | |
| } else { | |
| None | |
| }; | |
| // Main loop to accept incoming connections | |
| loop { | |
| let listener_borrowed = &listener; | |
| let listener_accept = async move { | |
| if let Some(listener) = listener_borrowed { | |
| listener.accept().await | |
| } else { | |
| futures_util::future::pending().await | |
| } | |
| }; | |
| let listener_tls_borrowed = &listener_tls; | |
| let listener_tls_accept = async move { | |
| if let Some(listener_tls) = listener_tls_borrowed { | |
| listener_tls.accept().await | |
| } else { | |
| futures_util::future::pending().await | |
| } | |
| }; | |
| let listener_quic_borrowed = &listener_quic; | |
| let listener_quic_accept = async move { | |
| if let Some(listener_quic) = listener_quic_borrowed { | |
| listener_quic.accept().await | |
| } else { | |
| futures_util::future::pending().await | |
| } | |
| }; | |
| if listener_borrowed.is_none() | |
| && listener_tls_borrowed.is_none() | |
| && listener_quic_borrowed.is_none() | |
| { | |
| logger | |
| .send(LogMessage::new( | |
| String::from("No server is listening"), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| Err(anyhow::anyhow!("No server is listening"))?; | |
| } | |
| tokio::select! { | |
| status = listener_accept => { | |
| match status { | |
| Ok((stream, remote_address)) => { | |
| accept_connection( | |
| stream, | |
| remote_address, | |
| None, | |
| acme_http01_resolver.clone(), | |
| yaml_config.clone(), | |
| logger.clone(), | |
| modules_arc.clone(), | |
| None | |
| ) | |
| .await; | |
| } | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot accept a connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| }, | |
| status = listener_tls_accept => { | |
| match status { | |
| Ok((stream, remote_address)) => { | |
| accept_connection( | |
| stream, | |
| remote_address, | |
| Some((tls_config_arc.clone(), acme_config_arc.clone())), | |
| None, | |
| yaml_config.clone(), | |
| logger.clone(), | |
| modules_arc.clone(), | |
| http3_enabled | |
| ) | |
| .await; | |
| } | |
| Err(err) => { | |
| logger | |
| .send(LogMessage::new( | |
| format!("Cannot accept a connection: {}", err), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| }, | |
| status = listener_quic_accept => { | |
| match status { | |
| Some(connection_attempt) => { | |
| let local_ip = SocketAddr::new(connection_attempt.local_ip().unwrap_or(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0))), addr_tls.port()); | |
| accept_quic_connection( | |
| connection_attempt, | |
| local_ip, | |
| yaml_config.clone(), | |
| logger.clone(), | |
| modules_arc.clone() | |
| ) | |
| .await; | |
| } | |
| None => { | |
| logger | |
| .send(LogMessage::new( | |
| "HTTP/3 connections can't be accepted anymore".to_string(), | |
| true, | |
| )) | |
| .await | |
| .unwrap_or_default(); | |
| } | |
| } | |
| } | |
| }; | |
| } | |
| } | |
| // Start the server | |
| pub fn start_server( | |
| yaml_config: Arc<Yaml>, | |
| modules: Vec<Box<dyn ServerModule + Send + Sync>>, | |
| module_error: Option<anyhow::Error>, | |
| modules_optional_builtin: Vec<String>, | |
| first_startup: bool, | |
| ) -> Result<bool, Box<dyn Error + Send + Sync>> { | |
| if let Some(environment_variables_hash) = yaml_config["global"]["environmentVariables"].as_hash() | |
| { | |
| let environment_variables_hash_iter = environment_variables_hash.iter(); | |
| for (variable_name, variable_value) in environment_variables_hash_iter { | |
| if let Some(variable_name) = variable_name.as_str() { | |
| if let Some(variable_value) = variable_value.as_str() { | |
| if !variable_name.is_empty() | |
| && !variable_name.contains('\0') | |
| && !variable_name.contains('=') | |
| && !variable_value.contains('\0') | |
| { | |
| // Safety: the environment variables are set before threads are spawned | |
| // The `std::env::set_var` function is safe to use in single-threaded environments | |
| // In Rust 2024 edition, the `std::env::set_var` function would be `unsafe`. | |
| env::set_var(variable_name, variable_value); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| let available_parallelism = thread::available_parallelism()?.get(); | |
| // Create Tokio runtime for the server | |
| let server_runtime = tokio::runtime::Builder::new_multi_thread() | |
| .worker_threads(available_parallelism) | |
| .max_blocking_threads(1536) | |
| .event_interval(25) | |
| .thread_name("server-pool") | |
| .enable_all() | |
| .build()?; | |
| // Create Tokio runtime for logging | |
| let log_runtime = tokio::runtime::Builder::new_multi_thread() | |
| .worker_threads(match available_parallelism / 2 { | |
| 0 => 1, | |
| non_zero => non_zero, | |
| }) | |
| .max_blocking_threads(768) | |
| .thread_name("log-pool") | |
| .enable_time() | |
| .build()?; | |
| let (logger, receive_log) = async_channel::bounded::<LogMessage>(10000); | |
| let log_filename = yaml_config["global"]["logFilePath"] | |
| .as_str() | |
| .map(String::from); | |
| let error_log_filename = yaml_config["global"]["errorLogFilePath"] | |
| .as_str() | |
| .map(String::from); | |
| log_runtime.spawn(async move { | |
| let log_file = match log_filename { | |
| Some(log_filename) => Some( | |
| fs::OpenOptions::new() | |
| .append(true) | |
| .create(true) | |
| .open(log_filename) | |
| .await, | |
| ), | |
| None => None, | |
| }; | |
| let error_log_file = match error_log_filename { | |
| Some(error_log_filename) => Some( | |
| fs::OpenOptions::new() | |
| .append(true) | |
| .create(true) | |
| .open(error_log_filename) | |
| .await, | |
| ), | |
| None => None, | |
| }; | |
| let log_file_wrapped = match log_file { | |
| Some(Ok(file)) => Some(Arc::new(Mutex::new(BufWriter::with_capacity(131072, file)))), | |
| Some(Err(e)) => { | |
| eprintln!("Failed to open log file: {}", e); | |
| None | |
| } | |
| None => None, | |
| }; | |
| let error_log_file_wrapped = match error_log_file { | |
| Some(Ok(file)) => Some(Arc::new(Mutex::new(BufWriter::with_capacity(131072, file)))), | |
| Some(Err(e)) => { | |
| eprintln!("Failed to open error log file: {}", e); | |
| None | |
| } | |
| None => None, | |
| }; | |
| // The logs are written when the log message is received by the log event loop, and flushed every 100 ms, improving the server performance. | |
| let log_file_wrapped_cloned_for_sleep = log_file_wrapped.clone(); | |
| let error_log_file_wrapped_cloned_for_sleep = error_log_file_wrapped.clone(); | |
| tokio::task::spawn(async move { | |
| let mut interval = time::interval(time::Duration::from_millis(100)); | |
| loop { | |
| interval.tick().await; | |
| if let Some(log_file_wrapped_cloned) = log_file_wrapped_cloned_for_sleep.clone() { | |
| let mut locked_file = log_file_wrapped_cloned.lock().await; | |
| locked_file.flush().await.unwrap_or_default(); | |
| } | |
| if let Some(error_log_file_wrapped_cloned) = error_log_file_wrapped_cloned_for_sleep.clone() | |
| { | |
| let mut locked_file = error_log_file_wrapped_cloned.lock().await; | |
| locked_file.flush().await.unwrap_or_default(); | |
| } | |
| } | |
| }); | |
| // Logging loop | |
| while let Ok(message) = receive_log.recv().await { | |
| let (mut message, is_error) = message.get_message(); | |
| let log_file_wrapped_cloned = if !is_error { | |
| log_file_wrapped.clone() | |
| } else { | |
| error_log_file_wrapped.clone() | |
| }; | |
| if let Some(log_file_wrapped_cloned) = log_file_wrapped_cloned { | |
| tokio::task::spawn(async move { | |
| let mut locked_file = log_file_wrapped_cloned.lock().await; | |
| if is_error { | |
| let now: DateTime<Local> = Local::now(); | |
| let formatted_time = now.format("%Y-%m-%d %H:%M:%S").to_string(); | |
| message = format!("[{}]: {}", formatted_time, message); | |
| } | |
| message.push('\n'); | |
| if let Err(e) = locked_file.write(message.as_bytes()).await { | |
| eprintln!("Failed to write to log file: {}", e); | |
| } | |
| }); | |
| } | |
| } | |
| }); | |
| // Run the server event loop | |
| let result = server_runtime.block_on(async { | |
| let event_loop_future = server_event_loop( | |
| yaml_config, | |
| logger, | |
| modules, | |
| module_error, | |
| modules_optional_builtin, | |
| first_startup, | |
| ); | |
| { | |
| use tokio::signal; | |
| match signal::unix::signal(signal::unix::SignalKind::hangup()) { | |
| Ok(mut signal) => { | |
| tokio::select! { | |
| result = event_loop_future => { | |
| // Sleep the Tokio runtime to ensure error logs are saved | |
| time::sleep(tokio::time::Duration::from_millis(100)).await; | |
| result.map(|_| false) | |
| }, | |
| _ = signal.recv() => Ok(true) | |
| } | |
| } | |
| Err(_) => { | |
| let result = event_loop_future.await; | |
| // Sleep the Tokio runtime to ensure error logs are saved | |
| time::sleep(tokio::time::Duration::from_millis(100)).await; | |
| result.map(|_| false) | |
| } | |
| } | |
| } | |
| { | |
| let result = event_loop_future.await; | |
| // Sleep the Tokio runtime to ensure error logs are saved | |
| time::sleep(tokio::time::Duration::from_millis(100)).await; | |
| result.map(|_| false) | |
| } | |
| }); | |
| // Wait 10 seconds or until all tasks are complete | |
| server_runtime.shutdown_timeout(time::Duration::from_secs(10)); | |
| result | |
| } | |