moving some code to server config
This commit is contained in:
		@@ -1,3 +1,7 @@
 | 
				
			|||||||
 | 
					use http::Request;
 | 
				
			||||||
 | 
					use hyper::body::Incoming;
 | 
				
			||||||
 | 
					use std::net::IpAddr;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub struct ServerConfig {
 | 
					pub struct ServerConfig {
 | 
				
			||||||
    pub ip: String,
 | 
					    pub ip: String,
 | 
				
			||||||
    pub port: u16,
 | 
					    pub port: u16,
 | 
				
			||||||
@@ -9,6 +13,44 @@ pub struct ServerConfig {
 | 
				
			|||||||
    pub api_key: Option<String>,
 | 
					    pub api_key: Option<String>,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					impl ServerConfig {
 | 
				
			||||||
 | 
					    pub fn is_ip_authorized(&self, ip: &IpAddr) -> bool {
 | 
				
			||||||
 | 
					        if !self.ips_filter {
 | 
				
			||||||
 | 
					            return true;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.private_ips {
 | 
				
			||||||
 | 
					            let is_private = match ip {
 | 
				
			||||||
 | 
					                IpAddr::V4(ip4) => ip4.is_private(),
 | 
				
			||||||
 | 
					                IpAddr::V6(_) => false,
 | 
				
			||||||
 | 
					            };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if is_private {
 | 
				
			||||||
 | 
					                return true;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let ips = &self.ips;
 | 
				
			||||||
 | 
					        if ips.is_empty() {
 | 
				
			||||||
 | 
					            return true;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        ips.iter()
 | 
				
			||||||
 | 
					            .any(|authorized_ip| &ip.to_string() == authorized_ip)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pub fn is_req_authorized(&self, req: &Request<Incoming>) -> bool {
 | 
				
			||||||
 | 
					        if self.api_key.is_none() {
 | 
				
			||||||
 | 
					            return true;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        match req.headers().get("X-API-Key") {
 | 
				
			||||||
 | 
					            Some(header) => header.eq(self.api_key.as_ref().unwrap()),
 | 
				
			||||||
 | 
					            None => false,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl Default for ServerConfig {
 | 
					impl Default for ServerConfig {
 | 
				
			||||||
    fn default() -> Self {
 | 
					    fn default() -> Self {
 | 
				
			||||||
        ServerConfig {
 | 
					        ServerConfig {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,15 +1,10 @@
 | 
				
			|||||||
use crate::{builder::ServerBuilder, config::ServerConfig};
 | 
					use crate::{builder::ServerBuilder, config::ServerConfig};
 | 
				
			||||||
use http::HeaderValue;
 | 
					use http1::Builder;
 | 
				
			||||||
use http_body_util::Full;
 | 
					use http_body_util::Full;
 | 
				
			||||||
use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
 | 
					use hyper::{body::Incoming, server::conn::http1, service::service_fn, Request, Response};
 | 
				
			||||||
use hyper_util::rt::{TokioIo, TokioTimer};
 | 
					use hyper_util::rt::{TokioIo, TokioTimer};
 | 
				
			||||||
use log::error;
 | 
					use log::error;
 | 
				
			||||||
use std::{
 | 
					use std::{convert::Infallible, future::Future, net::SocketAddr, sync::Arc};
 | 
				
			||||||
    convert::Infallible,
 | 
					 | 
				
			||||||
    future::Future,
 | 
					 | 
				
			||||||
    net::{IpAddr, SocketAddr},
 | 
					 | 
				
			||||||
    sync::Arc,
 | 
					 | 
				
			||||||
};
 | 
					 | 
				
			||||||
use tokio::{net::TcpListener, spawn};
 | 
					use tokio::{net::TcpListener, spawn};
 | 
				
			||||||
use tokio_util::bytes::Bytes;
 | 
					use tokio_util::bytes::Bytes;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -36,17 +31,22 @@ impl Server {
 | 
				
			|||||||
        let handler = Arc::new(handler);
 | 
					        let handler = Arc::new(handler);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        loop {
 | 
					        loop {
 | 
				
			||||||
            let (tcp, client_addr) = listener.accept().await.unwrap();
 | 
					            let listener_res = listener.accept().await;
 | 
				
			||||||
            let io = TokioIo::new(tcp);
 | 
					            if listener_res.is_err() {
 | 
				
			||||||
 | 
					 | 
				
			||||||
            let config = Arc::clone(&self.config);
 | 
					 | 
				
			||||||
            if self.config.ips_filter && !self.is_ip_authorized(&client_addr.ip()) {
 | 
					 | 
				
			||||||
                continue;
 | 
					                continue;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            let (tcp, client_addr) = listener_res.unwrap();
 | 
				
			||||||
 | 
					            let io = TokioIo::new(tcp);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if !self.config.is_ip_authorized(&client_addr.ip()) {
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            let config = Arc::clone(&self.config);
 | 
				
			||||||
            let handler = Arc::clone(&handler);
 | 
					            let handler = Arc::clone(&handler);
 | 
				
			||||||
            spawn(async move {
 | 
					            spawn(async move {
 | 
				
			||||||
                if let Err(error) = http1::Builder::new()
 | 
					                if let Err(error) = Builder::new()
 | 
				
			||||||
                    .timer(TokioTimer::new())
 | 
					                    .timer(TokioTimer::new())
 | 
				
			||||||
                    .serve_connection(
 | 
					                    .serve_connection(
 | 
				
			||||||
                        io,
 | 
					                        io,
 | 
				
			||||||
@@ -55,19 +55,15 @@ impl Server {
 | 
				
			|||||||
                            let handler = Arc::clone(&handler);
 | 
					                            let handler = Arc::clone(&handler);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                            async move {
 | 
					                            async move {
 | 
				
			||||||
                                if let Some(ref token) = config.api_key {
 | 
					                                if config.is_req_authorized(&req) {
 | 
				
			||||||
                                    if req.headers().get("X-API-Key")
 | 
					                                    handler(req).await
 | 
				
			||||||
                                        != Some(&HeaderValue::from_str(token).unwrap())
 | 
					                                } else {
 | 
				
			||||||
                                    {
 | 
					                                    Ok(Response::builder()
 | 
				
			||||||
                                        return Ok(Response::builder()
 | 
					 | 
				
			||||||
                                        .status(401)
 | 
					                                        .status(401)
 | 
				
			||||||
                                        .body(Full::new(Bytes::from("Unauthorized")))
 | 
					                                        .body(Full::new(Bytes::from("Unauthorized")))
 | 
				
			||||||
                                            .unwrap());
 | 
					                                        .unwrap())
 | 
				
			||||||
                                }
 | 
					                                }
 | 
				
			||||||
                            }
 | 
					                            }
 | 
				
			||||||
 | 
					 | 
				
			||||||
                                handler(req).await
 | 
					 | 
				
			||||||
                            }
 | 
					 | 
				
			||||||
                        }),
 | 
					                        }),
 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    .await
 | 
					                    .await
 | 
				
			||||||
@@ -79,25 +75,4 @@ impl Server {
 | 
				
			|||||||
            });
 | 
					            });
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					 | 
				
			||||||
    fn is_ip_authorized(&self, ip: &IpAddr) -> bool {
 | 
					 | 
				
			||||||
        if self.config.private_ips {
 | 
					 | 
				
			||||||
            let is_private = match ip {
 | 
					 | 
				
			||||||
                IpAddr::V4(ip4) => ip4.is_private(),
 | 
					 | 
				
			||||||
                IpAddr::V6(_) => false,
 | 
					 | 
				
			||||||
            };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if is_private {
 | 
					 | 
				
			||||||
                return true;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        let ips = &self.config.ips;
 | 
					 | 
				
			||||||
        if ips.is_empty() {
 | 
					 | 
				
			||||||
            return true;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        ips.iter()
 | 
					 | 
				
			||||||
            .any(|authorized_ip| &ip.to_string() == authorized_ip)
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user