adding jwt middlewares, and improving overall middlewares structure and server
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
use crate::{config::ServerConfig, server::Server};
|
||||
use crate::{
|
||||
config::ServerConfig,
|
||||
middleware::{ApiKeyMiddleware, IpFilterMiddleware, JwtMiddleware, Middleware},
|
||||
server::Server,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct ServerBuilder {
|
||||
pub config: ServerConfig,
|
||||
pub middlewares: Vec<Box<dyn Middleware>>,
|
||||
}
|
||||
|
||||
impl ServerBuilder {
|
||||
@@ -12,36 +17,51 @@ impl ServerBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn enable_private_ips(self) -> Self {
|
||||
self.set_private_ips(true)
|
||||
}
|
||||
|
||||
pub fn set_private_ips(mut self, enabled: bool) -> Self {
|
||||
self.config.private_ips = enabled;
|
||||
self.update_ip_filter_state();
|
||||
pub fn add_api_key_middleware(mut self, api_key: &str) -> Self {
|
||||
self.middlewares
|
||||
.push(Box::new(ApiKeyMiddleware::new(api_key)));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn ips(mut self, ips: Vec<String>) -> Self {
|
||||
self.config.ips = ips;
|
||||
self.update_ip_filter_state();
|
||||
pub fn add_ip_filter_middleware(
|
||||
mut self,
|
||||
allowed_ips: Vec<String>,
|
||||
allow_private: bool,
|
||||
) -> Self {
|
||||
self.middlewares.push(Box::new(IpFilterMiddleware::new(
|
||||
allowed_ips,
|
||||
allow_private,
|
||||
)));
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_jwt_middleware(mut self, public_key: &str, public_routes: Vec<String>) -> Self {
|
||||
let middleware = JwtMiddleware::new(public_key, public_routes)
|
||||
.expect("Failed to initialize JWT Middleware: Invalid Public Key");
|
||||
|
||||
self.middlewares.push(Box::new(middleware));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn api_key(mut self, api_key: &str) -> Self {
|
||||
self.config.api_key = Some(api_key.to_string());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn middleware<M>(mut self, middleware: M) -> Self
|
||||
where
|
||||
M: Middleware + 'static,
|
||||
{
|
||||
self.middlewares.push(Box::new(middleware));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Server {
|
||||
Server {
|
||||
config: Arc::new(self.config),
|
||||
middlewares: Arc::new(self.middlewares),
|
||||
}
|
||||
}
|
||||
|
||||
fn update_ip_filter_state(&mut self) {
|
||||
self.config.ip_filter = self.config.private_ips || !self.config.ips.is_empty();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,58 +1,9 @@
|
||||
use http::Request;
|
||||
use hyper::body::Incoming;
|
||||
use std::net::IpAddr;
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub ip: String,
|
||||
pub port: u16,
|
||||
|
||||
// IP filtering
|
||||
pub ip_filter: bool,
|
||||
pub private_ips: bool,
|
||||
pub ips: Vec<String>,
|
||||
|
||||
// Request filtering
|
||||
pub api_key: Option<String>,
|
||||
|
||||
pub log_unauthorized: bool,
|
||||
}
|
||||
|
||||
impl ServerConfig {
|
||||
pub fn is_ip_authorized(&self, ip: &IpAddr) -> bool {
|
||||
if !self.ip_filter {
|
||||
return true;
|
||||
}
|
||||
|
||||
if self.private_ips {
|
||||
let is_private = match ip {
|
||||
IpAddr::V4(ip4) => ip4.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
};
|
||||
|
||||
if is_private {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
let ips = &self.ips;
|
||||
if ips.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
ips.iter()
|
||||
.any(|authorized_ip| &ip.to_string() == authorized_ip)
|
||||
}
|
||||
|
||||
pub fn is_req_authorized(&self, req: &Request<Incoming>) -> bool {
|
||||
if self.api_key.is_none() {
|
||||
return true;
|
||||
}
|
||||
|
||||
match req.headers().get("X-API-Key") {
|
||||
Some(header) => header.eq(self.api_key.as_ref().unwrap()),
|
||||
None => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
@@ -61,13 +12,7 @@ impl Default for ServerConfig {
|
||||
ip: "127.0.0.1".to_string(),
|
||||
port: 8080,
|
||||
|
||||
ip_filter: false,
|
||||
private_ips: false,
|
||||
ips: Vec::new(),
|
||||
|
||||
api_key: None,
|
||||
|
||||
log_unauthorized: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
mod builder;
|
||||
mod config;
|
||||
mod middleware;
|
||||
mod requester;
|
||||
mod responder;
|
||||
mod server;
|
||||
mod url_extract;
|
||||
|
||||
pub use middleware::{Middleware, MiddlewareFuture, MiddlewareResult};
|
||||
pub use requester::Requester;
|
||||
pub use responder::Responder;
|
||||
pub use server::Server;
|
||||
|
||||
42
src/middleware/api_key.rs
Normal file
42
src/middleware/api_key.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use crate::{
|
||||
middleware::{Middleware, MiddlewareFuture, MiddlewareResult},
|
||||
Responder,
|
||||
};
|
||||
use http::Request;
|
||||
use hyper::body::Incoming;
|
||||
use log::warn;
|
||||
|
||||
pub struct ApiKeyMiddleware {
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
impl ApiKeyMiddleware {
|
||||
pub fn new(api_key: &str) -> Self {
|
||||
Self {
|
||||
api_key: api_key.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Middleware for ApiKeyMiddleware {
|
||||
fn run<'a>(&'a self, req: Request<Incoming>) -> MiddlewareFuture<'a> {
|
||||
let expected_key = self.api_key.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
match req.headers().get("X-API-Key") {
|
||||
Some(header) => {
|
||||
if header == expected_key.as_str() {
|
||||
MiddlewareResult::Continue(req)
|
||||
} else {
|
||||
warn!("X-API-Key wrong");
|
||||
MiddlewareResult::Respond(Responder::unauthorized().unwrap())
|
||||
}
|
||||
}
|
||||
None => {
|
||||
warn!("X-API-Key missing");
|
||||
MiddlewareResult::Respond(Responder::unauthorized().unwrap())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
17
src/middleware/auth_types.rs
Normal file
17
src/middleware/auth_types.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
impl Claims {
|
||||
pub fn is_expired(&self, current_timestamp: i64) -> bool {
|
||||
current_timestamp > self.exp as i64
|
||||
}
|
||||
|
||||
pub fn username(&self) -> &str {
|
||||
&self.sub
|
||||
}
|
||||
}
|
||||
56
src/middleware/ip_filter.rs
Normal file
56
src/middleware/ip_filter.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
use crate::{
|
||||
Responder,
|
||||
middleware::{Middleware, MiddlewareFuture, MiddlewareResult},
|
||||
};
|
||||
use http::Request;
|
||||
use hyper::body::Incoming;
|
||||
use log::warn;
|
||||
use std::net::IpAddr;
|
||||
|
||||
pub struct IpFilterMiddleware {
|
||||
allowed_ips: Vec<String>,
|
||||
allow_private: bool,
|
||||
}
|
||||
|
||||
impl IpFilterMiddleware {
|
||||
pub fn new(allowed_ips: Vec<String>, allow_private: bool) -> Self {
|
||||
Self {
|
||||
allowed_ips,
|
||||
allow_private,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_authorized(&self, ip: &IpAddr) -> bool {
|
||||
if self.allow_private {
|
||||
let is_private = match ip {
|
||||
IpAddr::V4(ip4) => ip4.is_private(),
|
||||
IpAddr::V6(_) => false,
|
||||
};
|
||||
if is_private {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if self.allowed_ips.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
self.allowed_ips.iter().any(|auth| &ip.to_string() == auth)
|
||||
}
|
||||
}
|
||||
|
||||
impl Middleware for IpFilterMiddleware {
|
||||
fn run<'a>(&'a self, req: Request<Incoming>) -> MiddlewareFuture<'a> {
|
||||
Box::pin(async move {
|
||||
let client_ip = req.extensions().get::<IpAddr>();
|
||||
|
||||
match client_ip {
|
||||
Some(ip) if self.is_authorized(ip) => MiddlewareResult::Continue(req),
|
||||
_ => {
|
||||
warn!("Unauthorized IP");
|
||||
MiddlewareResult::Respond(Responder::unauthorized().unwrap())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
71
src/middleware/jwt.rs
Normal file
71
src/middleware/jwt.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use crate::{
|
||||
middleware::{auth_types::Claims, Middleware, MiddlewareFuture, MiddlewareResult},
|
||||
Responder,
|
||||
};
|
||||
use http::Request;
|
||||
use hyper::body::Incoming;
|
||||
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
|
||||
use log::error;
|
||||
|
||||
pub struct JwtMiddleware {
|
||||
decoding_key: DecodingKey,
|
||||
public_routes: Vec<String>,
|
||||
}
|
||||
|
||||
impl JwtMiddleware {
|
||||
pub fn new(
|
||||
public_key: &str,
|
||||
public_routes: Vec<String>,
|
||||
) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let decoding_key = DecodingKey::from_rsa_pem(public_key.as_bytes())?;
|
||||
|
||||
Ok(Self {
|
||||
decoding_key,
|
||||
public_routes,
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_request(
|
||||
&self,
|
||||
req: &Request<Incoming>,
|
||||
) -> Result<Claims, Box<dyn std::error::Error + Send + Sync>> {
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get("Authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.filter(|h| h.starts_with("Bearer "))
|
||||
.map(|h| &h[7..])
|
||||
.ok_or("No token found")?;
|
||||
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.set_required_spec_claims(&["exp", "sub"]);
|
||||
let token_data = decode::<Claims>(auth_header, &self.decoding_key, &validation)?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
}
|
||||
|
||||
impl Middleware for JwtMiddleware {
|
||||
fn run(&self, mut req: Request<Incoming>) -> MiddlewareFuture<'_> {
|
||||
let path = req.uri().path().to_string();
|
||||
let public_routes = self.public_routes.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
if public_routes.contains(&path) {
|
||||
return MiddlewareResult::Continue(req);
|
||||
}
|
||||
|
||||
match self.validate_request(&req) {
|
||||
Ok(claims) => {
|
||||
req.extensions_mut().insert(claims);
|
||||
MiddlewareResult::Continue(req)
|
||||
}
|
||||
Err(e) => {
|
||||
error!(target: "auth", "JWT validation failed: {}", e);
|
||||
let res = Responder::unauthorized().expect("Responder failed");
|
||||
MiddlewareResult::Respond(res)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
26
src/middleware/mod.rs
Normal file
26
src/middleware/mod.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
mod api_key;
|
||||
mod auth_types;
|
||||
mod ip_filter;
|
||||
mod jwt;
|
||||
|
||||
use http::{Request, Response};
|
||||
use http_body_util::Full;
|
||||
use hyper::body::{Bytes, Incoming};
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
pub enum MiddlewareResult {
|
||||
Continue(Request<Incoming>),
|
||||
Respond(Response<Full<Bytes>>),
|
||||
}
|
||||
|
||||
pub type MiddlewareFuture<'a> = Pin<Box<dyn Future<Output = MiddlewareResult> + Send + 'a>>;
|
||||
|
||||
pub trait Middleware: Send + Sync {
|
||||
fn run<'a>(&'a self, req: Request<Incoming>) -> MiddlewareFuture<'a>;
|
||||
}
|
||||
|
||||
pub use api_key::ApiKeyMiddleware;
|
||||
pub use auth_types::Claims;
|
||||
pub use ip_filter::IpFilterMiddleware;
|
||||
pub use jwt::JwtMiddleware;
|
||||
@@ -1,4 +1,4 @@
|
||||
use http::{header::CONTENT_TYPE, response::Builder, Response};
|
||||
use http::{header::CONTENT_TYPE, Response, StatusCode};
|
||||
use http_body_util::Full;
|
||||
use hyper::body::Bytes;
|
||||
use serde::Serialize;
|
||||
@@ -8,50 +8,52 @@ pub struct Responder;
|
||||
|
||||
impl Responder {
|
||||
pub fn not_found() -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Self::response_using_builder(Self::create_builder(404), "Not Found")
|
||||
Self::text_using_status(StatusCode::NOT_FOUND.as_u16(), "Not Found")
|
||||
}
|
||||
|
||||
pub fn unathorized() -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Self::response_using_builder(Self::create_builder(401), "Unathorized")
|
||||
pub fn unauthorized() -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Self::text_using_status(StatusCode::UNAUTHORIZED.as_u16(), "Unauthorized")
|
||||
}
|
||||
|
||||
pub fn text(response: &str) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Self::response_using_builder(Self::create_builder(200), response)
|
||||
Self::text_using_status(StatusCode::OK.as_u16(), response)
|
||||
}
|
||||
|
||||
pub fn json<T>(json: &T) -> Result<Response<Full<Bytes>>, Infallible>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
Self::json_using_status(200, json)
|
||||
pub fn json<T: Serialize>(json: &T) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Self::json_using_status(StatusCode::OK.as_u16(), json)
|
||||
}
|
||||
|
||||
pub fn text_using_status(
|
||||
status: u16,
|
||||
response: &str,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Self::response_using_builder(Self::create_builder(status), response)
|
||||
let builder = Response::builder().status(status);
|
||||
Self::build_response(builder, response.to_string().into())
|
||||
}
|
||||
|
||||
pub fn json_using_status<T>(status: u16, json: &T) -> Result<Response<Full<Bytes>>, Infallible>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let builder = Self::create_builder(status).header(CONTENT_TYPE, "application/json");
|
||||
|
||||
Self::response_using_builder(builder, &serde_json::to_string(json).unwrap())
|
||||
}
|
||||
|
||||
fn response_using_builder(
|
||||
builder: Builder,
|
||||
response: &str,
|
||||
pub fn json_using_status<T: Serialize>(
|
||||
status: u16,
|
||||
json: &T,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
Ok(builder
|
||||
.body(Full::<Bytes>::from(response.to_string()))
|
||||
.unwrap())
|
||||
let builder = Response::builder()
|
||||
.status(status)
|
||||
.header(CONTENT_TYPE, "application/json");
|
||||
|
||||
match serde_json::to_string(json) {
|
||||
Ok(body) => Self::build_response(builder, body.into()),
|
||||
Err(e) => Self::text_using_status(500, &format!("JSON Error: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_builder(status: u16) -> Builder {
|
||||
Response::builder().status(status)
|
||||
// Método privado interno para centralizar la construcción
|
||||
fn build_response(
|
||||
builder: http::response::Builder,
|
||||
body: Bytes,
|
||||
) -> Result<Response<Full<Bytes>>, Infallible> {
|
||||
// En un servidor web real, un error de construcción aquí es casi imposible,
|
||||
// pero manejarlo formalmente es mejor que hacer unwrap()
|
||||
Ok(builder
|
||||
.body(Full::new(body))
|
||||
.unwrap_or_else(|_| Response::new(Full::new(Bytes::from("Internal Server Error")))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
use crate::{builder::ServerBuilder, config::ServerConfig, responder::Responder};
|
||||
use crate::{
|
||||
builder::ServerBuilder,
|
||||
config::ServerConfig,
|
||||
middleware::{Middleware, MiddlewareResult},
|
||||
};
|
||||
use http1::Builder;
|
||||
use http_body_util::Full;
|
||||
use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
|
||||
use hyper_util::rt::{TokioIo, TokioTimer};
|
||||
use hyper_util::rt::TokioIo;
|
||||
use log::error;
|
||||
use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc};
|
||||
use tokio::{net::TcpListener, spawn};
|
||||
@@ -10,12 +14,14 @@ use tokio_util::bytes::Bytes;
|
||||
|
||||
pub struct Server {
|
||||
pub config: Arc<ServerConfig>,
|
||||
pub middlewares: Arc<Vec<Box<dyn Middleware>>>,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
pub fn builder() -> ServerBuilder {
|
||||
ServerBuilder {
|
||||
config: ServerConfig::default(),
|
||||
middlewares: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,56 +33,51 @@ impl Server {
|
||||
let addr: SocketAddr = format!("{}:{}", self.config.ip, self.config.port)
|
||||
.parse()
|
||||
.expect("Invalid IP or port");
|
||||
let listener = TcpListener::bind(addr).await.unwrap();
|
||||
|
||||
let listener = TcpListener::bind(addr)
|
||||
.await
|
||||
.expect("Failed to bind to address");
|
||||
let handler = Arc::new(handler);
|
||||
|
||||
let shared_middlewares = self.middlewares;
|
||||
loop {
|
||||
let (tcp, client_addr) = match listener.accept().await {
|
||||
Ok(conn) => conn,
|
||||
Err(error) => {
|
||||
error!(
|
||||
error = error.to_string().as_str();
|
||||
"Failed to accept connection"
|
||||
);
|
||||
Err(e) => {
|
||||
error!("Accept error: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let io = TokioIo::new(tcp);
|
||||
|
||||
let config = Arc::clone(&self.config);
|
||||
let handler = Arc::clone(&handler);
|
||||
let mws = Arc::clone(&shared_middlewares);
|
||||
let h = Arc::clone(&handler);
|
||||
let client_ip = client_addr.ip();
|
||||
|
||||
spawn(async move {
|
||||
if let Err(error) = Builder::new()
|
||||
.timer(TokioTimer::new())
|
||||
.serve_connection(
|
||||
io,
|
||||
service_fn(move |req| {
|
||||
let config = Arc::clone(&config);
|
||||
let handler = Arc::clone(&handler);
|
||||
let conn = Builder::new().serve_connection(
|
||||
io,
|
||||
service_fn(move |mut req| {
|
||||
let mws = Arc::clone(&mws);
|
||||
let h = Arc::clone(&h);
|
||||
|
||||
async move {
|
||||
if !config.is_ip_authorized(&client_addr.ip())
|
||||
|| !config.is_req_authorized(&req)
|
||||
{
|
||||
if config.log_unauthorized {
|
||||
error!(tag = "ban",
|
||||
ip = client_addr.ip().to_string().as_str();
|
||||
"Unauthorized"
|
||||
);
|
||||
}
|
||||
async move {
|
||||
req.extensions_mut().insert(client_ip);
|
||||
|
||||
Responder::unathorized()
|
||||
} else {
|
||||
handler(req).await
|
||||
for mw in mws.iter() {
|
||||
match mw.run(req).await {
|
||||
MiddlewareResult::Continue(next_req) => req = next_req,
|
||||
MiddlewareResult::Respond(res) => return Ok(res),
|
||||
}
|
||||
}
|
||||
}),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!(error = error.to_string().as_str();
|
||||
"Serving connection"
|
||||
);
|
||||
h(req).await
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
if let Err(err) = conn.await {
|
||||
error!("Error serving connection: {:?}", err);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user