use std::error::Error; use std::str::FromStr; use crate::ferron_common::{ ErrorLogger, HyperUpgraded, RequestData, ResponseData, ServerConfig, ServerModule, ServerModuleHandlers, SocketData, }; use crate::ferron_common::{HyperResponse, WithRuntime}; use async_trait::async_trait; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; use hyper::body::Bytes; use hyper::{header, Request, StatusCode, Uri}; use hyper_tungstenite::HyperWebsocket; use hyper_util::rt::TokioIo; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::net::TcpStream; use tokio::runtime::Handle; pub fn server_module_init( _config: &ServerConfig, ) -> Result, Box> { Ok(Box::new(ForwardProxyModule::new())) } struct ForwardProxyModule; impl ForwardProxyModule { fn new() -> Self { Self } } impl ServerModule for ForwardProxyModule { fn get_handlers(&self, handle: Handle) -> Box { Box::new(ForwardProxyModuleHandlers { handle }) } } struct ForwardProxyModuleHandlers { handle: Handle, } #[async_trait] impl ServerModuleHandlers for ForwardProxyModuleHandlers { async fn request_handler( &mut self, request: RequestData, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result> { Ok(ResponseData::builder(request).build()) } async fn proxy_request_handler( &mut self, request: RequestData, _config: &ServerConfig, _socket_data: &SocketData, error_logger: &ErrorLogger, ) -> Result> { WithRuntime::new(self.handle.clone(), async move { // Code taken from reverse proxy module let (hyper_request, _auth_user, _original_url) = request.into_parts(); let (mut hyper_request_parts, request_body) = hyper_request.into_parts(); match hyper_request_parts.uri.scheme_str() { Some("http") | None => (), _ => { return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_REQUEST) .build(), ); } }; let host = match hyper_request_parts.uri.host() { Some(host) => host, None => { return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_REQUEST) .build(), ); } }; let port = hyper_request_parts.uri.port_u16().unwrap_or(80); let addr = format!("{}:{}", host, port); let stream = match TcpStream::connect(addr).await { Ok(stream) => stream, Err(err) => { match err.kind() { tokio::io::ErrorKind::ConnectionRefused | tokio::io::ErrorKind::NotFound | tokio::io::ErrorKind::HostUnreachable => { error_logger .log(&format!("Service unavailable: {}", err)) .await; return Ok( ResponseData::builder_without_request() .status(StatusCode::SERVICE_UNAVAILABLE) .build(), ); } tokio::io::ErrorKind::TimedOut => { error_logger.log(&format!("Gateway timeout: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::GATEWAY_TIMEOUT) .build(), ); } _ => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; } }; match stream.set_nodelay(true) { Ok(_) => (), Err(err) => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; let hyper_request_path = hyper_request_parts.uri.path(); hyper_request_parts.uri = Uri::from_str(&format!( "{}{}", hyper_request_path, match hyper_request_parts.uri.query() { Some(query) => format!("?{}", query), None => "".to_string(), } ))?; // Connection header to disable HTTP/1.1 keep-alive hyper_request_parts .headers .insert(header::CONNECTION, "close".parse()?); let proxy_request = Request::from_parts(hyper_request_parts, request_body); http_proxy(stream, proxy_request, error_logger).await }) .await } async fn response_modifying_handler( &mut self, response: HyperResponse, ) -> Result> { Ok(response) } async fn proxy_response_modifying_handler( &mut self, response: HyperResponse, ) -> Result> { Ok(response) } async fn connect_proxy_request_handler( &mut self, upgraded_request: HyperUpgraded, connect_address: &str, _config: &ServerConfig, _socket_data: &SocketData, error_logger: &ErrorLogger, ) -> Result<(), Box> { WithRuntime::new(self.handle.clone(), async move { let mut stream = match TcpStream::connect(connect_address).await { Ok(stream) => stream, Err(err) => { error_logger .log(&format!("Cannot connect to the remote server: {}", err)) .await; return Ok(()); } }; match stream.set_nodelay(true) { Ok(_) => (), Err(err) => { error_logger .log(&format!( "Cannot disable Nagle algorithm when connecting to the remote server: {}", err )) .await; return Ok(()); } }; let mut upgraded = TokioIo::new(upgraded_request); tokio::io::copy_bidirectional(&mut upgraded, &mut stream) .await .unwrap_or_default(); Ok(()) }) .await } fn does_connect_proxy_requests(&mut self) -> bool { true } async fn websocket_request_handler( &mut self, _websocket: HyperWebsocket, _uri: &hyper::Uri, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result<(), Box> { Ok(()) } fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool { false } } async fn http_proxy( stream: impl AsyncRead + AsyncWrite + Send + Unpin + 'static, proxy_request: Request>, error_logger: &ErrorLogger, ) -> Result> { let io = TokioIo::new(stream); let (mut sender, conn) = match hyper::client::conn::http1::handshake(io).await { Ok(data) => data, Err(err) => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok( ResponseData::builder_without_request() .status(StatusCode::BAD_GATEWAY) .build(), ); } }; let send_request = sender.send_request(proxy_request); let mut pinned_conn = Box::pin(conn); tokio::pin!(send_request); let response; loop { tokio::select! { biased; proxy_response = &mut send_request => { let proxy_response = match proxy_response { Ok(response) => response, Err(err) => { error_logger.log(&format!("Bad gateway: {}", err)).await; return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build()); } }; response = ResponseData::builder_without_request() .response(proxy_response.map(|b| { b.map_err(|e| std::io::Error::other(e.to_string())) .boxed() })) .parallel_fn(async move { pinned_conn.await.unwrap_or_default(); }) .build(); break; }, state = &mut pinned_conn => { if state.is_err() { error_logger.log("Bad gateway: incomplete response").await; return Ok(ResponseData::builder_without_request().status(StatusCode::BAD_GATEWAY).build()); } }, }; } Ok(response) }