diff --git a/master/src/config.rs b/master/src/config.rs index f89bb13..7c06d5a 100644 --- a/master/src/config.rs +++ b/master/src/config.rs @@ -1,22 +1,18 @@ // SPDX-License-Identifier: GPL-3.0-only // SPDX-FileCopyrightText: 2023 Denis Drakhnia -use std::fs; -use std::io; -use std::net::IpAddr; -use std::path::Path; +use std::{ + fs, io, + net::{IpAddr, Ipv4Addr}, + path::Path, +}; use log::LevelFilter; use serde::{de::Error as _, Deserialize, Deserializer}; use thiserror::Error; -use xash3d_protocol::admin; -use xash3d_protocol::filter::Version; +use xash3d_protocol::{admin, filter::Version}; -#[cfg(not(feature = "ipv6"))] -pub const DEFAULT_MASTER_SERVER_IP: IpAddr = IpAddr::V4(std::net::Ipv4Addr::UNSPECIFIED); - -#[cfg(feature = "ipv6")] -pub const DEFAULT_MASTER_SERVER_IP: IpAddr = IpAddr::V6(std::net::Ipv6Addr::UNSPECIFIED); +pub const DEFAULT_MASTER_SERVER_IP: IpAddr = IpAddr::V4(Ipv4Addr::UNSPECIFIED); pub const DEFAULT_MASTER_SERVER_PORT: u16 = 27010; pub const DEFAULT_CHALLENGE_TIMEOUT: u32 = 10; diff --git a/master/src/main.rs b/master/src/main.rs index b939ef9..97eeb0c 100644 --- a/master/src/main.rs +++ b/master/src/main.rs @@ -23,7 +23,7 @@ use signal_hook::{consts::signal::*, flag as signal_flag}; use crate::cli::Cli; use crate::config::Config; -use crate::master_server::{Error, MasterServer}; +use crate::master_server::{Error, Master}; fn load_config(cli: &Cli) -> Result { let mut cfg = match cli.config_path { @@ -68,7 +68,7 @@ fn run() -> Result<(), Error> { process::exit(1); }); - let mut master = MasterServer::new(cfg)?; + let mut master = Master::new(cfg)?; let sig_flag = Arc::new(AtomicBool::new(false)); // XXX: Windows does not support SIGUSR1. #[cfg(not(windows))] diff --git a/master/src/master_server.rs b/master/src/master_server.rs index e82708f..034470b 100644 --- a/master/src/master_server.rs +++ b/master/src/master_server.rs @@ -2,9 +2,12 @@ // SPDX-FileCopyrightText: 2023 Denis Drakhnia use std::{ + cmp::Eq, collections::hash_map, + fmt::Display, + hash::Hash, io, - net::{SocketAddr, ToSocketAddrs, UdpSocket}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs, UdpSocket}, ops::Deref, sync::atomic::{AtomicBool, Ordering}, time::{Duration, Instant}, @@ -31,60 +34,54 @@ use crate::{ stats::Stats, }; -#[cfg(not(feature = "ipv6"))] -mod protocol { - use std::net::SocketAddr; +pub trait AddrExt: Sized + Eq + Hash + Display + Copy + ToSocketAddrs + ServerAddress { + type Ip; - pub type Addr = std::net::SocketAddrV4; - pub type Ip = std::net::Ipv4Addr; + fn extract(addr: SocketAddr) -> Result; + fn ip(&self) -> &Self::Ip; + fn wrap(self) -> SocketAddr; +} + +impl AddrExt for SocketAddrV4 { + type Ip = Ipv4Addr; - pub fn extract_addr(addr: SocketAddr) -> Option { - if let SocketAddr::V4(a) = addr { - Some(a) + fn extract(addr: SocketAddr) -> Result { + if let SocketAddr::V4(addr) = addr { + Ok(addr) } else { - None + Err(addr) } } - #[inline(always)] - pub fn wrap_addr(addr: Addr) -> SocketAddr { - SocketAddr::V4(addr) + fn ip(&self) -> &Self::Ip { + SocketAddrV4::ip(self) } - #[inline(always)] - pub fn check_addr(addr: &SocketAddr) -> bool { - addr.is_ipv4() + fn wrap(self) -> SocketAddr { + SocketAddr::V4(self) } } -#[cfg(feature = "ipv6")] -mod protocol { - use std::net::SocketAddr; +impl AddrExt for SocketAddrV6 { + type Ip = Ipv6Addr; - pub type Addr = std::net::SocketAddrV6; - pub type Ip = std::net::Ipv6Addr; - - pub fn extract_addr(addr: SocketAddr) -> Option { - if let SocketAddr::V6(a) = addr { - Some(a) + fn extract(addr: SocketAddr) -> Result { + if let SocketAddr::V6(addr) = addr { + Ok(addr) } else { - None + Err(addr) } } - #[inline(always)] - pub fn wrap_addr(addr: Addr) -> SocketAddr { - SocketAddr::V6(addr) + fn ip(&self) -> &Self::Ip { + SocketAddrV6::ip(self) } - #[inline(always)] - pub fn check_addr(addr: &SocketAddr) -> bool { - addr.is_ipv6() + fn wrap(self) -> SocketAddr { + SocketAddr::V6(self) } } -use self::protocol::*; - /// The maximum size of UDP packets. const MAX_PACKET_SIZE: usize = 512; @@ -104,8 +101,6 @@ const ADMIN_LIMIT_CLEANUP_MAX: usize = 100; pub enum Error { #[error("Failed to bind server socket: {0}")] BindSocket(io::Error), - #[error("IP version is not supported")] - Unsupported, #[error(transparent)] Protocol(#[from] ProtocolError), #[error(transparent)] @@ -136,8 +131,8 @@ impl Entry { } impl Entry { - fn matches(&self, addr: Addr, region: Region, filter: &Filter) -> bool { - self.region == region && filter.matches(wrap_addr(addr), &self.value) + fn matches(&self, addr: Addr, region: Region, filter: &Filter) -> bool { + self.region == region && filter.matches(addr.wrap(), &self.value) } } @@ -170,7 +165,7 @@ impl Counter { } } -pub struct MasterServer { +pub struct MasterServer { sock: UdpSocket, challenges: HashMap>, challenges_counter: Counter, @@ -187,15 +182,15 @@ pub struct MasterServer { update_map: Box, update_addr: SocketAddr, - admin_challenges: HashMap>, + admin_challenges: HashMap>, admin_challenges_counter: Counter, admin_list: Box<[config::AdminConfig]>, // rate limit if hash is invalid - admin_limit: HashMap>, + admin_limit: HashMap>, admin_limit_counter: Counter, hash: config::HashConfig, - blocklist: HashSet, + blocklist: HashSet, stats: Stats, @@ -233,18 +228,52 @@ fn resolve_update_addr(cfg: &Config, local_addr: SocketAddr) -> SocketAddr { local_addr } -impl MasterServer { +pub enum Master { + V4(MasterServer), + V6(MasterServer), +} + +impl Master { pub fn new(cfg: Config) -> Result { - let addr = SocketAddr::new(cfg.server.ip, cfg.server.port); - if !check_addr(&addr) { - return Err(Error::Unsupported); + match SocketAddr::new(cfg.server.ip, cfg.server.port) { + SocketAddr::V4(addr) => MasterServer::new(cfg, addr).map(Self::V4), + SocketAddr::V6(addr) => MasterServer::new(cfg, addr).map(Self::V6), } + } + + pub fn update_config(&mut self, cfg: Config) -> Result<(), Error> { + let cfg = match self { + Self::V4(inner) => inner.update_config(cfg)?, + Self::V6(inner) => inner.update_config(cfg)?, + }; + if let Some(cfg) = cfg { + info!("Server IP version changed, full restart"); + *self = Self::new(cfg)?; + } + Ok(()) + } + + pub fn run(&mut self, sig_flag: &AtomicBool) -> Result<(), Error> { + match self { + Self::V4(inner) => inner.run(sig_flag), + Self::V6(inner) => inner.run(sig_flag), + } + } +} + +impl MasterServer +where + Addr: AddrExt, + Addr::Ip: Eq + Hash + Display + Copy + std::str::FromStr, +{ + pub fn new(cfg: Config, addr: Addr) -> Result { info!("Listen address: {}", addr); + let sock = UdpSocket::bind(addr).map_err(Error::BindSocket)?; // make socket interruptable by singals sock.set_read_timeout(Some(Duration::from_secs(u32::MAX as u64)))?; - let update_addr = resolve_update_addr(&cfg, addr); + let update_addr = resolve_update_addr(&cfg, addr.wrap()); Ok(Self { sock, @@ -274,10 +303,16 @@ impl MasterServer { }) } - pub fn update_config(&mut self, cfg: Config) -> Result<(), Error> { - let local_addr = self.sock.local_addr()?; + fn local_addr(&self) -> io::Result { + self.sock.local_addr() + } + + pub fn update_config(&mut self, cfg: Config) -> Result, Error> { + let local_addr = self.local_addr()?; let addr = SocketAddr::new(cfg.server.ip, cfg.server.port); - if local_addr != addr { + if local_addr.is_ipv4() != addr.is_ipv4() { + return Ok(Some(cfg)); + } else if local_addr != addr { info!("Listen address: {}", addr); self.sock = UdpSocket::bind(addr).map_err(Error::BindSocket)?; // make socket interruptable by singals @@ -295,7 +330,7 @@ impl MasterServer { self.hash = cfg.hash; self.stats.update_config(cfg.stat); - Ok(()) + Ok(None) } pub fn run(&mut self, sig_flag: &AtomicBool) -> Result<(), Error> { @@ -310,14 +345,9 @@ impl MasterServer { }, }; - let from = match extract_addr(from) { - Some(from) => from, - None => { - if from.is_ipv6() { - debug!("{}: IPv6 is not implemented", from); - } - continue; - } + let from = match Addr::extract(from) { + Ok(from) => from, + Err(_) => continue, }; let src = &buf[..n]; @@ -596,7 +626,7 @@ impl MasterServer { } } - fn count_servers(&self, ip: &Ip) -> u16 { + fn count_servers(&self, ip: &Addr::Ip) -> u16 { self.servers.keys().filter(|i| i.ip() == ip).count() as u16 } @@ -645,7 +675,7 @@ impl MasterServer { fn send_client_to_nat_servers(&self, to: Addr, servers: &[Addr]) -> Result<(), Error> { let mut buf = [0; 64]; - let n = master::ClientAnnounce::new(wrap_addr(to)).encode(&mut buf)?; + let n = master::ClientAnnounce::new(to.wrap()).encode(&mut buf)?; let buf = &buf[..n]; for i in servers { self.sock.send_to(buf, i)?; @@ -654,15 +684,20 @@ impl MasterServer { } #[inline] - fn is_blocked(&self, ip: &Ip) -> bool { + fn is_blocked(&self, ip: &Addr::Ip) -> bool { self.blocklist.contains(ip) } fn admin_command(&mut self, cmd: &str) { let args: Vec<_> = cmd.split(' ').collect(); - fn helper(args: &[&str], mut op: F) { - let iter = args.iter().map(|i| (i, i.parse::())); + fn helper(args: &[&str], mut op: F) + where + Addr: AddrExt, + Addr::Ip: std::str::FromStr, + F: FnMut(&str, Addr::Ip), + { + let iter = args.iter().map(|i| (i, i.parse::())); for (i, ip) in iter { match ip { Ok(ip) => op(i, ip), @@ -673,14 +708,14 @@ impl MasterServer { match args[0] { "ban" => { - helper(&args[1..], |_, ip| { + helper::(&args[1..], |_, ip| { if self.blocklist.insert(ip) { info!("ban ip: {}", ip); } }); } "unban" => { - helper(&args[1..], |_, ip| { + helper::(&args[1..], |_, ip| { if self.blocklist.remove(&ip) { info!("unban ip: {}", ip); }