diff --git a/admin/src/main.rs b/admin/src/main.rs index 41f7ee7..aea9f59 100644 --- a/admin/src/main.rs +++ b/admin/src/main.rs @@ -31,7 +31,7 @@ fn send_command(cli: &cli::Cli) -> Result<(), Error> { let n = sock.recv(&mut buf)?; let (master_challenge, hash_challenge) = match master::Packet::decode(&buf[..n])? { - master::Packet::AdminChallengeResponse(p) => (p.master_challenge, p.hash_challenge), + Some(master::Packet::AdminChallengeResponse(p)) => (p.master_challenge, p.hash_challenge), _ => return Err(Error::UnexpectedPacket), }; diff --git a/master/src/master_server.rs b/master/src/master_server.rs index 72259f2..ccb63f9 100644 --- a/master/src/master_server.rs +++ b/master/src/master_server.rs @@ -44,6 +44,10 @@ pub enum Error { Io(#[from] io::Error), #[error("Admin challenge do not exist")] AdminChallengeNotFound, + #[error("Undefined packet")] + UndefinedPacket, + #[error("Unexpected packet")] + UnexpectedPacket, } /// HashMap entry to keep tracking creation time. @@ -235,8 +239,9 @@ impl MasterServer { } }; - if let Err(e) = self.handle_packet(from, &buf[..n]) { - error!("{}: {}", from, e); + let src = &buf[..n]; + if let Err(e) = self.handle_packet(from, src) { + debug!("{}: {}: \"{}\"", from, e, Str(src)); } } Ok(()) @@ -249,167 +254,196 @@ impl MasterServer { self.admin_challenges.clear(); } - fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> { - if self.is_blocked(from.ip()) { - return Ok(()); - } - - if let Ok(p) = server::Packet::decode(src) { - match p { - server::Packet::Challenge(p) => { - trace!("{}: recv {:?}", from, p); - let master_challenge = self.add_challenge(from); - let mut buf = [0; MAX_PACKET_SIZE]; - let p = master::ChallengeResponse::new(master_challenge, p.server_challenge); - trace!("{}: send {:?}", from, p); - let n = p.encode(&mut buf)?; - self.sock.send_to(&buf[..n], from)?; - self.remove_outdated_challenges(); - } - server::Packet::ServerAdd(p) => { - trace!("{}: recv {:?}", from, p); - let entry = match self.challenges.get(&from) { - Some(e) => e, - None => { - trace!("{}: Challenge does not exists", from); - return Ok(()); - } - }; - if !entry.is_valid(self.now(), self.timeout.challenge) { - return Ok(()); - } - if p.challenge != entry.value { - warn!( - "{}: Expected challenge {} but received {}", - from, entry.value, p.challenge - ); + fn handle_server_packet(&mut self, from: SocketAddrV4, p: server::Packet) -> Result<(), Error> { + trace!("{}: recv {:?}", from, p); + + match p { + server::Packet::Challenge(p) => { + let master_challenge = self.add_challenge(from); + let mut buf = [0; MAX_PACKET_SIZE]; + let p = master::ChallengeResponse::new(master_challenge, p.server_challenge); + trace!("{}: send {:?}", from, p); + let n = p.encode(&mut buf)?; + self.sock.send_to(&buf[..n], from)?; + self.remove_outdated_challenges(); + } + server::Packet::ServerAdd(p) => { + let entry = match self.challenges.get(&from) { + Some(e) => e, + None => { + trace!("{}: Challenge does not exists", from); return Ok(()); } - if self.challenges.remove(&from).is_some() { - self.add_server(from, ServerInfo::new(&p)); - } - self.remove_outdated_servers(); + }; + if !entry.is_valid(self.now(), self.timeout.challenge) { + return Ok(()); } - _ => { - trace!("{}: recv {:?}", from, p); + if p.challenge != entry.value { + warn!( + "{}: Expected challenge {} but received {}", + from, entry.value, p.challenge + ); + return Ok(()); + } + if self.challenges.remove(&from).is_some() { + self.add_server(from, ServerInfo::new(&p)); } + self.remove_outdated_servers(); + } + server::Packet::ServerRemove => { + // ignore } - } else if let Ok(p) = game::Packet::decode(src) { - match p { - game::Packet::QueryServers(p) => { - trace!("{}: recv {:?}", from, p); - if p.filter.clver.map_or(false, |v| v < self.clver) { - let iter = std::iter::once(self.update_addr); - self.send_server_list(from, p.filter.key, iter)?; - } else { - let now = self.now(); - let iter = self - .servers - .iter() - .filter(|i| i.1.is_valid(now, self.timeout.server)) - .filter(|i| i.1.matches(*i.0, p.region, &p.filter)) - .map(|i| *i.0); - - self.send_server_list(from, p.filter.key, iter.clone())?; - - if p.filter.flags.contains(FilterFlags::NAT) { - self.send_client_to_nat_servers(from, iter)?; - } + _ => { + return Err(Error::UnexpectedPacket); + } + } + + Ok(()) + } + + fn handle_game_packet(&mut self, from: SocketAddrV4, p: game::Packet) -> Result<(), Error> { + trace!("{}: recv {:?}", from, p); + + match p { + game::Packet::QueryServers(p) => { + if p.filter.clver.map_or(false, |v| v < self.clver) { + let iter = std::iter::once(self.update_addr); + self.send_server_list(from, p.filter.key, iter)?; + } else { + let now = self.now(); + let iter = self + .servers + .iter() + .filter(|i| i.1.is_valid(now, self.timeout.server)) + .filter(|i| i.1.matches(*i.0, p.region, &p.filter)) + .map(|i| *i.0); + + self.send_server_list(from, p.filter.key, iter.clone())?; + + if p.filter.flags.contains(FilterFlags::NAT) { + self.send_client_to_nat_servers(from, iter)?; } } - game::Packet::GetServerInfo(p) => { - trace!("{}: recv {:?}", from, p); - let p = server::GetServerInfoResponse { - map: self.update_map.as_ref(), - host: self.update_title.as_ref(), - protocol: 48, // XXX: how to detect what version client will accept? - dm: true, - maxcl: 32, - gamedir: "valve", // XXX: probably must be specific for client... - ..Default::default() - }; - trace!("{}: send {:?}", from, p); - let mut buf = [0; MAX_PACKET_SIZE]; - let n = p.encode(&mut buf)?; - self.sock.send_to(&buf[..n], from)?; - } } - } else if let Ok(p) = admin::Packet::decode(self.hash.len, src) { - let now = self.now(); + game::Packet::GetServerInfo(_) => { + let p = server::GetServerInfoResponse { + map: self.update_map.as_ref(), + host: self.update_title.as_ref(), + protocol: 48, // XXX: how to detect what version client will accept? + dm: true, + maxcl: 32, + gamedir: "valve", // XXX: probably must be specific for client... + ..Default::default() + }; + trace!("{}: send {:?}", from, p); + let mut buf = [0; MAX_PACKET_SIZE]; + let n = p.encode(&mut buf)?; + self.sock.send_to(&buf[..n], from)?; + } + } - if let Some(e) = self.admin_limit.get(from.ip()) { - if e.is_valid(now, self.timeout.admin) { - trace!("{}: rate limit", from); - return Ok(()); - } + Ok(()) + } + + fn handle_admin_packet(&mut self, from: SocketAddrV4, p: admin::Packet) -> Result<(), Error> { + trace!("{}: recv {:?}", from, p); + + let now = self.now(); + + if let Some(e) = self.admin_limit.get(from.ip()) { + if e.is_valid(now, self.timeout.admin) { + trace!("{}: rate limit", from); + return Ok(()); } + } - match p { - admin::Packet::AdminChallenge(p) => { - trace!("{}: recv {:?}", from, p); - let (master_challenge, hash_challenge) = self.admin_challenge_add(from); + match p { + admin::Packet::AdminChallenge => { + let (master_challenge, hash_challenge) = self.admin_challenge_add(from); - let p = master::AdminChallengeResponse::new(master_challenge, hash_challenge); - trace!("{}: send {:?}", from, p); - let mut buf = [0; 64]; - let n = p.encode(&mut buf)?; - self.sock.send_to(&buf[..n], from)?; + let p = master::AdminChallengeResponse::new(master_challenge, hash_challenge); + trace!("{}: send {:?}", from, p); + let mut buf = [0; 64]; + let n = p.encode(&mut buf)?; + self.sock.send_to(&buf[..n], from)?; - self.admin_challenges_cleanup(); + self.admin_challenges_cleanup(); + } + admin::Packet::AdminCommand(p) => { + let entry = *self + .admin_challenges + .get(from.ip()) + .ok_or(Error::AdminChallengeNotFound)?; + + if entry.0 != p.master_challenge { + trace!("{}: master challenge is not valid", from); + return Ok(()); } - admin::Packet::AdminCommand(p) => { - trace!("{}: recv {:?}", from, p); - let entry = *self - .admin_challenges - .get(from.ip()) - .ok_or(Error::AdminChallengeNotFound)?; - - if entry.0 != p.master_challenge { - trace!("{}: master challenge is not valid", from); - return Ok(()); - } - if !entry.is_valid(now, self.timeout.challenge) { - trace!("{}: challenge is outdated", from); - return Ok(()); - } + if !entry.is_valid(now, self.timeout.challenge) { + trace!("{}: challenge is outdated", from); + return Ok(()); + } - let state = Params::new() - .hash_length(self.hash.len) - .key(self.hash.key.as_bytes()) - .personal(self.hash.personal.as_bytes()) - .to_state(); - - let admin = self.admin_list.iter().find(|i| { - let hash = state - .clone() - .update(i.password.as_bytes()) - .update(&entry.1.to_le_bytes()) - .finalize(); - *p.hash == hash.as_bytes() - }); - - match admin { - Some(admin) => { - info!("{}: admin({}), command: {:?}", from, &admin.name, p.command); - self.admin_command(p.command); - self.admin_challenge_remove(from); - } - None => { - warn!("{}: invalid admin hash, command: {:?}", from, p.command); - self.admin_limit.insert(*from.ip(), Entry::new(now, ())); - self.admin_limit_cleanup(); - } + let state = Params::new() + .hash_length(self.hash.len) + .key(self.hash.key.as_bytes()) + .personal(self.hash.personal.as_bytes()) + .to_state(); + + let admin = self.admin_list.iter().find(|i| { + let hash = state + .clone() + .update(i.password.as_bytes()) + .update(&entry.1.to_le_bytes()) + .finalize(); + *p.hash == hash.as_bytes() + }); + + match admin { + Some(admin) => { + info!("{}: admin({}), command: {:?}", from, &admin.name, p.command); + self.admin_command(p.command); + self.admin_challenge_remove(from); + } + None => { + warn!("{}: invalid admin hash, command: {:?}", from, p.command); + self.admin_limit.insert(*from.ip(), Entry::new(now, ())); + self.admin_limit_cleanup(); } } } - } else { - debug!("{}: invalid packet: \"{}\"", from, Str(src)); } Ok(()) } + fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> { + if self.is_blocked(from.ip()) { + return Ok(()); + } + + match server::Packet::decode(src) { + Ok(Some(p)) => return self.handle_server_packet(from, p), + Ok(None) => {} + Err(e) => Err(e)?, + } + + match game::Packet::decode(src) { + Ok(Some(p)) => return self.handle_game_packet(from, p), + Ok(None) => {} + Err(e) => Err(e)?, + } + + match admin::Packet::decode(self.hash.len, src) { + Ok(Some(p)) => return self.handle_admin_packet(from, p), + Ok(None) => {} + Err(e) => Err(e)?, + } + + Err(Error::UndefinedPacket) + } + fn now(&self) -> u32 { self.start_time.elapsed().as_secs() as u32 } diff --git a/protocol/src/admin.rs b/protocol/src/admin.rs index 69bc835..78ba814 100644 --- a/protocol/src/admin.rs +++ b/protocol/src/admin.rs @@ -5,7 +5,7 @@ use crate::cursor::{Cursor, CursorMut}; use crate::types::Hide; -use crate::Error; +use crate::{CursorError, Error}; /// Default hash length. pub const HASH_LEN: usize = 64; @@ -27,7 +27,7 @@ impl AdminChallenge { if src == Self::HEADER { Ok(Self) } else { - Err(Error::InvalidPacket) + Err(CursorError::Expect)? } } @@ -97,23 +97,22 @@ impl<'a> AdminCommand<'a> { #[derive(Clone, Debug, PartialEq)] pub enum Packet<'a> { /// Admin challenge request. - AdminChallenge(AdminChallenge), + AdminChallenge, /// Admin command. AdminCommand(AdminCommand<'a>), } impl<'a> Packet<'a> { /// Decode packet from `src` with specified hash length. - pub fn decode(hash_len: usize, src: &'a [u8]) -> Result { - if let Ok(p) = AdminChallenge::decode(src) { - return Ok(Self::AdminChallenge(p)); - } - - if let Ok(p) = AdminCommand::decode_with_hash_len(hash_len, src) { - return Ok(Self::AdminCommand(p)); + pub fn decode(hash_len: usize, src: &'a [u8]) -> Result, Error> { + if src.starts_with(AdminChallenge::HEADER) { + AdminChallenge::decode(src).map(|_| Self::AdminChallenge) + } else if src.starts_with(AdminCommand::HEADER) { + AdminCommand::decode_with_hash_len(hash_len, src).map(Self::AdminCommand) + } else { + return Ok(None); } - - Err(Error::InvalidPacket) + .map(Some) } } @@ -126,7 +125,10 @@ mod tests { let p = AdminChallenge; let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(AdminChallenge::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(HASH_LEN, &buf[..n]), + Ok(Some(Packet::AdminChallenge)) + ); } #[test] @@ -134,6 +136,9 @@ mod tests { let p = AdminCommand::new(0x12345678, &[1; HASH_LEN], "foo bar baz"); let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(AdminCommand::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(HASH_LEN, &buf[..n]), + Ok(Some(Packet::AdminCommand(p))) + ); } } diff --git a/protocol/src/cursor.rs b/protocol/src/cursor.rs index 52effad..abb8de6 100644 --- a/protocol/src/cursor.rs +++ b/protocol/src/cursor.rs @@ -7,8 +7,42 @@ use std::mem; use std::slice; use std::str; +use thiserror::Error; + +use super::color; use super::types::Str; -use super::{color, Error}; + +/// The error type for `Cursor` and `CursorMut`. +#[derive(Error, Debug, PartialEq, Eq)] +pub enum Error { + /// Invalid number. + #[error("Invalid number")] + InvalidNumber, + /// Invalid string. + #[error("Invalid string")] + InvalidString, + /// Invalid boolean. + #[error("Invalid boolean")] + InvalidBool, + /// Invalid table entry. + #[error("Invalid table key")] + InvalidTableKey, + /// Invalid table entry. + #[error("Invalid table entry")] + InvalidTableValue, + /// Table end found. + #[error("Table end")] + TableEnd, + /// Expected data not found. + #[error("Expected data not found")] + Expect, + /// An unexpected data found. + #[error("Unexpected data")] + ExpectEmpty, + /// Buffer size is no enougth to decode or encode a packet. + #[error("Unexpected end of buffer")] + UnexpectedEnd, +} pub trait GetKeyValue<'a>: Sized { fn get_key_value(cur: &mut Cursor<'a>) -> Result; @@ -56,7 +90,7 @@ impl<'a> GetKeyValue<'a> for bool { match cur.get_key_value_raw()? { b"0" => Ok(false), b"1" => Ok(true), - _ => Err(Error::InvalidPacket), + _ => Err(Error::InvalidBool), } } } @@ -68,7 +102,7 @@ macro_rules! impl_get_value { let s = cur.get_key_value::<&str>()?; // HACK: special case for one asshole let (_, s) = color::trim_start_color(s); - s.parse().map_err(|_| Error::InvalidPacket) + s.parse().map_err(|_| Error::InvalidNumber) } })+ }; @@ -216,13 +250,13 @@ impl<'a> Cursor<'a> { self.advance(s.len())?; Ok(()) } else { - Err(Error::InvalidPacket) + Err(Error::Expect) } } pub fn expect_empty(&self) -> Result<(), Error> { if self.has_remaining() { - Err(Error::InvalidPacket) + Err(Error::ExpectEmpty) } else { Ok(()) } @@ -252,12 +286,13 @@ impl<'a> Cursor<'a> { pub fn get_key_value_raw(&mut self) -> Result<&'a [u8], Error> { let mut cur = *self; - if cur.get_u8()? == b'\\' { - let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n'); - *self = cur; - Ok(value) - } else { - Err(Error::InvalidPacket) + match cur.get_u8()? { + b'\\' => { + let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n'); + *self = cur; + Ok(value) + } + _ => Err(Error::InvalidTableValue), } } @@ -265,14 +300,20 @@ impl<'a> Cursor<'a> { T::get_key_value(self) } + pub fn skip_key_value>(&mut self) -> Result<(), Error> { + T::get_key_value(self).map(|_| ()) + } + pub fn get_key_raw(&mut self) -> Result<&'a [u8], Error> { let mut cur = *self; - if cur.get_u8()? == b'\\' { - let value = cur.take_while(|c| c != b'\\' && c != b'\n')?; - *self = cur; - Ok(value) - } else { - Err(Error::InvalidPacket) + match cur.get_u8() { + Ok(b'\\') => { + let value = cur.take_while(|c| c != b'\\' && c != b'\n')?; + *self = cur; + Ok(value) + } + Ok(b'\n') | Err(Error::UnexpectedEnd) => Err(Error::TableEnd), + _ => Err(Error::InvalidTableKey), } } @@ -288,6 +329,18 @@ pub trait PutKeyValue { ) -> Result<&'b mut CursorMut<'a>, Error>; } +impl PutKeyValue for &T +where + T: PutKeyValue, +{ + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, Error> { + (*self).put_key_value(cur) + } +} + impl PutKeyValue for &str { fn put_key_value<'a, 'b>( &self, @@ -532,7 +585,7 @@ mod tests { assert_eq!(cur.get_key(), Ok((&b"gamedir"[..], "valve"))); assert_eq!(cur.get_key(), Ok((&b"password"[..], false))); assert_eq!(cur.get_key(), Ok((&b"host"[..], "test"))); - assert_eq!(cur.get_key::<&[u8]>(), Err(Error::UnexpectedEnd)); + assert_eq!(cur.get_key::<&[u8]>(), Err(Error::TableEnd)); Ok(()) } diff --git a/protocol/src/filter.rs b/protocol/src/filter.rs index ddc816f..6d50ece 100644 --- a/protocol/src/filter.rs +++ b/protocol/src/filter.rs @@ -31,7 +31,6 @@ use std::fmt; use std::net::SocketAddrV4; -use std::num::ParseIntError; use std::str::FromStr; use bitflags::bitflags; @@ -40,7 +39,7 @@ use log::debug; use crate::cursor::{Cursor, GetKeyValue, PutKeyValue}; use crate::server::{ServerAdd, ServerFlags, ServerType}; use crate::types::Str; -use crate::{Error, ServerInfo}; +use crate::{CursorError, Error, ServerInfo}; bitflags! { /// Additional filter flags. @@ -129,21 +128,21 @@ impl fmt::Display for Version { } impl FromStr for Version { - type Err = ParseIntError; + type Err = CursorError; fn from_str(s: &str) -> Result { let (major, tail) = s.split_once('.').unwrap_or((s, "0")); let (minor, patch) = tail.split_once('.').unwrap_or((tail, "0")); - let major = major.parse()?; - let minor = minor.parse()?; - let patch = patch.parse()?; + let major = major.parse().map_err(|_| CursorError::InvalidNumber)?; + let minor = minor.parse().map_err(|_| CursorError::InvalidNumber)?; + let patch = patch.parse().map_err(|_| CursorError::InvalidNumber)?; Ok(Self::with_patch(major, minor, patch)) } } impl GetKeyValue<'_> for Version { - fn get_key_value(cur: &mut Cursor) -> Result { - Self::from_str(cur.get_key_value()?).map_err(|_| Error::InvalidPacket) + fn get_key_value(cur: &mut Cursor) -> Result { + cur.get_key_value().and_then(Self::from_str) } } @@ -151,7 +150,7 @@ impl PutKeyValue for Version { fn put_key_value<'a, 'b>( &self, cur: &'b mut crate::cursor::CursorMut<'a>, - ) -> Result<&'b mut crate::cursor::CursorMut<'a>, Error> { + ) -> Result<&'b mut crate::cursor::CursorMut<'a>, CursorError> { cur.put_key_value(self.major)? .put_u8(b'.')? .put_key_value(self.minor)?; @@ -201,42 +200,48 @@ impl<'a> TryFrom<&'a [u8]> for Filter<'a> { type Error = Error; fn try_from(src: &'a [u8]) -> Result { + trait Helper<'a> { + fn get>(&mut self, key: &'static str) -> Result; + } + + impl<'a> Helper<'a> for Cursor<'a> { + fn get>(&mut self, key: &'static str) -> Result { + T::get_key_value(self).map_err(|e| Error::InvalidFilterValue(key, e)) + } + } + let mut cur = Cursor::new(src); let mut filter = Self::default(); loop { let key = match cur.get_key_raw().map(Str) { Ok(s) => s, - Err(Error::UnexpectedEnd) => break, - Err(e) => return Err(e), + Err(CursorError::TableEnd) => break, + Err(e) => Err(e)?, }; match *key { - b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get_key_value()?), - b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get_key_value()?), - b"gamedir" => filter.gamedir = Some(cur.get_key_value()?), - b"map" => filter.map = Some(cur.get_key_value()?), - b"protocol" => filter.protocol = Some(cur.get_key_value()?), - b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get_key_value()?), - b"full" => filter.insert_flag(FilterFlags::FULL, cur.get_key_value()?), - b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get_key_value()?), - b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get_key_value()?), - b"clver" => { - filter.clver = Some( - cur.get_key_value::<&str>()? - .parse() - .map_err(|_| Error::InvalidPacket)?, - ); - } - b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get_key_value()?), - b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get_key_value()?), - b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get_key_value()?), + b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get("dedicated")?), + b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get("secure")?), + b"gamedir" => filter.gamedir = Some(cur.get("gamedir")?), + b"map" => filter.map = Some(cur.get("map")?), + b"protocol" => filter.protocol = Some(cur.get("protocol")?), + b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get("empty")?), + b"full" => filter.insert_flag(FilterFlags::FULL, cur.get("full")?), + b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get("password")?), + b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get("noplayers")?), + b"clver" => filter.clver = Some(cur.get("clver")?), + b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get("nat")?), + b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get("lan")?), + b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get("bots")?), b"key" => { - filter.key = { - let s = cur.get_key_value::<&str>()?; - let x = u32::from_str_radix(s, 16).map_err(|_| Error::InvalidPacket)?; - Some(x) - } + filter.key = Some( + cur.get_key_value::<&str>() + .and_then(|s| { + u32::from_str_radix(s, 16).map_err(|_| CursorError::InvalidNumber) + }) + .map_err(|e| Error::InvalidFilterValue("key", e))?, + ) } _ => { // skip unknown fields diff --git a/protocol/src/game.rs b/protocol/src/game.rs index 91f4f92..2075e88 100644 --- a/protocol/src/game.rs +++ b/protocol/src/game.rs @@ -35,7 +35,7 @@ where pub fn decode(src: &'a [u8]) -> Result { let mut cur = Cursor::new(src); cur.expect(QueryServers::HEADER)?; - let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidPacket)?; + let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidRegion)?; let last = cur.get_cstr_as_str()?; let filter = match cur.get_bytes(cur.remaining())? { // some clients may have bug and filter will be with zero at the end @@ -44,7 +44,7 @@ where }; Ok(Self { region, - last: last.parse().map_err(|_| Error::InvalidPacket)?, + last: last.parse().map_err(|_| Error::InvalidQueryServersLast)?, filter: T::try_from(filter)?, }) } @@ -114,16 +114,15 @@ pub enum Packet<'a> { impl<'a> Packet<'a> { /// Decode packet from `src`. - pub fn decode(src: &'a [u8]) -> Result { - if let Ok(p) = QueryServers::decode(src) { - return Ok(Self::QueryServers(p)); - } - - if let Ok(p) = GetServerInfo::decode(src) { - return Ok(Self::GetServerInfo(p)); + pub fn decode(src: &'a [u8]) -> Result, Error> { + if src.starts_with(QueryServers::HEADER) { + QueryServers::decode(src).map(Self::QueryServers) + } else if src.starts_with(GetServerInfo::HEADER) { + GetServerInfo::decode(src).map(Self::GetServerInfo) + } else { + return Ok(None); } - - Err(Error::InvalidPacket) + .map(Some) } } @@ -151,7 +150,7 @@ mod tests { }; let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(QueryServers::decode(&buf[..n]), Ok(p)); + assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::QueryServers(p)))); } #[test] @@ -171,10 +170,10 @@ mod tests { }; let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0\0"; - assert_eq!(QueryServers::decode(s), Ok(p.clone())); + assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p.clone())))); let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0"; - assert_eq!(QueryServers::decode(s), Ok(p)); + assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p)))); } #[test] @@ -182,6 +181,9 @@ mod tests { let p = GetServerInfo::new(49); let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(GetServerInfo::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(&buf[..n]), + Ok(Some(Packet::GetServerInfo(p))) + ); } } diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index be03dc3..577eb23 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -16,6 +16,7 @@ pub mod master; pub mod server; pub mod types; +pub use cursor::Error as CursorError; pub use server_info::ServerInfo; use thiserror::Error; @@ -33,13 +34,25 @@ pub enum Error { /// Failed to decode a packet. #[error("Invalid packet")] InvalidPacket, - /// Invalid string in a packet. - #[error("Invalid UTF-8 string")] - InvalidString, - /// Buffer size is no enougth to decode or encode a packet. - #[error("Unexpected end of buffer")] - UnexpectedEnd, + /// Invalid region. + #[error("Invalid region")] + InvalidRegion, + /// Invalid client announce IP. + #[error("Invalid client announce IP")] + InvalidClientAnnounceIp, + /// Invalid last IP. + #[error("Invalid last server IP")] + InvalidQueryServersLast, /// Server protocol version is not supported. #[error("Invalid protocol version")] InvalidProtocolVersion, + /// Cursor error. + #[error("{0}")] + CursorError(#[from] CursorError), + /// Invalid value for server add packet. + #[error("Invalid value for server add key `{0}`: {1}")] + InvalidServerValue(&'static str, #[source] CursorError), + /// Invalid value for query servers packet. + #[error("Invalid value for filter key `{0}`: {1}")] + InvalidFilterValue(&'static str, #[source] CursorError), } diff --git a/protocol/src/master.rs b/protocol/src/master.rs index 41b019c..2d21ab1 100644 --- a/protocol/src/master.rs +++ b/protocol/src/master.rs @@ -174,7 +174,7 @@ impl ClientAnnounce { let addr = cur .get_str(cur.remaining())? .parse() - .map_err(|_| Error::InvalidPacket)?; + .map_err(|_| Error::InvalidClientAnnounceIp)?; cur.expect_empty()?; Ok(Self { addr }) } @@ -247,24 +247,19 @@ pub enum Packet<'a> { impl<'a> Packet<'a> { /// Decode packet from `src`. - pub fn decode(src: &'a [u8]) -> Result { - if let Ok(p) = ChallengeResponse::decode(src) { - return Ok(Self::ChallengeResponse(p)); - } - - if let Ok(p) = QueryServersResponse::decode(src) { - return Ok(Self::QueryServersResponse(p)); - } - - if let Ok(p) = ClientAnnounce::decode(src) { - return Ok(Self::ClientAnnounce(p)); - } - - if let Ok(p) = AdminChallengeResponse::decode(src) { - return Ok(Self::AdminChallengeResponse(p)); + pub fn decode(src: &'a [u8]) -> Result, Error> { + if src.starts_with(ChallengeResponse::HEADER) { + ChallengeResponse::decode(src).map(Self::ChallengeResponse) + } else if src.starts_with(QueryServersResponse::HEADER) { + QueryServersResponse::decode(src).map(Self::QueryServersResponse) + } else if src.starts_with(ClientAnnounce::HEADER) { + ClientAnnounce::decode(src).map(Self::ClientAnnounce) + } else if src.starts_with(AdminChallengeResponse::HEADER) { + AdminChallengeResponse::decode(src).map(Self::AdminChallengeResponse) + } else { + return Ok(None); } - - Err(Error::InvalidPacket) + .map(Some) } } @@ -277,7 +272,10 @@ mod tests { let p = ChallengeResponse::new(0x12345678, Some(0x87654321)); let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(&buf[..n]), + Ok(Some(Packet::ChallengeResponse(p))) + ); } #[test] @@ -291,7 +289,10 @@ mod tests { let p = ChallengeResponse::new(0x12345678, None); let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(&buf[..n]), + Ok(Some(Packet::ChallengeResponse(p))) + ); } #[test] @@ -314,7 +315,10 @@ mod tests { let p = ClientAnnounce::new("1.2.3.4:12345".parse().unwrap()); let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(ClientAnnounce::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(&buf[..n]), + Ok(Some(Packet::ClientAnnounce(p))) + ); } #[test] @@ -322,6 +326,9 @@ mod tests { let p = AdminChallengeResponse::new(0x12345678, 0x87654321); let mut buf = [0; 64]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(AdminChallengeResponse::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(&buf[..n]), + Ok(Some(Packet::AdminChallengeResponse(p))) + ); } } diff --git a/protocol/src/server.rs b/protocol/src/server.rs index 5f00894..52d9be6 100644 --- a/protocol/src/server.rs +++ b/protocol/src/server.rs @@ -11,7 +11,7 @@ use log::debug; use super::cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue}; use super::filter::Version; use super::types::Str; -use super::Error; +use super::{CursorError, Error}; /// Sended to a master server before `ServerAdd` packet. #[derive(Clone, Debug, PartialEq)] @@ -74,7 +74,7 @@ impl Default for Os { } impl TryFrom<&[u8]> for Os { - type Error = Error; + type Error = CursorError; fn try_from(value: &[u8]) -> Result { match value { @@ -87,7 +87,7 @@ impl TryFrom<&[u8]> for Os { } impl GetKeyValue<'_> for Os { - fn get_key_value(cur: &mut Cursor) -> Result { + fn get_key_value(cur: &mut Cursor) -> Result { cur.get_key_value_raw()?.try_into() } } @@ -96,7 +96,7 @@ impl PutKeyValue for Os { fn put_key_value<'a, 'b>( &self, cur: &'b mut CursorMut<'a>, - ) -> Result<&'b mut CursorMut<'a>, Error> { + ) -> Result<&'b mut CursorMut<'a>, CursorError> { match self { Self::Linux => cur.put_str("l"), Self::Windows => cur.put_str("w"), @@ -139,7 +139,7 @@ impl Default for ServerType { } impl TryFrom<&[u8]> for ServerType { - type Error = Error; + type Error = CursorError; fn try_from(value: &[u8]) -> Result { match value { @@ -152,7 +152,7 @@ impl TryFrom<&[u8]> for ServerType { } impl GetKeyValue<'_> for ServerType { - fn get_key_value(cur: &mut Cursor) -> Result { + fn get_key_value(cur: &mut Cursor) -> Result { cur.get_key_value_raw()?.try_into() } } @@ -161,7 +161,7 @@ impl PutKeyValue for ServerType { fn put_key_value<'a, 'b>( &self, cur: &'b mut CursorMut<'a>, - ) -> Result<&'b mut CursorMut<'a>, Error> { + ) -> Result<&'b mut CursorMut<'a>, CursorError> { match self { Self::Dedicated => cur.put_str("d"), Self::Local => cur.put_str("l"), @@ -217,7 +217,7 @@ impl Default for Region { } impl TryFrom for Region { - type Error = Error; + type Error = CursorError; fn try_from(value: u8) -> Result { match value { @@ -230,13 +230,13 @@ impl TryFrom for Region { 0x06 => Ok(Region::MiddleEast), 0x07 => Ok(Region::Africa), 0xff => Ok(Region::RestOfTheWorld), - _ => Err(Error::InvalidPacket), + _ => Err(CursorError::InvalidNumber), } } } impl GetKeyValue<'_> for Region { - fn get_key_value(cur: &mut Cursor) -> Result { + fn get_key_value(cur: &mut Cursor) -> Result { cur.get_key_value::()?.try_into() } } @@ -304,28 +304,38 @@ where { /// Decode packet from `src`. pub fn decode(src: &'a [u8]) -> Result { + trait Helper<'a> { + fn get>(&mut self, key: &'static str) -> Result; + } + + impl<'a> Helper<'a> for Cursor<'a> { + fn get>(&mut self, key: &'static str) -> Result { + T::get_key_value(self).map_err(|e| Error::InvalidServerValue(key, e)) + } + } + let mut cur = Cursor::new(src); cur.expect(ServerAdd::HEADER)?; let mut ret = Self::default(); let mut challenge = None; - while cur.as_slice().starts_with(&[b'\\']) { + loop { let key = match cur.get_key_raw() { Ok(s) => s, - Err(Error::UnexpectedEnd) => break, - Err(e) => return Err(e), + Err(CursorError::TableEnd) => break, + Err(e) => Err(e)?, }; match key { - b"protocol" => ret.protocol = cur.get_key_value()?, - b"challenge" => challenge = Some(cur.get_key_value()?), - b"players" => ret.players = cur.get_key_value()?, - b"max" => ret.max = cur.get_key_value()?, - b"gamedir" => ret.gamedir = cur.get_key_value()?, - b"product" => { let _ = cur.get_key_value::>()?; }, // legacy key, ignore - b"map" => ret.map = cur.get_key_value()?, - b"type" => ret.server_type = cur.get_key_value()?, - b"os" => ret.os = cur.get_key_value()?, + b"protocol" => ret.protocol = cur.get("protocol")?, + b"challenge" => challenge = Some(cur.get("challenge")?), + b"players" => ret.players = cur.get("players")?, + b"max" => ret.max = cur.get("max")?, + b"gamedir" => ret.gamedir = cur.get("gamedir")?, + b"product" => cur.skip_key_value::<&[u8]>()?, // legacy key, ignore + b"map" => ret.map = cur.get("map")?, + b"type" => ret.server_type = cur.get("type")?, + b"os" => ret.os = cur.get("os")?, b"version" => { ret.version = cur .get_key_value() @@ -335,12 +345,14 @@ where }) .unwrap_or_default() } - b"region" => ret.region = cur.get_key_value()?, - b"bots" => ret.flags.set(ServerFlags::BOTS, cur.get_key_value::()? != 0), - b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get_key_value()?), - b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get_key_value()?), - b"lan" => ret.flags.set(ServerFlags::LAN, cur.get_key_value()?), - b"nat" => ret.flags.set(ServerFlags::NAT, cur.get_key_value()?), + b"region" => ret.region = cur.get("region")?, + b"bots" => ret + .flags + .set(ServerFlags::BOTS, cur.get::("bots")? != 0), + b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get("password")?), + b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get("secure")?), + b"lan" => ret.flags.set(ServerFlags::LAN, cur.get("lan")?), + b"nat" => ret.flags.set(ServerFlags::NAT, cur.get("nat")?), _ => { // skip unknown fields let value = cur.get_key_value::>()?; @@ -354,14 +366,14 @@ where ret.challenge = c; Ok(ret) } - None => Err(Error::InvalidPacket), + None => Err(Error::InvalidServerValue("challenge", CursorError::Expect)), } } } impl ServerAdd where - T: PutKeyValue + Clone, + T: PutKeyValue, { /// Encode packet to `buf`. pub fn encode(&self, buf: &mut [u8]) -> Result { @@ -371,8 +383,8 @@ where .put_key("challenge", self.challenge)? .put_key("players", self.players)? .put_key("max", self.max)? - .put_key("gamedir", self.gamedir.clone())? - .put_key("map", self.map.clone())? + .put_key("gamedir", &self.gamedir)? + .put_key("map", &self.map)? .put_key("type", self.server_type)? .put_key("os", self.os)? .put_key("version", self.version)? @@ -469,8 +481,8 @@ where loop { let key = match cur.get_key_raw() { Ok(s) => s, - Err(Error::UnexpectedEnd) => break, - Err(e) => return Err(e), + Err(CursorError::TableEnd) => break, + Err(e) => Err(e)?, }; match key { @@ -500,21 +512,24 @@ where } } -impl<'a> GetServerInfoResponse<&'a str> { +impl GetServerInfoResponse +where + T: PutKeyValue, +{ /// Encode packet to `buf`. pub fn encode(&self, buf: &mut [u8]) -> Result { Ok(CursorMut::new(buf) .put_bytes(GetServerInfoResponse::HEADER)? .put_key("p", self.protocol)? - .put_key("map", self.map)? + .put_key("map", &self.map)? .put_key("dm", self.dm)? .put_key("team", self.team)? .put_key("coop", self.coop)? .put_key("numcl", self.numcl)? .put_key("maxcl", self.maxcl)? - .put_key("gamedir", self.gamedir)? + .put_key("gamedir", &self.gamedir)? .put_key("password", self.password)? - .put_key("host", self.host)? + .put_key("host", &self.host)? .pos()) } } @@ -534,24 +549,19 @@ pub enum Packet<'a> { impl<'a> Packet<'a> { /// Decode packet from `src`. - pub fn decode(src: &'a [u8]) -> Result { - if let Ok(p) = Challenge::decode(src) { - return Ok(Self::Challenge(p)); - } - - if let Ok(p) = ServerAdd::decode(src) { - return Ok(Self::ServerAdd(p)); - } - - if ServerRemove::decode(src).is_ok() { - return Ok(Self::ServerRemove); - } - - if let Ok(p) = GetServerInfoResponse::decode(src) { - return Ok(Self::GetServerInfoResponse(p)); + pub fn decode(src: &'a [u8]) -> Result, Error> { + if src.starts_with(Challenge::HEADER) { + Challenge::decode(src).map(Self::Challenge) + } else if src.starts_with(ServerAdd::HEADER) { + ServerAdd::decode(src).map(Self::ServerAdd) + } else if src.starts_with(ServerRemove::HEADER) { + ServerRemove::decode(src).map(|_| Self::ServerRemove) + } else if src.starts_with(GetServerInfoResponse::HEADER) { + GetServerInfoResponse::decode(src).map(Self::GetServerInfoResponse) + } else { + return Ok(None); } - - Err(Error::InvalidPacket) + .map(Some) } } @@ -564,13 +574,16 @@ mod tests { let p = Challenge::new(Some(0x12345678)); let mut buf = [0; 128]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(Challenge::decode(&buf[..n]), Ok(p)); + assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::Challenge(p)))); } #[test] fn challenge_old() { let s = b"q\xff"; - assert_eq!(Challenge::decode(s), Ok(Challenge::new(None))); + assert_eq!( + Packet::decode(s), + Ok(Some(Packet::Challenge(Challenge::new(None)))) + ); let p = Challenge::new(None); let mut buf = [0; 128]; @@ -581,8 +594,8 @@ mod tests { #[test] fn server_add() { let p = ServerAdd { - gamedir: "valve", - map: "crossfire", + gamedir: Str(&b"valve"[..]), + map: Str(&b"crossfire"[..]), version: Version::new(0, 20), challenge: 0x12345678, server_type: ServerType::Dedicated, @@ -595,7 +608,7 @@ mod tests { }; let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(ServerAdd::decode(&buf[..n]), Ok(p)); + assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::ServerAdd(p)))); } #[test] @@ -603,26 +616,29 @@ mod tests { let p = ServerRemove; let mut buf = [0; 64]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(ServerRemove::decode(&buf[..n]), Ok(p)); + assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::ServerRemove))); } #[test] fn get_server_info_response() { let p = GetServerInfoResponse { protocol: 49, - map: "crossfire", + map: Str("crossfire".as_bytes()), dm: true, team: true, coop: true, numcl: 4, maxcl: 32, - gamedir: "valve", + gamedir: Str("valve".as_bytes()), password: true, - host: "Test", + host: Str("Test".as_bytes()), }; let mut buf = [0; 512]; let n = p.encode(&mut buf).unwrap(); - assert_eq!(GetServerInfoResponse::decode(&buf[..n]), Ok(p)); + assert_eq!( + Packet::decode(&buf[..n]), + Ok(Some(Packet::GetServerInfoResponse(p))) + ); } #[test] diff --git a/protocol/src/types.rs b/protocol/src/types.rs index 54331da..a678647 100644 --- a/protocol/src/types.rs +++ b/protocol/src/types.rs @@ -6,6 +6,9 @@ use std::fmt; use std::ops::Deref; +use crate::cursor::{CursorMut, PutKeyValue}; +use crate::CursorError; + /// Wrapper for slice of bytes with printing the bytes as a string. /// /// # Examples @@ -24,6 +27,15 @@ impl From for Str { } } +impl PutKeyValue for Str<&[u8]> { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, CursorError> { + cur.put_bytes(self.0) + } +} + impl fmt::Debug for Str where T: AsRef<[u8]>,