use std::collections::HashMap; use std::error::Error; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; use crate::ferron_common::{ ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule, ServerModuleHandlers, SocketData, }; use crate::ferron_common::{HyperResponse, WithRuntime}; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; use http::uri::{PathAndQuery, Scheme}; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use hyper::body::Bytes; use hyper::client::conn::http1::SendRequest; use hyper::{header, Request, StatusCode, Uri}; use hyper_tungstenite::HyperWebsocket; use hyper_util::rt::TokioIo; use rustls::pki_types::ServerName; use rustls::RootCertStore; use rustls_native_certs::load_native_certs; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::runtime::Handle; use tokio::sync::RwLock; use tokio_rustls::TlsConnector; use tokio_tungstenite::Connector; use crate::ferron_util::no_server_verifier::NoServerVerifier; use crate::ferron_util::ttl_cache::TtlCache; const DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST: u32 = 32; pub fn server_module_init( config: &ServerConfig, ) -> Result, Box> { let mut roots: RootCertStore = RootCertStore::empty(); let certs_result = load_native_certs(); if !certs_result.errors.is_empty() { Err(anyhow::anyhow!(format!( "Couldn't load the native certificate store: {}", certs_result.errors[0] )))? } let certs = certs_result.certs; for cert in certs { match roots.add(cert) { Ok(_) => (), Err(err) => Err(anyhow::anyhow!(format!( "Couldn't add a certificate to the certificate store: {}", err )))?, } } let mut connections_vec = Vec::new(); for _ in 0..DEFAULT_CONCURRENT_CONNECTIONS_PER_HOST { connections_vec.push(RwLock::new(HashMap::new())); } Ok(Box::new(ReverseProxyModule::new( Arc::new(roots), Arc::new(connections_vec), Arc::new(RwLock::new(TtlCache::new(Duration::from_millis( config["global"]["loadBalancerHealthCheckWindow"] .as_i64() .unwrap_or(5000) as u64, )))), ))) } #[allow(clippy::type_complexity)] struct ReverseProxyModule { roots: Arc, connections: Arc>>>>>, failed_backends: Arc>>, } impl ReverseProxyModule { #[allow(clippy::type_complexity)] fn new( roots: Arc, connections: Arc>>>>>, failed_backends: Arc>>, ) -> Self { Self { roots, connections, failed_backends, } } } impl ServerModule for ReverseProxyModule { fn get_handlers(&self, handle: Handle) -> Box { Box::new(ReverseProxyModuleHandlers { roots: self.roots.clone(), connections: self.connections.clone(), failed_backends: self.failed_backends.clone(), handle, }) } } #[allow(clippy::type_complexity)] struct ReverseProxyModuleHandlers { handle: Handle, roots: Arc, connections: Arc>>>>>, failed_backends: Arc>>, } #[async_trait] impl ServerModuleHandlers for ReverseProxyModuleHandlers { async fn request_handler( &mut self, request: RequestData, config: &ServerConfig, socket_data: &SocketData, error_logger: &ErrorLogger, ) -> Result> { WithRuntime::new(self.handle.clone(), async move { let enable_health_check = config["enableLoadBalancerHealthCheck"] .as_bool() .unwrap_or(false); let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"] .as_i64() .unwrap_or(3) as u64; let disable_certificate_verification = config["disableProxyCertificateVerification"] .as_bool() .unwrap_or(false); if let Some(proxy_to) = determine_proxy_to( config, socket_data.encrypted, &self.failed_backends, enable_health_check, health_check_max_fails, ) .await { let (hyper_request, _auth_user, _original_url) = request.into_parts(); let (mut hyper_request_parts, request_body) = hyper_request.into_parts(); let proxy_request_url = proxy_to.parse::()?; let scheme_str = proxy_request_url.scheme_str(); let mut encrypted = false; match scheme_str { Some("http") => { encrypted = false; } Some("https") => { encrypted = true; } _ => Err(anyhow::anyhow!( "Only HTTP and HTTPS reverse proxy URLs are supported." ))?, }; let host = match proxy_request_url.host() { Some(host) => host, None => Err(anyhow::anyhow!( "The reverse proxy URL doesn't include the host" ))?, }; let port = proxy_request_url.port_u16().unwrap_or(match scheme_str { Some("http") => 80, Some("https") => 443, _ => 80, }); let addr = format!("{}:{}", host, port); let authority = proxy_request_url.authority().cloned(); let hyper_request_path = hyper_request_parts.uri.path(); let path = match hyper_request_path.as_bytes().first() { Some(b'/') => { let mut proxy_request_path = proxy_request_url.path(); while proxy_request_path.as_bytes().last().copied() == Some(b'/') { proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)]; } format!("{}{}", proxy_request_path, hyper_request_path) } _ => hyper_request_path.to_string(), }; hyper_request_parts.uri = Uri::from_str(&format!( "{}{}", path, match hyper_request_parts.uri.query() { Some(query) => format!("?{}", query), None => "".to_string(), } ))?; let original_host = hyper_request_parts.headers.get(header::HOST).cloned(); // Host header for host identification match authority { Some(authority) => { hyper_request_parts .headers .insert(header::HOST, authority.to_string().parse()?); } None => { hyper_request_parts.headers.remove(header::HOST); } } // Connection header to enable HTTP/1.1 keep-alive hyper_request_parts .headers .insert(header::CONNECTION, "keep-alive".parse()?); // X-Forwarded-* headers to send the client's data to a server that's behind the reverse proxy hyper_request_parts.headers.insert( "x-forwarded-for", socket_data .remote_addr .ip() .to_canonical() .to_string() .parse()?, ); if socket_data.encrypted { hyper_request_parts .headers .insert("x-forwarded-proto", "https".parse()?); } else { hyper_request_parts .headers .insert("x-forwarded-proto", "http".parse()?); } if let Some(original_host) = original_host { hyper_request_parts .headers .insert("x-forwarded-host", original_host); } let proxy_request = Request::from_parts(hyper_request_parts, request_body); let connections = &self.connections[rand::random_range(..self.connections.len())]; let rwlock_read = connections.read().await; let sender_read_option = rwlock_read.get(&addr); if let Some(sender_read) = sender_read_option { if !sender_read.is_closed() { drop(rwlock_read); let mut rwlock_write = connections.write().await; let sender_option = rwlock_write.get_mut(&addr); if let Some(sender) = sender_option { if !sender.is_closed() { let result = http_proxy_kept_alive(sender, proxy_request, error_logger).await; drop(rwlock_write); return result; } else { drop(rwlock_write); } } else { drop(rwlock_write); } } else { drop(rwlock_read); } } else { drop(rwlock_read); } let stream = match TcpStream::connect(&addr).await { Ok(stream) => stream, Err(err) => { if enable_health_check { let mut failed_backends_write = self.failed_backends.write().await; let proxy_to = proxy_to.clone(); let failed_attempts = failed_backends_write.get(&proxy_to); failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1)); } match err.kind() { tokio::io::ErrorKind::ConnectionRefused | tokio::io::ErrorKind::NotFound | tokio::io::ErrorKind::HostUnreachable => { error_logger .log(&format!("Service unavailable: {}", err)) .await; return Ok( ResponseData::builder_without_request() .status(StatusCode::SERVICE_UNAVAILABLE) .build(), ); } tokio::io::ErrorKind::TimedOut => { error_logger.log(&format!("Gateway timeout: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::GATEWAY_TIMEOUT) .build(), ); } _ => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; } }; match stream.set_nodelay(true) { Ok(_) => (), Err(err) => { if enable_health_check { let mut failed_backends_write = self.failed_backends.write().await; let proxy_to = proxy_to.clone(); let failed_attempts = failed_backends_write.get(&proxy_to); failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1)); } error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; let failed_backends_option_borrowed = if enable_health_check { Some(&*self.failed_backends) } else { None }; if !encrypted { http_proxy( connections, addr, stream, proxy_request, error_logger, proxy_to, failed_backends_option_borrowed, ) .await } else { let tls_client_config = (if disable_certificate_verification { rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new())) } else { rustls::ClientConfig::builder().with_root_certificates(self.roots.clone()) }) .with_no_client_auth(); let connector = TlsConnector::from(Arc::new(tls_client_config)); let domain = ServerName::try_from(host)?.to_owned(); let tls_stream = match connector.connect(domain, stream).await { Ok(stream) => stream, Err(err) => { if enable_health_check { let mut failed_backends_write = self.failed_backends.write().await; let proxy_to = proxy_to.clone(); let failed_attempts = failed_backends_write.get(&proxy_to); failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1)); } error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; http_proxy( connections, addr, tls_stream, proxy_request, error_logger, proxy_to, failed_backends_option_borrowed, ) .await } } else { Ok(ResponseData::builder(request).build()) } }) .await } async fn proxy_request_handler( &mut self, request: RequestData, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result> { Ok(ResponseData::builder(request).build()) } async fn response_modifying_handler( &mut self, response: HyperResponse, ) -> Result> { Ok(response) } async fn proxy_response_modifying_handler( &mut self, response: HyperResponse, ) -> Result> { Ok(response) } async fn connect_proxy_request_handler( &mut self, _upgraded_request: HyperUpgraded, _connect_address: &str, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result<(), Box> { Ok(()) } fn does_connect_proxy_requests(&mut self) -> bool { false } async fn websocket_request_handler( &mut self, websocket: HyperWebsocket, uri: &hyper::Uri, config: &ServerConfig, socket_data: &SocketData, error_logger: &ErrorLogger, ) -> Result<(), Box> { WithRuntime::new(self.handle.clone(), async move { let enable_health_check = config["enableLoadBalancerHealthCheck"] .as_bool() .unwrap_or(false); let health_check_max_fails = config["loadBalancerHealthCheckMaximumFails"] .as_i64() .unwrap_or(3) as u64; let disable_certificate_verification = config["disableProxyCertificateVerification"] .as_bool() .unwrap_or(false); if let Some(proxy_to) = determine_proxy_to( config, socket_data.encrypted, &self.failed_backends, enable_health_check, health_check_max_fails, ) .await { let proxy_request_url = proxy_to.parse::()?; let scheme_str = proxy_request_url.scheme_str(); let mut encrypted = false; match scheme_str { Some("http") => { encrypted = false; } Some("https") => { encrypted = true; } _ => Err(anyhow::anyhow!( "Only HTTP and HTTPS reverse proxy URLs are supported." ))?, }; let request_path = uri.path(); let path = match request_path.as_bytes().first() { Some(b'/') => { let mut proxy_request_path = proxy_request_url.path(); while proxy_request_path.as_bytes().last().copied() == Some(b'/') { proxy_request_path = &proxy_request_path[..(proxy_request_path.len() - 1)]; } format!("{}{}", proxy_request_path, request_path) } _ => request_path.to_string(), }; let mut proxy_request_url_parts = proxy_request_url.into_parts(); proxy_request_url_parts.scheme = if encrypted { Some(Scheme::from_str("wss")?) } else { Some(Scheme::from_str("ws")?) }; match proxy_request_url_parts.path_and_query { Some(path_and_query) => { let path_and_query_string = match path_and_query.query() { Some(query) => { format!("{}?{}", path, query) } None => path, }; proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path_and_query_string)?); } None => { proxy_request_url_parts.path_and_query = Some(PathAndQuery::from_str(&path)?); } }; let proxy_request_url = hyper::Uri::from_parts(proxy_request_url_parts)?; let connector = if !encrypted { Connector::Plain } else { Connector::Rustls(Arc::new( (if disable_certificate_verification { rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(NoServerVerifier::new())) } else { rustls::ClientConfig::builder().with_root_certificates(self.roots.clone()) }) .with_no_client_auth(), )) }; let client_bi_stream = websocket.await?; let (proxy_bi_stream, _) = match tokio_tungstenite::connect_async_tls_with_config( proxy_request_url, None, true, Some(connector), ) .await { Ok(data) => data, Err(err) => { error_logger .log(&format!("Cannot connect to WebSocket server: {}", err)) .await; return Ok(()); } }; let (mut client_sink, mut client_stream) = client_bi_stream.split(); let (mut proxy_sink, mut proxy_stream) = proxy_bi_stream.split(); let client_to_proxy = async { while let Some(Ok(value)) = client_stream.next().await { if proxy_sink.send(value).await.is_err() { break; } } }; let proxy_to_client = async { while let Some(Ok(value)) = proxy_stream.next().await { if client_sink.send(value).await.is_err() { break; } } }; tokio::pin!(client_to_proxy); tokio::pin!(proxy_to_client); let client_to_proxy_first; tokio::select! { _ = &mut client_to_proxy => { client_to_proxy_first = true; } _ = &mut proxy_to_client => { client_to_proxy_first = false; } } if client_to_proxy_first { proxy_to_client.await; } else { client_to_proxy.await; } } Ok(()) }) .await } fn does_websocket_requests(&mut self, config: &ServerConfig, socket_data: &SocketData) -> bool { if socket_data.encrypted { let secure_proxy_to = &config["secureProxyTo"]; if secure_proxy_to.as_vec().is_some() || secure_proxy_to.as_str().is_some() { return true; } } let proxy_to = &config["proxyTo"]; proxy_to.as_vec().is_some() || proxy_to.as_str().is_some() } } async fn determine_proxy_to( config: &ServerConfig, encrypted: bool, failed_backends: &RwLock>, enable_health_check: bool, health_check_max_fails: u64, ) -> Option { let mut proxy_to = None; // When the array is supplied with non-string values, the reverse proxy may have undesirable behavior // The "proxyTo" and "secureProxyTo" are validated though. if encrypted { let secure_proxy_to_yaml = &config["secureProxyTo"]; if let Some(secure_proxy_to_vector) = secure_proxy_to_yaml.as_vec() { if enable_health_check { let mut secure_proxy_to_vector = secure_proxy_to_vector.clone(); loop { if !secure_proxy_to_vector.is_empty() { let index = rand::random_range(..secure_proxy_to_vector.len()); if let Some(secure_proxy_to) = secure_proxy_to_vector[index].as_str() { proxy_to = Some(secure_proxy_to.to_string()); let failed_backends_read = failed_backends.read().await; let failed_backend_fails = match failed_backends_read.get(&secure_proxy_to.to_string()) { Some(fails) => fails, None => break, }; if failed_backend_fails > health_check_max_fails { secure_proxy_to_vector.remove(index); } else { break; } } } else { break; } } } else if !secure_proxy_to_vector.is_empty() { if let Some(secure_proxy_to) = secure_proxy_to_vector[rand::random_range(..secure_proxy_to_vector.len())].as_str() { proxy_to = Some(secure_proxy_to.to_string()); } } } else if let Some(secure_proxy_to) = secure_proxy_to_yaml.as_str() { proxy_to = Some(secure_proxy_to.to_string()); } } if proxy_to.is_none() { let proxy_to_yaml = &config["proxyTo"]; if let Some(proxy_to_vector) = proxy_to_yaml.as_vec() { if enable_health_check { let mut proxy_to_vector = proxy_to_vector.clone(); loop { if !proxy_to_vector.is_empty() { let index = rand::random_range(..proxy_to_vector.len()); if let Some(proxy_to_str) = proxy_to_vector[index].as_str() { proxy_to = Some(proxy_to_str.to_string()); let failed_backends_read = failed_backends.read().await; let failed_backend_fails = match failed_backends_read.get(&proxy_to_str.to_string()) { Some(fails) => fails, None => break, }; if failed_backend_fails > health_check_max_fails { proxy_to_vector.remove(index); } else { break; } } } else { break; } } } else if !proxy_to_vector.is_empty() { if let Some(proxy_to_str) = proxy_to_vector[rand::random_range(..proxy_to_vector.len())].as_str() { proxy_to = Some(proxy_to_str.to_string()); } } } else if let Some(proxy_to_str) = proxy_to_yaml.as_str() { proxy_to = Some(proxy_to_str.to_string()); } } proxy_to } async fn http_proxy( connections: &RwLock>>>, connect_addr: String, stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static, proxy_request: Request>, error_logger: &ErrorLogger, proxy_to: String, failed_backends: Option<&tokio::sync::RwLock>>, ) -> Result> { let io = TokioIo::new(stream); let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await { Ok(data) => data, Err(err) => { if let Some(failed_backends) = failed_backends { let mut failed_backends_write = failed_backends.write().await; let failed_attempts = failed_backends_write.get(&proxy_to); failed_backends_write.insert(proxy_to, failed_attempts.map_or(1, |x| x + 1)); } error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; let send_request = sender.send_request(proxy_request); let mut pinned_conn = Box::pin(conn); tokio::pin!(send_request); let response; loop { tokio::select! { biased; proxy_response = &mut send_request => { let proxy_response = match proxy_response { Ok(response) => response, Err(err) => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build()); } }; response = ResponseData::builder_without_request() .response(proxy_response.map(|b| { b.map_err(|e| std::io::Error::other(e.to_string())) .boxed() })) .parallel_fn(async move { pinned_conn.await.unwrap_or_default(); }) .build(); break; }, state = &mut pinned_conn => { if state.is_err() { error_logger.log("Bad gateway: incomplete response").await; return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build()); } }, }; } if !sender.is_closed() { let mut rwlock_write = connections.write().await; rwlock_write.insert(connect_addr, sender); drop(rwlock_write); } Ok(response) } async fn http_proxy_kept_alive( sender: &mut SendRequest>, proxy_request: Request>, error_logger: &ErrorLogger, ) -> Result> { let proxy_response = match sender.send_request(proxy_request).await { Ok(response) => response, Err(err) => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; let response = ResponseData::builder_without_request() .response(proxy_response.map(|b| b.map_err(|e| std::io::Error::other(e.to_string())).boxed())) .build(); Ok(response) }