diff --git a/src/config.rs b/src/config.rs index 0d2c25f..8b465e7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,3 +1,7 @@ +use http::Request; +use hyper::body::Incoming; +use std::net::IpAddr; + pub struct ServerConfig { pub ip: String, pub port: u16, @@ -9,6 +13,44 @@ pub struct ServerConfig { pub api_key: Option, } +impl ServerConfig { + pub fn is_ip_authorized(&self, ip: &IpAddr) -> bool { + if !self.ips_filter { + return true; + } + + if self.private_ips { + let is_private = match ip { + IpAddr::V4(ip4) => ip4.is_private(), + IpAddr::V6(_) => false, + }; + + if is_private { + return true; + } + } + + let ips = &self.ips; + if ips.is_empty() { + return true; + } + + ips.iter() + .any(|authorized_ip| &ip.to_string() == authorized_ip) + } + + pub fn is_req_authorized(&self, req: &Request) -> bool { + if self.api_key.is_none() { + return true; + } + + match req.headers().get("X-API-Key") { + Some(header) => header.eq(self.api_key.as_ref().unwrap()), + None => false, + } + } +} + impl Default for ServerConfig { fn default() -> Self { ServerConfig { diff --git a/src/server.rs b/src/server.rs index 203df40..edb19ec 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,15 +1,10 @@ use crate::{builder::ServerBuilder, config::ServerConfig}; -use http::HeaderValue; +use http1::Builder; use http_body_util::Full; use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response}; use hyper_util::rt::{TokioIo, TokioTimer}; use log::error; -use std::{ - convert::Infallible, - future::Future, - net::{IpAddr, SocketAddr}, - sync::Arc, -}; +use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc}; use tokio::{net::TcpListener, spawn}; use tokio_util::bytes::Bytes; @@ -36,17 +31,22 @@ impl Server { let handler = Arc::new(handler); loop { - let (tcp, client_addr) = listener.accept().await.unwrap(); - let io = TokioIo::new(tcp); - - let config = Arc::clone(&self.config); - if self.config.ips_filter && !self.is_ip_authorized(&client_addr.ip()) { + let listener_res = listener.accept().await; + if listener_res.is_err() { continue; } + let (tcp, client_addr) = listener_res.unwrap(); + let io = TokioIo::new(tcp); + + if !self.config.is_ip_authorized(&client_addr.ip()) { + continue; + } + + let config = Arc::clone(&self.config); let handler = Arc::clone(&handler); spawn(async move { - if let Err(error) = http1::Builder::new() + if let Err(error) = Builder::new() .timer(TokioTimer::new()) .serve_connection( io, @@ -55,18 +55,14 @@ impl Server { let handler = Arc::clone(&handler); async move { - if let Some(ref token) = config.api_key { - if req.headers().get("X-API-Key") - != Some(&HeaderValue::from_str(token).unwrap()) - { - return Ok(Response::builder() - .status(401) - .body(Full::new(Bytes::from("Unauthorized"))) - .unwrap()); - } + if config.is_req_authorized(&req) { + handler(req).await + } else { + Ok(Response::builder() + .status(401) + .body(Full::new(Bytes::from("Unauthorized"))) + .unwrap()) } - - handler(req).await } }), ) @@ -79,25 +75,4 @@ impl Server { }); } } - - fn is_ip_authorized(&self, ip: &IpAddr) -> bool { - if self.config.private_ips { - let is_private = match ip { - IpAddr::V4(ip4) => ip4.is_private(), - IpAddr::V6(_) => false, - }; - - if is_private { - return true; - } - } - - let ips = &self.config.ips; - if ips.is_empty() { - return true; - } - - ips.iter() - .any(|authorized_ip| &ip.to_string() == authorized_ip) - } }