refactor: unify error handling, graceful shutdown, and constants across framework

This commit is contained in:
2026-04-29 23:23:46 +02:00
committed by ForgeCode
parent db7b26864b
commit f37befacdd
14 changed files with 1990 additions and 182 deletions
+50 -7
View File
@@ -1,25 +1,36 @@
//! API Key authentication middleware.
//!
//! Validates requests by checking for a valid API key in the X-API-Key header.
use crate::{
Responder,
middleware::{Middleware, MiddlewareFuture, MiddlewareResult},
};
use http::Request;
use hyper::body::Incoming;
use http::{Request, Response};
use hyper::body::{Bytes, Incoming};
use log::warn;
/// Middleware that validates API key authentication via X-API-Key header.
pub struct ApiKeyMiddleware {
api_key: String,
}
impl ApiKeyMiddleware {
/// Creates a new ApiKeyMiddleware with the specified expected API key.
pub fn new(api_key: &str) -> Self {
Self {
api_key: api_key.to_string(),
}
}
/// Checks if the given API key is invalid.
pub fn is_invalid_key(&self, key: &str) -> bool {
key != self.api_key
}
}
impl Middleware for ApiKeyMiddleware {
fn run<'a>(&'a self, req: Request<Incoming>) -> MiddlewareFuture<'a> {
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
let expected_key = self.api_key.clone();
Box::pin(async move {
@@ -28,15 +39,47 @@ impl Middleware for ApiKeyMiddleware {
if header == expected_key.as_str() {
MiddlewareResult::Continue(req)
} else {
warn!("X-API-Key wrong");
MiddlewareResult::Respond(Responder::unauthorized().unwrap())
warn!("X-API-Key validation failed for request");
// Return a default unauthorized response if Responder fails
let response = Responder::unauthorized()
.unwrap_or_else(|_| {
// Fallback to a basic unauthorized response
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(http_body_util::Full::new(
Bytes::from("Unauthorized")
))
.expect("Failed to build fallback response")
});
MiddlewareResult::Respond(response)
}
}
None => {
warn!("X-API-Key missing");
MiddlewareResult::Respond(Responder::unauthorized().unwrap())
warn!("X-API-Key header missing from request");
let response = Responder::unauthorized()
.unwrap_or_else(|_| {
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(http_body_util::Full::new(
Bytes::from("Unauthorized")
))
.expect("Failed to build fallback response")
});
MiddlewareResult::Respond(response)
}
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use http::Request;
#[test]
fn test_api_key_middleware_new() {
let middleware = ApiKeyMiddleware::new("test-key");
assert_eq!(middleware.api_key, "test-key");
}
}
+193 -19
View File
@@ -1,56 +1,230 @@
//! IP address filtering middleware.
//!
//! Allows or denies requests based on the client's IP address.
//! Supports allowlisting specific IPs and optionally allows private network ranges.
use crate::{
Responder,
error::{ServerError, Result},
middleware::{Middleware, MiddlewareFuture, MiddlewareResult},
};
use http::Request;
use hyper::body::Incoming;
use http::{Request, Response};
use hyper::body::{Bytes, Incoming};
use log::warn;
use std::collections::HashSet;
use std::net::IpAddr;
/// Maximum number of IPs that can be configured in the allow list.
const MAX_ALLOWED_IPS: usize = 1000;
/// Middleware that filters requests based on client IP address.
///
/// Uses a `HashSet` for O(1) lookups instead of O(n) with a Vec.
pub struct IpFilterMiddleware {
allowed_ips: Vec<String>,
allowed_ips: HashSet<IpAddr>,
allow_private: bool,
}
impl IpFilterMiddleware {
pub fn new(allowed_ips: Vec<String>, allow_private: bool) -> Self {
/// Creates a new IpFilterMiddleware.
///
/// Validates and parses IP addresses at construction time for optimal runtime performance.
///
/// # Arguments
/// * `allowed_ips` - List of IP addresses to allow (empty list allows all)
/// * `allow_private` - Whether to allow private network ranges
///
/// # Errors
/// Returns an error if any IP address cannot be parsed or if too many IPs are provided.
pub fn new(allowed_ips: Vec<String>, allow_private: bool) -> Result<Self> {
if allowed_ips.len() > MAX_ALLOWED_IPS {
return Err(ServerError::validation(
"allowed_ips",
format!("Too many IPs specified (max {})", MAX_ALLOWED_IPS),
));
}
let mut allowed_set = HashSet::with_capacity(allowed_ips.len());
for ip_str in allowed_ips {
let ip: IpAddr = ip_str.parse().map_err(|_| {
ServerError::validation("allowed_ips", format!("Invalid IP address: {}", ip_str))
})?;
allowed_set.insert(ip);
}
Ok(Self {
allowed_ips: allowed_set,
allow_private,
})
}
/// Creates a new IpFilterMiddleware without validation (for testing).
#[cfg(test)]
pub fn new_unchecked(allowed_ips: Vec<String>, allow_private: bool) -> Self {
let allowed_set: HashSet<IpAddr> = allowed_ips
.into_iter()
.filter_map(|s| s.parse().ok())
.collect();
Self {
allowed_ips,
allowed_ips: allowed_set,
allow_private,
}
}
fn is_authorized(&self, ip: &IpAddr) -> bool {
/// Checks if the given IP address is authorized.
///
/// Performance: O(1) lookup using HashSet.
pub fn is_authorized(&self, ip: &IpAddr) -> bool {
// Check private ranges first (fast path for local networks)
// Note: Only IPv4 has is_private() method
if self.allow_private {
let is_private = match ip {
IpAddr::V4(ip4) => ip4.is_private(),
IpAddr::V6(_) => false,
};
if is_private {
return true;
if let IpAddr::V4(ipv4) = ip {
if ipv4.is_private() {
return true;
}
}
}
// Empty allowlist means "allow all"
if self.allowed_ips.is_empty() {
return true;
}
self.allowed_ips.iter().any(|auth| &ip.to_string() == auth)
// O(1) lookup
self.allowed_ips.contains(ip)
}
}
impl Middleware for IpFilterMiddleware {
fn run<'a>(&'a self, req: Request<Incoming>) -> MiddlewareFuture<'a> {
Box::pin(async move {
let client_ip = req.extensions().get::<IpAddr>();
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
let client_ip = req.extensions().get::<IpAddr>().copied();
Box::pin(async move {
match client_ip {
Some(ip) if self.is_authorized(ip) => MiddlewareResult::Continue(req),
Some(ip) if self.is_authorized(&ip) => MiddlewareResult::Continue(req),
_ => {
warn!("Unauthorized IP");
MiddlewareResult::Respond(Responder::unauthorized().unwrap())
warn!("Unauthorized IP access attempt");
let response = Responder::unauthorized()
.unwrap_or_else(|_| {
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.header(http::header::CONTENT_TYPE, "text/plain")
.body(http_body_util::Full::new(
Bytes::from("Unauthorized")
))
.expect("Failed to build fallback response")
});
MiddlewareResult::Respond(response)
}
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_new_validates_ip() {
// Valid IPs should work
let result = IpFilterMiddleware::new(vec![], false);
assert!(result.is_ok());
let result = IpFilterMiddleware::new(
vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()],
false
);
assert!(result.is_ok());
}
#[test]
fn test_new_rejects_invalid_ip() {
let result = IpFilterMiddleware::new(
vec!["not-an-ip".to_string()],
false
);
assert!(result.is_err());
}
#[test]
fn test_new_rejects_too_many_ips() {
let ips: Vec<String> = (0..MAX_ALLOWED_IPS + 1)
.map(|i| format!("192.168.{}.{}", i / 256, i % 256))
.collect();
let result = IpFilterMiddleware::new(ips, false);
assert!(result.is_err());
}
#[test]
fn test_empty_allow_list_allows_all() {
let middleware = IpFilterMiddleware::new_unchecked(vec![], false);
let ip: IpAddr = "192.168.1.1".parse().unwrap();
assert!(middleware.is_authorized(&ip));
}
#[test]
fn test_specific_ip_in_allow_list() {
let middleware = IpFilterMiddleware::new_unchecked(
vec!["192.168.1.100".to_string()],
false
);
let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap();
let denied_ip: IpAddr = "192.168.1.200".parse().unwrap();
assert!(middleware.is_authorized(&allowed_ip));
assert!(!middleware.is_authorized(&denied_ip));
}
#[test]
fn test_private_ip_with_allow_private() {
let middleware = IpFilterMiddleware::new_unchecked(vec![], true);
let private_ip: IpAddr = "192.168.1.1".parse().unwrap();
assert!(middleware.is_authorized(&private_ip));
}
#[test]
fn test_private_ip_without_allow_private() {
let middleware = IpFilterMiddleware::new_unchecked(vec![], false);
let private_ip: IpAddr = "192.168.1.1".parse().unwrap();
let public_ip: IpAddr = "8.8.8.8".parse().unwrap();
assert!(middleware.is_authorized(&private_ip));
assert!(middleware.is_authorized(&public_ip));
}
#[test]
fn test_multiple_allowed_ips() {
let middleware = IpFilterMiddleware::new_unchecked(
vec![
"192.168.1.100".to_string(),
"192.168.1.200".to_string(),
],
false,
);
let ip1: IpAddr = "192.168.1.100".parse().unwrap();
let ip2: IpAddr = "192.168.1.200".parse().unwrap();
let ip3: IpAddr = "192.168.1.150".parse().unwrap();
assert!(middleware.is_authorized(&ip1));
assert!(middleware.is_authorized(&ip2));
assert!(!middleware.is_authorized(&ip3));
}
#[test]
fn test_ipv6_support() {
let middleware = IpFilterMiddleware::new_unchecked(
vec!["::1".to_string()],
false,
);
let ipv6_local: IpAddr = "::1".parse().unwrap();
let ipv6_other: IpAddr = "::2".parse().unwrap();
assert!(middleware.is_authorized(&ipv6_local));
assert!(!middleware.is_authorized(&ipv6_other));
}
}
+115 -41
View File
@@ -1,34 +1,43 @@
//! JWT authentication middleware.
//!
//! Validates JWT tokens using RS256 algorithm with support for
//! Bearer tokens in Authorization header and access_token cookies.
use crate::{
constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME},
error::{ServerError, Result},
Responder,
middleware::{Middleware, MiddlewareFuture, MiddlewareResult, auth_types::Claims},
};
use http::Request;
use hyper::body::Incoming;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use log::error;
/// Common file extensions that indicate a file path
const FILE_EXTENSIONS: &[&str] = &[
".html", ".htm", ".js", ".mjs", ".css", ".scss", ".sass", ".less",
".json", ".xml", ".yaml", ".yml", ".toml", ".env",
".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".webp", ".avif", ".bmp",
".woff", ".woff2", ".ttf", ".eot", ".otf", ".css",
".pdf", ".txt", ".md", ".csv", ".xlsx", ".docx", ".zip", ".tar", ".gz",
".mp4", ".webm", ".mp3", ".wav", ".ogg", ".flac",
".wasm", ".br",
];
use log::warn;
/// JWT authentication middleware.
///
/// Validates JWT tokens using RS256 algorithm. Supports both
/// Bearer tokens in Authorization header and access_token cookies.
pub struct JwtMiddleware {
decoding_key: DecodingKey,
public_routes: Vec<String>,
}
impl JwtMiddleware {
/// Creates a new JwtMiddleware with the given RSA public key.
///
/// # Arguments
/// * `public_key` - RSA public key in PEM format
/// * `public_routes` - List of routes that don't require authentication
pub fn new(
public_key: &str,
public_routes: Vec<String>,
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes())?;
) -> Result<Self> {
let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes())
.map_err(|e| ServerError::jwt_with_source(
"Failed to parse RSA public key",
Box::new(e),
))?;
Ok(Self {
decoding_key,
@@ -37,7 +46,9 @@ impl JwtMiddleware {
}
/// Determines if the given path has a file extension.
/// Returns true if the last segment of the path contains a dot followed by a known extension.
///
/// Returns true if the last segment of the path contains a dot
/// followed by a known extension.
pub fn has_file_extension(path: &str) -> bool {
// Get the last segment of the path (after the last '/')
if let Some(segment) = path.rsplit('/').next() {
@@ -51,6 +62,7 @@ impl JwtMiddleware {
}
/// Checks if a request path is a public route.
///
/// - For routes WITH a file extension: exact match required
/// - For routes WITHOUT a file extension: prefix match (allows all subpaths)
/// - Special case: "/" as public route allows everything
@@ -76,27 +88,33 @@ impl JwtMiddleware {
})
}
/// Validates the request and extracts claims from the JWT token.
fn validate_request(
&self,
req: &Request<Incoming>,
) -> Result<Claims, Box<dyn std::error::Error + Send + Sync>> {
let cookie_header = req.headers().get("Cookie").and_then(|v| v.to_str().ok());
) -> Result<Claims> {
// Try to get token from cookie first
let cookie_header = req.headers()
.get("Cookie")
.and_then(|v| v.to_str().ok());
let token = cookie_header
.and_then(|c| c.split(';').find(|s| s.trim().starts_with("access_token=")))
.map(|s| s.trim().trim_start_matches("access_token="))
.and_then(|c| c.split(';').find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME))))
.map(|s| s.trim().trim_start_matches(&format!("{}=", JWT_COOKIE_NAME)))
.or_else(|| {
req.headers()
.get("Authorization")
.and_then(|v| v.to_str().ok())
.filter(|h| h.starts_with("Bearer "))
.map(|h| &h[7..])
.filter(|h| h.starts_with(BEARER_PREFIX))
.map(|h| &h[BEARER_PREFIX.len()..])
})
.ok_or("No token found in Cookies or Authorization header")?;
.ok_or_else(|| ServerError::jwt("No token found in Cookies or Authorization header"))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_required_spec_claims(&["exp", "sub"]);
let token_data = decode::<Claims>(token, &self.decoding_key, &validation)?;
let token_data = decode::<Claims>(token, &self.decoding_key, &validation)
.map_err(|e| ServerError::jwt_with_source("JWT validation failed", Box::new(e)))?;
Ok(token_data.claims)
}
@@ -118,8 +136,15 @@ impl Middleware for JwtMiddleware {
return MiddlewareResult::Continue(req);
}
error!(target: "auth", "JWT validation failed: {}", e);
let res = Responder::unauthorized().expect("Responder failed");
warn!("JWT validation failed for {}: {}", request_path, e);
let res = Responder::unauthorized()
.unwrap_or_else(|_| {
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.header(CONTENT_TYPE, "text/plain")
.body(Full::new(Bytes::from("Unauthorized")))
.expect("Failed to build fallback response")
});
MiddlewareResult::Respond(res)
}
}
@@ -127,6 +152,11 @@ impl Middleware for JwtMiddleware {
}
}
use http::Response;
use http_body_util::Full;
use hyper::body::Bytes;
use http::header::CONTENT_TYPE;
#[cfg(test)]
mod tests {
use super::*;
@@ -183,11 +213,20 @@ mod tests {
let public_routes = vec!["/static/logo.png".to_string()];
// Exact match should work
assert!(JwtMiddleware::is_public_route(&public_routes, "/static/logo.png"));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/static/logo.png"
));
// Different file in same directory should NOT be public
assert!(!JwtMiddleware::is_public_route(&public_routes, "/static/other.png"));
assert!(!JwtMiddleware::is_public_route(&public_routes, "/static/image.jpg"));
assert!(!JwtMiddleware::is_public_route(
&public_routes,
"/static/other.png"
));
assert!(!JwtMiddleware::is_public_route(
&public_routes,
"/static/image.jpg"
));
}
#[test]
@@ -198,10 +237,22 @@ mod tests {
assert!(JwtMiddleware::is_public_route(&public_routes, "/static"));
// Any file under the directory should be public
assert!(JwtMiddleware::is_public_route(&public_routes, "/static/app.js"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/static/css/main.css"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/static/images/logo.png"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/static/deep/nested/path/file.txt"));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/static/app.js"
));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/static/css/main.css"
));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/static/images/logo.png"
));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/static/deep/nested/path/file.txt"
));
}
#[test]
@@ -213,28 +264,48 @@ mod tests {
];
// Exact file match
assert!(JwtMiddleware::is_public_route(&public_routes, "/public/file.css"));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/public/file.css"
));
// Directory prefix match
assert!(JwtMiddleware::is_public_route(&public_routes, "/static"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/static/app.js"));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/static/app.js"
));
// API endpoint
assert!(JwtMiddleware::is_public_route(&public_routes, "/api/health"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/api/health/detailed"));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/api/health"
));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/api/health/detailed"
));
// Non-public paths
assert!(!JwtMiddleware::is_public_route(&public_routes, "/api/users"));
assert!(!JwtMiddleware::is_public_route(
&public_routes,
"/api/users"
));
assert!(!JwtMiddleware::is_public_route(&public_routes, "/admin"));
assert!(!JwtMiddleware::is_public_route(&public_routes, "/private/data"));
assert!(!JwtMiddleware::is_public_route(
&public_routes,
"/private/data"
));
}
#[test]
fn test_is_public_route_case_insensitive_extensions() {
let public_routes = vec!["/assets/LOGO.PNG".to_string()];
assert!(JwtMiddleware::is_public_route(&public_routes, "/assets/LOGO.PNG"));
// Note: exact match is case-sensitive for the path, only extension check is case-insensitive
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/assets/LOGO.PNG"
));
}
#[test]
@@ -243,7 +314,10 @@ mod tests {
let public_routes = vec!["/".to_string()];
assert!(JwtMiddleware::is_public_route(&public_routes, "/"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/any/path"));
assert!(JwtMiddleware::is_public_route(&public_routes, "/deep/nested/route"));
assert!(JwtMiddleware::is_public_route(
&public_routes,
"/deep/nested/route"
));
// Empty route should not match anything
let empty_routes = vec!["".to_string()];