From 1672e5367ed3a7cc1047551d57247c8c3ceed2d8 Mon Sep 17 00:00:00 2001 From: midefos Date: Wed, 29 Apr 2026 23:25:15 +0200 Subject: [PATCH] style: format imports and refactor ip filter authorization logic --- src/constants.rs | 18 +++++-------- src/error.rs | 15 +++-------- src/lib.rs | 13 +++++---- src/middleware/api_key.rs | 36 +++++++++++-------------- src/middleware/ip_filter.rs | 53 +++++++++++++------------------------ src/middleware/jwt.rs | 53 +++++++++++++++++-------------------- src/responder.rs | 39 ++++++++++++++------------- src/server.rs | 22 +++++++++------ tests/integration_tests.rs | 39 +++++++++------------------ 9 files changed, 120 insertions(+), 168 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 787ab83..30e8c16 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -19,20 +19,14 @@ pub const BEARER_PREFIX: &str = "Bearer "; /// Used by JWT middleware to determine public routes. pub const FILE_EXTENSIONS: &[&str] = &[ // HTML/CSS/JS - ".html", ".htm", ".js", ".mjs", ".css", ".scss", ".sass", ".less", - // Data formats - ".json", ".xml", ".yaml", ".yml", ".toml", ".env", - // Images + ".html", ".htm", ".js", ".mjs", ".css", ".scss", ".sass", ".less", // Data formats + ".json", ".xml", ".yaml", ".yml", ".toml", ".env", // Images ".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".webp", ".avif", ".bmp", // Fonts - ".woff", ".woff2", ".ttf", ".eot", ".otf", - // Documents - ".pdf", ".txt", ".md", ".csv", ".xlsx", ".docx", - // Archives - ".zip", ".tar", ".gz", - // Media - ".mp4", ".webm", ".mp3", ".wav", ".ogg", ".flac", - // Other + ".woff", ".woff2", ".ttf", ".eot", ".otf", // Documents + ".pdf", ".txt", ".md", ".csv", ".xlsx", ".docx", // Archives + ".zip", ".tar", ".gz", // Media + ".mp4", ".webm", ".mp3", ".wav", ".ogg", ".flac", // Other ".wasm", ".br", ]; diff --git a/src/error.rs b/src/error.rs index 0613911..5f9614b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -11,10 +11,7 @@ use std::net::AddrParseError; #[derive(Debug)] pub enum ServerError { /// Failed to bind to the specified address. - Bind { - address: String, - source: io::Error, - }, + Bind { address: String, source: io::Error }, /// Failed to parse an address string into a SocketAddr. ParseAddress { @@ -23,10 +20,7 @@ pub enum ServerError { }, /// Validation failed for a configuration value. - Validation { - field: String, - message: String, - }, + Validation { field: String, message: String }, /// JWT authentication or validation failed. Jwt { @@ -35,10 +29,7 @@ pub enum ServerError { }, /// Middleware execution failed. - Middleware { - name: String, - message: String, - }, + Middleware { name: String, message: String }, /// Request body parsing or processing failed. Request { diff --git a/src/lib.rs b/src/lib.rs index baa25b8..304d47b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,7 +2,7 @@ mod builder; mod config; pub mod constants; mod error; -pub mod middleware; // Export entire module for testing +pub mod middleware; // Export entire module for testing mod requester; mod responder; mod server; @@ -11,14 +11,13 @@ mod url_extract; pub use builder::ServerBuilder; pub use config::ServerConfig; pub use constants::{ - DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHUTDOWN_TIMEOUT_SECS, - FILE_EXTENSIONS, JWT_COOKIE_NAME, BEARER_PREFIX, - MAX_ALLOWED_IPS, + BEARER_PREFIX, DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHUTDOWN_TIMEOUT_SECS, FILE_EXTENSIONS, + JWT_COOKIE_NAME, MAX_ALLOWED_IPS, }; -pub use error::{ServerError, Result}; +pub use error::{Result, ServerError}; pub use middleware::{ - Claims, ApiKeyMiddleware, IpFilterMiddleware, JwtMiddleware, - Middleware, MiddlewareFuture, MiddlewareResult, + ApiKeyMiddleware, Claims, IpFilterMiddleware, JwtMiddleware, Middleware, MiddlewareFuture, + MiddlewareResult, }; pub use requester::Requester; pub use responder::Responder; diff --git a/src/middleware/api_key.rs b/src/middleware/api_key.rs index bfdea3b..a01cbcc 100644 --- a/src/middleware/api_key.rs +++ b/src/middleware/api_key.rs @@ -41,30 +41,24 @@ impl Middleware for ApiKeyMiddleware { } else { warn!("X-API-Key validation failed for request"); // Return a default unauthorized response if Responder fails - let response = Responder::unauthorized() - .unwrap_or_else(|_| { - // Fallback to a basic unauthorized response - Response::builder() - .status(http::StatusCode::UNAUTHORIZED) - .body(http_body_util::Full::new( - Bytes::from("Unauthorized") - )) - .expect("Failed to build fallback response") - }); + let response = Responder::unauthorized().unwrap_or_else(|_| { + // Fallback to a basic unauthorized response + Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(http_body_util::Full::new(Bytes::from("Unauthorized"))) + .expect("Failed to build fallback response") + }); MiddlewareResult::Respond(response) } } None => { warn!("X-API-Key header missing from request"); - let response = Responder::unauthorized() - .unwrap_or_else(|_| { - Response::builder() - .status(http::StatusCode::UNAUTHORIZED) - .body(http_body_util::Full::new( - Bytes::from("Unauthorized") - )) - .expect("Failed to build fallback response") - }); + let response = Responder::unauthorized().unwrap_or_else(|_| { + Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .body(http_body_util::Full::new(Bytes::from("Unauthorized"))) + .expect("Failed to build fallback response") + }); MiddlewareResult::Respond(response) } } @@ -75,11 +69,11 @@ impl Middleware for ApiKeyMiddleware { #[cfg(test)] mod tests { use super::*; - use http::Request; #[test] fn test_api_key_middleware_new() { let middleware = ApiKeyMiddleware::new("test-key"); - assert_eq!(middleware.api_key, "test-key"); + assert!(!middleware.is_invalid_key("test-key")); + assert!(middleware.is_invalid_key("wrong-key")); } } diff --git a/src/middleware/ip_filter.rs b/src/middleware/ip_filter.rs index 12364d5..42de712 100644 --- a/src/middleware/ip_filter.rs +++ b/src/middleware/ip_filter.rs @@ -5,7 +5,7 @@ use crate::{ Responder, - error::{ServerError, Result}, + error::{Result, ServerError}, middleware::{Middleware, MiddlewareFuture, MiddlewareResult}, }; use http::{Request, Response}; @@ -78,12 +78,11 @@ impl IpFilterMiddleware { pub fn is_authorized(&self, ip: &IpAddr) -> bool { // Check private ranges first (fast path for local networks) // Note: Only IPv4 has is_private() method - if self.allow_private { - if let IpAddr::V4(ipv4) = ip { - if ipv4.is_private() { - return true; - } - } + if self.allow_private + && let IpAddr::V4(ipv4) = ip + && ipv4.is_private() + { + return true; } // Empty allowlist means "allow all" @@ -105,16 +104,13 @@ impl Middleware for IpFilterMiddleware { Some(ip) if self.is_authorized(&ip) => MiddlewareResult::Continue(req), _ => { warn!("Unauthorized IP access attempt"); - let response = Responder::unauthorized() - .unwrap_or_else(|_| { - Response::builder() - .status(http::StatusCode::UNAUTHORIZED) - .header(http::header::CONTENT_TYPE, "text/plain") - .body(http_body_util::Full::new( - Bytes::from("Unauthorized") - )) - .expect("Failed to build fallback response") - }); + let response = Responder::unauthorized().unwrap_or_else(|_| { + Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .header(http::header::CONTENT_TYPE, "text/plain") + .body(http_body_util::Full::new(Bytes::from("Unauthorized"))) + .expect("Failed to build fallback response") + }); MiddlewareResult::Respond(response) } } @@ -134,17 +130,14 @@ mod tests { let result = IpFilterMiddleware::new( vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()], - false + false, ); assert!(result.is_ok()); } #[test] fn test_new_rejects_invalid_ip() { - let result = IpFilterMiddleware::new( - vec!["not-an-ip".to_string()], - false - ); + let result = IpFilterMiddleware::new(vec!["not-an-ip".to_string()], false); assert!(result.is_err()); } @@ -167,10 +160,8 @@ mod tests { #[test] fn test_specific_ip_in_allow_list() { - let middleware = IpFilterMiddleware::new_unchecked( - vec!["192.168.1.100".to_string()], - false - ); + let middleware = + IpFilterMiddleware::new_unchecked(vec!["192.168.1.100".to_string()], false); let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap(); let denied_ip: IpAddr = "192.168.1.200".parse().unwrap(); @@ -199,10 +190,7 @@ mod tests { #[test] fn test_multiple_allowed_ips() { let middleware = IpFilterMiddleware::new_unchecked( - vec![ - "192.168.1.100".to_string(), - "192.168.1.200".to_string(), - ], + vec!["192.168.1.100".to_string(), "192.168.1.200".to_string()], false, ); @@ -217,10 +205,7 @@ mod tests { #[test] fn test_ipv6_support() { - let middleware = IpFilterMiddleware::new_unchecked( - vec!["::1".to_string()], - false, - ); + let middleware = IpFilterMiddleware::new_unchecked(vec!["::1".to_string()], false); let ipv6_local: IpAddr = "::1".parse().unwrap(); let ipv6_other: IpAddr = "::2".parse().unwrap(); diff --git a/src/middleware/jwt.rs b/src/middleware/jwt.rs index bda5beb..37679a1 100644 --- a/src/middleware/jwt.rs +++ b/src/middleware/jwt.rs @@ -4,9 +4,9 @@ //! Bearer tokens in Authorization header and access_token cookies. use crate::{ - constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME}, - error::{ServerError, Result}, Responder, + constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME}, + error::{Result, ServerError}, middleware::{Middleware, MiddlewareFuture, MiddlewareResult, auth_types::Claims}, }; use http::Request; @@ -29,15 +29,10 @@ impl JwtMiddleware { /// # Arguments /// * `public_key` - RSA public key in PEM format /// * `public_routes` - List of routes that don't require authentication - pub fn new( - public_key: &str, - public_routes: Vec, - ) -> Result { - let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes()) - .map_err(|e| ServerError::jwt_with_source( - "Failed to parse RSA public key", - Box::new(e), - ))?; + pub fn new(public_key: &str, public_routes: Vec) -> Result { + let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes()).map_err(|e| { + ServerError::jwt_with_source("Failed to parse RSA public key", Box::new(e)) + })?; Ok(Self { decoding_key, @@ -89,18 +84,19 @@ impl JwtMiddleware { } /// Validates the request and extracts claims from the JWT token. - fn validate_request( - &self, - req: &Request, - ) -> Result { + fn validate_request(&self, req: &Request) -> Result { // Try to get token from cookie first - let cookie_header = req.headers() - .get("Cookie") - .and_then(|v| v.to_str().ok()); + let cookie_header = req.headers().get("Cookie").and_then(|v| v.to_str().ok()); let token = cookie_header - .and_then(|c| c.split(';').find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME)))) - .map(|s| s.trim().trim_start_matches(&format!("{}=", JWT_COOKIE_NAME))) + .and_then(|c| { + c.split(';') + .find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME))) + }) + .map(|s| { + s.trim() + .trim_start_matches(&format!("{}=", JWT_COOKIE_NAME)) + }) .or_else(|| { req.headers() .get("Authorization") @@ -137,14 +133,13 @@ impl Middleware for JwtMiddleware { } warn!("JWT validation failed for {}: {}", request_path, e); - let res = Responder::unauthorized() - .unwrap_or_else(|_| { - Response::builder() - .status(http::StatusCode::UNAUTHORIZED) - .header(CONTENT_TYPE, "text/plain") - .body(Full::new(Bytes::from("Unauthorized"))) - .expect("Failed to build fallback response") - }); + let res = Responder::unauthorized().unwrap_or_else(|_| { + Response::builder() + .status(http::StatusCode::UNAUTHORIZED) + .header(CONTENT_TYPE, "text/plain") + .body(Full::new(Bytes::from("Unauthorized"))) + .expect("Failed to build fallback response") + }); MiddlewareResult::Respond(res) } } @@ -153,9 +148,9 @@ impl Middleware for JwtMiddleware { } use http::Response; +use http::header::CONTENT_TYPE; use http_body_util::Full; use hyper::body::Bytes; -use http::header::CONTENT_TYPE; #[cfg(test)] mod tests { diff --git a/src/responder.rs b/src/responder.rs index b0d63cf..56a19f1 100644 --- a/src/responder.rs +++ b/src/responder.rs @@ -11,7 +11,7 @@ use http_body_util::Full; use hyper::body::Bytes; use serde::Serialize; -use crate::error::{ServerError, Result}; +use crate::error::{Result, ServerError}; /// Builder utility for constructing HTTP responses. /// @@ -38,8 +38,7 @@ impl Responder { .status(StatusCode::OK) .header(CONTENT_TYPE, "text/html; charset=utf-8") .body(Full::new(body.into())) - .map_err(|e| ServerError::response("Failed to build HTML response") - .with_source(e)) + .map_err(|e| ServerError::response("Failed to build HTML response").with_source(e)) } /// Creates a JSON response with the given value. @@ -63,8 +62,7 @@ impl Responder { .status(StatusCode::SEE_OTHER) .header(LOCATION, url) .body(Full::new(Bytes::new())) - .map_err(|e| ServerError::response("Failed to build redirect response") - .with_source(e)) + .map_err(|e| ServerError::response("Failed to build redirect response").with_source(e)) } /// Creates a 404 Not Found response. @@ -95,8 +93,7 @@ impl Responder { Response::builder() .status(status) .body(Full::new(body.into())) - .map_err(|e| ServerError::response("Failed to build response") - .with_source(e)) + .map_err(|e| ServerError::response("Failed to build response").with_source(e)) } /// Creates a response with custom headers. @@ -111,10 +108,9 @@ impl Responder { builder = builder.header(name, value); } - builder - .body(Full::new(body.into())) - .map_err(|e| ServerError::response("Failed to build response with headers") - .with_source(e)) + builder.body(Full::new(body.into())).map_err(|e| { + ServerError::response("Failed to build response with headers").with_source(e) + }) } /// Creates a JSON response with a custom status code. @@ -123,15 +119,13 @@ impl Responder { value: &T, ) -> Result>> { let bytes = serde_json::to_vec(value) - .map_err(|e| ServerError::response("JSON serialization failed") - .with_source(e))?; + .map_err(|e| ServerError::response("JSON serialization failed").with_source(e))?; Response::builder() .status(status) .header(CONTENT_TYPE, "application/json") .body(Full::new(Bytes::from(bytes))) - .map_err(|e| ServerError::response("Failed to build JSON response") - .with_source(e)) + .map_err(|e| ServerError::response("Failed to build JSON response").with_source(e)) } /// Creates a 400 Bad Request response. @@ -144,18 +138,25 @@ impl Responder { Response::builder() .status(StatusCode::NO_CONTENT) .body(Full::new(Bytes::new())) - .map_err(|e| ServerError::response("Failed to build no content response") - .with_source(e)) + .map_err(|e| { + ServerError::response("Failed to build no content response").with_source(e) + }) } } // Helper trait to add with_source method to ServerError trait WithSource { - fn with_source(self, source: impl Into>) -> ServerError; + fn with_source( + self, + source: impl Into>, + ) -> ServerError; } impl WithSource for ServerError { - fn with_source(mut self, source: impl Into>) -> ServerError { + fn with_source( + mut self, + source: impl Into>, + ) -> ServerError { match &mut self { ServerError::Response { source: s, .. } => *s = Some(source.into()), ServerError::Request { source: s, .. } => *s = Some(source.into()), diff --git a/src/server.rs b/src/server.rs index bef8934..076ea68 100644 --- a/src/server.rs +++ b/src/server.rs @@ -11,7 +11,9 @@ use crate::{ }; use http_body_util::Full; use http1::Builder; -use hyper::{Request, Response, body::Incoming, server::conn::http1, service::service_fn, body::Bytes}; +use hyper::{ + Request, Response, body::Bytes, body::Incoming, server::conn::http1, service::service_fn, +}; use hyper_util::rt::TokioIo; use log::{error, info, warn}; use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration}; @@ -63,7 +65,8 @@ impl Server { F: Fn(Request) -> Fut + Send + Sync + 'static, Fut: Future>>> + Send, { - self.run_with_shutdown(handler, DEFAULT_SHUTDOWN_TIMEOUT).await; + self.run_with_shutdown(handler, DEFAULT_SHUTDOWN_TIMEOUT) + .await; } /// Runs the HTTP server with a custom shutdown timeout. @@ -75,13 +78,13 @@ impl Server { F: Fn(Request) -> Fut + Send + Sync + 'static, Fut: Future>>> + Send, { - let addr: SocketAddr = match format!("{}:{}", self.config.ip, self.config.port) - .parse() - { + let addr: SocketAddr = match format!("{}:{}", self.config.ip, self.config.port).parse() { Ok(addr) => addr, Err(e) => { - error!("Failed to parse server address '{}:{}': {}", - self.config.ip, self.config.port, e); + error!( + "Failed to parse server address '{}:{}': {}", + self.config.ip, self.config.port, e + ); return; } }; @@ -128,7 +131,10 @@ impl Server { } // Graceful shutdown - info!("Entering graceful shutdown (timeout: {}s)", shutdown_timeout.as_secs()); + info!( + "Entering graceful shutdown (timeout: {}s)", + shutdown_timeout.as_secs() + ); // Give time for in-flight requests to complete timeout(shutdown_timeout, async { diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index 33bb42a..fcdce2d 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -4,8 +4,8 @@ //! including middleware chains and request handling. use servme::{ - ApiKeyMiddleware, Claims, IpFilterMiddleware, Responder, - ServerBuilder, ServerConfig, ServerError, UrlExtract, + ApiKeyMiddleware, Claims, IpFilterMiddleware, Responder, ServerBuilder, ServerConfig, + ServerError, UrlExtract, }; use std::net::IpAddr; @@ -26,8 +26,7 @@ fn test_server_builder_default_config() { #[test] fn test_server_builder_with_address() { - let builder = ServerBuilder::new() - .address("0.0.0.0", 3000); + let builder = ServerBuilder::new().address("0.0.0.0", 3000); assert_eq!(builder.config.ip, "0.0.0.0"); assert_eq!(builder.config.port, 3000); @@ -96,10 +95,7 @@ fn test_responder_redirect() { let response = result.unwrap(); assert_eq!(response.status(), http::StatusCode::SEE_OTHER); - assert_eq!( - response.headers().get("location").unwrap(), - "/new-location" - ); + assert_eq!(response.headers().get("location").unwrap(), "/new-location"); } #[test] @@ -144,25 +140,19 @@ fn test_ip_filter_middleware_validation() { let result = IpFilterMiddleware::new( vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()], - false + false, ); assert!(result.is_ok()); // Invalid IP should fail - let result = IpFilterMiddleware::new( - vec!["not-an-ip".to_string()], - false - ); + let result = IpFilterMiddleware::new(vec!["not-an-ip".to_string()], false); assert!(result.is_err()); } #[test] fn test_ip_filter_authorization() { // Test with checked middleware for valid IPs - let middleware = IpFilterMiddleware::new( - vec!["192.168.1.100".to_string()], - false - ).unwrap(); + let middleware = IpFilterMiddleware::new(vec!["192.168.1.100".to_string()], false).unwrap(); let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap(); let denied_ip: IpAddr = "192.168.1.200".parse().unwrap(); @@ -173,10 +163,7 @@ fn test_ip_filter_authorization() { #[test] fn test_ip_filter_ipv6() { - let middleware = IpFilterMiddleware::new( - vec!["::1".to_string()], - false, - ).unwrap(); + let middleware = IpFilterMiddleware::new(vec!["::1".to_string()], false).unwrap(); let ipv6_local: IpAddr = "::1".parse().unwrap(); let ipv6_other: IpAddr = "::2".parse().unwrap(); @@ -237,10 +224,10 @@ fn test_claims_username() { #[test] fn test_server_error_display() { - let error = ServerError::bind("127.0.0.1:8080", std::io::Error::new( - std::io::ErrorKind::AddrInUse, - "Address already in use" - )); + let error = ServerError::bind( + "127.0.0.1:8080", + std::io::Error::new(std::io::ErrorKind::AddrInUse, "Address already in use"), + ); let display = format!("{}", error); assert!(display.contains("Failed to bind")); @@ -263,7 +250,7 @@ fn test_server_error_validation() { #[test] fn test_constants_values() { use servme::constants::{ - DEFAULT_HOST, DEFAULT_PORT, JWT_COOKIE_NAME, BEARER_PREFIX, FILE_EXTENSIONS, + BEARER_PREFIX, DEFAULT_HOST, DEFAULT_PORT, FILE_EXTENSIONS, JWT_COOKIE_NAME, }; assert_eq!(DEFAULT_HOST, "127.0.0.1");