diff --git a/master/src/config.rs b/master/src/config.rs index 32157e9..7c06d5a 100644 --- a/master/src/config.rs +++ b/master/src/config.rs @@ -1,18 +1,19 @@ // SPDX-License-Identifier: GPL-3.0-only // SPDX-FileCopyrightText: 2023 Denis Drakhnia -use std::fs; -use std::io; -use std::net::{IpAddr, Ipv4Addr}; -use std::path::Path; +use std::{ + fs, io, + net::{IpAddr, Ipv4Addr}, + path::Path, +}; use log::LevelFilter; use serde::{de::Error as _, Deserialize, Deserializer}; use thiserror::Error; -use xash3d_protocol::admin; -use xash3d_protocol::filter::Version; +use xash3d_protocol::{admin, filter::Version}; + +pub const DEFAULT_MASTER_SERVER_IP: IpAddr = IpAddr::V4(Ipv4Addr::UNSPECIFIED); -pub const DEFAULT_MASTER_SERVER_IP: IpAddr = IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)); pub const DEFAULT_MASTER_SERVER_PORT: u16 = 27010; pub const DEFAULT_CHALLENGE_TIMEOUT: u32 = 10; pub const DEFAULT_SERVER_TIMEOUT: u32 = 300; diff --git a/master/src/main.rs b/master/src/main.rs index 7c9bb3e..97eeb0c 100644 --- a/master/src/main.rs +++ b/master/src/main.rs @@ -9,9 +9,13 @@ mod logger; mod master_server; mod stats; -use std::process; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::{ + process, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; use log::{error, info}; #[cfg(not(windows))] @@ -19,7 +23,7 @@ use signal_hook::{consts::signal::*, flag as signal_flag}; use crate::cli::Cli; use crate::config::Config; -use crate::master_server::{Error, MasterServer}; +use crate::master_server::{Error, Master}; fn load_config(cli: &Cli) -> Result { let mut cfg = match cli.config_path { @@ -64,7 +68,7 @@ fn run() -> Result<(), Error> { process::exit(1); }); - let mut master = MasterServer::new(cfg)?; + let mut master = Master::new(cfg)?; let sig_flag = Arc::new(AtomicBool::new(false)); // XXX: Windows does not support SIGUSR1. #[cfg(not(windows))] diff --git a/master/src/master_server.rs b/master/src/master_server.rs index 27dbd69..d6d8ad4 100644 --- a/master/src/master_server.rs +++ b/master/src/master_server.rs @@ -1,25 +1,87 @@ // SPDX-License-Identifier: GPL-3.0-only // SPDX-FileCopyrightText: 2023 Denis Drakhnia -use std::collections::hash_map; -use std::io; -use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs, UdpSocket}; -use std::ops::Deref; -use std::sync::atomic::{AtomicBool, Ordering}; -use std::time::{Duration, Instant}; +use std::{ + cmp::Eq, + collections::hash_map, + fmt::Display, + hash::Hash, + io, + net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs, UdpSocket}, + ops::Deref, + str::FromStr, + sync::atomic::{AtomicBool, Ordering}, + time::{Duration, Instant}, +}; use ahash::{AHashMap as HashMap, AHashSet as HashSet}; use blake2b_simd::Params; use fastrand::Rng; use log::{debug, error, info, trace, warn}; use thiserror::Error; -use xash3d_protocol::filter::{Filter, FilterFlags, Version}; -use xash3d_protocol::server::Region; -use xash3d_protocol::wrappers::Str; -use xash3d_protocol::{admin, game, master, server, Error as ProtocolError, ServerInfo}; +use xash3d_protocol::{ + admin, + filter::{Filter, FilterFlags, Version}, + game, + master::{self, ServerAddress}, + server, + server::Region, + wrappers::Str, + Error as ProtocolError, ServerInfo, +}; + +use crate::{ + config::{self, Config}, + stats::Stats, +}; + +pub trait AddrExt: Sized + Eq + Hash + Display + Copy + ToSocketAddrs + ServerAddress { + type Ip: Eq + Hash + Display + Copy + FromStr; + + fn extract(addr: SocketAddr) -> Result; + fn ip(&self) -> &Self::Ip; + fn wrap(self) -> SocketAddr; +} + +impl AddrExt for SocketAddrV4 { + type Ip = Ipv4Addr; + + fn extract(addr: SocketAddr) -> Result { + if let SocketAddr::V4(addr) = addr { + Ok(addr) + } else { + Err(addr) + } + } -use crate::config::{self, Config}; -use crate::stats::Stats; + fn ip(&self) -> &Self::Ip { + SocketAddrV4::ip(self) + } + + fn wrap(self) -> SocketAddr { + SocketAddr::V4(self) + } +} + +impl AddrExt for SocketAddrV6 { + type Ip = Ipv6Addr; + + fn extract(addr: SocketAddr) -> Result { + if let SocketAddr::V6(addr) = addr { + Ok(addr) + } else { + Err(addr) + } + } + + fn ip(&self) -> &Self::Ip { + SocketAddrV6::ip(self) + } + + fn wrap(self) -> SocketAddr { + SocketAddr::V6(self) + } +} /// The maximum size of UDP packets. const MAX_PACKET_SIZE: usize = 512; @@ -70,8 +132,8 @@ impl Entry { } impl Entry { - fn matches(&self, addr: SocketAddrV4, region: Region, filter: &Filter) -> bool { - self.region == region && filter.matches(addr, &self.value) + fn matches(&self, addr: Addr, region: Region, filter: &Filter) -> bool { + self.region == region && filter.matches(addr.wrap(), &self.value) } } @@ -104,11 +166,11 @@ impl Counter { } } -pub struct MasterServer { +pub struct MasterServer { sock: UdpSocket, - challenges: HashMap>, + challenges: HashMap>, challenges_counter: Counter, - servers: HashMap>, + servers: HashMap>, servers_counter: Counter, max_servers_per_ip: u16, rng: Rng, @@ -119,39 +181,38 @@ pub struct MasterServer { clver: Version, update_title: Box, update_map: Box, - update_addr: SocketAddrV4, + update_addr: SocketAddr, - admin_challenges: HashMap>, + admin_challenges: HashMap>, admin_challenges_counter: Counter, admin_list: Box<[config::AdminConfig]>, // rate limit if hash is invalid - admin_limit: HashMap>, + admin_limit: HashMap>, admin_limit_counter: Counter, hash: config::HashConfig, - blocklist: HashSet, + blocklist: HashSet, stats: Stats, // temporary data - filtered_servers: Vec, - filtered_servers_nat: Vec, + filtered_servers: Vec, + filtered_servers_nat: Vec, } -fn resolve_socket_addr(addr: A) -> io::Result> +fn resolve_socket_addr(addr: A, is_ipv4: bool) -> io::Result> where A: ToSocketAddrs, { for i in addr.to_socket_addrs()? { - match i { - SocketAddr::V4(i) => return Ok(Some(i)), - SocketAddr::V6(_) => {} + if i.is_ipv4() == is_ipv4 { + return Ok(Some(i)); } } Ok(None) } -fn resolve_update_addr(cfg: &Config, local_addr: SocketAddr) -> SocketAddrV4 { +fn resolve_update_addr(cfg: &Config, local_addr: SocketAddr) -> SocketAddr { if let Some(s) = cfg.client.update_addr.as_deref() { let addr = if !s.contains(':') { format!("{}:{}", s, local_addr.port()) @@ -159,28 +220,57 @@ fn resolve_update_addr(cfg: &Config, local_addr: SocketAddr) -> SocketAddrV4 { s.to_owned() }; - match resolve_socket_addr(&addr) { + match resolve_socket_addr(&addr, local_addr.is_ipv4()) { Ok(Some(x)) => return x, - Ok(None) => error!("Update address: failed to resolve IPv4 for \"{}\"", addr), + Ok(None) => error!("Update address: failed to resolve IP for \"{}\"", addr), Err(e) => error!("Update address: {}", e), } } + local_addr +} - match local_addr { - SocketAddr::V4(x) => x, - SocketAddr::V6(_) => todo!(), - } +pub enum Master { + V4(MasterServer), + V6(MasterServer), } -impl MasterServer { +impl Master { pub fn new(cfg: Config) -> Result { - let addr = SocketAddr::new(cfg.server.ip, cfg.server.port); + match SocketAddr::new(cfg.server.ip, cfg.server.port) { + SocketAddr::V4(addr) => MasterServer::new(cfg, addr).map(Self::V4), + SocketAddr::V6(addr) => MasterServer::new(cfg, addr).map(Self::V6), + } + } + + pub fn update_config(&mut self, cfg: Config) -> Result<(), Error> { + let cfg = match self { + Self::V4(inner) => inner.update_config(cfg)?, + Self::V6(inner) => inner.update_config(cfg)?, + }; + if let Some(cfg) = cfg { + info!("Server IP version changed, full restart"); + *self = Self::new(cfg)?; + } + Ok(()) + } + + pub fn run(&mut self, sig_flag: &AtomicBool) -> Result<(), Error> { + match self { + Self::V4(inner) => inner.run(sig_flag), + Self::V6(inner) => inner.run(sig_flag), + } + } +} + +impl MasterServer { + pub fn new(cfg: Config, addr: Addr) -> Result { info!("Listen address: {}", addr); + let sock = UdpSocket::bind(addr).map_err(Error::BindSocket)?; // make socket interruptable by singals sock.set_read_timeout(Some(Duration::from_secs(u32::MAX as u64)))?; - let update_addr = resolve_update_addr(&cfg, addr); + let update_addr = resolve_update_addr(&cfg, addr.wrap()); Ok(Self { sock, @@ -210,10 +300,16 @@ impl MasterServer { }) } - pub fn update_config(&mut self, cfg: Config) -> Result<(), Error> { - let local_addr = self.sock.local_addr()?; + fn local_addr(&self) -> io::Result { + self.sock.local_addr() + } + + pub fn update_config(&mut self, cfg: Config) -> Result, Error> { + let local_addr = self.local_addr()?; let addr = SocketAddr::new(cfg.server.ip, cfg.server.port); - if local_addr != addr { + if local_addr.is_ipv4() != addr.is_ipv4() { + return Ok(Some(cfg)); + } else if local_addr != addr { info!("Listen address: {}", addr); self.sock = UdpSocket::bind(addr).map_err(Error::BindSocket)?; // make socket interruptable by singals @@ -231,7 +327,7 @@ impl MasterServer { self.hash = cfg.hash; self.stats.update_config(cfg.stat); - Ok(()) + Ok(None) } pub fn run(&mut self, sig_flag: &AtomicBool) -> Result<(), Error> { @@ -246,12 +342,9 @@ impl MasterServer { }, }; - let from = match from { - SocketAddr::V4(a) => a, - _ => { - warn!("{}: Received message from IPv6, unimplemented", from); - continue; - } + let from = match Addr::extract(from) { + Ok(from) => from, + Err(_) => continue, }; let src = &buf[..n]; @@ -271,7 +364,7 @@ impl MasterServer { self.stats.clear(); } - fn handle_server_packet(&mut self, from: SocketAddrV4, p: server::Packet) -> Result<(), Error> { + fn handle_server_packet(&mut self, from: Addr, p: server::Packet) -> Result<(), Error> { trace!("{}: recv {:?}", from, p); match p { @@ -320,13 +413,20 @@ impl MasterServer { Ok(()) } - fn handle_game_packet(&mut self, from: SocketAddrV4, p: game::Packet) -> Result<(), Error> { + fn handle_game_packet(&mut self, from: Addr, p: game::Packet) -> Result<(), Error> { trace!("{}: recv {:?}", from, p); match p { game::Packet::QueryServers(p) => { if p.filter.clver.map_or(false, |v| v < self.clver) { - self.send_server_list(from, p.filter.key, &[self.update_addr])?; + match self.update_addr { + SocketAddr::V4(addr) => { + self.send_server_list(from, p.filter.key, &[addr])?; + } + SocketAddr::V6(addr) => { + self.send_server_list(from, p.filter.key, &[addr])?; + } + } } else { let now = self.now(); @@ -376,7 +476,7 @@ impl MasterServer { Ok(()) } - fn handle_admin_packet(&mut self, from: SocketAddrV4, p: admin::Packet) -> Result<(), Error> { + fn handle_admin_packet(&mut self, from: Addr, p: admin::Packet) -> Result<(), Error> { trace!("{}: recv {:?}", from, p); let now = self.now(); @@ -449,7 +549,7 @@ impl MasterServer { Ok(()) } - fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> { + fn handle_packet(&mut self, from: Addr, src: &[u8]) -> Result<(), Error> { if self.is_blocked(from.ip()) { return Ok(()); } @@ -479,7 +579,7 @@ impl MasterServer { self.start_time.elapsed().as_secs() as u32 } - fn add_challenge(&mut self, addr: SocketAddrV4) -> u32 { + fn add_challenge(&mut self, addr: Addr) -> u32 { let x = self.rng.u32(..); let entry = Entry::new(self.now(), x); self.challenges.insert(addr, entry); @@ -494,7 +594,7 @@ impl MasterServer { } } - fn admin_challenge_add(&mut self, addr: SocketAddrV4) -> (u32, u32) { + fn admin_challenge_add(&mut self, addr: Addr) -> (u32, u32) { let x = self.rng.u32(..); let y = self.rng.u32(..); let entry = Entry::new(self.now(), (x, y)); @@ -502,7 +602,7 @@ impl MasterServer { (x, y) } - fn admin_challenge_remove(&mut self, addr: SocketAddrV4) { + fn admin_challenge_remove(&mut self, addr: Addr) { self.admin_challenges.remove(addr.ip()); } @@ -523,11 +623,11 @@ impl MasterServer { } } - fn count_servers(&self, addr: &Ipv4Addr) -> u16 { - self.servers.keys().filter(|i| i.ip() == addr).count() as u16 + fn count_servers(&self, ip: &Addr::Ip) -> u16 { + self.servers.keys().filter(|i| i.ip() == ip).count() as u16 } - fn add_server(&mut self, addr: SocketAddrV4, server: ServerInfo) { + fn add_server(&mut self, addr: Addr, server: ServerInfo) { let now = self.now(); match self.servers.entry(addr) { hash_map::Entry::Occupied(mut e) => { @@ -554,34 +654,25 @@ impl MasterServer { } } - fn send_server_list( - &self, - to: A, - key: Option, - servers: &[SocketAddrV4], - ) -> Result<(), Error> + fn send_server_list(&self, to: A, key: Option, servers: &[S]) -> Result<(), Error> where A: ToSocketAddrs, + S: ServerAddress, { - let mut list = master::QueryServersResponse::new(servers.iter().copied(), key); - loop { - let mut buf = [0; MAX_PACKET_SIZE]; - let (n, is_end) = list.encode(&mut buf)?; + let mut buf = [0; MAX_PACKET_SIZE]; + let mut offset = 0; + let mut list = master::QueryServersResponse::new(key); + while offset < servers.len() { + let (n, c) = list.encode(&mut buf, &servers[offset..])?; + offset += c; self.sock.send_to(&buf[..n], &to)?; - if is_end { - break; - } } Ok(()) } - fn send_client_to_nat_servers( - &self, - to: SocketAddrV4, - servers: &[SocketAddrV4], - ) -> Result<(), Error> { + fn send_client_to_nat_servers(&self, to: Addr, servers: &[Addr]) -> Result<(), Error> { let mut buf = [0; 64]; - let n = master::ClientAnnounce::new(to).encode(&mut buf)?; + let n = master::ClientAnnounce::new(to.wrap()).encode(&mut buf)?; let buf = &buf[..n]; for i in servers { self.sock.send_to(buf, i)?; @@ -590,15 +681,19 @@ impl MasterServer { } #[inline] - fn is_blocked(&self, ip: &Ipv4Addr) -> bool { + fn is_blocked(&self, ip: &Addr::Ip) -> bool { self.blocklist.contains(ip) } fn admin_command(&mut self, cmd: &str) { let args: Vec<_> = cmd.split(' ').collect(); - fn helper(args: &[&str], mut op: F) { - let iter = args.iter().map(|i| (i, i.parse::())); + fn helper(args: &[&str], mut op: F) + where + Addr: AddrExt, + F: FnMut(&str, Addr::Ip), + { + let iter = args.iter().map(|i| (i, i.parse::())); for (i, ip) in iter { match ip { Ok(ip) => op(i, ip), @@ -609,14 +704,14 @@ impl MasterServer { match args[0] { "ban" => { - helper(&args[1..], |_, ip| { + helper::(&args[1..], |_, ip| { if self.blocklist.insert(ip) { info!("ban ip: {}", ip); } }); } "unban" => { - helper(&args[1..], |_, ip| { + helper::(&args[1..], |_, ip| { if self.blocklist.remove(&ip) { info!("unban ip: {}", ip); } diff --git a/master/src/stats/stub.rs b/master/src/stats/stub.rs index 4df92a0..b353b67 100644 --- a/master/src/stats/stub.rs +++ b/master/src/stats/stub.rs @@ -6,7 +6,9 @@ struct Counters; pub struct Stats; impl Stats { - pub fn new(_: StatConfig) -> Self { Self } + pub fn new(_: StatConfig) -> Self { + Self + } pub fn update_config(&mut self, _: StatConfig) {} pub fn clear(&self) {} pub fn servers_count(&self, _: usize) {} diff --git a/protocol/src/filter.rs b/protocol/src/filter.rs index 5194f8f..135dc9e 100644 --- a/protocol/src/filter.rs +++ b/protocol/src/filter.rs @@ -30,7 +30,7 @@ //! * Is not protected by a password use std::fmt; -use std::net::SocketAddrV4; +use std::net::SocketAddr; use std::str::FromStr; use bitflags::bitflags; @@ -196,7 +196,8 @@ impl Filter<'_> { } /// Returns `true` if a server matches the filter. - pub fn matches(&self, _addr: SocketAddrV4, info: &ServerInfo) -> bool { + pub fn matches(&self, _addr: SocketAddr, info: &ServerInfo) -> bool { + // TODO: match addr !((info.flags & self.flags_mask) != self.flags || self.gamedir.map_or(false, |s| *s != &*info.gamedir) || self.map.map_or(false, |s| *s != &*info.map) @@ -308,6 +309,7 @@ mod tests { use super::*; use crate::cursor::CursorMut; use crate::wrappers::Str; + use std::net::SocketAddr; macro_rules! tests { ($($name:ident$(($($predefined_f:ident: $predefined_v:expr),+ $(,)?))? { @@ -450,7 +452,7 @@ mod tests { macro_rules! servers { ($($addr:expr => $info:expr $(=> $func:expr)?)+) => ( [$({ - let addr = $addr.parse::().unwrap(); + let addr = $addr.parse::().unwrap(); let mut buf = [0; 512]; let n = CursorMut::new(&mut buf) .put_bytes(ServerAdd::HEADER).unwrap() diff --git a/protocol/src/game.rs b/protocol/src/game.rs index c725857..1d4819d 100644 --- a/protocol/src/game.rs +++ b/protocol/src/game.rs @@ -4,7 +4,7 @@ //! Game client packets. use std::fmt; -use std::net::SocketAddrV4; +use std::net::SocketAddr; use crate::cursor::{Cursor, CursorMut}; use crate::filter::Filter; @@ -17,7 +17,7 @@ pub struct QueryServers { /// Servers must be from the `region`. pub region: Region, /// Last received server address __(not used)__. - pub last: SocketAddrV4, + pub last: SocketAddr, /// Select only servers that match the `filter`. pub filter: T, } @@ -131,13 +131,13 @@ mod tests { use super::*; use crate::filter::{FilterFlags, Version}; use crate::wrappers::Str; - use std::net::Ipv4Addr; + use std::net::{IpAddr, Ipv4Addr}; #[test] fn query_servers() { let p = QueryServers { region: Region::RestOfTheWorld, - last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), + last: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), filter: Filter { gamedir: Some(Str(&b"valve"[..])), map: Some(Str(&b"crossfire"[..])), @@ -157,7 +157,7 @@ mod tests { fn query_servers_filter_bug() { let p = QueryServers { region: Region::RestOfTheWorld, - last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), + last: SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0), filter: Filter { gamedir: None, protocol: Some(48), diff --git a/protocol/src/master.rs b/protocol/src/master.rs index 85fc978..5581ab5 100644 --- a/protocol/src/master.rs +++ b/protocol/src/master.rs @@ -3,7 +3,7 @@ //! Master server packets. -use std::net::{Ipv4Addr, SocketAddrV4}; +use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; use super::cursor::{Cursor, CursorMut}; use super::Error; @@ -58,6 +58,54 @@ impl ChallengeResponse { } } +/// Helper trait for dealing with server addresses. +pub trait ServerAddress: Sized { + /// Size of IP and port in bytes. + fn size() -> usize; + + /// Read address from a cursor. + fn get(cur: &mut Cursor) -> Result; + + /// Write address to a cursor. + fn put(&self, cur: &mut CursorMut) -> Result<(), Error>; +} + +impl ServerAddress for SocketAddrV4 { + fn size() -> usize { + 6 + } + + fn get(cur: &mut Cursor) -> Result { + let ip = Ipv4Addr::from(cur.get_array()?); + let port = cur.get_u16_be()?; + Ok(SocketAddrV4::new(ip, port)) + } + + fn put(&self, cur: &mut CursorMut) -> Result<(), Error> { + cur.put_array(&self.ip().octets())?; + cur.put_u16_be(self.port())?; + Ok(()) + } +} + +impl ServerAddress for SocketAddrV6 { + fn size() -> usize { + 18 + } + + fn get(cur: &mut Cursor) -> Result { + let ip = Ipv6Addr::from(cur.get_array()?); + let port = cur.get_u16_be()?; + Ok(SocketAddrV6::new(ip, port, 0, 0)) + } + + fn put(&self, cur: &mut CursorMut) -> Result<(), Error> { + cur.put_array(&self.ip().octets())?; + cur.put_u16_be(self.port())?; + Ok(()) + } +} + /// Game server addresses list. #[derive(Clone, Debug, PartialEq)] pub struct QueryServersResponse { @@ -76,33 +124,31 @@ impl<'a> QueryServersResponse<&'a [u8]> { pub fn decode(src: &'a [u8]) -> Result { let mut cur = Cursor::new(src); cur.expect(QueryServersResponse::HEADER)?; - if cur.remaining() % 6 != 0 { - return Err(Error::InvalidPacket); - } - let s = cur.get_bytes(cur.remaining())?; + let s = cur.end(); // extra header for key sent in QueryServers packet - let (s, key) = if s.len() >= 6 && s[0] == 0x7f && s[5] == 8 { - (&s[6..], Some(u32::from_le_bytes([s[1], s[2], s[3], s[4]]))) + let (inner, key) = if s.len() >= 6 && s[0] == 0x7f && s[5] == 8 { + let key = u32::from_le_bytes([s[1], s[2], s[3], s[4]]); + (&s[6..], Some(key)) } else { (s, None) }; - let inner = if s.ends_with(&[0; 6]) { - &s[..s.len() - 6] - } else { - s - }; Ok(Self { inner, key }) } /// Iterator over game server addresses. - pub fn iter(&self) -> impl 'a + Iterator { + pub fn iter(&self) -> impl 'a + Iterator + where + A: ServerAddress, + { let mut cur = Cursor::new(self.inner); - (0..self.inner.len() / 6).map(move |_| { - let ip = Ipv4Addr::from(cur.get_array().unwrap()); - let port = cur.get_u16_be().unwrap(); - SocketAddrV4::new(ip, port) + std::iter::from_fn(move || { + if cur.remaining() == A::size() && cur.end().ends_with(&[0; 2]) { + // skip last address with port 0 + return None; + } + A::get(&mut cur).ok() }) } @@ -112,13 +158,10 @@ impl<'a> QueryServersResponse<&'a [u8]> { } } -impl QueryServersResponse -where - I: Iterator, -{ +impl QueryServersResponse<()> { /// Creates a new `QueryServersResponse`. - pub fn new(iter: I, key: Option) -> Self { - Self { inner: iter, key } + pub fn new(key: Option) -> Self { + Self { inner: (), key } } /// Encode packet to `buf`. @@ -127,7 +170,10 @@ where /// multiple times until the end flag equals `true`. /// /// Returns how many bytes was written in `buf` and the end flag. - pub fn encode(&mut self, buf: &mut [u8]) -> Result<(usize, bool), Error> { + pub fn encode(&mut self, buf: &mut [u8], list: &[A]) -> Result<(usize, usize), Error> + where + A: ServerAddress, + { let mut cur = CursorMut::new(buf); cur.put_bytes(QueryServersResponse::HEADER)?; if let Some(key) = self.key { @@ -135,19 +181,20 @@ where cur.put_u32_le(key)?; cur.put_u8(8)?; } - let mut is_end = false; - while cur.remaining() >= 12 { - match self.inner.next() { - Some(i) => { - cur.put_array(&i.ip().octets())?.put_u16_be(i.port())?; - } - None => { - is_end = true; - break; - } + let mut count = 0; + let mut iter = list.iter(); + while cur.remaining() >= A::size() * 2 { + if let Some(i) = iter.next() { + i.put(&mut cur)?; + count += 1; + } else { + break; } } - Ok((cur.put_array(&[0; 6])?.pos(), is_end)) + for _ in 0..A::size() { + cur.put_u8(0)?; + } + Ok((cur.pos(), count)) } } @@ -155,7 +202,7 @@ where #[derive(Clone, Debug, PartialEq)] pub struct ClientAnnounce { /// Address of the client. - pub addr: SocketAddrV4, + pub addr: SocketAddr, } impl ClientAnnounce { @@ -163,7 +210,7 @@ impl ClientAnnounce { pub const HEADER: &'static [u8] = b"\xff\xff\xff\xffc "; /// Creates a new `ClientAnnounce`. - pub fn new(addr: SocketAddrV4) -> Self { + pub fn new(addr: SocketAddr) -> Self { Self { addr } } @@ -296,18 +343,39 @@ mod tests { } #[test] - fn query_servers_response() { - let servers: &[SocketAddrV4] = &[ + fn query_servers_response_ipv4() { + type Addr = SocketAddrV4; + let servers: &[Addr] = &[ "1.2.3.4:27001".parse().unwrap(), "1.2.3.4:27002".parse().unwrap(), "1.2.3.4:27003".parse().unwrap(), "1.2.3.4:27004".parse().unwrap(), ]; - let mut p = QueryServersResponse::new(servers.iter().cloned(), Some(0xdeadbeef)); + let mut p = QueryServersResponse::new(Some(0xdeadbeef)); + let mut buf = [0; 512]; + let (n, c) = p.encode(&mut buf, servers).unwrap(); + assert_eq!(c, servers.len()); + assert_eq!(n, 12 + Addr::size() * (servers.len() + 1)); + let e = QueryServersResponse::decode(&buf[..n]).unwrap(); + assert_eq!(e.iter::().collect::>(), servers); + } + + #[test] + fn query_servers_response_ipv6() { + type Addr = SocketAddrV6; + let servers: &[Addr] = &[ + "[::1]:27001".parse().unwrap(), + "[::2]:27002".parse().unwrap(), + "[::3]:27003".parse().unwrap(), + "[::4]:27004".parse().unwrap(), + ]; + let mut p = QueryServersResponse::new(Some(0xdeadbeef)); let mut buf = [0; 512]; - let (n, _) = p.encode(&mut buf).unwrap(); + let (n, c) = p.encode(&mut buf, servers).unwrap(); + assert_eq!(c, servers.len()); + assert_eq!(n, 12 + Addr::size() * (servers.len() + 1)); let e = QueryServersResponse::decode(&buf[..n]).unwrap(); - assert_eq!(e.iter().collect::>(), servers); + assert_eq!(e.iter::().collect::>(), servers); } #[test] diff --git a/supply-chain/imports.lock b/supply-chain/imports.lock index 0ffde12..0f43add 100644 --- a/supply-chain/imports.lock +++ b/supply-chain/imports.lock @@ -363,7 +363,7 @@ who = "Nick Fitzgerald " criteria = "safe-to-deploy" user-id = 696 # Nick Fitzgerald (fitzgen) start = "2019-03-16" -end = "2024-03-10" +end = "2025-07-30" [[audits.bytecode-alliance.audits.arrayref]] who = "Nick Fitzgerald "