web / ferron /src /modules /non_standard_codes.rs
victorgeek's picture
Upload folder using huggingface_hub
9552aa0 verified
raw
history blame
18.9 kB
use std::error::Error;
use std::sync::Arc;
use std::time::Duration;
use crate::ferron_util::ip_blocklist::IpBlockList;
use crate::ferron_util::ip_match::ip_match;
use crate::ferron_util::match_hostname::match_hostname;
use crate::ferron_util::match_location::match_location;
use crate::ferron_util::non_standard_code_structs::{
NonStandardCode, NonStandardCodesLocationWrap, NonStandardCodesWrap,
};
use crate::ferron_util::ttl_cache::TtlCache;
use crate::ferron_common::{
ErrorLogger, HyperResponse, RequestData, ResponseData, ServerConfig, ServerModule,
ServerModuleHandlers, SocketData,
};
use crate::ferron_common::{HyperUpgraded, WithRuntime};
use async_trait::async_trait;
use base64::{engine::general_purpose, Engine};
use fancy_regex::RegexBuilder;
use http_body_util::{BodyExt, Empty};
use hyper::header::HeaderValue;
use hyper::{header, HeaderMap, Response, StatusCode};
use hyper_tungstenite::HyperWebsocket;
use password_auth::verify_password;
use tokio::runtime::Handle;
use tokio::sync::RwLock;
use yaml_rust2::Yaml;
fn non_standard_codes_config_init(
non_standard_codes_list: &[Yaml],
) -> Result<Vec<NonStandardCode>, anyhow::Error> {
let non_standard_codes_list_iter = non_standard_codes_list.iter();
let mut non_standard_codes_list_vec = Vec::new();
for non_standard_codes_list_entry in non_standard_codes_list_iter {
let status_code: u16 = match non_standard_codes_list_entry["scode"].as_i64() {
Some(scode) => scode.try_into()?,
None => {
return Err(anyhow::anyhow!(
"Non-standard codes must include a status code"
));
}
};
let regex = match non_standard_codes_list_entry["regex"].as_str() {
Some(regex_str) => match RegexBuilder::new(regex_str)
.case_insensitive(cfg!(windows))
.build()
{
Ok(regex) => Some(regex),
Err(err) => {
return Err(anyhow::anyhow!(
"Invalid non-standard code regular expression: {}",
err.to_string()
));
}
},
None => None,
};
let url = non_standard_codes_list_entry["url"]
.as_str()
.map(|s| s.to_string());
if regex.is_none() && url.is_none() {
return Err(anyhow::anyhow!(
"Non-standard codes must either include URL or a matching regular expression"
));
}
let location = non_standard_codes_list_entry["location"]
.as_str()
.map(|s| s.to_string());
let realm = non_standard_codes_list_entry["realm"]
.as_str()
.map(|s| s.to_string());
let disable_brute_force_protection = non_standard_codes_list_entry["disableBruteProtection"]
.as_bool()
.unwrap_or(false);
let user_list = match non_standard_codes_list_entry["userList"].as_vec() {
Some(userlist) => {
let mut new_userlist = Vec::new();
for user_yaml in userlist.iter() {
if let Some(user) = user_yaml.as_str() {
new_userlist.push(user.to_string());
}
}
Some(new_userlist)
}
None => None,
};
let users = match non_standard_codes_list_entry["users"].as_vec() {
Some(users_vec) => {
let mut users_str_vec = Vec::new();
for user_yaml in users_vec.iter() {
if let Some(user) = user_yaml.as_str() {
users_str_vec.push(user);
}
}
let mut users_init = IpBlockList::new();
users_init.load_from_vec(users_str_vec);
Some(users_init)
}
None => None,
};
non_standard_codes_list_vec.push(NonStandardCode::new(
status_code,
url,
regex,
location,
realm,
disable_brute_force_protection,
user_list,
users,
));
}
Ok(non_standard_codes_list_vec)
}
/// Builds the non-standard codes module from the server configuration.
///
/// Rules are read from `global.nonStandardCodes`, from each host's `nonStandardCodes`,
/// and from each host location that has both a `path` and a `nonStandardCodes` list.
/// A host is registered when it has host-level rules or at least one location with rules.
///
/// # Errors
///
/// Propagates the first rule-parsing error from `non_standard_codes_config_init`.
pub fn server_module_init(
  config: &ServerConfig,
) -> Result<Box<dyn ServerModule + Send + Sync>, Box<dyn Error + Send + Sync>> {
  let global_rules = match config["global"]["nonStandardCodes"].as_vec() {
    Some(rules_yaml) => non_standard_codes_config_init(rules_yaml)?,
    None => Vec::new(),
  };

  let mut per_host_rules = Vec::new();
  for host_yaml in config["hosts"].as_vec().into_iter().flatten() {
    // Location-level rules are parsed before host-level rules.
    let mut location_rules = Vec::new();
    for location_yaml in host_yaml["locations"].as_vec().into_iter().flatten() {
      let (path_str, rules_yaml) = match (
        location_yaml["path"].as_str(),
        location_yaml["nonStandardCodes"].as_vec(),
      ) {
        (Some(path_str), Some(rules_yaml)) => (path_str, rules_yaml),
        _ => continue,
      };
      location_rules.push(NonStandardCodesLocationWrap::new(
        String::from(path_str),
        non_standard_codes_config_init(rules_yaml)?,
      ));
    }

    let host_rules = match host_yaml["nonStandardCodes"].as_vec() {
      Some(rules_yaml) => non_standard_codes_config_init(rules_yaml)?,
      // Nothing configured for this host at any level: do not register it.
      None if location_rules.is_empty() => continue,
      None => Vec::new(),
    };

    per_host_rules.push(NonStandardCodesWrap::new(
      host_yaml["domain"].as_str().map(String::from),
      host_yaml["ip"].as_str().map(String::from),
      host_rules,
      location_rules,
    ));
  }

  // Failed-authorization counters are kept in a cache with a 300-second TTL.
  let brute_force_db = TtlCache::new(Duration::from_secs(300));

  Ok(Box::new(NonStandardCodesModule::new(
    Arc::new(global_rules),
    Arc::new(per_host_rules),
    Arc::new(RwLock::new(brute_force_db)),
  )))
}
/// Module-level state for non-standard status code handling; each field is
/// reference-counted so `get_handlers` can hand out cheap clones.
struct NonStandardCodesModule {
/// Rules parsed from `global.nonStandardCodes` in the server configuration.
global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
/// Per-host rules (with optional domain/IP selectors and per-location rules)
/// parsed from the `hosts` list in the server configuration.
host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
/// Failed HTTP Basic authorization counters, keyed by the client IP address
/// rendered as a string. `server_module_init` creates the cache with a 300-second TTL.
brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
}
impl NonStandardCodesModule {
fn new(
global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
) -> Self {
Self {
global_non_standard_codes_list,
host_non_standard_codes_lists,
brute_force_db,
}
}
}
impl ServerModule for NonStandardCodesModule {
fn get_handlers(&self, handle: Handle) -> Box<dyn ServerModuleHandlers + Send> {
Box::new(NonStandardCodesModuleHandlers {
global_non_standard_codes_list: self.global_non_standard_codes_list.clone(),
host_non_standard_codes_lists: self.host_non_standard_codes_lists.clone(),
brute_force_db: self.brute_force_db.clone(),
handle,
})
}
}
fn parse_basic_auth(auth_str: &str) -> Option<(String, String)> {
if let Some(base64_credentials) = auth_str.strip_prefix("Basic ") {
if let Ok(decoded) = general_purpose::STANDARD.decode(base64_credentials) {
if let Ok(decoded_str) = std::str::from_utf8(&decoded) {
let parts: Vec<&str> = decoded_str.splitn(2, ':').collect();
if parts.len() == 2 {
return Some((parts[0].to_string(), parts[1].to_string()));
}
}
}
}
None
}
/// Handler set produced by `NonStandardCodesModule::get_handlers`; the first
/// three fields are clones of the module's reference-counted state.
struct NonStandardCodesModuleHandlers {
/// Rules parsed from `global.nonStandardCodes`; evaluated first for every request.
global_non_standard_codes_list: Arc<Vec<NonStandardCode>>,
/// Per-host rules; the request handler uses the first entry whose domain/IP
/// selectors match the request.
host_non_standard_codes_lists: Arc<Vec<NonStandardCodesWrap>>,
/// Failed HTTP Basic authorization counters, keyed by the client IP address string.
brute_force_db: Arc<RwLock<TtlCache<String, u8>>>,
/// Tokio runtime handle passed to `WithRuntime::new` when handling a request.
handle: Handle,
}
/// Request handlers for the non-standard codes module.
///
/// Only `request_handler` does real work: it applies the configured rules and
/// may answer with a redirect, an HTTP Basic authorization challenge, or a bare
/// status code. Every other handler passes its input through unchanged, and the
/// module declines CONNECT-proxy and WebSocket requests.
#[async_trait]
impl ServerModuleHandlers for NonStandardCodesModuleHandlers {
  /// Applies the configured non-standard code rules to an incoming request.
  ///
  /// Rules are evaluated in this order: global rules, then the rules of the
  /// first host entry whose domain and IP selectors match, then the rules of
  /// the first location of that host matching the URL-decoded request path.
  ///
  /// For each rule:
  /// - a rule with a `users` list is skipped unless `is_blocked` reports the
  ///   client address as being in that list (the block list type is used here
  ///   as a membership test);
  /// - the rule matches when its regex finds a match in the path plus query
  ///   string, or otherwise when its `url` equals the request path exactly;
  /// - 301/302/307/308 respond with a `Location` header;
  /// - 401 performs HTTP Basic authorization against `config["users"]`; on
  ///   success the user name is remembered and evaluation continues with the
  ///   next rule, otherwise a 401 challenge (or 429 after too many failures)
  ///   is returned;
  /// - any other status code is returned as-is.
  ///
  /// When no rule produces a response, the request is passed on, carrying the
  /// authorized user name if one was established.
  ///
  /// # Errors
  ///
  /// Returns an error when the regex engine fails, when a configured status
  /// code is rejected by `StatusCode::from_u16`, when a response or header
  /// value cannot be built, or when the password verification task fails to join.
  async fn request_handler(
    &mut self,
    request: RequestData,
    config: &ServerConfig,
    socket_data: &SocketData,
    error_logger: &ErrorLogger,
  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
    // NOTE(review): `WithRuntime` presumably runs the future inside the given
    // Tokio runtime context — confirm against its definition.
    WithRuntime::new(self.handle.clone(), async move {
      let hyper_request = request.get_hyper_request();
      let global_non_standard_codes_list = self.global_non_standard_codes_list.iter();
      // Empty placeholders so the host/location iterators have a value even
      // when no host entry matches.
      let empty_vector = Vec::new();
      let another_empty_vector = Vec::new();
      let mut host_non_standard_codes_list = empty_vector.iter();
      let mut location_non_standard_codes_list = another_empty_vector.iter();

      // Should have used a HashMap instead of iterating over an array for better performance...
      for host_non_standard_codes_list_wrap in self.host_non_standard_codes_lists.iter() {
        if match_hostname(
          match &host_non_standard_codes_list_wrap.domain {
            Some(value) => Some(value as &str),
            None => None,
          },
          match hyper_request.headers().get(header::HOST) {
            Some(value) => value.to_str().ok(),
            None => None,
          },
        ) && match &host_non_standard_codes_list_wrap.ip {
          Some(value) => ip_match(value as &str, socket_data.remote_addr.ip()),
          None => true,
        } {
          host_non_standard_codes_list =
            host_non_standard_codes_list_wrap.non_standard_codes.iter();
          // Locations are matched against the original URL when one is
          // recorded on the request, otherwise against the current URI.
          if let Ok(path_decoded) = urlencoding::decode(
            request
              .get_original_url()
              .unwrap_or(request.get_hyper_request().uri())
              .path(),
          ) {
            for location_wrap in host_non_standard_codes_list_wrap.locations.iter() {
              if match_location(&location_wrap.path, &path_decoded) {
                location_non_standard_codes_list = location_wrap.non_standard_codes.iter();
                break;
              }
            }
          }
          // Only the first matching host entry is used.
          break;
        }
      }

      let combined_non_standard_codes_list = global_non_standard_codes_list
        .chain(host_non_standard_codes_list)
        .chain(location_non_standard_codes_list);

      // Path plus query string; this is what rule regexes are matched against.
      let request_url = format!(
        "{}{}",
        hyper_request.uri().path(),
        match hyper_request.uri().query() {
          Some(query) => format!("?{}", query),
          None => String::from(""),
        }
      );

      let mut auth_user = None;

      for non_standard_code in combined_non_standard_codes_list {
        let mut redirect_url = None;
        let mut url_matched = false;

        if let Some(users) = &non_standard_code.users {
          if !users.is_blocked(socket_data.remote_addr.ip()) {
            // Don't process this non-standard code
            continue;
          }
        }

        if let Some(regex) = &non_standard_code.regex {
          let regex_match_option = regex.find(&request_url)?;
          if let Some(regex_match) = regex_match_option {
            url_matched = true;
            if non_standard_code.status_code == 301
              || non_standard_code.status_code == 302
              || non_standard_code.status_code == 307
              || non_standard_code.status_code == 308
            {
              // The replacement is applied to the matched text only, not to
              // the whole request URL.
              let matched_text = regex_match.as_str();
              if let Some(location) = &non_standard_code.location {
                redirect_url = Some(regex.replace(matched_text, location).to_string());
              }
            }
          }
        }

        if !url_matched {
          if let Some(url) = &non_standard_code.url {
            // Exact comparison against the path only (the query string is ignored).
            if url == hyper_request.uri().path() {
              url_matched = true;
              if non_standard_code.status_code == 301
                || non_standard_code.status_code == 302
                || non_standard_code.status_code == 307
                || non_standard_code.status_code == 308
              {
                if let Some(location) = &non_standard_code.location {
                  // Carry the original query string over to the redirect target.
                  redirect_url = Some(format!(
                    "{}{}",
                    location,
                    match hyper_request.uri().query() {
                      Some(query) => format!("?{}", query),
                      None => String::from(""),
                    }
                  ));
                }
              }
            }
          }
        }

        if url_matched {
          match non_standard_code.status_code {
            301 | 302 | 307 | 308 => {
              // Without a configured `location`, the redirect points back at
              // the requested URL itself.
              return Ok(
                ResponseData::builder(request)
                  .response(
                    Response::builder()
                      .status(StatusCode::from_u16(non_standard_code.status_code)?)
                      .header(header::LOCATION, redirect_url.unwrap_or(request_url))
                      .body(Empty::new().map_err(|e| match e {}).boxed())?,
                  )
                  .build(),
              );
            }
            401 => {
              let brute_force_db_key = socket_data.remote_addr.ip().to_string();
              if !non_standard_code.disable_brute_force_protection {
                let rwlock_read = self.brute_force_db.read().await;
                let current_attempts = rwlock_read.get(&brute_force_db_key).unwrap_or(0);
                if current_attempts >= 10 {
                  error_logger
                    .log(&format!(
                      "Too many failed authorization attempts for client \"{}\"",
                      socket_data.remote_addr.ip()
                    ))
                    .await;
                  return Ok(
                    ResponseData::builder(request)
                      .status(StatusCode::TOO_MANY_REQUESTS)
                      .build(),
                  );
                }
              }

              // Challenge header; backslashes and double quotes in the realm
              // are escaped so the quoted-string stays well-formed.
              let mut header_map = HeaderMap::new();
              header_map.insert(
                header::WWW_AUTHENTICATE,
                HeaderValue::from_str(&format!(
                  "Basic realm=\"{}\", charset=\"UTF-8\"",
                  non_standard_code
                    .realm
                    .clone()
                    .unwrap_or("Ferron HTTP Basic Authorization".to_string())
                    .replace("\\", "\\\\")
                    .replace("\"", "\\\"")
                ))?,
              );

              if let Some(authorization_header_value) =
                hyper_request.headers().get(header::AUTHORIZATION)
              {
                let authorization_str = match authorization_header_value.to_str() {
                  Ok(str) => str,
                  Err(_) => {
                    return Ok(
                      ResponseData::builder(request)
                        .status(StatusCode::BAD_REQUEST)
                        .build(),
                    );
                  }
                };

                if let Some((username, password)) = parse_basic_auth(authorization_str) {
                  if let Some(users_vec_yaml) = config["users"].as_vec() {
                    let mut authorized_user = None;
                    for user_yaml in users_vec_yaml {
                      if let Some(username_db) = user_yaml["name"].as_str() {
                        if username_db != username {
                          continue;
                        }
                        // A rule-level `userList` restricts which configured
                        // users may authorize for this rule.
                        if let Some(user_list) = &non_standard_code.user_list {
                          if !user_list.contains(&username) {
                            continue;
                          }
                        }
                        if let Some(password_hash_db) = user_yaml["pass"].as_str() {
                          let password_cloned = password.clone();
                          let password_hash_db_cloned = password_hash_db.to_string();
                          // Offload verifying the hash into a separate blocking thread.
                          let password_valid = tokio::task::spawn_blocking(move || {
                            verify_password(password_cloned, &password_hash_db_cloned).is_ok()
                          })
                          .await?;
                          if password_valid {
                            authorized_user = Some(&username);
                            break;
                          }
                        }
                      }
                    }
                    if let Some(authorized_user) = authorized_user {
                      // Authorized: remember the user and keep evaluating the
                      // remaining rules instead of responding.
                      auth_user = Some(authorized_user.to_owned());
                      continue;
                    }
                  }

                  // Only well-formed Basic credentials that failed verification
                  // count as a failed attempt.
                  if !non_standard_code.disable_brute_force_protection {
                    let mut rwlock_write = self.brute_force_db.write().await;
                    // NOTE(review): `cleanup` presumably evicts expired entries — confirm.
                    rwlock_write.cleanup();
                    let current_attempts = rwlock_write.get(&brute_force_db_key).unwrap_or(0);
                    // Saturate instead of `+ 1`: requests verified concurrently can
                    // all pass the limit check above before any of them increments,
                    // so the `u8` counter may reach 255. Overflow would panic in
                    // debug builds and wrap to 0 (resetting protection) in release.
                    rwlock_write.insert(brute_force_db_key, current_attempts.saturating_add(1));
                  }
                  error_logger
                    .log(&format!(
                      "Authorization failed for user \"{}\" and client \"{}\"",
                      username,
                      socket_data.remote_addr.ip()
                    ))
                    .await;
                }
              }

              return Ok(
                ResponseData::builder(request)
                  .status(StatusCode::UNAUTHORIZED)
                  .headers(header_map)
                  .build(),
              );
            }
            _ => {
              return Ok(
                ResponseData::builder(request)
                  .status(StatusCode::from_u16(non_standard_code.status_code)?)
                  .build(),
              )
            }
          }
        }
      }

      if auth_user.is_some() {
        // Rebuild the request so it carries the authorized user name.
        let (hyper_request, _, original_url) = request.into_parts();
        Ok(ResponseData::builder(RequestData::new(hyper_request, auth_user, original_url)).build())
      } else {
        Ok(ResponseData::builder(request).build())
      }
    })
    .await
  }

  /// Passes proxy requests through untouched.
  async fn proxy_request_handler(
    &mut self,
    request: RequestData,
    _config: &ServerConfig,
    _socket_data: &SocketData,
    _error_logger: &ErrorLogger,
  ) -> Result<ResponseData, Box<dyn Error + Send + Sync>> {
    Ok(ResponseData::builder(request).build())
  }

  /// Returns the response unchanged.
  async fn response_modifying_handler(
    &mut self,
    response: HyperResponse,
  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
    Ok(response)
  }

  /// Returns the proxy response unchanged.
  async fn proxy_response_modifying_handler(
    &mut self,
    response: HyperResponse,
  ) -> Result<HyperResponse, Box<dyn Error + Send + Sync>> {
    Ok(response)
  }

  /// No-op; this module does not handle CONNECT proxy requests.
  async fn connect_proxy_request_handler(
    &mut self,
    _upgraded_request: HyperUpgraded,
    _connect_address: &str,
    _config: &ServerConfig,
    _socket_data: &SocketData,
    _error_logger: &ErrorLogger,
  ) -> Result<(), Box<dyn Error + Send + Sync>> {
    Ok(())
  }

  /// Always `false`: CONNECT proxy requests are not handled here.
  fn does_connect_proxy_requests(&mut self) -> bool {
    false
  }

  /// No-op; this module does not handle WebSocket requests.
  async fn websocket_request_handler(
    &mut self,
    _websocket: HyperWebsocket,
    _uri: &hyper::Uri,
    _config: &ServerConfig,
    _socket_data: &SocketData,
    _error_logger: &ErrorLogger,
  ) -> Result<(), Box<dyn Error + Send + Sync>> {
    Ok(())
  }

  /// Always `false`: WebSocket requests are not handled here.
  fn does_websocket_requests(&mut self, _config: &ServerConfig, _socket_data: &SocketData) -> bool {
    false
  }
}