use std::error::Error; use std::sync::Arc; use std::time::Duration; use crate::ferron_util::ip_blocklist::IpBlockList; use crate::ferron_util::ip_match::ip_match; use crate::ferron_util::match_hostname::match_hostname; use crate::ferron_util::match_location::match_location; use crate::ferron_util::non_standard_code_structs::{ NonStandardCode, NonStandardCodesLocationWrap, NonStandardCodesWrap, }; use crate::ferron_util::ttl_cache::TtlCache; use crate::ferron_common::{ ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule, ServerModuleHandlers, SocketData, }; use crate::ferron_common::{HyperUpgraded, WithRuntime}; use async_trait::async_trait; use base64::{engine::general_purpose, Engine}; use fancy_regex::RegexBuilder; use http_body_util::{BodyExt, Empty}; use hyper::header::HeaderValue; use hyper::{header, HeaderMap, Response, StatusCode}; use hyper_tungstenite::HyperWebsocket; use password_auth::verify_password; use tokio::runtime::Handle; use tokio::sync::RwLock; use yaml_rust2::Yaml; fn non_standard_codes_config_init( non_standard_codes_list: &[Yaml], ) -> Result, anyhow::Error> { let non_standard_codes_list_iter = non_standard_codes_list.iter(); let mut non_standard_codes_list_vec = Vec::new(); for non_standard_codes_list_entry in non_standard_codes_list_iter { let status_code: u16 = match non_standard_codes_list_entry["scode"].as_i64() { Some(scode) => scode.try_into()?, None => { return Err(anyhow::anyhow!( "Non-standard codes must include a status code" )); } }; let regex = match non_standard_codes_list_entry["regex"].as_str() { Some(regex_str) => match RegexBuilder::new(regex_str) .case_insensitive(cfg!(windows)) .build() { Ok(regex) => Some(regex), Err(err) => { return Err(anyhow::anyhow!( "Invalid non-standard code regular expression: {}", err.to_string() )); } }, None => None, }; let url = non_standard_codes_list_entry["url"] .as_str() .map(|s| s.to_string()); if regex.is_none() && url.is_none() { return Err(anyhow::anyhow!( "Non-standard codes must either include URL or a matching regular expression" )); } let location = non_standard_codes_list_entry["location"] .as_str() .map(|s| s.to_string()); let realm = non_standard_codes_list_entry["realm"] .as_str() .map(|s| s.to_string()); let disable_brute_force_protection = non_standard_codes_list_entry["disableBruteProtection"] .as_bool() .unwrap_or(false); let user_list = match non_standard_codes_list_entry["userList"].as_vec() { Some(userlist) => { let mut new_userlist = Vec::new(); for user_yaml in userlist.iter() { if let Some(user) = user_yaml.as_str() { new_userlist.push(user.to_string()); } } Some(new_userlist) } None => None, }; let users = match non_standard_codes_list_entry["users"].as_vec() { Some(users_vec) => { let mut users_str_vec = Vec::new(); for user_yaml in users_vec.iter() { if let Some(user) = user_yaml.as_str() { users_str_vec.push(user); } } let mut users_init = IpBlockList::new(); users_init.load_from_vec(users_str_vec); Some(users_init) } None => None, }; non_standard_codes_list_vec.push(NonStandardCode::new( status_code, url, regex, location, realm, disable_brute_force_protection, user_list, users, )); } Ok(non_standard_codes_list_vec) } pub fn server_module_init( config: &ServerConfig, ) -> Result, Box> { let mut global_non_standard_codes_list = Vec::new(); let mut host_non_standard_codes_lists = Vec::new(); if let Some(non_standard_codes_list_yaml) = config["global"]["nonStandardCodes"].as_vec() { global_non_standard_codes_list = non_standard_codes_config_init(non_standard_codes_list_yaml)?; } if let Some(hosts) = config["hosts"].as_vec() { for host_yaml in hosts.iter() { let domain = host_yaml["domain"].as_str().map(String::from); let ip = host_yaml["ip"].as_str().map(String::from); let mut locations = Vec::new(); if let Some(locations_yaml) = host_yaml["locations"].as_vec() { for location_yaml in locations_yaml.iter() { if let Some(path_str) = location_yaml["path"].as_str() { let path = String::from(path_str); if let Some(non_standard_codes_list_yaml) = location_yaml["nonStandardCodes"].as_vec() { locations.push(NonStandardCodesLocationWrap::new( path, non_standard_codes_config_init(non_standard_codes_list_yaml)?, )); } } } } if let Some(non_standard_codes_list_yaml) = host_yaml["nonStandardCodes"].as_vec() { host_non_standard_codes_lists.push(NonStandardCodesWrap::new( domain, ip, non_standard_codes_config_init(non_standard_codes_list_yaml)?, locations, )); } else if !locations.is_empty() { host_non_standard_codes_lists.push(NonStandardCodesWrap::new( domain, ip, Vec::new(), locations, )); } } } Ok(Box::new(NonStandardCodesModule::new( Arc::new(global_non_standard_codes_list), Arc::new(host_non_standard_codes_lists), Arc::new(RwLock::new(TtlCache::new(Duration::new(300, 0)))), ))) } struct NonStandardCodesModule { global_non_standard_codes_list: Arc>, host_non_standard_codes_lists: Arc>, brute_force_db: Arc>>, } impl NonStandardCodesModule { fn new( global_non_standard_codes_list: Arc>, host_non_standard_codes_lists: Arc>, brute_force_db: Arc>>, ) -> Self { Self { global_non_standard_codes_list, host_non_standard_codes_lists, brute_force_db, } } } impl ServerModule for NonStandardCodesModule { fn get_handlers(&self, handle: Handle) -> Box { Box::new(NonStandardCodesModuleHandlers { global_non_standard_codes_list: self.global_non_standard_codes_list.clone(), host_non_standard_codes_lists: self.host_non_standard_codes_lists.clone(), brute_force_db: self.brute_force_db.clone(), handle, }) } } fn parse_basic_auth(auth_str: &str) -> Option<(String, String)> { if let Some(base64_credentials) = auth_str.strip_prefix("Basic ") { if let Ok(decoded) = general_purpose::STANDARD.decode(base64_credentials) { if let Ok(decoded_str) = std::str::from_utf8(&decoded) { let parts: Vec<&str> = decoded_str.splitn(2, ':').collect(); if parts.len() == 2 { return Some((parts[0].to_string(), parts[1].to_string())); } } } } None } struct NonStandardCodesModuleHandlers { global_non_standard_codes_list: Arc>, host_non_standard_codes_lists: Arc>, brute_force_db: Arc>>, handle: Handle, } #[async_trait] impl ServerModuleHandlers for NonStandardCodesModuleHandlers { async fn request_handler( &mut self, request: RequestData, config: &ServerConfig, socket_data: &SocketData, error_logger: &ErrorLogger, ) -> Result> { WithRuntime::new(self.handle.clone(), async move { let hyper_request = request.get_hyper_request(); let global_non_standard_codes_list = self.global_non_standard_codes_list.iter(); let empty_vector = Vec::new(); let another_empty_vector = Vec::new(); let mut host_non_standard_codes_list = empty_vector.iter(); let mut location_non_standard_codes_list = another_empty_vector.iter(); // Should have used a HashMap instead of iterating over an array for better performance... for host_non_standard_codes_list_wrap in self.host_non_standard_codes_lists.iter() { if match_hostname( match &host_non_standard_codes_list_wrap.domain { Some(value) => Some(value as &str), None => None, }, match hyper_request.headers().get(header::HOST) { Some(value) => value.to_str().ok(), None => None, }, ) && match &host_non_standard_codes_list_wrap.ip { Some(value) => ip_match(value as &str, socket_data.remote_addr.ip()), None => true, } { host_non_standard_codes_list = host_non_standard_codes_list_wrap.non_standard_codes.iter(); if let Ok(path_decoded) = urlencoding::decode( request .get_original_url() .unwrap_or(request.get_hyper_request().uri()) .path(), ) { for location_wrap in host_non_standard_codes_list_wrap.locations.iter() { if match_location(&location_wrap.path, &path_decoded) { location_non_standard_codes_list = location_wrap.non_standard_codes.iter(); break; } } } break; } } let combined_non_standard_codes_list = global_non_standard_codes_list .chain(host_non_standard_codes_list) .chain(location_non_standard_codes_list); let request_url = format!( "{}{}", hyper_request.uri().path(), match hyper_request.uri().query() { Some(query) => format!("?{}", query), None => String::from(""), } ); let mut auth_user = None; for non_standard_code in combined_non_standard_codes_list { let mut redirect_url = None; let mut url_matched = false; if let Some(users) = &non_standard_code.users { if !users.is_blocked(socket_data.remote_addr.ip()) { // Don't process this non-standard code continue; } } if let Some(regex) = &non_standard_code.regex { let regex_match_option = regex.find(&request_url)?; if let Some(regex_match) = regex_match_option { url_matched = true; if non_standard_code.status_code == 301 || non_standard_code.status_code == 302 || non_standard_code.status_code == 307 || non_standard_code.status_code == 308 { let matched_text = regex_match.as_str(); if let Some(location) = &non_standard_code.location { redirect_url = Some(regex.replace(matched_text, location).to_string()); } } } } if !url_matched { if let Some(url) = &non_standard_code.url { if url == hyper_request.uri().path() { url_matched = true; if non_standard_code.status_code == 301 || non_standard_code.status_code == 302 || non_standard_code.status_code == 307 || non_standard_code.status_code == 308 { if let Some(location) = &non_standard_code.location { redirect_url = Some(format!( "{}{}", location, match hyper_request.uri().query() { Some(query) => format!("?{}", query), None => String::from(""), } )); } } } } } if url_matched { match non_standard_code.status_code { 301 | 302 | 307 | 308 => { return Ok( ResponseData::builder(request) .response( Response::builder() .status(StatusCode::from_u16(non_standard_code.status_code)?) .header(header::LOCATION, redirect_url.unwrap_or(request_url)) .body(Empty::new().map_err(|e| match e {}).boxed())?, ) .build(), ); } 401 => { let brute_force_db_key = socket_data.remote_addr.ip().to_string(); if !non_standard_code.disable_brute_force_protection { let rwlock_read = self.brute_force_db.read().await; let current_attempts = rwlock_read.get(&brute_force_db_key).unwrap_or(0); if current_attempts >= 10 { error_logger .log(&format!( "Too many failed authorization attempts for client \"{}\"", socket_data.remote_addr.ip() )) .await; return Ok( ResponseData::builder(request) .status(StatusCode::TOO_MANY_REQUESTS) .build(), ); } } let mut header_map = HeaderMap::new(); header_map.insert( header::WWW_AUTHENTICATE, HeaderValue::from_str(&format!( "Basic realm=\"{}\", charset=\"UTF-8\"", non_standard_code .realm .clone() .unwrap_or("Ferron HTTP Basic Authorization".to_string()) .replace("\\", "\\\\") .replace("\"", "\\\"") ))?, ); if let Some(authorization_header_value) = hyper_request.headers().get(header::AUTHORIZATION) { let authorization_str = match authorization_header_value.to_str() { Ok(str) => str, Err(_) => { return Ok( ResponseData::builder(request) .status(StatusCode::BAD_REQUEST) .build(), ); } }; if let Some((username, password)) = parse_basic_auth(authorization_str) { if let Some(users_vec_yaml) = config["users"].as_vec() { let mut authorized_user = None; for user_yaml in users_vec_yaml { if let Some(username_db) = user_yaml["name"].as_str() { if username_db != username { continue; } if let Some(user_list) = &non_standard_code.user_list { if !user_list.contains(&username) { continue; } } if let Some(password_hash_db) = user_yaml["pass"].as_str() { let password_cloned = password.clone(); let password_hash_db_cloned = password_hash_db.to_string(); // Offload verifying the hash into a separate blocking thread. let password_valid = tokio::task::spawn_blocking(move || { verify_password(password_cloned, &password_hash_db_cloned).is_ok() }) .await?; if password_valid { authorized_user = Some(&username); break; } } } } if let Some(authorized_user) = authorized_user { auth_user = Some(authorized_user.to_owned()); continue; } } if !non_standard_code.disable_brute_force_protection { let mut rwlock_write = self.brute_force_db.write().await; rwlock_write.cleanup(); let current_attempts = rwlock_write.get(&brute_force_db_key).unwrap_or(0); rwlock_write.insert(brute_force_db_key, current_attempts + 1); } error_logger .log(&format!( "Authorization failed for user \"{}\" and client \"{}\"", username, socket_data.remote_addr.ip() )) .await; } } return Ok( ResponseData::builder(request) .status(StatusCode::UNAUTHORIZED) .headers(header_map) .build(), ); } _ => { return Ok( ResponseData::builder(request) .status(StatusCode::from_u16(non_standard_code.status_code)?) .build(), ) } } } } if auth_user.is_some() { let (hyper_request, _, original_url) = request.into_parts(); Ok(ResponseData::builder(RequestData::new(hyper_request, auth_user, original_url)).build()) } else { Ok(ResponseData::builder(request).build()) } }) .await } async fn proxy_request_handler( &mut self, request: RequestData, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result> { Ok(ResponseData::builder(request).build()) } async fn response_modifying_handler( &mut self, response: HyperResponse, ) -> Result> { Ok(response) } async fn proxy_response_modifying_handler( &mut self, response: HyperResponse, ) -> Result> { Ok(response) } async fn connect_proxy_request_handler( &mut self, _upgraded_request: HyperUpgraded, _connect_address: &str, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result<(), Box> { Ok(()) } fn does_connect_proxy_requests(&mut self) -> bool { false } async fn websocket_request_handler( &mut self, _websocket: HyperWebsocket, _uri: &hyper::Uri, _config: &ServerConfig, _socket_data: &SocketData, _error_logger: &ErrorLogger, ) -> Result<(), Box> { Ok(()) } fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool { false } }