chore: add http2 support with connection tracking and optimize middlewares

This commit is contained in:
2026-04-29 23:47:24 +02:00
committed by ForgeCode
parent ccfd200681
commit 9621033530
11 changed files with 425 additions and 450 deletions
+1
View File
@@ -154,6 +154,7 @@ impl<D: Clone + Send + Sync + 'static> ServerBuilder<D> {
config: Arc::new(self.config),
middlewares: Arc::new(self.middlewares),
data: self.data.map(Arc::new),
active_connections: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
}
}
}
+1 -5
View File
@@ -31,18 +31,14 @@ impl ApiKeyMiddleware {
impl Middleware for ApiKeyMiddleware {
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
let expected_key = self.api_key.clone();
Box::pin(async move {
match req.headers().get("X-API-Key") {
Some(header) => {
if header == expected_key.as_str() {
if header == self.api_key.as_str() {
MiddlewareResult::Continue(req)
} else {
warn!("X-API-Key validation failed for request");
// Return a default unauthorized response if Responder fails
let response = Responder::unauthorized().unwrap_or_else(|_| {
// Fallback to a basic unauthorized response
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
+16 -19
View File
@@ -76,8 +76,6 @@ impl IpFilterMiddleware {
///
/// Performance: O(1) lookup using HashSet.
pub fn is_authorized(&self, ip: &IpAddr) -> bool {
// Check private ranges first (fast path for local networks)
// Note: Only IPv4 has is_private() method
if self.allow_private
&& let IpAddr::V4(ipv4) = ip
&& ipv4.is_private()
@@ -85,12 +83,10 @@ impl IpFilterMiddleware {
return true;
}
// Empty allowlist means "allow all"
if self.allowed_ips.is_empty() {
return true;
}
// O(1) lookup
self.allowed_ips.contains(ip)
}
}
@@ -99,22 +95,23 @@ impl Middleware for IpFilterMiddleware {
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
let client_ip = req.extensions().get::<IpAddr>().copied();
Box::pin(async move {
match client_ip {
Some(ip) if self.is_authorized(&ip) => MiddlewareResult::Continue(req),
_ => {
warn!("Unauthorized IP access attempt");
let response = Responder::unauthorized().unwrap_or_else(|_| {
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.header(http::header::CONTENT_TYPE, "text/plain")
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
.expect("Failed to build fallback response")
});
MiddlewareResult::Respond(response)
}
match client_ip {
Some(ip) if self.is_authorized(&ip) => {
Box::pin(std::future::ready(MiddlewareResult::Continue(req)))
}
})
_ => {
warn!("Unauthorized IP access attempt");
// Build response synchronously (avoid async overhead)
let response = Responder::unauthorized().unwrap_or_else(|_| {
Response::builder()
.status(http::StatusCode::UNAUTHORIZED)
.header(http::header::CONTENT_TYPE, "text/plain")
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
.expect("Failed to build fallback response")
});
Box::pin(std::future::ready(MiddlewareResult::Respond(response)))
}
}
}
}
+30 -21
View File
@@ -5,7 +5,7 @@
use crate::{
Responder,
constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME},
constants::{BEARER_PREFIX, FILE_EXTENSIONS},
error::{Result, ServerError},
middleware::{Middleware, MiddlewareFuture, MiddlewareResult, auth_types::Claims},
};
@@ -14,6 +14,9 @@ use hyper::body::Incoming;
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
use log::warn;
/// Pre-computed cookie prefix for zero-allocation parsing.
const COOKIE_PREFIX: &str = "access_token=";
/// JWT authentication middleware.
///
/// Validates JWT tokens using RS256 algorithm. Supports both
@@ -40,62 +43,64 @@ impl JwtMiddleware {
})
}
/// Determines if the given path has a file extension.
/// Checks if the given path has a file extension.
///
/// Returns true if the last segment of the path contains a dot
/// followed by a known extension.
///
/// Optimized: Compares lowercase extension bytes directly against segment
/// without allocating a lowercase copy of the segment.
pub fn has_file_extension(path: &str) -> bool {
// Get the last segment of the path (after the last '/')
if let Some(segment) = path.rsplit('/').next() {
// Check if it contains a dot and has a known extension
if segment.contains('.') {
let lower = segment.to_lowercase();
return FILE_EXTENSIONS.iter().any(|ext| lower.ends_with(ext));
let segment_bytes = segment.as_bytes();
return FILE_EXTENSIONS.iter().any(|ext| {
let ext_lower = ext.to_ascii_lowercase();
let ext_bytes = ext_lower.as_bytes();
if segment_bytes.len() < ext_bytes.len() {
return false;
}
segment_bytes[segment_bytes.len() - ext_bytes.len()..]
.iter()
.zip(ext_bytes.iter())
.all(|(a, b)| a.eq_ignore_ascii_case(b))
});
}
}
false
}
/// Checks if a request path is a public route.
///
/// - For routes WITH a file extension: exact match required
/// - For routes WITHOUT a file extension: prefix match (allows all subpaths)
/// - Special case: "/" as public route allows everything
pub fn is_public_route(public_routes: &[String], request_path: &str) -> bool {
// Special case: "/" allows everything
if public_routes.iter().any(|r| r == "/") {
return true;
}
public_routes.iter().any(|route| {
// Skip empty routes
if route.is_empty() {
return false;
}
if Self::has_file_extension(route) {
// Exact match for file paths
request_path == route
} else {
// Prefix match for directory paths (allows /route and /route/*)
request_path == route || request_path.starts_with(&format!("{}/", route))
}
request_path == route.as_str()
|| (request_path.starts_with(route.as_str())
&& request_path.as_bytes().get(route.len()) == Some(&b'/'))
})
}
/// Validates the request and extracts claims from the JWT token.
fn validate_request(&self, req: &Request<Incoming>) -> Result<Claims> {
// Try to get token from cookie first
let cookie_header = req.headers().get("Cookie").and_then(|v| v.to_str().ok());
let token = cookie_header
.and_then(|c| {
c.split(';')
.find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME)))
.find(|s| s.trim().starts_with(COOKIE_PREFIX))
})
.map(|s| {
s.trim()
.trim_start_matches(&format!("{}=", JWT_COOKIE_NAME))
s.trim().strip_prefix(COOKIE_PREFIX).unwrap_or(s.trim())
})
.or_else(|| {
req.headers()
@@ -117,11 +122,15 @@ impl JwtMiddleware {
}
impl Middleware for JwtMiddleware {
fn run(&self, mut req: Request<Incoming>) -> MiddlewareFuture<'_> {
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
// Capture path as owned String only once, outside the async block
// This avoids the borrow conflict with async move
let request_path = req.uri().path().to_string();
let is_public_path = Self::is_public_route(&self.public_routes, &request_path);
Box::pin(async move {
let mut req = req;
match self.validate_request(&req) {
Ok(claims) => {
req.extensions_mut().insert(claims);
+11 -1
View File
@@ -1,12 +1,13 @@
use http::Request;
use http_body_util::BodyExt;
use hyper::body::Incoming;
use hyper::body::{Bytes, Incoming};
use serde::de::DeserializeOwned;
use std::error::Error;
pub struct Requester;
impl Requester {
/// Extracts and deserializes JSON body.
pub async fn extract_body<T>(req: Request<Incoming>) -> Result<T, Box<dyn Error + Send + Sync>>
where
T: DeserializeOwned,
@@ -15,10 +16,19 @@ impl Requester {
Ok(serde_json::from_slice(&body)?)
}
pub async fn extract_body_str(
req: Request<Incoming>,
) -> Result<String, Box<dyn Error + Send + Sync>> {
let body = req.collect().await?.to_bytes();
Ok(String::from_utf8(body.to_vec())?)
}
pub async fn extract_body_bytes(
req: Request<Incoming>,
) -> Result<Bytes, Box<dyn Error + Send + Sync>> {
Ok(req.collect().await?.to_bytes())
}
}
+61 -34
View File
@@ -16,8 +16,8 @@ use hyper::{
};
use hyper_util::rt::TokioIo;
use log::{error, info, warn};
use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
use tokio::{net::TcpListener, signal, spawn, time::timeout};
use std::{future::Future, net::SocketAddr, sync::Arc, sync::atomic::{AtomicUsize, Ordering}, time::Duration};
use tokio::{net::TcpListener, signal, spawn};
/// Default connection timeout duration.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
@@ -33,6 +33,8 @@ pub struct Server<D = ()> {
pub middlewares: Arc<Vec<Box<dyn Middleware>>>,
/// Shared application state.
pub data: Option<Arc<D>>,
/// Counter for active connections (used for graceful shutdown).
pub(crate) active_connections: Arc<AtomicUsize>,
}
impl Server<()> {
@@ -89,7 +91,7 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
}
};
let listener = match TcpListener::bind(addr).await {
let std_listener = match std::net::TcpListener::bind(addr) {
Ok(l) => l,
Err(e) => {
error!("Failed to bind to address {}: {}", addr, e);
@@ -97,23 +99,36 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
}
};
if let Err(e) = std_listener.set_nonblocking(true) {
warn!("Failed to set non-blocking: {}", e);
}
let listener = match TcpListener::from_std(std_listener) {
Ok(l) => l,
Err(e) => {
error!("Failed to convert to Tokio listener: {}", e);
return;
}
};
info!("Server listening on {}", addr);
let handler = Arc::new(handler);
let shared_middlewares = self.middlewares.clone();
let active_connections = self.active_connections.clone();
// Main accept loop
loop {
tokio::select! {
// Handle incoming connections
accept_result = listener.accept() => {
match accept_result {
Ok((tcp, client_addr)) => {
active_connections.fetch_add(1, Ordering::Relaxed);
self.handle_connection(
tcp,
client_addr,
handler.clone(),
shared_middlewares.clone(),
active_connections.clone(),
);
}
Err(e) => {
@@ -122,7 +137,6 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
}
}
// Handle shutdown signal
_ = signal::ctrl_c() => {
info!("Shutdown signal received, stopping server...");
break;
@@ -130,18 +144,26 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
}
}
// Graceful shutdown
info!(
"Entering graceful shutdown (timeout: {}s)",
shutdown_timeout.as_secs()
);
// Give time for in-flight requests to complete
timeout(shutdown_timeout, async {
info!("Shutdown complete");
})
.await
.ok();
let start = std::time::Instant::now();
while active_connections.load(Ordering::Relaxed) > 0 && start.elapsed() < shutdown_timeout {
let remaining = shutdown_timeout - start.elapsed();
tokio::time::sleep(std::cmp::min(remaining, Duration::from_millis(100))).await;
}
let remaining = active_connections.load(Ordering::Relaxed);
if remaining > 0 {
warn!(
"Shutdown timeout reached with {} connection(s) still active",
remaining
);
} else {
info!("All connections completed, shutdown complete");
}
}
/// Handles a single incoming TCP connection.
@@ -151,6 +173,7 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
client_addr: SocketAddr,
handler: Arc<F>,
middlewares: Arc<Vec<Box<dyn Middleware>>>,
active_connections: Arc<AtomicUsize>,
) where
F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
Fut: Future<Output = Result<Response<Full<Bytes>>>> + Send,
@@ -160,33 +183,37 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
let client_ip = client_addr.ip();
spawn(async move {
let conn = Builder::new().serve_connection(
io,
service_fn(move |mut req| {
let mws = middlewares.clone();
let h = handler.clone();
let conn = Builder::new()
.max_buf_size(8 * 1024 * 1024)
.serve_connection(
io,
service_fn(move |mut req| {
let mws = middlewares.clone();
let h = handler.clone();
if let Some(ref d) = data_to_inject {
req.extensions_mut().insert(Arc::clone(d));
}
async move {
req.extensions_mut().insert(client_ip);
for mw in mws.iter() {
match mw.run(req).await {
MiddlewareResult::Continue(next_req) => req = next_req,
MiddlewareResult::Respond(res) => return Ok(res),
}
if let Some(ref d) = data_to_inject {
req.extensions_mut().insert(Arc::clone(d));
}
h(req).await
}
}),
);
async move {
req.extensions_mut().insert(client_ip);
for mw in mws.iter() {
match mw.run(req).await {
MiddlewareResult::Continue(next_req) => req = next_req,
MiddlewareResult::Respond(res) => return Ok(res),
}
}
h(req).await
}
}),
);
if let Err(err) = conn.await {
error!("Error serving connection from {}: {:?}", client_ip, err);
}
active_connections.fetch_sub(1, Ordering::Relaxed);
});
}
}