chore: add http2 support with connection tracking and optimize middlewares
This commit is contained in:
@@ -154,6 +154,7 @@ impl<D: Clone + Send + Sync + 'static> ServerBuilder<D> {
|
||||
config: Arc::new(self.config),
|
||||
middlewares: Arc::new(self.middlewares),
|
||||
data: self.data.map(Arc::new),
|
||||
active_connections: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,18 +31,14 @@ impl ApiKeyMiddleware {
|
||||
|
||||
impl Middleware for ApiKeyMiddleware {
|
||||
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
|
||||
let expected_key = self.api_key.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
match req.headers().get("X-API-Key") {
|
||||
Some(header) => {
|
||||
if header == expected_key.as_str() {
|
||||
if header == self.api_key.as_str() {
|
||||
MiddlewareResult::Continue(req)
|
||||
} else {
|
||||
warn!("X-API-Key validation failed for request");
|
||||
// Return a default unauthorized response if Responder fails
|
||||
let response = Responder::unauthorized().unwrap_or_else(|_| {
|
||||
// Fallback to a basic unauthorized response
|
||||
Response::builder()
|
||||
.status(http::StatusCode::UNAUTHORIZED)
|
||||
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
|
||||
|
||||
+16
-19
@@ -76,8 +76,6 @@ impl IpFilterMiddleware {
|
||||
///
|
||||
/// Performance: O(1) lookup using HashSet.
|
||||
pub fn is_authorized(&self, ip: &IpAddr) -> bool {
|
||||
// Check private ranges first (fast path for local networks)
|
||||
// Note: Only IPv4 has is_private() method
|
||||
if self.allow_private
|
||||
&& let IpAddr::V4(ipv4) = ip
|
||||
&& ipv4.is_private()
|
||||
@@ -85,12 +83,10 @@ impl IpFilterMiddleware {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Empty allowlist means "allow all"
|
||||
if self.allowed_ips.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
// O(1) lookup
|
||||
self.allowed_ips.contains(ip)
|
||||
}
|
||||
}
|
||||
@@ -99,22 +95,23 @@ impl Middleware for IpFilterMiddleware {
|
||||
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
|
||||
let client_ip = req.extensions().get::<IpAddr>().copied();
|
||||
|
||||
Box::pin(async move {
|
||||
match client_ip {
|
||||
Some(ip) if self.is_authorized(&ip) => MiddlewareResult::Continue(req),
|
||||
_ => {
|
||||
warn!("Unauthorized IP access attempt");
|
||||
let response = Responder::unauthorized().unwrap_or_else(|_| {
|
||||
Response::builder()
|
||||
.status(http::StatusCode::UNAUTHORIZED)
|
||||
.header(http::header::CONTENT_TYPE, "text/plain")
|
||||
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
|
||||
.expect("Failed to build fallback response")
|
||||
});
|
||||
MiddlewareResult::Respond(response)
|
||||
}
|
||||
match client_ip {
|
||||
Some(ip) if self.is_authorized(&ip) => {
|
||||
Box::pin(std::future::ready(MiddlewareResult::Continue(req)))
|
||||
}
|
||||
})
|
||||
_ => {
|
||||
warn!("Unauthorized IP access attempt");
|
||||
// Build response synchronously (avoid async overhead)
|
||||
let response = Responder::unauthorized().unwrap_or_else(|_| {
|
||||
Response::builder()
|
||||
.status(http::StatusCode::UNAUTHORIZED)
|
||||
.header(http::header::CONTENT_TYPE, "text/plain")
|
||||
.body(http_body_util::Full::new(Bytes::from("Unauthorized")))
|
||||
.expect("Failed to build fallback response")
|
||||
});
|
||||
Box::pin(std::future::ready(MiddlewareResult::Respond(response)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+30
-21
@@ -5,7 +5,7 @@
|
||||
|
||||
use crate::{
|
||||
Responder,
|
||||
constants::{BEARER_PREFIX, FILE_EXTENSIONS, JWT_COOKIE_NAME},
|
||||
constants::{BEARER_PREFIX, FILE_EXTENSIONS},
|
||||
error::{Result, ServerError},
|
||||
middleware::{Middleware, MiddlewareFuture, MiddlewareResult, auth_types::Claims},
|
||||
};
|
||||
@@ -14,6 +14,9 @@ use hyper::body::Incoming;
|
||||
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
|
||||
use log::warn;
|
||||
|
||||
/// Pre-computed cookie prefix for zero-allocation parsing.
|
||||
const COOKIE_PREFIX: &str = "access_token=";
|
||||
|
||||
/// JWT authentication middleware.
|
||||
///
|
||||
/// Validates JWT tokens using RS256 algorithm. Supports both
|
||||
@@ -40,62 +43,64 @@ impl JwtMiddleware {
|
||||
})
|
||||
}
|
||||
|
||||
/// Determines if the given path has a file extension.
|
||||
/// Checks if the given path has a file extension.
|
||||
///
|
||||
/// Returns true if the last segment of the path contains a dot
|
||||
/// followed by a known extension.
|
||||
///
|
||||
/// Optimized: Compares lowercase extension bytes directly against segment
|
||||
/// without allocating a lowercase copy of the segment.
|
||||
pub fn has_file_extension(path: &str) -> bool {
|
||||
// Get the last segment of the path (after the last '/')
|
||||
if let Some(segment) = path.rsplit('/').next() {
|
||||
// Check if it contains a dot and has a known extension
|
||||
if segment.contains('.') {
|
||||
let lower = segment.to_lowercase();
|
||||
return FILE_EXTENSIONS.iter().any(|ext| lower.ends_with(ext));
|
||||
let segment_bytes = segment.as_bytes();
|
||||
return FILE_EXTENSIONS.iter().any(|ext| {
|
||||
let ext_lower = ext.to_ascii_lowercase();
|
||||
let ext_bytes = ext_lower.as_bytes();
|
||||
if segment_bytes.len() < ext_bytes.len() {
|
||||
return false;
|
||||
}
|
||||
segment_bytes[segment_bytes.len() - ext_bytes.len()..]
|
||||
.iter()
|
||||
.zip(ext_bytes.iter())
|
||||
.all(|(a, b)| a.eq_ignore_ascii_case(b))
|
||||
});
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Checks if a request path is a public route.
|
||||
///
|
||||
/// - For routes WITH a file extension: exact match required
|
||||
/// - For routes WITHOUT a file extension: prefix match (allows all subpaths)
|
||||
/// - Special case: "/" as public route allows everything
|
||||
pub fn is_public_route(public_routes: &[String], request_path: &str) -> bool {
|
||||
// Special case: "/" allows everything
|
||||
if public_routes.iter().any(|r| r == "/") {
|
||||
return true;
|
||||
}
|
||||
|
||||
public_routes.iter().any(|route| {
|
||||
// Skip empty routes
|
||||
if route.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if Self::has_file_extension(route) {
|
||||
// Exact match for file paths
|
||||
request_path == route
|
||||
} else {
|
||||
// Prefix match for directory paths (allows /route and /route/*)
|
||||
request_path == route || request_path.starts_with(&format!("{}/", route))
|
||||
}
|
||||
request_path == route.as_str()
|
||||
|| (request_path.starts_with(route.as_str())
|
||||
&& request_path.as_bytes().get(route.len()) == Some(&b'/'))
|
||||
})
|
||||
}
|
||||
|
||||
/// Validates the request and extracts claims from the JWT token.
|
||||
fn validate_request(&self, req: &Request<Incoming>) -> Result<Claims> {
|
||||
// Try to get token from cookie first
|
||||
let cookie_header = req.headers().get("Cookie").and_then(|v| v.to_str().ok());
|
||||
|
||||
let token = cookie_header
|
||||
.and_then(|c| {
|
||||
c.split(';')
|
||||
.find(|s| s.trim().starts_with(&format!("{}=", JWT_COOKIE_NAME)))
|
||||
.find(|s| s.trim().starts_with(COOKIE_PREFIX))
|
||||
})
|
||||
.map(|s| {
|
||||
s.trim()
|
||||
.trim_start_matches(&format!("{}=", JWT_COOKIE_NAME))
|
||||
s.trim().strip_prefix(COOKIE_PREFIX).unwrap_or(s.trim())
|
||||
})
|
||||
.or_else(|| {
|
||||
req.headers()
|
||||
@@ -117,11 +122,15 @@ impl JwtMiddleware {
|
||||
}
|
||||
|
||||
impl Middleware for JwtMiddleware {
|
||||
fn run(&self, mut req: Request<Incoming>) -> MiddlewareFuture<'_> {
|
||||
fn run(&self, req: Request<Incoming>) -> MiddlewareFuture<'_> {
|
||||
// Capture path as owned String only once, outside the async block
|
||||
// This avoids the borrow conflict with async move
|
||||
let request_path = req.uri().path().to_string();
|
||||
let is_public_path = Self::is_public_route(&self.public_routes, &request_path);
|
||||
|
||||
Box::pin(async move {
|
||||
let mut req = req;
|
||||
|
||||
match self.validate_request(&req) {
|
||||
Ok(claims) => {
|
||||
req.extensions_mut().insert(claims);
|
||||
|
||||
+11
-1
@@ -1,12 +1,13 @@
|
||||
use http::Request;
|
||||
use http_body_util::BodyExt;
|
||||
use hyper::body::Incoming;
|
||||
use hyper::body::{Bytes, Incoming};
|
||||
use serde::de::DeserializeOwned;
|
||||
use std::error::Error;
|
||||
|
||||
pub struct Requester;
|
||||
|
||||
impl Requester {
|
||||
/// Extracts and deserializes JSON body.
|
||||
pub async fn extract_body<T>(req: Request<Incoming>) -> Result<T, Box<dyn Error + Send + Sync>>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
@@ -15,10 +16,19 @@ impl Requester {
|
||||
Ok(serde_json::from_slice(&body)?)
|
||||
}
|
||||
|
||||
|
||||
pub async fn extract_body_str(
|
||||
req: Request<Incoming>,
|
||||
) -> Result<String, Box<dyn Error + Send + Sync>> {
|
||||
let body = req.collect().await?.to_bytes();
|
||||
|
||||
Ok(String::from_utf8(body.to_vec())?)
|
||||
}
|
||||
|
||||
|
||||
pub async fn extract_body_bytes(
|
||||
req: Request<Incoming>,
|
||||
) -> Result<Bytes, Box<dyn Error + Send + Sync>> {
|
||||
Ok(req.collect().await?.to_bytes())
|
||||
}
|
||||
}
|
||||
|
||||
+61
-34
@@ -16,8 +16,8 @@ use hyper::{
|
||||
};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::{error, info, warn};
|
||||
use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
|
||||
use tokio::{net::TcpListener, signal, spawn, time::timeout};
|
||||
use std::{future::Future, net::SocketAddr, sync::Arc, sync::atomic::{AtomicUsize, Ordering}, time::Duration};
|
||||
use tokio::{net::TcpListener, signal, spawn};
|
||||
|
||||
/// Default connection timeout duration.
|
||||
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
@@ -33,6 +33,8 @@ pub struct Server<D = ()> {
|
||||
pub middlewares: Arc<Vec<Box<dyn Middleware>>>,
|
||||
/// Shared application state.
|
||||
pub data: Option<Arc<D>>,
|
||||
/// Counter for active connections (used for graceful shutdown).
|
||||
pub(crate) active_connections: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
impl Server<()> {
|
||||
@@ -89,7 +91,7 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
||||
}
|
||||
};
|
||||
|
||||
let listener = match TcpListener::bind(addr).await {
|
||||
let std_listener = match std::net::TcpListener::bind(addr) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
error!("Failed to bind to address {}: {}", addr, e);
|
||||
@@ -97,23 +99,36 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
||||
}
|
||||
};
|
||||
|
||||
if let Err(e) = std_listener.set_nonblocking(true) {
|
||||
warn!("Failed to set non-blocking: {}", e);
|
||||
}
|
||||
|
||||
let listener = match TcpListener::from_std(std_listener) {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
error!("Failed to convert to Tokio listener: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
info!("Server listening on {}", addr);
|
||||
|
||||
let handler = Arc::new(handler);
|
||||
let shared_middlewares = self.middlewares.clone();
|
||||
let active_connections = self.active_connections.clone();
|
||||
|
||||
// Main accept loop
|
||||
loop {
|
||||
tokio::select! {
|
||||
// Handle incoming connections
|
||||
accept_result = listener.accept() => {
|
||||
match accept_result {
|
||||
Ok((tcp, client_addr)) => {
|
||||
active_connections.fetch_add(1, Ordering::Relaxed);
|
||||
self.handle_connection(
|
||||
tcp,
|
||||
client_addr,
|
||||
handler.clone(),
|
||||
shared_middlewares.clone(),
|
||||
active_connections.clone(),
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -122,7 +137,6 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle shutdown signal
|
||||
_ = signal::ctrl_c() => {
|
||||
info!("Shutdown signal received, stopping server...");
|
||||
break;
|
||||
@@ -130,18 +144,26 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
||||
}
|
||||
}
|
||||
|
||||
// Graceful shutdown
|
||||
info!(
|
||||
"Entering graceful shutdown (timeout: {}s)",
|
||||
shutdown_timeout.as_secs()
|
||||
);
|
||||
|
||||
// Give time for in-flight requests to complete
|
||||
timeout(shutdown_timeout, async {
|
||||
info!("Shutdown complete");
|
||||
})
|
||||
.await
|
||||
.ok();
|
||||
let start = std::time::Instant::now();
|
||||
while active_connections.load(Ordering::Relaxed) > 0 && start.elapsed() < shutdown_timeout {
|
||||
let remaining = shutdown_timeout - start.elapsed();
|
||||
tokio::time::sleep(std::cmp::min(remaining, Duration::from_millis(100))).await;
|
||||
}
|
||||
|
||||
let remaining = active_connections.load(Ordering::Relaxed);
|
||||
if remaining > 0 {
|
||||
warn!(
|
||||
"Shutdown timeout reached with {} connection(s) still active",
|
||||
remaining
|
||||
);
|
||||
} else {
|
||||
info!("All connections completed, shutdown complete");
|
||||
}
|
||||
}
|
||||
|
||||
/// Handles a single incoming TCP connection.
|
||||
@@ -151,6 +173,7 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
||||
client_addr: SocketAddr,
|
||||
handler: Arc<F>,
|
||||
middlewares: Arc<Vec<Box<dyn Middleware>>>,
|
||||
active_connections: Arc<AtomicUsize>,
|
||||
) where
|
||||
F: Fn(Request<Incoming>) -> Fut + Send + Sync + 'static,
|
||||
Fut: Future<Output = Result<Response<Full<Bytes>>>> + Send,
|
||||
@@ -160,33 +183,37 @@ impl<D: Clone + Send + Sync + 'static> Server<D> {
|
||||
let client_ip = client_addr.ip();
|
||||
|
||||
spawn(async move {
|
||||
let conn = Builder::new().serve_connection(
|
||||
io,
|
||||
service_fn(move |mut req| {
|
||||
let mws = middlewares.clone();
|
||||
let h = handler.clone();
|
||||
let conn = Builder::new()
|
||||
.max_buf_size(8 * 1024 * 1024)
|
||||
.serve_connection(
|
||||
io,
|
||||
service_fn(move |mut req| {
|
||||
let mws = middlewares.clone();
|
||||
let h = handler.clone();
|
||||
|
||||
if let Some(ref d) = data_to_inject {
|
||||
req.extensions_mut().insert(Arc::clone(d));
|
||||
}
|
||||
|
||||
async move {
|
||||
req.extensions_mut().insert(client_ip);
|
||||
|
||||
for mw in mws.iter() {
|
||||
match mw.run(req).await {
|
||||
MiddlewareResult::Continue(next_req) => req = next_req,
|
||||
MiddlewareResult::Respond(res) => return Ok(res),
|
||||
}
|
||||
if let Some(ref d) = data_to_inject {
|
||||
req.extensions_mut().insert(Arc::clone(d));
|
||||
}
|
||||
h(req).await
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
async move {
|
||||
req.extensions_mut().insert(client_ip);
|
||||
|
||||
for mw in mws.iter() {
|
||||
match mw.run(req).await {
|
||||
MiddlewareResult::Continue(next_req) => req = next_req,
|
||||
MiddlewareResult::Respond(res) => return Ok(res),
|
||||
}
|
||||
}
|
||||
h(req).await
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
if let Err(err) = conn.await {
|
||||
error!("Error serving connection from {}: {:?}", client_ip, err);
|
||||
}
|
||||
|
||||
active_connections.fetch_sub(1, Ordering::Relaxed);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user