diff --git a/Cargo.lock b/Cargo.lock index 87d3c52..403e8ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -141,6 +141,12 @@ dependencies = [ "cc", ] +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + [[package]] name = "js-sys" version = "0.3.64" @@ -219,6 +225,12 @@ dependencies = [ "redox_syscall", ] +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + [[package]] name = "serde" version = "1.0.188" @@ -239,6 +251,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "syn" version = "2.0.37" @@ -481,3 +504,14 @@ dependencies = [ "log", "thiserror", ] + +[[package]] +name = "xash3d-query" +version = "0.1.0" +dependencies = [ + "getopts", + "serde", + "serde_json", + "thiserror", + "xash3d-protocol", +] diff --git a/Cargo.toml b/Cargo.toml index 8985122..319ed37 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,4 +4,5 @@ members = [ "protocol", "master", "admin", + "query", ] diff --git a/protocol/src/cursor.rs b/protocol/src/cursor.rs index 8e7f0bb..620df7b 100644 --- a/protocol/src/cursor.rs +++ b/protocol/src/cursor.rs @@ -33,6 +33,24 @@ impl<'a> GetKeyValue<'a> for &'a str { } } +impl<'a> GetKeyValue<'a> for Box { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let raw = cur.get_key_value_raw()?; + str::from_utf8(raw) + .map(|s| s.to_owned().into_boxed_str()) + .map_err(|_| Error::InvalidString) + } +} + +impl<'a> GetKeyValue<'a> for String { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let raw = cur.get_key_value_raw()?; + str::from_utf8(raw) + .map(|s| s.to_owned()) + .map_err(|_| Error::InvalidString) + } +} + impl<'a> GetKeyValue<'a> for bool { fn get_key_value(cur: &mut Cursor<'a>) -> Result { match cur.get_key_value_raw()? { @@ -91,6 +109,10 @@ impl<'a> Cursor<'a> { self.buffer } + pub fn as_slice(&'a self) -> &'a [u8] { + self.buffer + } + #[inline(always)] pub fn remaining(&self) -> usize { self.buffer.len() diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index e4fde48..dd082f0 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -15,7 +15,7 @@ pub use server_info::ServerInfo; use thiserror::Error; -pub const VERSION: u32 = 49; +pub const VERSION: u8 = 49; #[derive(Error, Debug, PartialEq, Eq)] pub enum Error { @@ -25,4 +25,6 @@ pub enum Error { InvalidString, #[error("Unexpected end of buffer")] UnexpectedEnd, + #[error("Invalid protocol version")] + InvalidProtocolVersion, } diff --git a/protocol/src/server.rs b/protocol/src/server.rs index ed64158..cf181ce 100644 --- a/protocol/src/server.rs +++ b/protocol/src/server.rs @@ -361,6 +361,16 @@ where let mut cur = Cursor::new(src); cur.expect(GetServerInfoResponse::HEADER)?; + if !cur.as_slice().starts_with(b"\\") { + let s = cur.get_str(cur.remaining())?; + let p = s.rfind(':').ok_or(Error::InvalidPacket)?; + let msg = &s[p + 1..]; + match msg.trim() { + "wrong version" => return Err(Error::InvalidProtocolVersion), + _ => return Err(Error::InvalidPacket), + } + } + let mut ret = Self::default(); loop { let key = match cur.get_key_raw() { diff --git a/protocol/src/types.rs b/protocol/src/types.rs index 9c12859..d9a9eb9 100644 --- a/protocol/src/types.rs +++ b/protocol/src/types.rs @@ -15,6 +15,15 @@ impl From for Str { } impl fmt::Debug for Str +where + T: AsRef<[u8]>, +{ + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "\"{}\"", self) + } +} + +impl fmt::Display for Str where T: AsRef<[u8]>, { @@ -23,6 +32,7 @@ where match c { b'\n' => write!(fmt, "\\n")?, b'\t' => write!(fmt, "\\t")?, + b'\\' => write!(fmt, "\\\\")?, _ if c.is_ascii_graphic() || c == b' ' => { write!(fmt, "{}", c as char)?; } @@ -33,15 +43,6 @@ where } } -impl fmt::Display for Str -where - T: AsRef<[u8]>, -{ - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - ::fmt(self, fmt) - } -} - impl Deref for Str { type Target = T; diff --git a/query/Cargo.toml b/query/Cargo.toml new file mode 100644 index 0000000..88102e6 --- /dev/null +++ b/query/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "xash3d-query" +version = "0.1.0" +license = "GPL-3.0-only" +authors = ["Denis Drakhnia "] +edition = "2021" +rust-version = "1.56" + +[dependencies] +thiserror = "1.0.49" +getopts = "0.2.21" +serde = { version = "1.0.188", features = ["derive"] } +serde_json = "1.0.107" +xash3d-protocol = { path = "../protocol", version = "0.1.0" } diff --git a/query/src/cli.rs b/query/src/cli.rs new file mode 100644 index 0000000..361831b --- /dev/null +++ b/query/src/cli.rs @@ -0,0 +1,166 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +use std::process; + +use getopts::Options; + +const BIN_NAME: &str = env!("CARGO_BIN_NAME"); +const PKG_NAME: &str = env!("CARGO_PKG_NAME"); +const PKG_VERSION: &str = env!("CARGO_PKG_VERSION"); + +const DEFAULT_HOST: &str = "mentality.rip"; +const DEFAULT_PORT: u16 = 27010; + +#[derive(Debug)] +pub struct Cli { + pub masters: Vec>, + pub args: Vec, + pub master_timeout: u32, + pub server_timeout: u32, + pub protocol: Vec, + pub json: bool, + pub debug: bool, +} + +impl Default for Cli { + fn default() -> Cli { + Cli { + masters: vec![ + format!("{}:{}", DEFAULT_HOST, DEFAULT_PORT).into_boxed_str(), + //format!("{}:{}", DEFAULT_HOST, DEFAULT_PORT + 1).into_boxed_str(), + ], + args: Default::default(), + master_timeout: 2, + server_timeout: 2, + protocol: vec![xash3d_protocol::VERSION, xash3d_protocol::VERSION - 1], + json: false, + debug: false, + } + } +} + +fn print_usage(opts: Options) { + let brief = format!( + "\ +Usage: {} [options] [ARGS] + +COMMANDS: + all fetch servers from all masters and fetch info for each server + info hosts... fetch info for each server + list fetch servers from all masters and print server addresses\ + ", + BIN_NAME + ); + print!("{}", opts.usage(&brief)); +} + +fn print_version() { + println!("{} v{}", PKG_NAME, PKG_VERSION); +} + +pub fn parse() -> Cli { + let mut cli = Cli::default(); + + let args: Vec<_> = std::env::args().collect(); + let mut opts = Options::new(); + opts.optflag("h", "help", "print usage help"); + opts.optflag("v", "version", "print program version"); + let help = format!( + "master address to connect [default: {}]", + cli.masters.join(",") + ); + opts.optopt("m", "master", &help, "LIST"); + let help = format!( + "time to wait results from masters [default: {}]", + cli.master_timeout + ); + opts.optopt("T", "master-timeout", &help, "SECONDS"); + let help = format!( + "time to wait results from servers [default: {}]", + cli.server_timeout + ); + opts.optopt("t", "server-timeout", &help, "SECONDS"); + let protocols = cli + .protocol + .iter() + .map(|&i| format!("{}", i)) + .collect::>() + .join(","); + let help = format!("protocol version [default: {}]", protocols); + opts.optopt("p", "protocol", &help, "VERSION"); + opts.optflag("j", "json", "output JSON"); + opts.optflag("d", "debug", "output debug"); + + let matches = match opts.parse(&args[1..]) { + Ok(m) => m, + Err(e) => { + eprintln!("{}", e); + process::exit(1); + } + }; + + if matches.opt_present("help") { + print_usage(opts); + process::exit(0); + } + + if matches.opt_present("version") { + print_version(); + process::exit(0); + } + + if let Some(s) = matches.opt_str("master") { + cli.masters.clear(); + + for mut i in s.split(',').map(String::from) { + if !i.contains(':') { + i.push_str(":27010"); + } + cli.masters.push(i.into_boxed_str()); + } + } + + match matches.opt_get("master") { + Ok(Some(t)) => cli.master_timeout = t, + Ok(None) => {} + Err(_) => { + eprintln!("Invalid master-timeout"); + process::exit(1); + } + } + + match matches.opt_get("server-timeout") { + Ok(Some(t)) => cli.server_timeout = t, + Ok(None) => {} + Err(_) => { + eprintln!("Invalid server-timeout"); + process::exit(1); + } + } + + if let Some(s) = matches.opt_str("protocol") { + cli.protocol.clear(); + + let mut error = false; + for i in s.split(',') { + match i.parse() { + Ok(i) => cli.protocol.push(i), + Err(_) => { + eprintln!("Invalid protocol version: {}", i); + error = true; + } + } + } + + if error { + process::exit(1); + } + } + + cli.json = matches.opt_present("json"); + cli.debug = matches.opt_present("debug"); + cli.args = matches.free; + + cli +} diff --git a/query/src/main.rs b/query/src/main.rs new file mode 100644 index 0000000..4aec381 --- /dev/null +++ b/query/src/main.rs @@ -0,0 +1,447 @@ +// SPDX-License-Identifier: GPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +mod cli; + +use std::cmp; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::io; +use std::net::{Ipv4Addr, SocketAddrV4, UdpSocket}; +use std::process; +use std::sync::mpsc; +use std::thread; +use std::time::{Duration, Instant}; + +use serde::Serialize; +use thiserror::Error; +use xash3d_protocol::types::Str; +use xash3d_protocol::{filter, game, master, server, Error as ProtocolError}; + +use crate::cli::Cli; + +#[derive(Error, Debug)] +enum Error { + #[error("Undefined command")] + UndefinedCommand, + #[error(transparent)] + Protocol(#[from] ProtocolError), + #[error(transparent)] + Io(#[from] io::Error), +} + +#[derive(Clone, Debug, Serialize)] +#[serde(tag = "type")] +enum ServerResultKind { + #[serde(rename = "ok")] + Ok { info: ServerInfo }, + #[serde(rename = "error")] + Error { message: String }, + #[serde(rename = "invalid")] + Invalid { message: String, response: String }, + #[serde(rename = "timeout")] + Timeout, + #[serde(rename = "protocol")] + Protocol, +} + +#[derive(Clone, Debug, Serialize)] +struct ServerResult { + address: String, + #[serde(flatten)] + kind: ServerResultKind, +} + +impl ServerResult { + fn new(address: String, kind: ServerResultKind) -> Self { + Self { + address: address.to_string(), + kind, + } + } + + fn ok>(address: String, info: T) -> Self { + Self::new(address, ServerResultKind::Ok { info: info.into() }) + } + + fn timeout(address: String) -> Self { + Self::new(address, ServerResultKind::Timeout) + } + + fn protocol(address: String) -> Self { + Self::new(address, ServerResultKind::Protocol) + } + + fn error(address: String, message: T) -> Self + where + T: fmt::Display, + { + Self::new( + address, + ServerResultKind::Error { + message: message.to_string(), + }, + ) + } + + fn invalid(address: String, message: T, response: &[u8]) -> Self + where + T: fmt::Display, + { + Self::new( + address, + ServerResultKind::Invalid { + message: message.to_string(), + response: Str(response).to_string(), + }, + ) + } +} + +#[derive(Clone, Debug, Serialize)] +struct ServerInfo { + pub gamedir: String, + pub map: String, + pub host: String, + pub protocol: u8, + pub numcl: u8, + pub maxcl: u8, + pub dm: bool, + pub team: bool, + pub coop: bool, + pub password: bool, +} + +impl From> for ServerInfo { + fn from(other: server::GetServerInfoResponse<&str>) -> Self { + Self { + gamedir: other.gamedir.to_owned(), + map: other.map.to_owned(), + host: other.host.to_owned(), + protocol: other.protocol, + numcl: other.numcl, + maxcl: other.maxcl, + dm: other.dm, + team: other.team, + coop: other.coop, + password: other.password, + } + } +} + +#[derive(Clone, Debug, Serialize)] +struct InfoResult<'a> { + protocol: &'a [u8], + master_timeout: u32, + server_timeout: u32, + masters: &'a [Box], + servers: &'a [&'a ServerResult], +} + +#[derive(Clone, Debug, Serialize)] +struct ListResult<'a> { + master_timeout: u32, + masters: &'a [Box], + servers: &'a [&'a str], +} + +enum Message { + Servers(Vec), + ServerResult(ServerResult), + End, +} + +fn cmp_address(a: &str, b: &str) -> cmp::Ordering { + match (a.parse::(), b.parse::()) { + (Ok(a), Ok(b)) => a.cmp(&b), + _ => a.cmp(b), + } +} + +fn query_servers(host: &str, timeout: Duration, tx: &mpsc::Sender) -> Result<(), Error> { + let sock = UdpSocket::bind("0.0.0.0:0")?; + sock.connect(host)?; + + let mut buf = [0; 512]; + let p = game::QueryServers { + region: server::Region::RestOfTheWorld, + last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), + filter: filter::Filter { + gamedir: b"valve", + clver: filter::Version::new(0, 20), + // TODO: filter + ..Default::default() + }, + }; + let n = p.encode(&mut buf)?; + sock.send(&buf[..n])?; + + let start_time = Instant::now(); + while let Some(timeout) = timeout.checked_sub(start_time.elapsed()) { + sock.set_read_timeout(Some(timeout))?; + let n = match sock.recv(&mut buf) { + Ok(n) => n, + Err(e) => match e.kind() { + io::ErrorKind::AddrInUse | io::ErrorKind::WouldBlock => break, + _ => Err(e)?, + }, + }; + if let Ok(packet) = master::QueryServersResponse::decode(&buf[..n]) { + tx.send(Message::Servers( + packet.iter().map(|i| i.to_string()).collect(), + )) + .unwrap(); + } else { + eprintln!("Unexpected packet from master {}", host); + } + } + + Ok(()) +} + +fn get_server_info( + addr: String, + versions: &[u8], + timeout: Duration, +) -> Result { + let sock = UdpSocket::bind("0.0.0.0:0")?; + sock.connect(&addr)?; + sock.set_read_timeout(Some(timeout))?; + + for &i in versions { + let p = game::GetServerInfo::new(i); + let mut buf = [0; 2048]; + let n = p.encode(&mut buf)?; + sock.send(&buf[..n])?; + + let n = match sock.recv(&mut buf) { + Ok(n) => n, + Err(e) => match e.kind() { + io::ErrorKind::AddrInUse | io::ErrorKind::WouldBlock => { + return Ok(ServerResult::timeout(addr)); + } + _ => Err(e)?, + }, + }; + + let response = &buf[..n]; + match server::GetServerInfoResponse::decode(response) { + Ok(packet) => { + return Ok(ServerResult::ok(addr, packet)); + } + Err(ProtocolError::InvalidProtocolVersion) => { + // try another protocol version + } + Err(e) => { + return Ok(ServerResult::invalid(addr, e, response)); + } + } + } + + Ok(ServerResult::protocol(addr)) +} + +fn query_server_info(cli: &Cli, servers: &[String]) -> Result<(), Error> { + let (tx, rx) = mpsc::channel(); + let mut workers = 0; + + if servers.is_empty() { + for i in cli.masters.iter() { + let master = i.to_owned(); + let tx = tx.clone(); + let timeout = Duration::from_secs(cli.master_timeout as u64); + thread::spawn(move || { + if let Err(e) = query_servers(&master, timeout, &tx) { + eprintln!("master({}) error: {}", master, e); + } + tx.send(Message::End).unwrap(); + }); + workers += 1; + } + } else { + tx.send(Message::Servers(servers.to_vec())).unwrap(); + } + + let mut servers = HashMap::new(); + while let Ok(msg) = rx.recv() { + match msg { + Message::Servers(list) => { + for address in list { + let tx = tx.clone(); + let timeout = Duration::from_secs(cli.server_timeout as u64); + let versions = cli.protocol.clone(); + thread::spawn(move || { + let result = get_server_info(address.clone(), &versions, timeout) + .unwrap_or_else(|e| ServerResult::error(address, e)); + tx.send(Message::ServerResult(result)).unwrap(); + tx.send(Message::End).unwrap(); + }); + workers += 1; + } + } + Message::End => { + workers -= 1; + if workers == 0 { + break; + } + } + Message::ServerResult(result) => { + servers.insert(result.address.clone(), result); + } + } + } + + let mut servers: Vec<_> = servers.values().collect(); + servers.sort_by(|a, b| cmp_address(&a.address, &b.address)); + + if cli.json || cli.debug { + let result = InfoResult { + protocol: &cli.protocol, + master_timeout: cli.master_timeout, + server_timeout: cli.server_timeout, + masters: &cli.masters, + servers: &servers, + }; + + if cli.json { + println!("{}", serde_json::to_string_pretty(&result).unwrap()); + } else if cli.debug { + println!("{:#?}", result); + } else { + todo!() + } + } else { + for i in servers { + println!("server: {}", i.address); + + macro_rules! p { + ($($key:ident: $value:expr),+ $(,)?) => { + $(println!(" {}: \"{}\"", stringify!($key), $value);)+ + }; + } + + match &i.kind { + ServerResultKind::Ok { info } => { + p! { + type: "ok", + host: info.host, + gamedir: info.gamedir, + map: info.map, + protocol: info.protocol, + numcl: info.numcl, + maxcl: info.maxcl, + dm: info.dm, + team: info.team, + coop: info.coop, + password: info.password, + } + } + ServerResultKind::Timeout => { + p! { + type: "timeout", + } + } + ServerResultKind::Protocol => { + p! { + type: "protocol", + } + } + ServerResultKind::Error { message } => { + p! { + type: "error", + message: message, + } + } + ServerResultKind::Invalid { message, response } => { + p! { + type: "invalid", + message: message, + response: response, + } + } + } + println!(); + } + } + + Ok(()) +} + +fn list_servers(cli: &Cli) -> Result<(), Error> { + let (tx, rx) = mpsc::channel(); + let mut workers = 0; + + for i in cli.masters.iter() { + let master = i.to_owned(); + let tx = tx.clone(); + let timeout = Duration::from_secs(cli.master_timeout as u64); + thread::spawn(move || { + if let Err(e) = query_servers(&master, timeout, &tx) { + eprintln!("master({}) error: {}", master, e); + } + tx.send(Message::End).unwrap(); + }); + workers += 1; + } + + let mut servers = HashSet::new(); + while let Ok(msg) = rx.recv() { + match msg { + Message::Servers(list) => { + servers.extend(list); + } + Message::End => { + workers -= 1; + if workers == 0 { + break; + } + } + _ => panic!(), + } + } + + let mut servers: Vec<_> = servers.iter().map(|i| i.as_str()).collect(); + servers.sort_by(|a, b| cmp_address(a, b)); + + if cli.json || cli.debug { + let result = ListResult { + master_timeout: cli.master_timeout, + masters: &cli.masters, + servers: &servers, + }; + + if cli.json { + println!("{}", serde_json::to_string_pretty(&result).unwrap()); + } else if cli.debug { + println!("{:#?}", result); + } else { + todo!() + } + } else { + for i in servers { + println!("{}", i); + } + } + + Ok(()) +} + +fn execute(cli: Cli) -> Result<(), Error> { + match cli.args.get(0).map(|s| s.as_str()).unwrap_or_default() { + "all" | "" => query_server_info(&cli, &[])?, + "info" => query_server_info(&cli, &cli.args[1..])?, + "list" => list_servers(&cli)?, + _ => return Err(Error::UndefinedCommand), + } + + Ok(()) +} + +fn main() { + let cli = cli::parse(); + + if let Err(e) = execute(cli) { + eprintln!("error: {}", e); + process::exit(1); + } +}