diff --git a/Cargo.toml b/Cargo.toml index fd3f8c0..adc0905 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,7 @@ http-body-util = "0.1.2" hyper = { version = "1.5.2", features = ["http1", "server"] } hyper-util = { version = "0.1", features = ["http1", "server", "tokio"] } -log = { version = "0.4.22", features=["kv"]} +log = { version = "0.4.25", features=["kv"]} [lib] name = "servme" diff --git a/src/builder.rs b/src/builder.rs index 3527516..98b0991 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -1,4 +1,5 @@ use crate::{config::ServerConfig, server::Server}; +use std::sync::Arc; pub struct ServerBuilder { pub config: ServerConfig, @@ -11,9 +12,28 @@ impl ServerBuilder { self } + pub fn private_ips(mut self) -> Self { + self.config.ips_filter = true; + + self.config.private_ips = true; + self + } + + pub fn ips(mut self, ips: Vec) -> Self { + self.config.ips_filter = true; + + self.config.ips = ips; + self + } + + pub fn api_key(mut self, api_key: &str) -> Self { + self.config.api_key = Some(api_key.to_string()); + self + } + pub fn build(self) -> Server { Server { - config: self.config, + config: Arc::new(self.config), } } } diff --git a/src/config.rs b/src/config.rs index 042ef2d..0d2c25f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,12 @@ pub struct ServerConfig { pub ip: String, pub port: u16, + + pub ips_filter: bool, + pub private_ips: bool, + pub ips: Vec, + + pub api_key: Option, } impl Default for ServerConfig { @@ -8,6 +14,12 @@ impl Default for ServerConfig { ServerConfig { ip: "127.0.0.1".to_string(), port: 8080, + + ips_filter: false, + private_ips: false, + ips: Vec::new(), + + api_key: None, } } } diff --git a/src/server.rs b/src/server.rs index bc0a6ab..203df40 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,14 +1,20 @@ use crate::{builder::ServerBuilder, config::ServerConfig}; +use http::HeaderValue; use http_body_util::Full; use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response}; use hyper_util::rt::{TokioIo, TokioTimer}; use log::error; -use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc}; +use std::{ + convert::Infallible, + future::Future, + net::{IpAddr, SocketAddr}, + sync::Arc, +}; use tokio::{net::TcpListener, spawn}; use tokio_util::bytes::Bytes; pub struct Server { - pub config: ServerConfig, + pub config: Arc, } impl Server { @@ -30,14 +36,40 @@ impl Server { let handler = Arc::new(handler); loop { - let (tcp, _) = listener.accept().await.unwrap(); + let (tcp, client_addr) = listener.accept().await.unwrap(); let io = TokioIo::new(tcp); + let config = Arc::clone(&self.config); + if self.config.ips_filter && !self.is_ip_authorized(&client_addr.ip()) { + continue; + } + let handler = Arc::clone(&handler); spawn(async move { if let Err(error) = http1::Builder::new() .timer(TokioTimer::new()) - .serve_connection(io, service_fn(move |req| handler(req))) + .serve_connection( + io, + service_fn(move |req| { + let config = Arc::clone(&config); + let handler = Arc::clone(&handler); + + async move { + if let Some(ref token) = config.api_key { + if req.headers().get("X-API-Key") + != Some(&HeaderValue::from_str(token).unwrap()) + { + return Ok(Response::builder() + .status(401) + .body(Full::new(Bytes::from("Unauthorized"))) + .unwrap()); + } + } + + handler(req).await + } + }), + ) .await { error!(error = error.to_string().as_str(); @@ -47,4 +79,25 @@ impl Server { }); } } + + fn is_ip_authorized(&self, ip: &IpAddr) -> bool { + if self.config.private_ips { + let is_private = match ip { + IpAddr::V4(ip4) => ip4.is_private(), + IpAddr::V6(_) => false, + }; + + if is_private { + return true; + } + } + + let ips = &self.config.ips; + if ips.is_empty() { + return true; + } + + ips.iter() + .any(|authorized_ip| &ip.to_string() == authorized_ip) + } }