Browse Source

protocol: report more informative errors

ipv6
Denis Drakhnia 11 months ago
parent
commit
142b28ad64
  1. 2
      admin/src/main.rs
  2. 312
      master/src/master_server.rs
  3. 33
      protocol/src/admin.rs
  4. 89
      protocol/src/cursor.rs
  5. 75
      protocol/src/filter.rs
  6. 32
      protocol/src/game.rs
  7. 25
      protocol/src/lib.rs
  8. 51
      protocol/src/master.rs
  9. 146
      protocol/src/server.rs
  10. 12
      protocol/src/types.rs

2
admin/src/main.rs

@ -31,7 +31,7 @@ fn send_command(cli: &cli::Cli) -> Result<(), Error> {
let n = sock.recv(&mut buf)?; let n = sock.recv(&mut buf)?;
let (master_challenge, hash_challenge) = match master::Packet::decode(&buf[..n])? { let (master_challenge, hash_challenge) = match master::Packet::decode(&buf[..n])? {
master::Packet::AdminChallengeResponse(p) => (p.master_challenge, p.hash_challenge), Some(master::Packet::AdminChallengeResponse(p)) => (p.master_challenge, p.hash_challenge),
_ => return Err(Error::UnexpectedPacket), _ => return Err(Error::UnexpectedPacket),
}; };

312
master/src/master_server.rs

@ -44,6 +44,10 @@ pub enum Error {
Io(#[from] io::Error), Io(#[from] io::Error),
#[error("Admin challenge do not exist")] #[error("Admin challenge do not exist")]
AdminChallengeNotFound, AdminChallengeNotFound,
#[error("Undefined packet")]
UndefinedPacket,
#[error("Unexpected packet")]
UnexpectedPacket,
} }
/// HashMap entry to keep tracking creation time. /// HashMap entry to keep tracking creation time.
@ -235,8 +239,9 @@ impl MasterServer {
} }
}; };
if let Err(e) = self.handle_packet(from, &buf[..n]) { let src = &buf[..n];
error!("{}: {}", from, e); if let Err(e) = self.handle_packet(from, src) {
debug!("{}: {}: \"{}\"", from, e, Str(src));
} }
} }
Ok(()) Ok(())
@ -249,167 +254,196 @@ impl MasterServer {
self.admin_challenges.clear(); self.admin_challenges.clear();
} }
fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> { fn handle_server_packet(&mut self, from: SocketAddrV4, p: server::Packet) -> Result<(), Error> {
if self.is_blocked(from.ip()) { trace!("{}: recv {:?}", from, p);
return Ok(());
} match p {
server::Packet::Challenge(p) => {
if let Ok(p) = server::Packet::decode(src) { let master_challenge = self.add_challenge(from);
match p { let mut buf = [0; MAX_PACKET_SIZE];
server::Packet::Challenge(p) => { let p = master::ChallengeResponse::new(master_challenge, p.server_challenge);
trace!("{}: recv {:?}", from, p); trace!("{}: send {:?}", from, p);
let master_challenge = self.add_challenge(from); let n = p.encode(&mut buf)?;
let mut buf = [0; MAX_PACKET_SIZE]; self.sock.send_to(&buf[..n], from)?;
let p = master::ChallengeResponse::new(master_challenge, p.server_challenge); self.remove_outdated_challenges();
trace!("{}: send {:?}", from, p); }
let n = p.encode(&mut buf)?; server::Packet::ServerAdd(p) => {
self.sock.send_to(&buf[..n], from)?; let entry = match self.challenges.get(&from) {
self.remove_outdated_challenges(); Some(e) => e,
} None => {
server::Packet::ServerAdd(p) => { trace!("{}: Challenge does not exists", from);
trace!("{}: recv {:?}", from, p);
let entry = match self.challenges.get(&from) {
Some(e) => e,
None => {
trace!("{}: Challenge does not exists", from);
return Ok(());
}
};
if !entry.is_valid(self.now(), self.timeout.challenge) {
return Ok(());
}
if p.challenge != entry.value {
warn!(
"{}: Expected challenge {} but received {}",
from, entry.value, p.challenge
);
return Ok(()); return Ok(());
} }
if self.challenges.remove(&from).is_some() { };
self.add_server(from, ServerInfo::new(&p)); if !entry.is_valid(self.now(), self.timeout.challenge) {
} return Ok(());
self.remove_outdated_servers();
} }
_ => { if p.challenge != entry.value {
trace!("{}: recv {:?}", from, p); warn!(
"{}: Expected challenge {} but received {}",
from, entry.value, p.challenge
);
return Ok(());
}
if self.challenges.remove(&from).is_some() {
self.add_server(from, ServerInfo::new(&p));
} }
self.remove_outdated_servers();
}
server::Packet::ServerRemove => {
// ignore
} }
} else if let Ok(p) = game::Packet::decode(src) { _ => {
match p { return Err(Error::UnexpectedPacket);
game::Packet::QueryServers(p) => { }
trace!("{}: recv {:?}", from, p); }
if p.filter.clver.map_or(false, |v| v < self.clver) {
let iter = std::iter::once(self.update_addr); Ok(())
self.send_server_list(from, p.filter.key, iter)?; }
} else {
let now = self.now(); fn handle_game_packet(&mut self, from: SocketAddrV4, p: game::Packet) -> Result<(), Error> {
let iter = self trace!("{}: recv {:?}", from, p);
.servers
.iter() match p {
.filter(|i| i.1.is_valid(now, self.timeout.server)) game::Packet::QueryServers(p) => {
.filter(|i| i.1.matches(*i.0, p.region, &p.filter)) if p.filter.clver.map_or(false, |v| v < self.clver) {
.map(|i| *i.0); let iter = std::iter::once(self.update_addr);
self.send_server_list(from, p.filter.key, iter)?;
self.send_server_list(from, p.filter.key, iter.clone())?; } else {
let now = self.now();
if p.filter.flags.contains(FilterFlags::NAT) { let iter = self
self.send_client_to_nat_servers(from, iter)?; .servers
} .iter()
.filter(|i| i.1.is_valid(now, self.timeout.server))
.filter(|i| i.1.matches(*i.0, p.region, &p.filter))
.map(|i| *i.0);
self.send_server_list(from, p.filter.key, iter.clone())?;
if p.filter.flags.contains(FilterFlags::NAT) {
self.send_client_to_nat_servers(from, iter)?;
} }
} }
game::Packet::GetServerInfo(p) => {
trace!("{}: recv {:?}", from, p);
let p = server::GetServerInfoResponse {
map: self.update_map.as_ref(),
host: self.update_title.as_ref(),
protocol: 48, // XXX: how to detect what version client will accept?
dm: true,
maxcl: 32,
gamedir: "valve", // XXX: probably must be specific for client...
..Default::default()
};
trace!("{}: send {:?}", from, p);
let mut buf = [0; MAX_PACKET_SIZE];
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
}
} }
} else if let Ok(p) = admin::Packet::decode(self.hash.len, src) { game::Packet::GetServerInfo(_) => {
let now = self.now(); let p = server::GetServerInfoResponse {
map: self.update_map.as_ref(),
host: self.update_title.as_ref(),
protocol: 48, // XXX: how to detect what version client will accept?
dm: true,
maxcl: 32,
gamedir: "valve", // XXX: probably must be specific for client...
..Default::default()
};
trace!("{}: send {:?}", from, p);
let mut buf = [0; MAX_PACKET_SIZE];
let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?;
}
}
if let Some(e) = self.admin_limit.get(from.ip()) { Ok(())
if e.is_valid(now, self.timeout.admin) { }
trace!("{}: rate limit", from);
return Ok(()); fn handle_admin_packet(&mut self, from: SocketAddrV4, p: admin::Packet) -> Result<(), Error> {
} trace!("{}: recv {:?}", from, p);
let now = self.now();
if let Some(e) = self.admin_limit.get(from.ip()) {
if e.is_valid(now, self.timeout.admin) {
trace!("{}: rate limit", from);
return Ok(());
} }
}
match p { match p {
admin::Packet::AdminChallenge(p) => { admin::Packet::AdminChallenge => {
trace!("{}: recv {:?}", from, p); let (master_challenge, hash_challenge) = self.admin_challenge_add(from);
let (master_challenge, hash_challenge) = self.admin_challenge_add(from);
let p = master::AdminChallengeResponse::new(master_challenge, hash_challenge); let p = master::AdminChallengeResponse::new(master_challenge, hash_challenge);
trace!("{}: send {:?}", from, p); trace!("{}: send {:?}", from, p);
let mut buf = [0; 64]; let mut buf = [0; 64];
let n = p.encode(&mut buf)?; let n = p.encode(&mut buf)?;
self.sock.send_to(&buf[..n], from)?; self.sock.send_to(&buf[..n], from)?;
self.admin_challenges_cleanup(); self.admin_challenges_cleanup();
}
admin::Packet::AdminCommand(p) => {
let entry = *self
.admin_challenges
.get(from.ip())
.ok_or(Error::AdminChallengeNotFound)?;
if entry.0 != p.master_challenge {
trace!("{}: master challenge is not valid", from);
return Ok(());
} }
admin::Packet::AdminCommand(p) => {
trace!("{}: recv {:?}", from, p);
let entry = *self
.admin_challenges
.get(from.ip())
.ok_or(Error::AdminChallengeNotFound)?;
if entry.0 != p.master_challenge {
trace!("{}: master challenge is not valid", from);
return Ok(());
}
if !entry.is_valid(now, self.timeout.challenge) { if !entry.is_valid(now, self.timeout.challenge) {
trace!("{}: challenge is outdated", from); trace!("{}: challenge is outdated", from);
return Ok(()); return Ok(());
} }
let state = Params::new() let state = Params::new()
.hash_length(self.hash.len) .hash_length(self.hash.len)
.key(self.hash.key.as_bytes()) .key(self.hash.key.as_bytes())
.personal(self.hash.personal.as_bytes()) .personal(self.hash.personal.as_bytes())
.to_state(); .to_state();
let admin = self.admin_list.iter().find(|i| { let admin = self.admin_list.iter().find(|i| {
let hash = state let hash = state
.clone() .clone()
.update(i.password.as_bytes()) .update(i.password.as_bytes())
.update(&entry.1.to_le_bytes()) .update(&entry.1.to_le_bytes())
.finalize(); .finalize();
*p.hash == hash.as_bytes() *p.hash == hash.as_bytes()
}); });
match admin { match admin {
Some(admin) => { Some(admin) => {
info!("{}: admin({}), command: {:?}", from, &admin.name, p.command); info!("{}: admin({}), command: {:?}", from, &admin.name, p.command);
self.admin_command(p.command); self.admin_command(p.command);
self.admin_challenge_remove(from); self.admin_challenge_remove(from);
} }
None => { None => {
warn!("{}: invalid admin hash, command: {:?}", from, p.command); warn!("{}: invalid admin hash, command: {:?}", from, p.command);
self.admin_limit.insert(*from.ip(), Entry::new(now, ())); self.admin_limit.insert(*from.ip(), Entry::new(now, ()));
self.admin_limit_cleanup(); self.admin_limit_cleanup();
}
} }
} }
} }
} else {
debug!("{}: invalid packet: \"{}\"", from, Str(src));
} }
Ok(()) Ok(())
} }
fn handle_packet(&mut self, from: SocketAddrV4, src: &[u8]) -> Result<(), Error> {
if self.is_blocked(from.ip()) {
return Ok(());
}
match server::Packet::decode(src) {
Ok(Some(p)) => return self.handle_server_packet(from, p),
Ok(None) => {}
Err(e) => Err(e)?,
}
match game::Packet::decode(src) {
Ok(Some(p)) => return self.handle_game_packet(from, p),
Ok(None) => {}
Err(e) => Err(e)?,
}
match admin::Packet::decode(self.hash.len, src) {
Ok(Some(p)) => return self.handle_admin_packet(from, p),
Ok(None) => {}
Err(e) => Err(e)?,
}
Err(Error::UndefinedPacket)
}
fn now(&self) -> u32 { fn now(&self) -> u32 {
self.start_time.elapsed().as_secs() as u32 self.start_time.elapsed().as_secs() as u32
} }

33
protocol/src/admin.rs

@ -5,7 +5,7 @@
use crate::cursor::{Cursor, CursorMut}; use crate::cursor::{Cursor, CursorMut};
use crate::types::Hide; use crate::types::Hide;
use crate::Error; use crate::{CursorError, Error};
/// Default hash length. /// Default hash length.
pub const HASH_LEN: usize = 64; pub const HASH_LEN: usize = 64;
@ -27,7 +27,7 @@ impl AdminChallenge {
if src == Self::HEADER { if src == Self::HEADER {
Ok(Self) Ok(Self)
} else { } else {
Err(Error::InvalidPacket) Err(CursorError::Expect)?
} }
} }
@ -97,23 +97,22 @@ impl<'a> AdminCommand<'a> {
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum Packet<'a> { pub enum Packet<'a> {
/// Admin challenge request. /// Admin challenge request.
AdminChallenge(AdminChallenge), AdminChallenge,
/// Admin command. /// Admin command.
AdminCommand(AdminCommand<'a>), AdminCommand(AdminCommand<'a>),
} }
impl<'a> Packet<'a> { impl<'a> Packet<'a> {
/// Decode packet from `src` with specified hash length. /// Decode packet from `src` with specified hash length.
pub fn decode(hash_len: usize, src: &'a [u8]) -> Result<Self, Error> { pub fn decode(hash_len: usize, src: &'a [u8]) -> Result<Option<Self>, Error> {
if let Ok(p) = AdminChallenge::decode(src) { if src.starts_with(AdminChallenge::HEADER) {
return Ok(Self::AdminChallenge(p)); AdminChallenge::decode(src).map(|_| Self::AdminChallenge)
} } else if src.starts_with(AdminCommand::HEADER) {
AdminCommand::decode_with_hash_len(hash_len, src).map(Self::AdminCommand)
if let Ok(p) = AdminCommand::decode_with_hash_len(hash_len, src) { } else {
return Ok(Self::AdminCommand(p)); return Ok(None);
} }
.map(Some)
Err(Error::InvalidPacket)
} }
} }
@ -126,7 +125,10 @@ mod tests {
let p = AdminChallenge; let p = AdminChallenge;
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(AdminChallenge::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(HASH_LEN, &buf[..n]),
Ok(Some(Packet::AdminChallenge))
);
} }
#[test] #[test]
@ -134,6 +136,9 @@ mod tests {
let p = AdminCommand::new(0x12345678, &[1; HASH_LEN], "foo bar baz"); let p = AdminCommand::new(0x12345678, &[1; HASH_LEN], "foo bar baz");
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(AdminCommand::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(HASH_LEN, &buf[..n]),
Ok(Some(Packet::AdminCommand(p)))
);
} }
} }

89
protocol/src/cursor.rs

@ -7,8 +7,42 @@ use std::mem;
use std::slice; use std::slice;
use std::str; use std::str;
use thiserror::Error;
use super::color;
use super::types::Str; use super::types::Str;
use super::{color, Error};
/// The error type for `Cursor` and `CursorMut`.
#[derive(Error, Debug, PartialEq, Eq)]
pub enum Error {
/// Invalid number.
#[error("Invalid number")]
InvalidNumber,
/// Invalid string.
#[error("Invalid string")]
InvalidString,
/// Invalid boolean.
#[error("Invalid boolean")]
InvalidBool,
/// Invalid table entry.
#[error("Invalid table key")]
InvalidTableKey,
/// Invalid table entry.
#[error("Invalid table entry")]
InvalidTableValue,
/// Table end found.
#[error("Table end")]
TableEnd,
/// Expected data not found.
#[error("Expected data not found")]
Expect,
/// An unexpected data found.
#[error("Unexpected data")]
ExpectEmpty,
/// Buffer size is no enougth to decode or encode a packet.
#[error("Unexpected end of buffer")]
UnexpectedEnd,
}
pub trait GetKeyValue<'a>: Sized { pub trait GetKeyValue<'a>: Sized {
fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error>; fn get_key_value(cur: &mut Cursor<'a>) -> Result<Self, Error>;
@ -56,7 +90,7 @@ impl<'a> GetKeyValue<'a> for bool {
match cur.get_key_value_raw()? { match cur.get_key_value_raw()? {
b"0" => Ok(false), b"0" => Ok(false),
b"1" => Ok(true), b"1" => Ok(true),
_ => Err(Error::InvalidPacket), _ => Err(Error::InvalidBool),
} }
} }
} }
@ -68,7 +102,7 @@ macro_rules! impl_get_value {
let s = cur.get_key_value::<&str>()?; let s = cur.get_key_value::<&str>()?;
// HACK: special case for one asshole // HACK: special case for one asshole
let (_, s) = color::trim_start_color(s); let (_, s) = color::trim_start_color(s);
s.parse().map_err(|_| Error::InvalidPacket) s.parse().map_err(|_| Error::InvalidNumber)
} }
})+ })+
}; };
@ -216,13 +250,13 @@ impl<'a> Cursor<'a> {
self.advance(s.len())?; self.advance(s.len())?;
Ok(()) Ok(())
} else { } else {
Err(Error::InvalidPacket) Err(Error::Expect)
} }
} }
pub fn expect_empty(&self) -> Result<(), Error> { pub fn expect_empty(&self) -> Result<(), Error> {
if self.has_remaining() { if self.has_remaining() {
Err(Error::InvalidPacket) Err(Error::ExpectEmpty)
} else { } else {
Ok(()) Ok(())
} }
@ -252,12 +286,13 @@ impl<'a> Cursor<'a> {
pub fn get_key_value_raw(&mut self) -> Result<&'a [u8], Error> { pub fn get_key_value_raw(&mut self) -> Result<&'a [u8], Error> {
let mut cur = *self; let mut cur = *self;
if cur.get_u8()? == b'\\' { match cur.get_u8()? {
let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n'); b'\\' => {
*self = cur; let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n');
Ok(value) *self = cur;
} else { Ok(value)
Err(Error::InvalidPacket) }
_ => Err(Error::InvalidTableValue),
} }
} }
@ -265,14 +300,20 @@ impl<'a> Cursor<'a> {
T::get_key_value(self) T::get_key_value(self)
} }
pub fn skip_key_value<T: GetKeyValue<'a>>(&mut self) -> Result<(), Error> {
T::get_key_value(self).map(|_| ())
}
pub fn get_key_raw(&mut self) -> Result<&'a [u8], Error> { pub fn get_key_raw(&mut self) -> Result<&'a [u8], Error> {
let mut cur = *self; let mut cur = *self;
if cur.get_u8()? == b'\\' { match cur.get_u8() {
let value = cur.take_while(|c| c != b'\\' && c != b'\n')?; Ok(b'\\') => {
*self = cur; let value = cur.take_while(|c| c != b'\\' && c != b'\n')?;
Ok(value) *self = cur;
} else { Ok(value)
Err(Error::InvalidPacket) }
Ok(b'\n') | Err(Error::UnexpectedEnd) => Err(Error::TableEnd),
_ => Err(Error::InvalidTableKey),
} }
} }
@ -288,6 +329,18 @@ pub trait PutKeyValue {
) -> Result<&'b mut CursorMut<'a>, Error>; ) -> Result<&'b mut CursorMut<'a>, Error>;
} }
impl<T> PutKeyValue for &T
where
T: PutKeyValue,
{
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, Error> {
(*self).put_key_value(cur)
}
}
impl PutKeyValue for &str { impl PutKeyValue for &str {
fn put_key_value<'a, 'b>( fn put_key_value<'a, 'b>(
&self, &self,
@ -532,7 +585,7 @@ mod tests {
assert_eq!(cur.get_key(), Ok((&b"gamedir"[..], "valve"))); assert_eq!(cur.get_key(), Ok((&b"gamedir"[..], "valve")));
assert_eq!(cur.get_key(), Ok((&b"password"[..], false))); assert_eq!(cur.get_key(), Ok((&b"password"[..], false)));
assert_eq!(cur.get_key(), Ok((&b"host"[..], "test"))); assert_eq!(cur.get_key(), Ok((&b"host"[..], "test")));
assert_eq!(cur.get_key::<&[u8]>(), Err(Error::UnexpectedEnd)); assert_eq!(cur.get_key::<&[u8]>(), Err(Error::TableEnd));
Ok(()) Ok(())
} }

75
protocol/src/filter.rs

@ -31,7 +31,6 @@
use std::fmt; use std::fmt;
use std::net::SocketAddrV4; use std::net::SocketAddrV4;
use std::num::ParseIntError;
use std::str::FromStr; use std::str::FromStr;
use bitflags::bitflags; use bitflags::bitflags;
@ -40,7 +39,7 @@ use log::debug;
use crate::cursor::{Cursor, GetKeyValue, PutKeyValue}; use crate::cursor::{Cursor, GetKeyValue, PutKeyValue};
use crate::server::{ServerAdd, ServerFlags, ServerType}; use crate::server::{ServerAdd, ServerFlags, ServerType};
use crate::types::Str; use crate::types::Str;
use crate::{Error, ServerInfo}; use crate::{CursorError, Error, ServerInfo};
bitflags! { bitflags! {
/// Additional filter flags. /// Additional filter flags.
@ -129,21 +128,21 @@ impl fmt::Display for Version {
} }
impl FromStr for Version { impl FromStr for Version {
type Err = ParseIntError; type Err = CursorError;
fn from_str(s: &str) -> Result<Self, Self::Err> { fn from_str(s: &str) -> Result<Self, Self::Err> {
let (major, tail) = s.split_once('.').unwrap_or((s, "0")); let (major, tail) = s.split_once('.').unwrap_or((s, "0"));
let (minor, patch) = tail.split_once('.').unwrap_or((tail, "0")); let (minor, patch) = tail.split_once('.').unwrap_or((tail, "0"));
let major = major.parse()?; let major = major.parse().map_err(|_| CursorError::InvalidNumber)?;
let minor = minor.parse()?; let minor = minor.parse().map_err(|_| CursorError::InvalidNumber)?;
let patch = patch.parse()?; let patch = patch.parse().map_err(|_| CursorError::InvalidNumber)?;
Ok(Self::with_patch(major, minor, patch)) Ok(Self::with_patch(major, minor, patch))
} }
} }
impl GetKeyValue<'_> for Version { impl GetKeyValue<'_> for Version {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> { fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
Self::from_str(cur.get_key_value()?).map_err(|_| Error::InvalidPacket) cur.get_key_value().and_then(Self::from_str)
} }
} }
@ -151,7 +150,7 @@ impl PutKeyValue for Version {
fn put_key_value<'a, 'b>( fn put_key_value<'a, 'b>(
&self, &self,
cur: &'b mut crate::cursor::CursorMut<'a>, cur: &'b mut crate::cursor::CursorMut<'a>,
) -> Result<&'b mut crate::cursor::CursorMut<'a>, Error> { ) -> Result<&'b mut crate::cursor::CursorMut<'a>, CursorError> {
cur.put_key_value(self.major)? cur.put_key_value(self.major)?
.put_u8(b'.')? .put_u8(b'.')?
.put_key_value(self.minor)?; .put_key_value(self.minor)?;
@ -201,42 +200,48 @@ impl<'a> TryFrom<&'a [u8]> for Filter<'a> {
type Error = Error; type Error = Error;
fn try_from(src: &'a [u8]) -> Result<Self, Self::Error> { fn try_from(src: &'a [u8]) -> Result<Self, Self::Error> {
trait Helper<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error>;
}
impl<'a> Helper<'a> for Cursor<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error> {
T::get_key_value(self).map_err(|e| Error::InvalidFilterValue(key, e))
}
}
let mut cur = Cursor::new(src); let mut cur = Cursor::new(src);
let mut filter = Self::default(); let mut filter = Self::default();
loop { loop {
let key = match cur.get_key_raw().map(Str) { let key = match cur.get_key_raw().map(Str) {
Ok(s) => s, Ok(s) => s,
Err(Error::UnexpectedEnd) => break, Err(CursorError::TableEnd) => break,
Err(e) => return Err(e), Err(e) => Err(e)?,
}; };
match *key { match *key {
b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get_key_value()?), b"dedicated" => filter.insert_flag(FilterFlags::DEDICATED, cur.get("dedicated")?),
b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get_key_value()?), b"secure" => filter.insert_flag(FilterFlags::SECURE, cur.get("secure")?),
b"gamedir" => filter.gamedir = Some(cur.get_key_value()?), b"gamedir" => filter.gamedir = Some(cur.get("gamedir")?),
b"map" => filter.map = Some(cur.get_key_value()?), b"map" => filter.map = Some(cur.get("map")?),
b"protocol" => filter.protocol = Some(cur.get_key_value()?), b"protocol" => filter.protocol = Some(cur.get("protocol")?),
b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get_key_value()?), b"empty" => filter.insert_flag(FilterFlags::EMPTY, cur.get("empty")?),
b"full" => filter.insert_flag(FilterFlags::FULL, cur.get_key_value()?), b"full" => filter.insert_flag(FilterFlags::FULL, cur.get("full")?),
b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get_key_value()?), b"password" => filter.insert_flag(FilterFlags::PASSWORD, cur.get("password")?),
b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get_key_value()?), b"noplayers" => filter.insert_flag(FilterFlags::NOPLAYERS, cur.get("noplayers")?),
b"clver" => { b"clver" => filter.clver = Some(cur.get("clver")?),
filter.clver = Some( b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get("nat")?),
cur.get_key_value::<&str>()? b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get("lan")?),
.parse() b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get("bots")?),
.map_err(|_| Error::InvalidPacket)?,
);
}
b"nat" => filter.insert_flag(FilterFlags::NAT, cur.get_key_value()?),
b"lan" => filter.insert_flag(FilterFlags::LAN, cur.get_key_value()?),
b"bots" => filter.insert_flag(FilterFlags::BOTS, cur.get_key_value()?),
b"key" => { b"key" => {
filter.key = { filter.key = Some(
let s = cur.get_key_value::<&str>()?; cur.get_key_value::<&str>()
let x = u32::from_str_radix(s, 16).map_err(|_| Error::InvalidPacket)?; .and_then(|s| {
Some(x) u32::from_str_radix(s, 16).map_err(|_| CursorError::InvalidNumber)
} })
.map_err(|e| Error::InvalidFilterValue("key", e))?,
)
} }
_ => { _ => {
// skip unknown fields // skip unknown fields

32
protocol/src/game.rs

@ -35,7 +35,7 @@ where
pub fn decode(src: &'a [u8]) -> Result<Self, Error> { pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
let mut cur = Cursor::new(src); let mut cur = Cursor::new(src);
cur.expect(QueryServers::HEADER)?; cur.expect(QueryServers::HEADER)?;
let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidPacket)?; let region = cur.get_u8()?.try_into().map_err(|_| Error::InvalidRegion)?;
let last = cur.get_cstr_as_str()?; let last = cur.get_cstr_as_str()?;
let filter = match cur.get_bytes(cur.remaining())? { let filter = match cur.get_bytes(cur.remaining())? {
// some clients may have bug and filter will be with zero at the end // some clients may have bug and filter will be with zero at the end
@ -44,7 +44,7 @@ where
}; };
Ok(Self { Ok(Self {
region, region,
last: last.parse().map_err(|_| Error::InvalidPacket)?, last: last.parse().map_err(|_| Error::InvalidQueryServersLast)?,
filter: T::try_from(filter)?, filter: T::try_from(filter)?,
}) })
} }
@ -114,16 +114,15 @@ pub enum Packet<'a> {
impl<'a> Packet<'a> { impl<'a> Packet<'a> {
/// Decode packet from `src`. /// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> { pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
if let Ok(p) = QueryServers::decode(src) { if src.starts_with(QueryServers::HEADER) {
return Ok(Self::QueryServers(p)); QueryServers::decode(src).map(Self::QueryServers)
} } else if src.starts_with(GetServerInfo::HEADER) {
GetServerInfo::decode(src).map(Self::GetServerInfo)
if let Ok(p) = GetServerInfo::decode(src) { } else {
return Ok(Self::GetServerInfo(p)); return Ok(None);
} }
.map(Some)
Err(Error::InvalidPacket)
} }
} }
@ -151,7 +150,7 @@ mod tests {
}; };
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(QueryServers::decode(&buf[..n]), Ok(p)); assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::QueryServers(p))));
} }
#[test] #[test]
@ -171,10 +170,10 @@ mod tests {
}; };
let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0\0"; let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0\0";
assert_eq!(QueryServers::decode(s), Ok(p.clone())); assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p.clone()))));
let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0"; let s = b"1\xff0.0.0.0:0\x00\\protocol\\48\\clver\\0.20\\nat\\0";
assert_eq!(QueryServers::decode(s), Ok(p)); assert_eq!(Packet::decode(s), Ok(Some(Packet::QueryServers(p))));
} }
#[test] #[test]
@ -182,6 +181,9 @@ mod tests {
let p = GetServerInfo::new(49); let p = GetServerInfo::new(49);
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(GetServerInfo::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::GetServerInfo(p)))
);
} }
} }

25
protocol/src/lib.rs

@ -16,6 +16,7 @@ pub mod master;
pub mod server; pub mod server;
pub mod types; pub mod types;
pub use cursor::Error as CursorError;
pub use server_info::ServerInfo; pub use server_info::ServerInfo;
use thiserror::Error; use thiserror::Error;
@ -33,13 +34,25 @@ pub enum Error {
/// Failed to decode a packet. /// Failed to decode a packet.
#[error("Invalid packet")] #[error("Invalid packet")]
InvalidPacket, InvalidPacket,
/// Invalid string in a packet. /// Invalid region.
#[error("Invalid UTF-8 string")] #[error("Invalid region")]
InvalidString, InvalidRegion,
/// Buffer size is no enougth to decode or encode a packet. /// Invalid client announce IP.
#[error("Unexpected end of buffer")] #[error("Invalid client announce IP")]
UnexpectedEnd, InvalidClientAnnounceIp,
/// Invalid last IP.
#[error("Invalid last server IP")]
InvalidQueryServersLast,
/// Server protocol version is not supported. /// Server protocol version is not supported.
#[error("Invalid protocol version")] #[error("Invalid protocol version")]
InvalidProtocolVersion, InvalidProtocolVersion,
/// Cursor error.
#[error("{0}")]
CursorError(#[from] CursorError),
/// Invalid value for server add packet.
#[error("Invalid value for server add key `{0}`: {1}")]
InvalidServerValue(&'static str, #[source] CursorError),
/// Invalid value for query servers packet.
#[error("Invalid value for filter key `{0}`: {1}")]
InvalidFilterValue(&'static str, #[source] CursorError),
} }

51
protocol/src/master.rs

@ -174,7 +174,7 @@ impl ClientAnnounce {
let addr = cur let addr = cur
.get_str(cur.remaining())? .get_str(cur.remaining())?
.parse() .parse()
.map_err(|_| Error::InvalidPacket)?; .map_err(|_| Error::InvalidClientAnnounceIp)?;
cur.expect_empty()?; cur.expect_empty()?;
Ok(Self { addr }) Ok(Self { addr })
} }
@ -247,24 +247,19 @@ pub enum Packet<'a> {
impl<'a> Packet<'a> { impl<'a> Packet<'a> {
/// Decode packet from `src`. /// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> { pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
if let Ok(p) = ChallengeResponse::decode(src) { if src.starts_with(ChallengeResponse::HEADER) {
return Ok(Self::ChallengeResponse(p)); ChallengeResponse::decode(src).map(Self::ChallengeResponse)
} } else if src.starts_with(QueryServersResponse::HEADER) {
QueryServersResponse::decode(src).map(Self::QueryServersResponse)
if let Ok(p) = QueryServersResponse::decode(src) { } else if src.starts_with(ClientAnnounce::HEADER) {
return Ok(Self::QueryServersResponse(p)); ClientAnnounce::decode(src).map(Self::ClientAnnounce)
} } else if src.starts_with(AdminChallengeResponse::HEADER) {
AdminChallengeResponse::decode(src).map(Self::AdminChallengeResponse)
if let Ok(p) = ClientAnnounce::decode(src) { } else {
return Ok(Self::ClientAnnounce(p)); return Ok(None);
}
if let Ok(p) = AdminChallengeResponse::decode(src) {
return Ok(Self::AdminChallengeResponse(p));
} }
.map(Some)
Err(Error::InvalidPacket)
} }
} }
@ -277,7 +272,10 @@ mod tests {
let p = ChallengeResponse::new(0x12345678, Some(0x87654321)); let p = ChallengeResponse::new(0x12345678, Some(0x87654321));
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::ChallengeResponse(p)))
);
} }
#[test] #[test]
@ -291,7 +289,10 @@ mod tests {
let p = ChallengeResponse::new(0x12345678, None); let p = ChallengeResponse::new(0x12345678, None);
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(ChallengeResponse::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::ChallengeResponse(p)))
);
} }
#[test] #[test]
@ -314,7 +315,10 @@ mod tests {
let p = ClientAnnounce::new("1.2.3.4:12345".parse().unwrap()); let p = ClientAnnounce::new("1.2.3.4:12345".parse().unwrap());
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(ClientAnnounce::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::ClientAnnounce(p)))
);
} }
#[test] #[test]
@ -322,6 +326,9 @@ mod tests {
let p = AdminChallengeResponse::new(0x12345678, 0x87654321); let p = AdminChallengeResponse::new(0x12345678, 0x87654321);
let mut buf = [0; 64]; let mut buf = [0; 64];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(AdminChallengeResponse::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::AdminChallengeResponse(p)))
);
} }
} }

146
protocol/src/server.rs

@ -11,7 +11,7 @@ use log::debug;
use super::cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue}; use super::cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue};
use super::filter::Version; use super::filter::Version;
use super::types::Str; use super::types::Str;
use super::Error; use super::{CursorError, Error};
/// Sended to a master server before `ServerAdd` packet. /// Sended to a master server before `ServerAdd` packet.
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
@ -74,7 +74,7 @@ impl Default for Os {
} }
impl TryFrom<&[u8]> for Os { impl TryFrom<&[u8]> for Os {
type Error = Error; type Error = CursorError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> { fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
match value { match value {
@ -87,7 +87,7 @@ impl TryFrom<&[u8]> for Os {
} }
impl GetKeyValue<'_> for Os { impl GetKeyValue<'_> for Os {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> { fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value_raw()?.try_into() cur.get_key_value_raw()?.try_into()
} }
} }
@ -96,7 +96,7 @@ impl PutKeyValue for Os {
fn put_key_value<'a, 'b>( fn put_key_value<'a, 'b>(
&self, &self,
cur: &'b mut CursorMut<'a>, cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, Error> { ) -> Result<&'b mut CursorMut<'a>, CursorError> {
match self { match self {
Self::Linux => cur.put_str("l"), Self::Linux => cur.put_str("l"),
Self::Windows => cur.put_str("w"), Self::Windows => cur.put_str("w"),
@ -139,7 +139,7 @@ impl Default for ServerType {
} }
impl TryFrom<&[u8]> for ServerType { impl TryFrom<&[u8]> for ServerType {
type Error = Error; type Error = CursorError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> { fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
match value { match value {
@ -152,7 +152,7 @@ impl TryFrom<&[u8]> for ServerType {
} }
impl GetKeyValue<'_> for ServerType { impl GetKeyValue<'_> for ServerType {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> { fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value_raw()?.try_into() cur.get_key_value_raw()?.try_into()
} }
} }
@ -161,7 +161,7 @@ impl PutKeyValue for ServerType {
fn put_key_value<'a, 'b>( fn put_key_value<'a, 'b>(
&self, &self,
cur: &'b mut CursorMut<'a>, cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, Error> { ) -> Result<&'b mut CursorMut<'a>, CursorError> {
match self { match self {
Self::Dedicated => cur.put_str("d"), Self::Dedicated => cur.put_str("d"),
Self::Local => cur.put_str("l"), Self::Local => cur.put_str("l"),
@ -217,7 +217,7 @@ impl Default for Region {
} }
impl TryFrom<u8> for Region { impl TryFrom<u8> for Region {
type Error = Error; type Error = CursorError;
fn try_from(value: u8) -> Result<Self, Self::Error> { fn try_from(value: u8) -> Result<Self, Self::Error> {
match value { match value {
@ -230,13 +230,13 @@ impl TryFrom<u8> for Region {
0x06 => Ok(Region::MiddleEast), 0x06 => Ok(Region::MiddleEast),
0x07 => Ok(Region::Africa), 0x07 => Ok(Region::Africa),
0xff => Ok(Region::RestOfTheWorld), 0xff => Ok(Region::RestOfTheWorld),
_ => Err(Error::InvalidPacket), _ => Err(CursorError::InvalidNumber),
} }
} }
} }
impl GetKeyValue<'_> for Region { impl GetKeyValue<'_> for Region {
fn get_key_value(cur: &mut Cursor) -> Result<Self, Error> { fn get_key_value(cur: &mut Cursor) -> Result<Self, CursorError> {
cur.get_key_value::<u8>()?.try_into() cur.get_key_value::<u8>()?.try_into()
} }
} }
@ -304,28 +304,38 @@ where
{ {
/// Decode packet from `src`. /// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> { pub fn decode(src: &'a [u8]) -> Result<Self, Error> {
trait Helper<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error>;
}
impl<'a> Helper<'a> for Cursor<'a> {
fn get<T: GetKeyValue<'a>>(&mut self, key: &'static str) -> Result<T, Error> {
T::get_key_value(self).map_err(|e| Error::InvalidServerValue(key, e))
}
}
let mut cur = Cursor::new(src); let mut cur = Cursor::new(src);
cur.expect(ServerAdd::HEADER)?; cur.expect(ServerAdd::HEADER)?;
let mut ret = Self::default(); let mut ret = Self::default();
let mut challenge = None; let mut challenge = None;
while cur.as_slice().starts_with(&[b'\\']) { loop {
let key = match cur.get_key_raw() { let key = match cur.get_key_raw() {
Ok(s) => s, Ok(s) => s,
Err(Error::UnexpectedEnd) => break, Err(CursorError::TableEnd) => break,
Err(e) => return Err(e), Err(e) => Err(e)?,
}; };
match key { match key {
b"protocol" => ret.protocol = cur.get_key_value()?, b"protocol" => ret.protocol = cur.get("protocol")?,
b"challenge" => challenge = Some(cur.get_key_value()?), b"challenge" => challenge = Some(cur.get("challenge")?),
b"players" => ret.players = cur.get_key_value()?, b"players" => ret.players = cur.get("players")?,
b"max" => ret.max = cur.get_key_value()?, b"max" => ret.max = cur.get("max")?,
b"gamedir" => ret.gamedir = cur.get_key_value()?, b"gamedir" => ret.gamedir = cur.get("gamedir")?,
b"product" => { let _ = cur.get_key_value::<Str<&[u8]>>()?; }, // legacy key, ignore b"product" => cur.skip_key_value::<&[u8]>()?, // legacy key, ignore
b"map" => ret.map = cur.get_key_value()?, b"map" => ret.map = cur.get("map")?,
b"type" => ret.server_type = cur.get_key_value()?, b"type" => ret.server_type = cur.get("type")?,
b"os" => ret.os = cur.get_key_value()?, b"os" => ret.os = cur.get("os")?,
b"version" => { b"version" => {
ret.version = cur ret.version = cur
.get_key_value() .get_key_value()
@ -335,12 +345,14 @@ where
}) })
.unwrap_or_default() .unwrap_or_default()
} }
b"region" => ret.region = cur.get_key_value()?, b"region" => ret.region = cur.get("region")?,
b"bots" => ret.flags.set(ServerFlags::BOTS, cur.get_key_value::<u8>()? != 0), b"bots" => ret
b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get_key_value()?), .flags
b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get_key_value()?), .set(ServerFlags::BOTS, cur.get::<u8>("bots")? != 0),
b"lan" => ret.flags.set(ServerFlags::LAN, cur.get_key_value()?), b"password" => ret.flags.set(ServerFlags::PASSWORD, cur.get("password")?),
b"nat" => ret.flags.set(ServerFlags::NAT, cur.get_key_value()?), b"secure" => ret.flags.set(ServerFlags::SECURE, cur.get("secure")?),
b"lan" => ret.flags.set(ServerFlags::LAN, cur.get("lan")?),
b"nat" => ret.flags.set(ServerFlags::NAT, cur.get("nat")?),
_ => { _ => {
// skip unknown fields // skip unknown fields
let value = cur.get_key_value::<Str<&[u8]>>()?; let value = cur.get_key_value::<Str<&[u8]>>()?;
@ -354,14 +366,14 @@ where
ret.challenge = c; ret.challenge = c;
Ok(ret) Ok(ret)
} }
None => Err(Error::InvalidPacket), None => Err(Error::InvalidServerValue("challenge", CursorError::Expect)),
} }
} }
} }
impl<T> ServerAdd<T> impl<T> ServerAdd<T>
where where
T: PutKeyValue + Clone, T: PutKeyValue,
{ {
/// Encode packet to `buf`. /// Encode packet to `buf`.
pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> { pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
@ -371,8 +383,8 @@ where
.put_key("challenge", self.challenge)? .put_key("challenge", self.challenge)?
.put_key("players", self.players)? .put_key("players", self.players)?
.put_key("max", self.max)? .put_key("max", self.max)?
.put_key("gamedir", self.gamedir.clone())? .put_key("gamedir", &self.gamedir)?
.put_key("map", self.map.clone())? .put_key("map", &self.map)?
.put_key("type", self.server_type)? .put_key("type", self.server_type)?
.put_key("os", self.os)? .put_key("os", self.os)?
.put_key("version", self.version)? .put_key("version", self.version)?
@ -469,8 +481,8 @@ where
loop { loop {
let key = match cur.get_key_raw() { let key = match cur.get_key_raw() {
Ok(s) => s, Ok(s) => s,
Err(Error::UnexpectedEnd) => break, Err(CursorError::TableEnd) => break,
Err(e) => return Err(e), Err(e) => Err(e)?,
}; };
match key { match key {
@ -500,21 +512,24 @@ where
} }
} }
impl<'a> GetServerInfoResponse<&'a str> { impl<T> GetServerInfoResponse<T>
where
T: PutKeyValue,
{
/// Encode packet to `buf`. /// Encode packet to `buf`.
pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> { pub fn encode(&self, buf: &mut [u8]) -> Result<usize, Error> {
Ok(CursorMut::new(buf) Ok(CursorMut::new(buf)
.put_bytes(GetServerInfoResponse::HEADER)? .put_bytes(GetServerInfoResponse::HEADER)?
.put_key("p", self.protocol)? .put_key("p", self.protocol)?
.put_key("map", self.map)? .put_key("map", &self.map)?
.put_key("dm", self.dm)? .put_key("dm", self.dm)?
.put_key("team", self.team)? .put_key("team", self.team)?
.put_key("coop", self.coop)? .put_key("coop", self.coop)?
.put_key("numcl", self.numcl)? .put_key("numcl", self.numcl)?
.put_key("maxcl", self.maxcl)? .put_key("maxcl", self.maxcl)?
.put_key("gamedir", self.gamedir)? .put_key("gamedir", &self.gamedir)?
.put_key("password", self.password)? .put_key("password", self.password)?
.put_key("host", self.host)? .put_key("host", &self.host)?
.pos()) .pos())
} }
} }
@ -534,24 +549,19 @@ pub enum Packet<'a> {
impl<'a> Packet<'a> { impl<'a> Packet<'a> {
/// Decode packet from `src`. /// Decode packet from `src`.
pub fn decode(src: &'a [u8]) -> Result<Self, Error> { pub fn decode(src: &'a [u8]) -> Result<Option<Self>, Error> {
if let Ok(p) = Challenge::decode(src) { if src.starts_with(Challenge::HEADER) {
return Ok(Self::Challenge(p)); Challenge::decode(src).map(Self::Challenge)
} } else if src.starts_with(ServerAdd::HEADER) {
ServerAdd::decode(src).map(Self::ServerAdd)
if let Ok(p) = ServerAdd::decode(src) { } else if src.starts_with(ServerRemove::HEADER) {
return Ok(Self::ServerAdd(p)); ServerRemove::decode(src).map(|_| Self::ServerRemove)
} } else if src.starts_with(GetServerInfoResponse::HEADER) {
GetServerInfoResponse::decode(src).map(Self::GetServerInfoResponse)
if ServerRemove::decode(src).is_ok() { } else {
return Ok(Self::ServerRemove); return Ok(None);
}
if let Ok(p) = GetServerInfoResponse::decode(src) {
return Ok(Self::GetServerInfoResponse(p));
} }
.map(Some)
Err(Error::InvalidPacket)
} }
} }
@ -564,13 +574,16 @@ mod tests {
let p = Challenge::new(Some(0x12345678)); let p = Challenge::new(Some(0x12345678));
let mut buf = [0; 128]; let mut buf = [0; 128];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(Challenge::decode(&buf[..n]), Ok(p)); assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::Challenge(p))));
} }
#[test] #[test]
fn challenge_old() { fn challenge_old() {
let s = b"q\xff"; let s = b"q\xff";
assert_eq!(Challenge::decode(s), Ok(Challenge::new(None))); assert_eq!(
Packet::decode(s),
Ok(Some(Packet::Challenge(Challenge::new(None))))
);
let p = Challenge::new(None); let p = Challenge::new(None);
let mut buf = [0; 128]; let mut buf = [0; 128];
@ -581,8 +594,8 @@ mod tests {
#[test] #[test]
fn server_add() { fn server_add() {
let p = ServerAdd { let p = ServerAdd {
gamedir: "valve", gamedir: Str(&b"valve"[..]),
map: "crossfire", map: Str(&b"crossfire"[..]),
version: Version::new(0, 20), version: Version::new(0, 20),
challenge: 0x12345678, challenge: 0x12345678,
server_type: ServerType::Dedicated, server_type: ServerType::Dedicated,
@ -595,7 +608,7 @@ mod tests {
}; };
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(ServerAdd::decode(&buf[..n]), Ok(p)); assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::ServerAdd(p))));
} }
#[test] #[test]
@ -603,26 +616,29 @@ mod tests {
let p = ServerRemove; let p = ServerRemove;
let mut buf = [0; 64]; let mut buf = [0; 64];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(ServerRemove::decode(&buf[..n]), Ok(p)); assert_eq!(Packet::decode(&buf[..n]), Ok(Some(Packet::ServerRemove)));
} }
#[test] #[test]
fn get_server_info_response() { fn get_server_info_response() {
let p = GetServerInfoResponse { let p = GetServerInfoResponse {
protocol: 49, protocol: 49,
map: "crossfire", map: Str("crossfire".as_bytes()),
dm: true, dm: true,
team: true, team: true,
coop: true, coop: true,
numcl: 4, numcl: 4,
maxcl: 32, maxcl: 32,
gamedir: "valve", gamedir: Str("valve".as_bytes()),
password: true, password: true,
host: "Test", host: Str("Test".as_bytes()),
}; };
let mut buf = [0; 512]; let mut buf = [0; 512];
let n = p.encode(&mut buf).unwrap(); let n = p.encode(&mut buf).unwrap();
assert_eq!(GetServerInfoResponse::decode(&buf[..n]), Ok(p)); assert_eq!(
Packet::decode(&buf[..n]),
Ok(Some(Packet::GetServerInfoResponse(p)))
);
} }
#[test] #[test]

12
protocol/src/types.rs

@ -6,6 +6,9 @@
use std::fmt; use std::fmt;
use std::ops::Deref; use std::ops::Deref;
use crate::cursor::{CursorMut, PutKeyValue};
use crate::CursorError;
/// Wrapper for slice of bytes with printing the bytes as a string. /// Wrapper for slice of bytes with printing the bytes as a string.
/// ///
/// # Examples /// # Examples
@ -24,6 +27,15 @@ impl<T> From<T> for Str<T> {
} }
} }
impl PutKeyValue for Str<&[u8]> {
fn put_key_value<'a, 'b>(
&self,
cur: &'b mut CursorMut<'a>,
) -> Result<&'b mut CursorMut<'a>, CursorError> {
cur.put_bytes(self.0)
}
}
impl<T> fmt::Debug for Str<T> impl<T> fmt::Debug for Str<T>
where where
T: AsRef<[u8]>, T: AsRef<[u8]>,

Loading…
Cancel
Save