style: format imports and refactor ip filter authorization logic
This commit is contained in:
+6
-12
@@ -19,20 +19,14 @@ pub const BEARER_PREFIX: &str = "Bearer ";
|
|||||||
/// Used by JWT middleware to determine public routes.
|
/// Used by JWT middleware to determine public routes.
|
||||||
pub const FILE_EXTENSIONS: &[&str] = &[
|
pub const FILE_EXTENSIONS: &[&str] = &[
|
||||||
// HTML/CSS/JS
|
// HTML/CSS/JS
|
||||||
".html", ".htm", ".js", ".mjs", ".css", ".scss", ".sass", ".less",
|
".html", ".htm", ".js", ".mjs", ".css", ".scss", ".sass", ".less", // Data formats
|
||||||
// Data formats
|
".json", ".xml", ".yaml", ".yml", ".toml", ".env", // Images
|
||||||
".json", ".xml", ".yaml", ".yml", ".toml", ".env",
|
|
||||||
// Images
|
|
||||||
".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".webp", ".avif", ".bmp",
|
".png", ".jpg", ".jpeg", ".gif", ".svg", ".ico", ".webp", ".avif", ".bmp",
|
||||||
// Fonts
|
// Fonts
|
||||||
".woff", ".woff2", ".ttf", ".eot", ".otf",
|
".woff", ".woff2", ".ttf", ".eot", ".otf", // Documents
|
||||||
// Documents
|
".pdf", ".txt", ".md", ".csv", ".xlsx", ".docx", // Archives
|
||||||
".pdf", ".txt", ".md", ".csv", ".xlsx", ".docx",
|
".zip", ".tar", ".gz", // Media
|
||||||
// Archives
|
".mp4", ".webm", ".mp3", ".wav", ".ogg", ".flac", // Other
|
||||||
".zip", ".tar", ".gz",
|
|
||||||
// Media
|
|
||||||
".mp4", ".webm", ".mp3", ".wav", ".ogg", ".flac",
|
|
||||||
// Other
|
|
||||||
".wasm", ".br",
|
".wasm", ".br",
|
||||||
];
|
];
|
||||||
|
|
||||||
|
|||||||
+3
-12
@@ -11,10 +11,7 @@ use std::net::AddrParseError;
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum ServerError {
|
pub enum ServerError {
|
||||||
/// Failed to bind to the specified address.
|
/// Failed to bind to the specified address.
|
||||||
Bind {
|
Bind { address: String, source: io::Error },
|
||||||
address: String,
|
|
||||||
source: io::Error,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Failed to parse an address string into a SocketAddr.
|
/// Failed to parse an address string into a SocketAddr.
|
||||||
ParseAddress {
|
ParseAddress {
|
||||||
@@ -23,10 +20,7 @@ pub enum ServerError {
|
|||||||
},
|
},
|
||||||
|
|
||||||
/// Validation failed for a configuration value.
|
/// Validation failed for a configuration value.
|
||||||
Validation {
|
Validation { field: String, message: String },
|
||||||
field: String,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// JWT authentication or validation failed.
|
/// JWT authentication or validation failed.
|
||||||
Jwt {
|
Jwt {
|
||||||
@@ -35,10 +29,7 @@ pub enum ServerError {
|
|||||||
},
|
},
|
||||||
|
|
||||||
/// Middleware execution failed.
|
/// Middleware execution failed.
|
||||||
Middleware {
|
Middleware { name: String, message: String },
|
||||||
name: String,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
|
|
||||||
/// Request body parsing or processing failed.
|
/// Request body parsing or processing failed.
|
||||||
Request {
|
Request {
|
||||||
|
|||||||
+6
-7
@@ -2,7 +2,7 @@ mod builder;
|
|||||||
mod config;
|
mod config;
|
||||||
pub mod constants;
|
pub mod constants;
|
||||||
mod error;
|
mod error;
|
||||||
pub mod middleware; // Export entire module for testing
|
pub mod middleware; // Export entire module for testing
|
||||||
mod requester;
|
mod requester;
|
||||||
mod responder;
|
mod responder;
|
||||||
mod server;
|
mod server;
|
||||||
@@ -11,14 +11,13 @@ mod url_extract;
|
|||||||
pub use builder::ServerBuilder;
|
pub use builder::ServerBuilder;
|
||||||
pub use config::ServerConfig;
|
pub use config::ServerConfig;
|
||||||
pub use constants::{
|
pub use constants::{
|
||||||
DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHUTDOWN_TIMEOUT_SECS,
|
BEARER_PREFIX, DEFAULT_HOST, DEFAULT_PORT, DEFAULT_SHUTDOWN_TIMEOUT_SECS, FILE_EXTENSIONS,
|
||||||
FILE_EXTENSIONS, JWT_COOKIE_NAME, BEARER_PREFIX,
|
JWT_COOKIE_NAME, MAX_ALLOWED_IPS,
|
||||||
MAX_ALLOWED_IPS,
|
|
||||||
};
|
};
|
||||||
pub use error::{ServerError, Result};
|
pub use error::{Result, ServerError};
|
||||||
pub use middleware::{
|
pub use middleware::{
|
||||||
Claims, ApiKeyMiddleware, IpFilterMiddleware, JwtMiddleware,
|
ApiKeyMiddleware, Claims, IpFilterMiddleware, JwtMiddleware, Middleware, MiddlewareFuture,
|
||||||
Middleware, MiddlewareFuture, MiddlewareResult,
|
MiddlewareResult,
|
||||||
};
|
};
|
||||||
pub use requester::Requester;
|
pub use requester::Requester;
|
||||||
pub use responder::Responder;
|
pub use responder::Responder;
|
||||||
|
|||||||
+15
-21
@@ -41,30 +41,24 @@ impl Middleware for ApiKeyMiddleware {
|
|||||||
} else {
|
} else {
|
||||||
warn!("X-API-Key validation failed for request");
|
warn!("X-API-Key validation failed for request");
|
||||||
// Return a default unauthorized response if Responder fails
|
// Return a default unauthorized response if Responder fails
|
||||||
let response = Responder::unauthorized()
|
let response = Responder::unauthorized().unwrap_or_else(|_| {
|
||||||
.unwrap_or_else(|_| {
|
// Fallback to a basic unauthorized response
|
||||||
// Fallback to a basic unauthorized response
|
Response::builder()
|
||||||
Response::builder()
|
.status(http::StatusCode::UNAUTHORIZED)
|
||||||
.status(http::StatusCode::UNAUTHORIZED)
|
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
|
||||||
.body(http_body_util::Full::new(
|
.expect("Failed to build fallback response")
|
||||||
Bytes::from("Unauthorized")
|
});
|
||||||
))
|
|
||||||
.expect("Failed to build fallback response")
|
|
||||||
});
|
|
||||||
MiddlewareResult::Respond(response)
|
MiddlewareResult::Respond(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
None => {
|
None => {
|
||||||
warn!("X-API-Key header missing from request");
|
warn!("X-API-Key header missing from request");
|
||||||
let response = Responder::unauthorized()
|
let response = Responder::unauthorized().unwrap_or_else(|_| {
|
||||||
.unwrap_or_else(|_| {
|
Response::builder()
|
||||||
Response::builder()
|
.status(http::StatusCode::UNAUTHORIZED)
|
||||||
.status(http::StatusCode::UNAUTHORIZED)
|
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
|
||||||
.body(http_body_util::Full::new(
|
.expect("Failed to build fallback response")
|
||||||
Bytes::from("Unauthorized")
|
});
|
||||||
))
|
|
||||||
.expect("Failed to build fallback response")
|
|
||||||
});
|
|
||||||
MiddlewareResult::Respond(response)
|
MiddlewareResult::Respond(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,11 +69,11 @@ impl Middleware for ApiKeyMiddleware {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use http::Request;
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_api_key_middleware_new() {
|
fn test_api_key_middleware_new() {
|
||||||
let middleware = ApiKeyMiddleware::new("test-key");
|
let middleware = ApiKeyMiddleware::new("test-key");
|
||||||
assert_eq!(middleware.api_key, "test-key");
|
assert!(!middleware.is_invalid_key("test-key"));
|
||||||
|
assert!(middleware.is_invalid_key("wrong-key"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+19
-34
@@ -5,7 +5,7 @@
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
Responder,
|
Responder,
|
||||||
error::{ServerError, Result},
|
error::{Result, ServerError},
|
||||||
middleware::{Middleware, MiddlewareFuture, MiddlewareResult},
|
middleware::{Middleware, MiddlewareFuture, MiddlewareResult},
|
||||||
};
|
};
|
||||||
use http::{Request, Response};
|
use http::{Request, Response};
|
||||||
@@ -78,12 +78,11 @@ impl IpFilterMiddleware {
|
|||||||
pub fn is_authorized(&self, ip: &IpAddr) -> bool {
|
pub fn is_authorized(&self, ip: &IpAddr) -> bool {
|
||||||
// Check private ranges first (fast path for local networks)
|
// Check private ranges first (fast path for local networks)
|
||||||
// Note: Only IPv4 has is_private() method
|
// Note: Only IPv4 has is_private() method
|
||||||
if self.allow_private {
|
if self.allow_private
|
||||||
if let IpAddr::V4(ipv4) = ip {
|
&& let IpAddr::V4(ipv4) = ip
|
||||||
if ipv4.is_private() {
|
&& ipv4.is_private()
|
||||||
return true;
|
{
|
||||||
}
|
return true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Empty allowlist means "allow all"
|
// Empty allowlist means "allow all"
|
||||||
@@ -105,16 +104,13 @@ impl Middleware for IpFilterMiddleware {
|
|||||||
Some(ip) if self.is_authorized(&ip) => MiddlewareResult::Continue(req),
|
Some(ip) if self.is_authorized(&ip) => MiddlewareResult::Continue(req),
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unauthorized IP access attempt");
|
warn!("Unauthorized IP access attempt");
|
||||||
let response = Responder::unauthorized()
|
let response = Responder::unauthorized().unwrap_or_else(|_| {
|
||||||
.unwrap_or_else(|_| {
|
Response::builder()
|
||||||
Response::builder()
|
.status(http::StatusCode::UNAUTHORIZED)
|
||||||
.status(http::StatusCode::UNAUTHORIZED)
|
.header(http::header::CONTENT_TYPE, "text/plain")
|
||||||
.header(http::header::CONTENT_TYPE, "text/plain")
|
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
|
||||||
.body(http_body_util::Full::new(
|
.expect("Failed to build fallback response")
|
||||||
Bytes::from("Unauthorized")
|
});
|
||||||
))
|
|
||||||
.expect("Failed to build fallback response")
|
|
||||||
});
|
|
||||||
MiddlewareResult::Respond(response)
|
MiddlewareResult::Respond(response)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -134,17 +130,14 @@ mod tests {
|
|||||||
|
|
||||||
let result = IpFilterMiddleware::new(
|
let result = IpFilterMiddleware::new(
|
||||||
vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()],
|
vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()],
|
||||||
false
|
false,
|
||||||
);
|
);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_new_rejects_invalid_ip() {
|
fn test_new_rejects_invalid_ip() {
|
||||||
let result = IpFilterMiddleware::new(
|
let result = IpFilterMiddleware::new(vec!["not-an-ip".to_string()], false);
|
||||||
vec!["not-an-ip".to_string()],
|
|
||||||
false
|
|
||||||
);
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,10 +160,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_specific_ip_in_allow_list() {
|
fn test_specific_ip_in_allow_list() {
|
||||||
let middleware = IpFilterMiddleware::new_unchecked(
|
let middleware =
|
||||||
vec!["192.168.1.100".to_string()],
|
IpFilterMiddleware::new_unchecked(vec!["192.168.1.100".to_string()], false);
|
||||||
false
|
|
||||||
);
|
|
||||||
let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap();
|
let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap();
|
||||||
let denied_ip: IpAddr = "192.168.1.200".parse().unwrap();
|
let denied_ip: IpAddr = "192.168.1.200".parse().unwrap();
|
||||||
|
|
||||||
@@ -199,10 +190,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_multiple_allowed_ips() {
|
fn test_multiple_allowed_ips() {
|
||||||
let middleware = IpFilterMiddleware::new_unchecked(
|
let middleware = IpFilterMiddleware::new_unchecked(
|
||||||
vec![
|
vec!["192.168.1.100".to_string(), "192.168.1.200".to_string()],
|
||||||
"192.168.1.100".to_string(),
|
|
||||||
"192.168.1.200".to_string(),
|
|
||||||
],
|
|
||||||
false,
|
false,
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -217,10 +205,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ipv6_support() {
|
fn test_ipv6_support() {
|
||||||
let middleware = IpFilterMiddleware::new_unchecked(
|
let middleware = IpFilterMiddleware::new_unchecked(vec!["::1".to_string()], false);
|
||||||
vec!["::1".to_string()],
|
|
||||||
false,
|
|
||||||
);
|
|
||||||
let ipv6_local: IpAddr = "::1".parse().unwrap();
|
let ipv6_local: IpAddr = "::1".parse().unwrap();
|
||||||
let ipv6_other: IpAddr = "::2".parse().unwrap();
|
let ipv6_other: IpAddr = "::2".parse().unwrap();
|
||||||
|
|
||||||
|
|||||||
+24
-29
@@ -4,9 +4,9 @@
|
|||||||
//! Bearer tokens in Authorization header and access_token cookies.
|
//! Bearer tokens in Authorization header and access_token cookies.
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME},
|
|
||||||
error::{ServerError, Result},
|
|
||||||
Responder,
|
Responder,
|
||||||
|
constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME},
|
||||||
|
error::{Result, ServerError},
|
||||||
middleware::{Middleware, MiddlewareFuture, MiddlewareResult, auth_types::Claims},
|
middleware::{Middleware, MiddlewareFuture, MiddlewareResult, auth_types::Claims},
|
||||||
};
|
};
|
||||||
use http::Request;
|
use http::Request;
|
||||||
@@ -29,15 +29,10 @@ impl JwtMiddleware {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `public_key` - RSA public key in PEM format
|
/// * `public_key` - RSA public key in PEM format
|
||||||
/// * `public_routes` - List of routes that don't require authentication
|
/// * `public_routes` - List of routes that don't require authentication
|
||||||
pub fn new(
|
pub fn new(public_key: &str, public_routes: Vec<String>) -> Result<Self> {
|
||||||
public_key: &str,
|
let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes()).map_err(|e| {
|
||||||
public_routes: Vec<String>,
|
ServerError::jwt_with_source("Failed to parse RSA public key", Box::new(e))
|
||||||
) -> Result<Self> {
|
})?;
|
||||||
let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes())
|
|
||||||
.map_err(|e| ServerError::jwt_with_source(
|
|
||||||
"Failed to parse RSA public key",
|
|
||||||
Box::new(e),
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
decoding_key,
|
decoding_key,
|
||||||
@@ -89,18 +84,19 @@ impl JwtMiddleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Validates the request and extracts claims from the JWT token.
|
/// Validates the request and extracts claims from the JWT token.
|
||||||
fn validate_request(
|
fn validate_request(&self, req: &Request<Incoming>) -> Result<Claims> {
|
||||||
&self,
|
|
||||||
req: &Request<Incoming>,
|
|
||||||
) -> Result<Claims> {
|
|
||||||
// Try to get token from cookie first
|
// Try to get token from cookie first
|
||||||
let cookie_header = req.headers()
|
let cookie_header = req.headers().get("Cookie").and_then(|v| v.to_str().ok());
|
||||||
.get("Cookie")
|
|
||||||
.and_then(|v| v.to_str().ok());
|
|
||||||
|
|
||||||
let token = cookie_header
|
let token = cookie_header
|
||||||
.and_then(|c| c.split(';').find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME))))
|
.and_then(|c| {
|
||||||
.map(|s| s.trim().trim_start_matches(&format!("{}=", JWT_COOKIE_NAME)))
|
c.split(';')
|
||||||
|
.find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME)))
|
||||||
|
})
|
||||||
|
.map(|s| {
|
||||||
|
s.trim()
|
||||||
|
.trim_start_matches(&format!("{}=", JWT_COOKIE_NAME))
|
||||||
|
})
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
req.headers()
|
req.headers()
|
||||||
.get("Authorization")
|
.get("Authorization")
|
||||||
@@ -137,14 +133,13 @@ impl Middleware for JwtMiddleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
warn!("JWT validation failed for {}: {}", request_path, e);
|
warn!("JWT validation failed for {}: {}", request_path, e);
|
||||||
let res = Responder::unauthorized()
|
let res = Responder::unauthorized().unwrap_or_else(|_| {
|
||||||
.unwrap_or_else(|_| {
|
Response::builder()
|
||||||
Response::builder()
|
.status(http::StatusCode::UNAUTHORIZED)
|
||||||
.status(http::StatusCode::UNAUTHORIZED)
|
.header(CONTENT_TYPE, "text/plain")
|
||||||
.header(CONTENT_TYPE, "text/plain")
|
.body(Full::new(Bytes::from("Unauthorized")))
|
||||||
.body(Full::new(Bytes::from("Unauthorized")))
|
.expect("Failed to build fallback response")
|
||||||
.expect("Failed to build fallback response")
|
});
|
||||||
});
|
|
||||||
MiddlewareResult::Respond(res)
|
MiddlewareResult::Respond(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -153,9 +148,9 @@ impl Middleware for JwtMiddleware {
|
|||||||
}
|
}
|
||||||
|
|
||||||
use http::Response;
|
use http::Response;
|
||||||
|
use http::header::CONTENT_TYPE;
|
||||||
use http_body_util::Full;
|
use http_body_util::Full;
|
||||||
use hyper::body::Bytes;
|
use hyper::body::Bytes;
|
||||||
use http::header::CONTENT_TYPE;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
|||||||
+20
-19
@@ -11,7 +11,7 @@ use http_body_util::Full;
|
|||||||
use hyper::body::Bytes;
|
use hyper::body::Bytes;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
|
|
||||||
use crate::error::{ServerError, Result};
|
use crate::error::{Result, ServerError};
|
||||||
|
|
||||||
/// Builder utility for constructing HTTP responses.
|
/// Builder utility for constructing HTTP responses.
|
||||||
///
|
///
|
||||||
@@ -38,8 +38,7 @@ impl Responder {
|
|||||||
.status(StatusCode::OK)
|
.status(StatusCode::OK)
|
||||||
.header(CONTENT_TYPE, "text/html; charset=utf-8")
|
.header(CONTENT_TYPE, "text/html; charset=utf-8")
|
||||||
.body(Full::new(body.into()))
|
.body(Full::new(body.into()))
|
||||||
.map_err(|e| ServerError::response("Failed to build HTML response")
|
.map_err(|e| ServerError::response("Failed to build HTML response").with_source(e))
|
||||||
.with_source(e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a JSON response with the given value.
|
/// Creates a JSON response with the given value.
|
||||||
@@ -63,8 +62,7 @@ impl Responder {
|
|||||||
.status(StatusCode::SEE_OTHER)
|
.status(StatusCode::SEE_OTHER)
|
||||||
.header(LOCATION, url)
|
.header(LOCATION, url)
|
||||||
.body(Full::new(Bytes::new()))
|
.body(Full::new(Bytes::new()))
|
||||||
.map_err(|e| ServerError::response("Failed to build redirect response")
|
.map_err(|e| ServerError::response("Failed to build redirect response").with_source(e))
|
||||||
.with_source(e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a 404 Not Found response.
|
/// Creates a 404 Not Found response.
|
||||||
@@ -95,8 +93,7 @@ impl Responder {
|
|||||||
Response::builder()
|
Response::builder()
|
||||||
.status(status)
|
.status(status)
|
||||||
.body(Full::new(body.into()))
|
.body(Full::new(body.into()))
|
||||||
.map_err(|e| ServerError::response("Failed to build response")
|
.map_err(|e| ServerError::response("Failed to build response").with_source(e))
|
||||||
.with_source(e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a response with custom headers.
|
/// Creates a response with custom headers.
|
||||||
@@ -111,10 +108,9 @@ impl Responder {
|
|||||||
builder = builder.header(name, value);
|
builder = builder.header(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
builder
|
builder.body(Full::new(body.into())).map_err(|e| {
|
||||||
.body(Full::new(body.into()))
|
ServerError::response("Failed to build response with headers").with_source(e)
|
||||||
.map_err(|e| ServerError::response("Failed to build response with headers")
|
})
|
||||||
.with_source(e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a JSON response with a custom status code.
|
/// Creates a JSON response with a custom status code.
|
||||||
@@ -123,15 +119,13 @@ impl Responder {
|
|||||||
value: &T,
|
value: &T,
|
||||||
) -> Result<Response<Full<Bytes>>> {
|
) -> Result<Response<Full<Bytes>>> {
|
||||||
let bytes = serde_json::to_vec(value)
|
let bytes = serde_json::to_vec(value)
|
||||||
.map_err(|e| ServerError::response("JSON serialization failed")
|
.map_err(|e| ServerError::response("JSON serialization failed").with_source(e))?;
|
||||||
.with_source(e))?;
|
|
||||||
|
|
||||||
Response::builder()
|
Response::builder()
|
||||||
.status(status)
|
.status(status)
|
||||||
.header(CONTENT_TYPE, "application/json")
|
.header(CONTENT_TYPE, "application/json")
|
||||||
.body(Full::new(Bytes::from(bytes)))
|
.body(Full::new(Bytes::from(bytes)))
|
||||||
.map_err(|e| ServerError::response("Failed to build JSON response")
|
.map_err(|e| ServerError::response("Failed to build JSON response").with_source(e))
|
||||||
.with_source(e))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a 400 Bad Request response.
|
/// Creates a 400 Bad Request response.
|
||||||
@@ -144,18 +138,25 @@ impl Responder {
|
|||||||
Response::builder()
|
Response::builder()
|
||||||
.status(StatusCode::NO_CONTENT)
|
.status(StatusCode::NO_CONTENT)
|
||||||
.body(Full::new(Bytes::new()))
|
.body(Full::new(Bytes::new()))
|
||||||
.map_err(|e| ServerError::response("Failed to build no content response")
|
.map_err(|e| {
|
||||||
.with_source(e))
|
ServerError::response("Failed to build no content response").with_source(e)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper trait to add with_source method to ServerError
|
// Helper trait to add with_source method to ServerError
|
||||||
trait WithSource {
|
trait WithSource {
|
||||||
fn with_source(self, source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> ServerError;
|
fn with_source(
|
||||||
|
self,
|
||||||
|
source: impl Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
) -> ServerError;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WithSource for ServerError {
|
impl WithSource for ServerError {
|
||||||
fn with_source(mut self, source: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> ServerError {
|
fn with_source(
|
||||||
|
mut self,
|
||||||
|
source: impl Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||||
|
) -> ServerError {
|
||||||
match &mut self {
|
match &mut self {
|
||||||
ServerError::Response { source: s, .. } => *s = Some(source.into()),
|
ServerError::Response { source: s, .. } => *s = Some(source.into()),
|
||||||
ServerError::Request { source: s, .. } => *s = Some(source.into()),
|
ServerError::Request { source: s, .. } => *s = Some(source.into()),
|
||||||
|
|||||||
+14
-8
@@ -11,7 +11,9 @@ use crate::{
|
|||||||
};
|
};
|
||||||
use http_body_util::Full;
|
use http_body_util::Full;
|
||||||
use http1::Builder;
|
use http1::Builder;
|
||||||
use hyper::{Request, Response, body::Incoming, server::conn::http1, service::service_fn, body::Bytes};
|
use hyper::{
|
||||||
|
Request, Response, body::Bytes, body::Incoming, server::conn::http1, service::service_fn,
|
||||||
|
};
|
||||||
use hyper_util::rt::TokioIo;
|
use hyper_util::rt::TokioIo;
|
||||||
use log::{error, info, warn};
|
use log::{error, info, warn};
|
||||||
use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
|
use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
|
||||||
@@ -63,7 +65,8 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
|||||||
F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
|
F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
|
||||||
Fut: Future<Output = Result<Response<Full<Bytes>>>> + Send,
|
Fut: Future<Output = Result<Response<Full<Bytes>>>> + Send,
|
||||||
{
|
{
|
||||||
self.run_with_shutdown(handler, DEFAULT_SHUTDOWN_TIMEOUT).await;
|
self.run_with_shutdown(handler, DEFAULT_SHUTDOWN_TIMEOUT)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs the HTTP server with a custom shutdown timeout.
|
/// Runs the HTTP server with a custom shutdown timeout.
|
||||||
@@ -75,13 +78,13 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
|||||||
F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
|
F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
|
||||||
Fut: Future<Output = Result<Response<Full<Bytes>>>> + Send,
|
Fut: Future<Output = Result<Response<Full<Bytes>>>> + Send,
|
||||||
{
|
{
|
||||||
let addr: SocketAddr = match format!("{}:{}", self.config.ip, self.config.port)
|
let addr: SocketAddr = match format!("{}:{}", self.config.ip, self.config.port).parse() {
|
||||||
.parse()
|
|
||||||
{
|
|
||||||
Ok(addr) => addr,
|
Ok(addr) => addr,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to parse server address '{}:{}': {}",
|
error!(
|
||||||
self.config.ip, self.config.port, e);
|
"Failed to parse server address '{}:{}': {}",
|
||||||
|
self.config.ip, self.config.port, e
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -128,7 +131,10 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Graceful shutdown
|
// Graceful shutdown
|
||||||
info!("Entering graceful shutdown (timeout: {}s)", shutdown_timeout.as_secs());
|
info!(
|
||||||
|
"Entering graceful shutdown (timeout: {}s)",
|
||||||
|
shutdown_timeout.as_secs()
|
||||||
|
);
|
||||||
|
|
||||||
// Give time for in-flight requests to complete
|
// Give time for in-flight requests to complete
|
||||||
timeout(shutdown_timeout, async {
|
timeout(shutdown_timeout, async {
|
||||||
|
|||||||
+13
-26
@@ -4,8 +4,8 @@
|
|||||||
//! including middleware chains and request handling.
|
//! including middleware chains and request handling.
|
||||||
|
|
||||||
use servme::{
|
use servme::{
|
||||||
ApiKeyMiddleware, Claims, IpFilterMiddleware, Responder,
|
ApiKeyMiddleware, Claims, IpFilterMiddleware, Responder, ServerBuilder, ServerConfig,
|
||||||
ServerBuilder, ServerConfig, ServerError, UrlExtract,
|
ServerError, UrlExtract,
|
||||||
};
|
};
|
||||||
use std::net::IpAddr;
|
use std::net::IpAddr;
|
||||||
|
|
||||||
@@ -26,8 +26,7 @@ fn test_server_builder_default_config() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_server_builder_with_address() {
|
fn test_server_builder_with_address() {
|
||||||
let builder = ServerBuilder::new()
|
let builder = ServerBuilder::new().address("0.0.0.0", 3000);
|
||||||
.address("0.0.0.0", 3000);
|
|
||||||
|
|
||||||
assert_eq!(builder.config.ip, "0.0.0.0");
|
assert_eq!(builder.config.ip, "0.0.0.0");
|
||||||
assert_eq!(builder.config.port, 3000);
|
assert_eq!(builder.config.port, 3000);
|
||||||
@@ -96,10 +95,7 @@ fn test_responder_redirect() {
|
|||||||
|
|
||||||
let response = result.unwrap();
|
let response = result.unwrap();
|
||||||
assert_eq!(response.status(), http::StatusCode::SEE_OTHER);
|
assert_eq!(response.status(), http::StatusCode::SEE_OTHER);
|
||||||
assert_eq!(
|
assert_eq!(response.headers().get("location").unwrap(), "/new-location");
|
||||||
response.headers().get("location").unwrap(),
|
|
||||||
"/new-location"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@@ -144,25 +140,19 @@ fn test_ip_filter_middleware_validation() {
|
|||||||
|
|
||||||
let result = IpFilterMiddleware::new(
|
let result = IpFilterMiddleware::new(
|
||||||
vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()],
|
vec!["192.168.1.1".to_string(), "10.0.0.1".to_string()],
|
||||||
false
|
false,
|
||||||
);
|
);
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
|
|
||||||
// Invalid IP should fail
|
// Invalid IP should fail
|
||||||
let result = IpFilterMiddleware::new(
|
let result = IpFilterMiddleware::new(vec!["not-an-ip".to_string()], false);
|
||||||
vec!["not-an-ip".to_string()],
|
|
||||||
false
|
|
||||||
);
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ip_filter_authorization() {
|
fn test_ip_filter_authorization() {
|
||||||
// Test with checked middleware for valid IPs
|
// Test with checked middleware for valid IPs
|
||||||
let middleware = IpFilterMiddleware::new(
|
let middleware = IpFilterMiddleware::new(vec!["192.168.1.100".to_string()], false).unwrap();
|
||||||
vec!["192.168.1.100".to_string()],
|
|
||||||
false
|
|
||||||
).unwrap();
|
|
||||||
|
|
||||||
let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap();
|
let allowed_ip: IpAddr = "192.168.1.100".parse().unwrap();
|
||||||
let denied_ip: IpAddr = "192.168.1.200".parse().unwrap();
|
let denied_ip: IpAddr = "192.168.1.200".parse().unwrap();
|
||||||
@@ -173,10 +163,7 @@ fn test_ip_filter_authorization() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_ip_filter_ipv6() {
|
fn test_ip_filter_ipv6() {
|
||||||
let middleware = IpFilterMiddleware::new(
|
let middleware = IpFilterMiddleware::new(vec!["::1".to_string()], false).unwrap();
|
||||||
vec!["::1".to_string()],
|
|
||||||
false,
|
|
||||||
).unwrap();
|
|
||||||
|
|
||||||
let ipv6_local: IpAddr = "::1".parse().unwrap();
|
let ipv6_local: IpAddr = "::1".parse().unwrap();
|
||||||
let ipv6_other: IpAddr = "::2".parse().unwrap();
|
let ipv6_other: IpAddr = "::2".parse().unwrap();
|
||||||
@@ -237,10 +224,10 @@ fn test_claims_username() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_server_error_display() {
|
fn test_server_error_display() {
|
||||||
let error = ServerError::bind("127.0.0.1:8080", std::io::Error::new(
|
let error = ServerError::bind(
|
||||||
std::io::ErrorKind::AddrInUse,
|
"127.0.0.1:8080",
|
||||||
"Address already in use"
|
std::io::Error::new(std::io::ErrorKind::AddrInUse, "Address already in use"),
|
||||||
));
|
);
|
||||||
|
|
||||||
let display = format!("{}", error);
|
let display = format!("{}", error);
|
||||||
assert!(display.contains("Failed to bind"));
|
assert!(display.contains("Failed to bind"));
|
||||||
@@ -263,7 +250,7 @@ fn test_server_error_validation() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_constants_values() {
|
fn test_constants_values() {
|
||||||
use servme::constants::{
|
use servme::constants::{
|
||||||
DEFAULT_HOST, DEFAULT_PORT, JWT_COOKIE_NAME, BEARER_PREFIX, FILE_EXTENSIONS,
|
BEARER_PREFIX, DEFAULT_HOST, DEFAULT_PORT, FILE_EXTENSIONS, JWT_COOKIE_NAME,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(DEFAULT_HOST, "127.0.0.1");
|
assert_eq!(DEFAULT_HOST, "127.0.0.1");
|
||||||
|
|||||||
Reference in New Issue
Block a user