From 1254e41adf50014fe978e9a17aae367e55aa3902 Mon Sep 17 00:00:00 2001 From: Denis Drakhnia Date: Thu, 19 Oct 2023 17:38:39 +0300 Subject: [PATCH] query: filter parameter --- master/src/master_server.rs | 2 +- protocol/src/color.rs | 5 +++- protocol/src/cursor.rs | 2 +- protocol/src/filter.rs | 59 ++++++++++++++++++++----------------- protocol/src/game.rs | 31 ++++++++++++------- protocol/src/lib.rs | 8 +++-- query/src/cli.rs | 23 +++++++++++++-- query/src/main.rs | 33 +++++++++++++-------- 8 files changed, 106 insertions(+), 57 deletions(-) diff --git a/master/src/master_server.rs b/master/src/master_server.rs index 91b4dc3..723cfc5 100644 --- a/master/src/master_server.rs +++ b/master/src/master_server.rs @@ -229,7 +229,7 @@ impl MasterServer { match p { game::Packet::QueryServers(p) => { trace!("{}: recv {:?}", from, p); - if p.filter.clver < self.clver { + if p.filter.clver.map_or(false, |v| v < self.clver) { let iter = std::iter::once(self.update_addr); self.send_server_list(from, iter)?; } else { diff --git a/protocol/src/color.rs b/protocol/src/color.rs index 0206514..266eabe 100644 --- a/protocol/src/color.rs +++ b/protocol/src/color.rs @@ -59,7 +59,10 @@ impl<'a> Iterator for ColorIter<'a> { fn next(&mut self) -> Option { if !self.inner.is_empty() { - let i = self.inner[1..].find('^').map(|i| i + 1).unwrap_or(self.inner.len()); + let i = self.inner[1..] + .find('^') + .map(|i| i + 1) + .unwrap_or(self.inner.len()); let (head, tail) = self.inner.split_at(i); let (color, text) = trim_start_color(head); self.inner = tail; diff --git a/protocol/src/cursor.rs b/protocol/src/cursor.rs index 4087a54..52effad 100644 --- a/protocol/src/cursor.rs +++ b/protocol/src/cursor.rs @@ -8,7 +8,7 @@ use std::slice; use std::str; use super::types::Str; -use super::{Error, color}; +use super::{color, Error}; pub trait GetKeyValue<'a>: Sized { fn get_key_value(cur: &mut Cursor<'a>) -> Result; diff --git a/protocol/src/filter.rs b/protocol/src/filter.rs index 0aa4572..eb55c70 100644 --- a/protocol/src/filter.rs +++ b/protocol/src/filter.rs @@ -90,7 +90,7 @@ pub struct Version { } impl Version { - pub fn new(major: u8, minor: u8) -> Self { + pub const fn new(major: u8, minor: u8) -> Self { Self { major, minor } } } @@ -136,11 +136,11 @@ impl PutKeyValue for Version { #[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct Filter<'a> { /// Servers running the specified modification (ex. cstrike) - pub gamedir: &'a [u8], + pub gamedir: Option<&'a [u8]>, /// Servers running the specified map (ex. cs_italy) - pub map: &'a [u8], + pub map: Option<&'a [u8]>, /// Client version. - pub clver: Version, + pub clver: Option, pub flags: FilterFlags, pub flags_mask: FilterFlags, @@ -154,13 +154,15 @@ impl Filter<'_> { pub fn matches(&self, _addr: SocketAddrV4, info: &ServerInfo) -> bool { !((info.flags & self.flags_mask) != self.flags - || (!self.gamedir.is_empty() && self.gamedir != &*info.gamedir) - || (!self.map.is_empty() && self.map != &*info.map)) + || self.gamedir.map_or(false, |s| s != &*info.gamedir) + || self.map.map_or(false, |s| s != &*info.map)) } } -impl<'a> Filter<'a> { - pub fn from_bytes(src: &'a [u8]) -> Result { +impl<'a> TryFrom<&'a [u8]> for Filter<'a> { + type Error = Error; + + fn try_from(src: &'a [u8]) -> Result { let mut cur = Cursor::new(src); let mut filter = Self::default(); @@ -174,17 +176,18 @@ impl<'a> Filter<'a> { match *key { b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get_key_value()?), b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get_key_value()?), - b"gamedir" => filter.gamedir = cur.get_key_value()?, - b"map" => filter.map = cur.get_key_value()?, + b"gamedir" => filter.gamedir = Some(cur.get_key_value()?), + b"map" => filter.map = Some(cur.get_key_value()?), b"empty" => filter.insert_flag(FilterFlags::NOT_EMPTY, cur.get_key_value()?), b"full" => filter.insert_flag(FilterFlags::FULL, cur.get_key_value()?), b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get_key_value()?), b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get_key_value()?), b"clver" => { - filter.clver = cur - .get_key_value::<&str>()? - .parse() - .map_err(|_| Error::InvalidPacket)? + filter.clver = Some( + cur.get_key_value::<&str>()? + .parse() + .map_err(|_| Error::InvalidPacket)?, + ); } b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get_key_value()?), b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get_key_value()?), @@ -214,18 +217,20 @@ impl fmt::Display for &Filter<'_> { display_flag!("dedicated", FilterFlags::DEDICATED); display_flag!("secure", FilterFlags::SECURE); - if !self.gamedir.is_empty() { - write!(fmt, "\\gamedir\\{}", Str(self.gamedir))?; + if let Some(s) = self.gamedir { + write!(fmt, "\\gamedir\\{}", Str(s))?; } display_flag!("secure", FilterFlags::SECURE); - if !self.map.is_empty() { - write!(fmt, "\\map\\{}", Str(self.map))?; + if let Some(s) = self.map { + write!(fmt, "\\map\\{}", Str(s))?; } display_flag!("empty", FilterFlags::NOT_EMPTY); display_flag!("full", FilterFlags::FULL); display_flag!("password", FilterFlags::PASSWORD); display_flag!("noplayers", FilterFlags::NOPLAYERS); - write!(fmt, "\\clver\\{}", self.clver)?; + if let Some(v) = self.clver { + write!(fmt, "\\clver\\{}", v)?; + } display_flag!("nat", FilterFlags::NAT); display_flag!("lan", FilterFlags::LAN); display_flag!("bots", FilterFlags::BOTS); @@ -253,7 +258,7 @@ mod tests { .. Filter::default() }; $(assert_eq!( - Filter::from_bytes($src), + Filter::try_from($src as &[u8]), Ok(Filter { $($field: $value,)* ..predefined @@ -266,17 +271,17 @@ mod tests { tests! { parse_gamedir { b"\\gamedir\\valve" => { - gamedir: &b"valve"[..], + gamedir: Some(&b"valve"[..]), } } parse_map { b"\\map\\crossfire" => { - map: &b"crossfire"[..], + map: Some(&b"crossfire"[..]), } } parse_clver { b"\\clver\\0.20" => { - clver: Version::new(0, 20), + clver: Some(Version::new(0, 20)), } } parse_dedicated(flags_mask: FilterFlags::DEDICATED) { @@ -349,9 +354,9 @@ mod tests { \\password\\1\ \\secure\\1\ " => { - gamedir: &b"valve"[..], - map: &b"crossfire"[..], - clver: Version::new(0, 20), + gamedir: Some(&b"valve"[..]), + map: Some(&b"crossfire"[..]), + clver: Some(Version::new(0, 20)), flags: FilterFlags::all(), flags_mask: FilterFlags::all(), } @@ -383,7 +388,7 @@ mod tests { macro_rules! matches { ($servers:expr, $filter:expr$(, $expected:expr)*) => ( let servers = &$servers; - let filter = Filter::from_bytes($filter).unwrap(); + let filter = Filter::try_from($filter as &[u8]).unwrap(); let iter = servers .iter() .enumerate() diff --git a/protocol/src/game.rs b/protocol/src/game.rs index c6df3ef..a5a03ee 100644 --- a/protocol/src/game.rs +++ b/protocol/src/game.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: GPL-3.0-only // SPDX-FileCopyrightText: 2023 Denis Drakhnia +use std::fmt; use std::net::SocketAddrV4; use crate::cursor::{Cursor, CursorMut}; @@ -9,18 +10,23 @@ use crate::server::Region; use crate::Error; #[derive(Clone, Debug, PartialEq)] -pub struct QueryServers<'a> { +pub struct QueryServers { pub region: Region, pub last: SocketAddrV4, - pub filter: Filter<'a>, + pub filter: T, } -impl<'a> QueryServers<'a> { +impl QueryServers<()> { pub const HEADER: &'static [u8] = b"1"; +} +impl<'a, T: 'a> QueryServers +where + T: TryFrom<&'a [u8], Error = Error>, +{ pub fn decode(src: &'a [u8]) -> Result { let mut cur = Cursor::new(src); - cur.expect(Self::HEADER)?; + cur.expect(QueryServers::HEADER)?; let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidPacket)?; let last = cur.get_cstr_as_str()?; let filter = cur.get_cstr()?; @@ -28,13 +34,18 @@ impl<'a> QueryServers<'a> { Ok(Self { region, last: last.parse().map_err(|_| Error::InvalidPacket)?, - filter: Filter::from_bytes(&filter)?, + filter: T::try_from(*filter)?, }) } +} +impl<'a, T: 'a> QueryServers +where + for<'b> &'b T: fmt::Display, +{ pub fn encode(&self, buf: &mut [u8]) -> Result { Ok(CursorMut::new(buf) - .put_bytes(Self::HEADER)? + .put_bytes(QueryServers::HEADER)? .put_u8(self.region as u8)? .put_as_str(self.last)? .put_u8(0)? @@ -76,7 +87,7 @@ impl GetServerInfo { #[derive(Clone, Debug, PartialEq)] pub enum Packet<'a> { - QueryServers(QueryServers<'a>), + QueryServers(QueryServers>), GetServerInfo(GetServerInfo), } @@ -106,9 +117,9 @@ mod tests { region: Region::RestOfTheWorld, last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), filter: Filter { - gamedir: &b"valve"[..], - map: &b"crossfire"[..], - clver: Version::new(0, 20), + gamedir: Some(&b"valve"[..]), + map: Some(&b"crossfire"[..]), + clver: Some(Version::new(0, 20)), flags: FilterFlags::all(), flags_mask: FilterFlags::all(), }, diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index 1ee1bff..83402f3 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -5,18 +5,22 @@ mod cursor; mod server_info; pub mod admin; +pub mod color; pub mod filter; pub mod game; pub mod master; pub mod server; pub mod types; -pub mod color; pub use server_info::ServerInfo; use thiserror::Error; -pub const VERSION: u8 = 49; +use crate::filter::Version; + +pub const PROTOCOL_VERSION: u8 = 49; + +pub const CLIENT_VERSION: Version = Version::new(0, 20); #[derive(Error, Debug, PartialEq, Eq)] pub enum Error { diff --git a/query/src/cli.rs b/query/src/cli.rs index b97e0aa..d87546c 100644 --- a/query/src/cli.rs +++ b/query/src/cli.rs @@ -5,6 +5,8 @@ use std::process; use getopts::Options; +use xash3d_protocol as proto; + const BIN_NAME: &str = env!("CARGO_BIN_NAME"); const PKG_NAME: &str = env!("CARGO_PKG_NAME"); const PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -22,6 +24,7 @@ pub struct Cli { pub json: bool, pub debug: bool, pub force_color: bool, + pub filter: String, } impl Default for Cli { @@ -34,10 +37,12 @@ impl Default for Cli { args: Default::default(), master_timeout: 2, server_timeout: 2, - protocol: vec![xash3d_protocol::VERSION, xash3d_protocol::VERSION - 1], + protocol: vec![proto::PROTOCOL_VERSION, proto::PROTOCOL_VERSION - 1], json: false, debug: false, force_color: false, + // if changed do not forget to update cli parsing + filter: format!("\\gamedir\\valve\\clver\\{}", proto::CLIENT_VERSION), } } } @@ -94,6 +99,8 @@ pub fn parse() -> Cli { opts.optflag("j", "json", "output JSON"); opts.optflag("d", "debug", "output debug"); opts.optflag("F", "force-color", "force colored output"); + let help = format!("query filter [default: {:?}]", cli.filter); + opts.optopt("f", "filter", &help, "FILTER"); let matches = match opts.parse(&args[1..]) { Ok(m) => m, @@ -124,7 +131,7 @@ pub fn parse() -> Cli { } } - match matches.opt_get("master") { + match matches.opt_get("master-timeout") { Ok(Some(t)) => cli.master_timeout = t, Ok(None) => {} Err(_) => { @@ -161,6 +168,18 @@ pub fn parse() -> Cli { } } + if let Some(s) = matches.opt_str("filter") { + let mut filter = String::with_capacity(cli.filter.len() + s.len()); + if !s.contains("\\gamedir") { + filter.push_str("\\gamedir\\valve"); + } + if !s.contains("\\clver") { + filter.push_str("\\clver\\0.20"); + } + filter.push_str(&s); + cli.filter = filter; + } + cli.json = matches.opt_present("json"); cli.debug = matches.opt_present("debug"); cli.force_color = matches.opt_present("force-color"); diff --git a/query/src/main.rs b/query/src/main.rs index 210ed12..056f6b0 100644 --- a/query/src/main.rs +++ b/query/src/main.rs @@ -9,14 +9,14 @@ use std::fmt; use std::io; use std::net::{Ipv4Addr, SocketAddrV4, UdpSocket}; use std::process; -use std::sync::mpsc; +use std::sync::{mpsc, Arc}; use std::thread; use std::time::{Duration, Instant}; use serde::Serialize; use thiserror::Error; use xash3d_protocol::types::Str; -use xash3d_protocol::{color, filter, game, master, server, Error as ProtocolError}; +use xash3d_protocol::{color, game, master, server, Error as ProtocolError}; use crate::cli::Cli; @@ -135,6 +135,7 @@ struct InfoResult<'a> { master_timeout: u32, server_timeout: u32, masters: &'a [Box], + filter: &'a str, servers: &'a [&'a ServerResult], } @@ -142,6 +143,7 @@ struct InfoResult<'a> { struct ListResult<'a> { master_timeout: u32, masters: &'a [Box], + filter: &'a str, servers: &'a [&'a str], } @@ -199,7 +201,12 @@ fn cmp_address(a: &str, b: &str) -> cmp::Ordering { } } -fn query_servers(host: &str, timeout: Duration, tx: &mpsc::Sender) -> Result<(), Error> { +fn query_servers( + host: &str, + cli: &Cli, + timeout: Duration, + tx: &mpsc::Sender, +) -> Result<(), Error> { let sock = UdpSocket::bind("0.0.0.0:0")?; sock.connect(host)?; @@ -207,12 +214,7 @@ fn query_servers(host: &str, timeout: Duration, tx: &mpsc::Sender) -> R let p = game::QueryServers { region: server::Region::RestOfTheWorld, last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), - filter: filter::Filter { - gamedir: b"valve", - clver: filter::Version::new(0, 20), - // TODO: filter - ..Default::default() - }, + filter: cli.filter.as_str(), }; let n = p.encode(&mut buf)?; sock.send(&buf[..n])?; @@ -282,7 +284,7 @@ fn get_server_info( Ok(ServerResult::protocol(addr)) } -fn query_server_info(cli: &Cli, servers: &[String]) -> Result<(), Error> { +fn query_server_info(cli: &Arc, servers: &[String]) -> Result<(), Error> { let (tx, rx) = mpsc::channel(); let mut workers = 0; @@ -291,8 +293,9 @@ fn query_server_info(cli: &Cli, servers: &[String]) -> Result<(), Error> { let master = i.to_owned(); let tx = tx.clone(); let timeout = Duration::from_secs(cli.master_timeout as u64); + let cli = cli.clone(); thread::spawn(move || { - if let Err(e) = query_servers(&master, timeout, &tx) { + if let Err(e) = query_servers(&master, &cli, timeout, &tx) { eprintln!("master({}) error: {}", master, e); } tx.send(Message::End).unwrap(); @@ -341,6 +344,7 @@ fn query_server_info(cli: &Cli, servers: &[String]) -> Result<(), Error> { master_timeout: cli.master_timeout, server_timeout: cli.server_timeout, masters: &cli.masters, + filter: &cli.filter, servers: &servers, }; @@ -408,7 +412,7 @@ fn query_server_info(cli: &Cli, servers: &[String]) -> Result<(), Error> { Ok(()) } -fn list_servers(cli: &Cli) -> Result<(), Error> { +fn list_servers(cli: &Arc) -> Result<(), Error> { let (tx, rx) = mpsc::channel(); let mut workers = 0; @@ -416,8 +420,9 @@ fn list_servers(cli: &Cli) -> Result<(), Error> { let master = i.to_owned(); let tx = tx.clone(); let timeout = Duration::from_secs(cli.master_timeout as u64); + let cli = cli.clone(); thread::spawn(move || { - if let Err(e) = query_servers(&master, timeout, &tx) { + if let Err(e) = query_servers(&master, &cli, timeout, &tx) { eprintln!("master({}) error: {}", master, e); } tx.send(Message::End).unwrap(); @@ -448,6 +453,7 @@ fn list_servers(cli: &Cli) -> Result<(), Error> { let result = ListResult { master_timeout: cli.master_timeout, masters: &cli.masters, + filter: &cli.filter, servers: &servers, }; @@ -468,6 +474,7 @@ fn list_servers(cli: &Cli) -> Result<(), Error> { } fn execute(cli: Cli) -> Result<(), Error> { + let cli = Arc::new(cli); match cli.args.get(0).map(|s| s.as_str()).unwrap_or_default() { "all" | "" => query_server_info(&cli, &[])?, "info" => query_server_info(&cli, &cli.args[1..])?,