diff --git a/protocol/src/master.rs b/protocol/src/master.rs index 9272f7e..bba4a65 100644 --- a/protocol/src/master.rs +++ b/protocol/src/master.rs @@ -76,6 +76,10 @@ impl<'a> QueryServersResponse<&'a [u8]> { SocketAddrV4::new(ip, port) }) } + + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } } impl QueryServersResponse diff --git a/query/src/main.rs b/query/src/main.rs index 64d597a..f986f30 100644 --- a/query/src/main.rs +++ b/query/src/main.rs @@ -7,10 +7,9 @@ use std::cmp; use std::collections::{HashMap, HashSet}; use std::fmt; use std::io; -use std::net::{Ipv4Addr, SocketAddrV4, UdpSocket}; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, UdpSocket}; use std::process; -use std::sync::{mpsc, Arc}; -use std::thread; +use std::sync::Arc; use std::time::{Duration, Instant}; use serde::{Serialize, Serializer}; @@ -50,34 +49,34 @@ enum ServerResultKind { #[derive(Clone, Debug, Serialize)] struct ServerResult { - address: String, + address: SocketAddrV4, ping: f32, #[serde(flatten)] kind: ServerResultKind, } impl ServerResult { - fn new(address: String, ping: f32, kind: ServerResultKind) -> Self { + fn new(address: SocketAddrV4, ping: f32, kind: ServerResultKind) -> Self { Self { - address: address.to_string(), + address, ping, kind, } } - fn ok(address: String, ping: f32, info: ServerInfo) -> Self { + fn ok(address: SocketAddrV4, ping: f32, info: ServerInfo) -> Self { Self::new(address, ping, ServerResultKind::Ok { info }) } - fn timeout(address: String) -> Self { + fn timeout(address: SocketAddrV4) -> Self { Self::new(address, 0.0, ServerResultKind::Timeout) } - fn protocol(address: String, ping: f32) -> Self { + fn protocol(address: SocketAddrV4, ping: f32) -> Self { Self::new(address, ping, ServerResultKind::Protocol) } - fn error(address: String, message: T) -> Self + fn error(address: SocketAddrV4, message: T) -> Self where T: fmt::Display, { @@ -90,7 +89,7 @@ impl ServerResult { ) } - fn invalid(address: String, ping: f32, message: T, response: &[u8]) -> Self + fn invalid(address: SocketAddrV4, ping: f32, message: T, response: &[u8]) -> Self where T: fmt::Display, { @@ -152,7 +151,7 @@ struct ListResult<'a> { master_timeout: u32, masters: &'a [Box], filter: &'a str, - servers: &'a [&'a str], + servers: &'a [SocketAddrV4], } fn serialize_colored(s: &str, ser: S) -> Result @@ -203,159 +202,243 @@ impl fmt::Display for Colored<'_> { } } -enum Message { - Servers(Vec), - ServerResult(ServerResult), - End, +fn get_socket_addrs<'a>(iter: impl Iterator) -> Result, Error> { + use std::net::ToSocketAddrs; + + let mut out = Vec::with_capacity(iter.size_hint().0); + for i in iter { + match i + .to_socket_addrs()? + .find(|i| matches!(i, SocketAddr::V4(_))) + { + Some(SocketAddr::V4(addr)) => out.push(addr), + _ => eprintln!("warn: failed to resolve address for {}", i), + } + } + + Ok(out) +} + +struct ServerQuery { + start: Instant, + protocol: usize, } -fn cmp_address(a: &str, b: &str) -> cmp::Ordering { - match (a.parse::(), b.parse::()) { - (Ok(a), Ok(b)) => a.cmp(&b), - _ => a.cmp(b), +impl ServerQuery { + fn ping(&self) -> f32 { + self.start.elapsed().as_micros() as f32 / 1000.0 } } -fn query_servers( - host: &str, - cli: &Cli, - timeout: Duration, - tx: &mpsc::Sender, -) -> Result<(), Error> { - let sock = UdpSocket::bind("0.0.0.0:0")?; - sock.connect(host)?; - - let mut buf = [0; 512]; - let p = game::QueryServers { - region: server::Region::RestOfTheWorld, - last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), - filter: cli.filter.as_str(), - }; - let n = p.encode(&mut buf)?; - sock.send(&buf[..n])?; - - let start_time = Instant::now(); - while let Some(timeout) = timeout.checked_sub(start_time.elapsed()) { - sock.set_read_timeout(Some(timeout))?; - let n = match sock.recv(&mut buf) { - Ok(n) => n, - Err(e) => match e.kind() { - io::ErrorKind::AddrInUse | io::ErrorKind::WouldBlock => break, - _ => Err(e)?, - }, - }; - if let Ok(packet) = master::QueryServersResponse::decode(&buf[..n]) { - tx.send(Message::Servers( - packet.iter().map(|i| i.to_string()).collect(), - )) - .unwrap(); - } else { - eprintln!("Unexpected packet from master {}", host); +impl ServerQuery { + fn new(protocol: usize) -> Self { + Self { + start: Instant::now(), + protocol, } } +} - Ok(()) +struct Scan<'a> { + cli: &'a Cli, + masters: Vec, + sock: UdpSocket, } -fn get_server_info( - addr: String, - versions: &[u8], - timeout: Duration, -) -> Result { - let sock = UdpSocket::bind("0.0.0.0:0")?; - sock.connect(&addr)?; - sock.set_read_timeout(Some(timeout))?; - - let mut ping = 0.0; - for &i in versions { - let p = game::GetServerInfo::new(i); - let mut buf = [0; 2048]; - let n = p.encode(&mut buf)?; - let start = Instant::now(); - sock.send(&buf[..n])?; - - let n = match sock.recv(&mut buf) { - Ok(n) => n, - Err(e) => match e.kind() { - io::ErrorKind::AddrInUse | io::ErrorKind::WouldBlock => { - return Ok(ServerResult::timeout(addr)); - } - _ => Err(e)?, - }, +impl<'a> Scan<'a> { + fn new(cli: &'a Cli) -> Result { + Ok(Self { + cli, + masters: get_socket_addrs(cli.masters.iter().map(|i| i.as_ref()))?, + sock: UdpSocket::bind(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0))?, + }) + } + + fn is_master(&self, addr: &SocketAddrV4) -> bool { + self.masters.iter().any(|i| i == addr) + } + + fn query_servers(&self) -> Result<(), Error> { + let mut buf = [0; 512]; + let packet = game::QueryServers { + region: server::Region::RestOfTheWorld, + last: SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0), + filter: self.cli.filter.as_str(), }; - ping = start.elapsed().as_micros() as f32 / 1000.0; + let n = packet.encode(&mut buf)?; + let packet = &buf[..n]; - let response = &buf[..n]; - match server::GetServerInfoResponse::decode(response) { - Ok(packet) => { - let info = ServerInfo::from(packet); - return Ok(ServerResult::ok(addr, ping, info)); - } - Err(ProtocolError::InvalidProtocolVersion) => { - // try another protocol version - } - Err(e) => { - return Ok(ServerResult::invalid(addr, ping, e, response)); - } + for i in &self.masters { + self.sock.send_to(packet, i)?; } + + Ok(()) } - Ok(ServerResult::protocol(addr, ping)) -} + fn servers(&self) -> Result, Error> { + self.query_servers()?; -fn query_server_info(cli: &Arc, servers: &[String]) -> Result<(), Error> { - let (tx, rx) = mpsc::channel(); - let mut workers = 0; - - if servers.is_empty() { - for i in cli.masters.iter() { - let master = i.to_owned(); - let tx = tx.clone(); - let timeout = Duration::from_secs(cli.master_timeout as u64); - let cli = cli.clone(); - thread::spawn(move || { - if let Err(e) = query_servers(&master, &cli, timeout, &tx) { - eprintln!("master({}) error: {}", master, e); + let mut set = HashSet::with_capacity(256); + let mut buf = [0; 2048]; + let timeout = Duration::from_secs(self.cli.master_timeout as u64); + let start_time = Instant::now(); + + while let Some(timeout) = timeout.checked_sub(start_time.elapsed()) { + self.sock.set_read_timeout(Some(timeout))?; + + let (n, from) = match self.sock.recv_from(&mut buf) { + Ok(x) => x, + Err(e) => match e.kind() { + io::ErrorKind::AddrInUse => break, + io::ErrorKind::WouldBlock => break, + _ => Err(e)?, + }, + }; + + let from = match from { + SocketAddr::V4(x) => x, + _ => todo!(), + }; + + if self.is_master(&from) { + if let Ok(packet) = master::QueryServersResponse::decode(&buf[..n]) { + set.extend(packet.iter()); + } else { + eprintln!("warn: invalid packet from master {}", from); } - tx.send(Message::End).unwrap(); - }); - workers += 1; + } } - } else { - tx.send(Message::Servers(servers.to_vec())).unwrap(); + + Ok(set) } - let mut servers = HashMap::new(); - while let Ok(msg) = rx.recv() { - match msg { - Message::Servers(list) => { - for address in list { - let tx = tx.clone(); - let timeout = Duration::from_secs(cli.server_timeout as u64); - let versions = cli.protocol.clone(); - thread::spawn(move || { - let result = get_server_info(address.clone(), &versions, timeout) - .unwrap_or_else(|e| ServerResult::error(address, e)); - tx.send(Message::ServerResult(result)).unwrap(); - tx.send(Message::End).unwrap(); - }); - workers += 1; + fn server_info( + &self, + list: &[SocketAddrV4], + ) -> Result, Error> { + let mut set = HashSet::new(); + let mut active = HashMap::new(); + let mut out = HashMap::new(); + let mut buf = [0; 2048]; + + let now = Instant::now(); + let master_timeout = Duration::from_secs(self.cli.master_timeout as u64); + let server_timeout = Duration::from_secs(self.cli.server_timeout as u64); + let master_end = now + master_timeout; + let mut server_end = now + server_timeout; + + if list.is_empty() { + self.query_servers()?; + } else { + let mut buf = [0; 512]; + let n = game::GetServerInfo::new(self.cli.protocol[0]).encode(&mut buf)?; + + for addr in list.iter().filter(|i| set.insert(**i)) { + match self.sock.send_to(&buf[..n], addr) { + Ok(_) => { + let query = ServerQuery::new(0); + server_end = query.start + server_timeout; + active.insert(*addr, query); + } + Err(e) => { + let res = ServerResult::error(*addr, e); + out.insert(*addr, res); + } } } - Message::End => { - workers -= 1; - if workers == 0 { - break; - } + } + + loop { + let time = cmp::max(master_end, server_end); + match time.checked_duration_since(Instant::now()) { + Some(t) => self.sock.set_read_timeout(Some(t))?, + None => break, } - Message::ServerResult(result) => { - servers.insert(result.address.clone(), result); + + let (n, from) = match self.sock.recv_from(&mut buf) { + Ok(x) => x, + Err(e) => match e.kind() { + io::ErrorKind::AddrInUse => break, + io::ErrorKind::WouldBlock => break, + _ => Err(e)?, + }, + }; + let from = match from { + SocketAddr::V4(x) => x, + _ => todo!(), + }; + let raw = &buf[..n]; + + if self.is_master(&from) { + if let Ok(packet) = master::QueryServersResponse::decode(raw) { + for addr in packet.iter().filter(|i| set.insert(*i)) { + let mut buf = [0; 512]; + let n = game::GetServerInfo::new(self.cli.protocol[0]).encode(&mut buf)?; + + match self.sock.send_to(&buf[..n], addr) { + Ok(_) => { + let query = ServerQuery::new(0); + server_end = query.start + server_timeout; + active.insert(addr, query); + } + Err(e) => { + let res = ServerResult::error(addr, e); + out.insert(addr, res); + } + } + } + } + } else if let Some(query) = active.remove(&from) { + match server::GetServerInfoResponse::decode(raw) { + Ok(packet) => { + let info = ServerInfo::from(packet); + let res = ServerResult::ok(from, query.ping(), info); + out.insert(from, res); + } + Err(ProtocolError::InvalidProtocolVersion) => { + let next_protocol = query.protocol + 1; + if let Some(protocol) = self.cli.protocol.get(next_protocol) { + let mut buf = [0; 512]; + let n = game::GetServerInfo::new(*protocol).encode(&mut buf)?; + + match self.sock.send_to(&buf[..n], from) { + Ok(_) => { + active.insert(from, ServerQuery::new(next_protocol)); + } + Err(e) => { + let res = ServerResult::error(from, e); + out.insert(from, res); + } + } + } else { + let res = ServerResult::protocol(from, query.ping()); + out.insert(from, res); + } + } + Err(e) => { + let res = ServerResult::invalid(from, query.ping(), e, raw); + out.insert(from, res); + } + } } } + + for (addr, _) in active { + let res = ServerResult::timeout(addr); + out.insert(addr, res); + } + + Ok(out) } +} + +fn query_server_info(cli: &Arc, servers: &[String]) -> Result<(), Error> { + let scan = Scan::new(cli)?; + let servers = get_socket_addrs(servers.iter().map(|i| i.as_str()))?; + let servers = scan.server_info(&servers)?; let mut servers: Vec<_> = servers.values().collect(); - servers.sort_by(|a, b| cmp_address(&a.address, &b.address)); + servers.sort_by(|a, b| a.address.cmp(&b.address)); if cli.json || cli.debug { let result = InfoResult { @@ -432,41 +515,9 @@ fn query_server_info(cli: &Arc, servers: &[String]) -> Result<(), Error> { } fn list_servers(cli: &Arc) -> Result<(), Error> { - let (tx, rx) = mpsc::channel(); - let mut workers = 0; - - for i in cli.masters.iter() { - let master = i.to_owned(); - let tx = tx.clone(); - let timeout = Duration::from_secs(cli.master_timeout as u64); - let cli = cli.clone(); - thread::spawn(move || { - if let Err(e) = query_servers(&master, &cli, timeout, &tx) { - eprintln!("master({}) error: {}", master, e); - } - tx.send(Message::End).unwrap(); - }); - workers += 1; - } - - let mut servers = HashSet::new(); - while let Ok(msg) = rx.recv() { - match msg { - Message::Servers(list) => { - servers.extend(list); - } - Message::End => { - workers -= 1; - if workers == 0 { - break; - } - } - _ => panic!(), - } - } - - let mut servers: Vec<_> = servers.iter().map(|i| i.as_str()).collect(); - servers.sort_by(|a, b| cmp_address(a, b)); + let scan = Scan::new(cli)?; + let mut servers: Vec<_> = scan.servers()?.into_iter().collect(); + servers.sort(); if cli.json || cli.debug { let result = ListResult {