diff --git a/src/builder.rs b/src/builder.rs index 98b0991..2734394 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -12,17 +12,21 @@ impl ServerBuilder { self } - pub fn private_ips(mut self) -> Self { - self.config.ips_filter = true; + pub fn enable_private_ips(self) -> Self { + self.set_private_ips(true) + } + + pub fn set_private_ips(mut self, enabled: bool) -> Self { + self.config.private_ips = enabled; + self.update_ip_filter_state(); - self.config.private_ips = true; self } pub fn ips(mut self, ips: Vec) -> Self { - self.config.ips_filter = true; - self.config.ips = ips; + self.update_ip_filter_state(); + self } @@ -36,4 +40,12 @@ impl ServerBuilder { config: Arc::new(self.config), } } + + fn update_ip_filter_state(&mut self) { + if self.config.private_ips || !self.config.ips.is_empty() { + self.config.ip_filter = true; + } else { + self.config.ip_filter = false; + } + } } diff --git a/src/config.rs b/src/config.rs index 8b465e7..6494d9b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -6,16 +6,20 @@ pub struct ServerConfig { pub ip: String, pub port: u16, - pub ips_filter: bool, + // IP filtering + pub ip_filter: bool, pub private_ips: bool, pub ips: Vec, + // Request filtering pub api_key: Option, + + pub log_unauthorized: bool, } impl ServerConfig { pub fn is_ip_authorized(&self, ip: &IpAddr) -> bool { - if !self.ips_filter { + if !self.ip_filter { return true; } @@ -57,11 +61,13 @@ impl Default for ServerConfig { ip: "127.0.0.1".to_string(), port: 8080, - ips_filter: false, + ip_filter: false, private_ips: false, ips: Vec::new(), api_key: None, + + log_unauthorized: true, } } } diff --git a/src/server.rs b/src/server.rs index 2356e48..dc2ac98 100644 --- a/src/server.rs +++ b/src/server.rs @@ -31,13 +31,16 @@ impl Server { let handler = Arc::new(handler); loop { - let listener_res = listener.accept().await; - if listener_res.is_err() { - continue; - } - - let (tcp, client_addr) = listener_res.unwrap(); - let client_ip = client_addr.ip(); + let (tcp, client_addr) = match listener.accept().await { + Ok(conn) => conn, + Err(error) => { + error!( + error = error.to_string().as_str(); + "Failed to accept connection" + ); + continue; + } + }; let io = TokioIo::new(tcp); let config = Arc::clone(&self.config); @@ -52,9 +55,16 @@ impl Server { let handler = Arc::clone(&handler); async move { - if !config.is_ip_authorized(&client_ip) + if !config.is_ip_authorized(&client_addr.ip()) || !config.is_req_authorized(&req) { + if config.log_unauthorized { + error!(tag = "ban", + ip = client_addr.ip().to_string().as_str(); + "Unauthorized" + ); + } + Responder::unathorized() } else { handler(req).await