diff --git a/protocol/Cargo.toml b/protocol/Cargo.toml index cc3ceb6..090e8b0 100644 --- a/protocol/Cargo.toml +++ b/protocol/Cargo.toml @@ -12,9 +12,10 @@ homepage = "https://xash.su" repository = "https://git.mentality.rip/numas13/xash3d-master" [features] -default = ["std"] +default = ["std", "net"] std = ["alloc"] alloc = [] +net = [] [dependencies] log = "0.4.18" diff --git a/protocol/src/cursor.rs b/protocol/src/cursor.rs index f0aeafb..3c0bda6 100644 --- a/protocol/src/cursor.rs +++ b/protocol/src/cursor.rs @@ -1,16 +1,13 @@ // SPDX-License-Identifier: LGPL-3.0-only // SPDX-FileCopyrightText: 2023 Denis Drakhnia -use core::{ - fmt::{self, Write}, - mem, str, -}; +mod cursor; +mod cursor_mut; -#[cfg(feature = "alloc")] -use alloc::{borrow::ToOwned, boxed::Box, string::String}; +use core::fmt; -use super::color; -use super::wrappers::Str; +pub use cursor::{Cursor, GetKeyValue}; +pub use cursor_mut::{CursorMut, PutKeyValue}; /// The error type for `Cursor` and `CursorMut`. #[derive(Debug, PartialEq, Eq)] @@ -56,472 +53,12 @@ impl core::error::Error for CursorError {} pub type Result = core::result::Result; -pub trait GetKeyValue<'a>: Sized { - fn get_key_value(cur: &mut Cursor<'a>) -> Result; -} - -impl<'a> GetKeyValue<'a> for &'a [u8] { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - cur.get_key_value_raw() - } -} - -impl<'a> GetKeyValue<'a> for Str<&'a [u8]> { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - cur.get_key_value_raw().map(Str) - } -} - -impl<'a> GetKeyValue<'a> for &'a str { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - let raw = cur.get_key_value_raw()?; - str::from_utf8(raw).map_err(|_| CursorError::InvalidString) - } -} - -#[cfg(feature = "alloc")] -impl<'a> GetKeyValue<'a> for Box { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - let raw = cur.get_key_value_raw()?; - str::from_utf8(raw) - .map(|s| s.to_owned().into_boxed_str()) - .map_err(|_| CursorError::InvalidString) - } -} - -#[cfg(feature = "alloc")] -impl<'a> GetKeyValue<'a> for String { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - let raw = cur.get_key_value_raw()?; - str::from_utf8(raw) - .map(|s| s.to_owned()) - .map_err(|_| CursorError::InvalidString) - } -} - -impl<'a> GetKeyValue<'a> for bool { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - match cur.get_key_value_raw()? { - b"0" => Ok(false), - b"1" => Ok(true), - _ => Err(CursorError::InvalidBool), - } - } -} - -macro_rules! impl_get_value { - ($($t:ty),+ $(,)?) => { - $(impl<'a> GetKeyValue<'a> for $t { - fn get_key_value(cur: &mut Cursor<'a>) -> Result { - let s = cur.get_key_value::<&str>()?; - // HACK: special case for one asshole - let (_, s) = color::trim_start_color(s); - s.parse().map_err(|_| CursorError::InvalidNumber) - } - })+ - }; -} - -impl_get_value! { - u8, - u16, - u32, - u64, - - i8, - i16, - i32, - i64, -} - -// TODO: impl GetKeyValue for f32 and f64 - -#[derive(Copy, Clone)] -pub struct Cursor<'a> { - buffer: &'a [u8], -} - -macro_rules! impl_get { - ($($n:ident: $t:ty = $f:ident),+ $(,)?) => ( - $(#[inline] - pub fn $n(&mut self) -> Result<$t> { - const N: usize = mem::size_of::<$t>(); - self.get_array::().map(<$t>::$f) - })+ - ); -} - -impl<'a> Cursor<'a> { - pub fn new(buffer: &'a [u8]) -> Self { - Self { buffer } - } - - pub fn end(self) -> &'a [u8] { - self.buffer - } - - pub fn as_slice(&'a self) -> &'a [u8] { - self.buffer - } - - #[inline(always)] - pub fn remaining(&self) -> usize { - self.buffer.len() - } - - #[inline(always)] - pub fn has_remaining(&self) -> bool { - self.remaining() != 0 - } - - pub fn get_bytes(&mut self, count: usize) -> Result<&'a [u8]> { - if count <= self.remaining() { - let (head, tail) = self.buffer.split_at(count); - self.buffer = tail; - Ok(head) - } else { - Err(CursorError::UnexpectedEnd) - } - } - - pub fn advance(&mut self, count: usize) -> Result<()> { - self.get_bytes(count).map(|_| ()) - } - - pub fn get_array(&mut self) -> Result<[u8; N]> { - self.get_bytes(N).map(|s| { - let mut array = [0; N]; - array.copy_from_slice(s); - array - }) - } - - pub fn get_str(&mut self, n: usize) -> Result<&'a str> { - let mut cur = *self; - let s = cur - .get_bytes(n) - .and_then(|s| str::from_utf8(s).map_err(|_| CursorError::InvalidString))?; - *self = cur; - Ok(s) - } - - pub fn get_cstr(&mut self) -> Result> { - let pos = self - .buffer - .iter() - .position(|&c| c == b'\0') - .ok_or(CursorError::UnexpectedEnd)?; - let (head, tail) = self.buffer.split_at(pos); - self.buffer = &tail[1..]; - Ok(Str(&head[..pos])) - } - - pub fn get_cstr_as_str(&mut self) -> Result<&'a str> { - str::from_utf8(&self.get_cstr()?).map_err(|_| CursorError::InvalidString) - } - - #[inline(always)] - pub fn get_u8(&mut self) -> Result { - self.get_array::<1>().map(|s| s[0]) - } - - #[inline(always)] - pub fn get_i8(&mut self) -> Result { - self.get_array::<1>().map(|s| s[0] as i8) - } - - impl_get! { - get_u16_le: u16 = from_le_bytes, - get_u32_le: u32 = from_le_bytes, - get_u64_le: u64 = from_le_bytes, - get_i16_le: i16 = from_le_bytes, - get_i32_le: i32 = from_le_bytes, - get_i64_le: i64 = from_le_bytes, - get_f32_le: f32 = from_le_bytes, - get_f64_le: f64 = from_le_bytes, - - get_u16_be: u16 = from_be_bytes, - get_u32_be: u32 = from_be_bytes, - get_u64_be: u64 = from_be_bytes, - get_i16_be: i16 = from_be_bytes, - get_i32_be: i32 = from_be_bytes, - get_i64_be: i64 = from_be_bytes, - get_f32_be: f32 = from_be_bytes, - get_f64_be: f64 = from_be_bytes, - - get_u16_ne: u16 = from_ne_bytes, - get_u32_ne: u32 = from_ne_bytes, - get_u64_ne: u64 = from_ne_bytes, - get_i16_ne: i16 = from_ne_bytes, - get_i32_ne: i32 = from_ne_bytes, - get_i64_ne: i64 = from_ne_bytes, - get_f32_ne: f32 = from_ne_bytes, - get_f64_ne: f64 = from_ne_bytes, - } - - pub fn expect(&mut self, s: &[u8]) -> Result<()> { - if self.buffer.starts_with(s) { - self.advance(s.len())?; - Ok(()) - } else { - Err(CursorError::Expect) - } - } - - pub fn expect_empty(&self) -> Result<()> { - if self.has_remaining() { - Err(CursorError::ExpectEmpty) - } else { - Ok(()) - } - } - - pub fn take_while(&mut self, mut cond: F) -> Result<&'a [u8]> - where - F: FnMut(u8) -> bool, - { - self.buffer - .iter() - .position(|&i| !cond(i)) - .ok_or(CursorError::UnexpectedEnd) - .and_then(|n| self.get_bytes(n)) - } - - pub fn take_while_or_all(&mut self, cond: F) -> &'a [u8] - where - F: FnMut(u8) -> bool, - { - self.take_while(cond).unwrap_or_else(|_| { - let (head, tail) = self.buffer.split_at(self.buffer.len()); - self.buffer = tail; - head - }) - } - - pub fn get_key_value_raw(&mut self) -> Result<&'a [u8]> { - let mut cur = *self; - match cur.get_u8()? { - b'\\' => { - let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n'); - *self = cur; - Ok(value) - } - _ => Err(CursorError::InvalidTableValue), - } - } - - pub fn get_key_value>(&mut self) -> Result { - T::get_key_value(self) - } - - pub fn skip_key_value>(&mut self) -> Result<()> { - T::get_key_value(self).map(|_| ()) - } - - pub fn get_key_raw(&mut self) -> Result<&'a [u8]> { - let mut cur = *self; - match cur.get_u8() { - Ok(b'\\') => { - let value = cur.take_while(|c| c != b'\\' && c != b'\n')?; - *self = cur; - Ok(value) - } - Ok(b'\n') | Err(CursorError::UnexpectedEnd) => Err(CursorError::TableEnd), - _ => Err(CursorError::InvalidTableKey), - } - } - - pub fn get_key>(&mut self) -> Result<(&'a [u8], T)> { - Ok((self.get_key_raw()?, self.get_key_value()?)) - } -} - -pub trait PutKeyValue { - fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>>; -} - -impl PutKeyValue for &T -where - T: PutKeyValue, -{ - fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { - (*self).put_key_value(cur) - } -} - -impl PutKeyValue for &str { - fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { - cur.put_str(self) - } -} - -impl PutKeyValue for bool { - fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { - cur.put_u8(if *self { b'1' } else { b'0' }) - } -} - -macro_rules! impl_put_key_value { - ($($t:ty),+ $(,)?) => { - $(impl PutKeyValue for $t { - fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { - cur.put_as_str(self) - } - })+ - }; -} - -impl_put_key_value! { - u8, - u16, - u32, - u64, - - i8, - i16, - i32, - i64, - - f32, - f64, -} - -pub struct CursorMut<'a> { - buffer: &'a mut [u8], - pos: usize, -} - -macro_rules! impl_put { - ($($n:ident: $t:ty = $f:ident),+ $(,)?) => ( - $(#[inline] - pub fn $n(&mut self, n: $t) -> Result<&mut Self> { - self.put_array(&n.$f()) - })+ - ); -} - -impl<'a> CursorMut<'a> { - pub fn new(buffer: &'a mut [u8]) -> Self { - Self { buffer, pos: 0 } - } - - pub fn pos(&mut self) -> usize { - self.pos - } - - #[inline(always)] - pub fn available(&self) -> usize { - self.buffer.len() - self.pos - } - - pub fn advance(&mut self, count: usize, mut f: F) -> Result<&mut Self> - where - F: FnMut(&mut [u8]), - { - if count <= self.available() { - f(&mut self.buffer[self.pos..self.pos + count]); - self.pos += count; - Ok(self) - } else { - Err(CursorError::UnexpectedEnd) - } - } - - pub fn put_bytes(&mut self, s: &[u8]) -> Result<&mut Self> { - self.advance(s.len(), |i| { - i.copy_from_slice(s); - }) - } - - pub fn put_array(&mut self, s: &[u8; N]) -> Result<&mut Self> { - self.advance(N, |i| { - i.copy_from_slice(s); - }) - } - - pub fn put_str(&mut self, s: &str) -> Result<&mut Self> { - self.put_bytes(s.as_bytes()) - } - - pub fn put_cstr(&mut self, s: &str) -> Result<&mut Self> { - self.put_str(s)?.put_u8(0) - } - - #[inline(always)] - pub fn put_u8(&mut self, n: u8) -> Result<&mut Self> { - self.put_array(&[n]) - } - - #[inline(always)] - pub fn put_i8(&mut self, n: i8) -> Result<&mut Self> { - self.put_u8(n as u8) - } - - impl_put! { - put_u16_le: u16 = to_le_bytes, - put_u32_le: u32 = to_le_bytes, - put_u64_le: u64 = to_le_bytes, - put_i16_le: i16 = to_le_bytes, - put_i32_le: i32 = to_le_bytes, - put_i64_le: i64 = to_le_bytes, - put_f32_le: f32 = to_le_bytes, - put_f64_le: f64 = to_le_bytes, - - put_u16_be: u16 = to_be_bytes, - put_u32_be: u32 = to_be_bytes, - put_u64_be: u64 = to_be_bytes, - put_i16_be: i16 = to_be_bytes, - put_i32_be: i32 = to_be_bytes, - put_i64_be: i64 = to_be_bytes, - put_f32_be: f32 = to_be_bytes, - put_f64_be: f64 = to_be_bytes, - - put_u16_ne: u16 = to_ne_bytes, - put_u32_ne: u32 = to_ne_bytes, - put_u64_ne: u64 = to_ne_bytes, - put_i16_ne: i16 = to_ne_bytes, - put_i32_ne: i32 = to_ne_bytes, - put_i64_ne: i64 = to_ne_bytes, - put_f32_ne: f32 = to_ne_bytes, - put_f64_ne: f64 = to_ne_bytes, - } - - pub fn put_as_str(&mut self, value: T) -> Result<&mut Self> { - write!(self, "{}", value).map_err(|_| CursorError::UnexpectedEnd)?; - Ok(self) - } - - pub fn put_key_value(&mut self, value: T) -> Result<&mut Self> { - value.put_key_value(self) - } - - pub fn put_key_raw(&mut self, key: &str, value: &[u8]) -> Result<&mut Self> { - self.put_u8(b'\\')? - .put_str(key)? - .put_u8(b'\\')? - .put_bytes(value) - } - - pub fn put_key(&mut self, key: &str, value: T) -> Result<&mut Self> { - self.put_u8(b'\\')? - .put_str(key)? - .put_u8(b'\\')? - .put_key_value(value) - } -} - -impl fmt::Write for CursorMut<'_> { - fn write_str(&mut self, s: &str) -> fmt::Result { - self.put_bytes(s.as_bytes()) - .map(|_| ()) - .map_err(|_| fmt::Error) - } -} - #[cfg(test)] mod tests { use super::*; + use crate::wrappers::Str; + #[test] fn cursor() -> Result<()> { let mut buf = [0; 64]; diff --git a/protocol/src/cursor/cursor.rs b/protocol/src/cursor/cursor.rs new file mode 100644 index 0000000..896c7fc --- /dev/null +++ b/protocol/src/cursor/cursor.rs @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: LGPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +#![cfg_attr(not(feature = "net"), allow(dead_code))] + +use core::{mem, str}; + +#[cfg(feature = "alloc")] +use alloc::{borrow::ToOwned, boxed::Box, string::String}; + +use crate::{color, wrappers::Str}; + +use super::{CursorError, Result}; + +pub trait GetKeyValue<'a>: Sized { + fn get_key_value(cur: &mut Cursor<'a>) -> Result; +} + +impl<'a> GetKeyValue<'a> for &'a [u8] { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + cur.get_key_value_raw() + } +} + +impl<'a> GetKeyValue<'a> for Str<&'a [u8]> { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + cur.get_key_value_raw().map(Str) + } +} + +impl<'a> GetKeyValue<'a> for &'a str { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let raw = cur.get_key_value_raw()?; + str::from_utf8(raw).map_err(|_| CursorError::InvalidString) + } +} + +#[cfg(feature = "alloc")] +impl<'a> GetKeyValue<'a> for Box { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let raw = cur.get_key_value_raw()?; + str::from_utf8(raw) + .map(|s| s.to_owned().into_boxed_str()) + .map_err(|_| CursorError::InvalidString) + } +} + +#[cfg(feature = "alloc")] +impl<'a> GetKeyValue<'a> for String { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let raw = cur.get_key_value_raw()?; + str::from_utf8(raw) + .map(|s| s.to_owned()) + .map_err(|_| CursorError::InvalidString) + } +} + +impl<'a> GetKeyValue<'a> for bool { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + match cur.get_key_value_raw()? { + b"0" => Ok(false), + b"1" => Ok(true), + _ => Err(CursorError::InvalidBool), + } + } +} + +impl GetKeyValue<'_> for crate::server_info::Region { + fn get_key_value(cur: &mut Cursor) -> Result { + cur.get_key_value::()?.try_into() + } +} + +impl GetKeyValue<'_> for crate::server_info::ServerType { + fn get_key_value(cur: &mut Cursor) -> Result { + cur.get_key_value_raw()?.try_into() + } +} + +macro_rules! impl_get_value { + ($($t:ty),+ $(,)?) => { + $(impl<'a> GetKeyValue<'a> for $t { + fn get_key_value(cur: &mut Cursor<'a>) -> Result { + let s = cur.get_key_value::<&str>()?; + // HACK: special case for one asshole + let (_, s) = color::trim_start_color(s); + s.parse().map_err(|_| CursorError::InvalidNumber) + } + })+ + }; +} + +impl_get_value! { + u8, + u16, + u32, + u64, + + i8, + i16, + i32, + i64, +} + +// TODO: impl GetKeyValue for f32 and f64 + +#[derive(Copy, Clone)] +pub struct Cursor<'a> { + buffer: &'a [u8], +} + +macro_rules! impl_get { + ($($n:ident: $t:ty = $f:ident),+ $(,)?) => ( + $(#[inline] + pub fn $n(&mut self) -> Result<$t> { + const N: usize = mem::size_of::<$t>(); + self.get_array::().map(<$t>::$f) + })+ + ); +} + +impl<'a> Cursor<'a> { + pub fn new(buffer: &'a [u8]) -> Self { + Self { buffer } + } + + pub fn end(self) -> &'a [u8] { + self.buffer + } + + pub fn as_slice(&'a self) -> &'a [u8] { + self.buffer + } + + #[inline(always)] + pub fn remaining(&self) -> usize { + self.buffer.len() + } + + #[inline(always)] + pub fn has_remaining(&self) -> bool { + self.remaining() != 0 + } + + pub fn get_bytes(&mut self, count: usize) -> Result<&'a [u8]> { + if count <= self.remaining() { + let (head, tail) = self.buffer.split_at(count); + self.buffer = tail; + Ok(head) + } else { + Err(CursorError::UnexpectedEnd) + } + } + + pub fn advance(&mut self, count: usize) -> Result<()> { + self.get_bytes(count).map(|_| ()) + } + + pub fn get_array(&mut self) -> Result<[u8; N]> { + self.get_bytes(N).map(|s| { + let mut array = [0; N]; + array.copy_from_slice(s); + array + }) + } + + pub fn get_str(&mut self, n: usize) -> Result<&'a str> { + let mut cur = *self; + let s = cur + .get_bytes(n) + .and_then(|s| str::from_utf8(s).map_err(|_| CursorError::InvalidString))?; + *self = cur; + Ok(s) + } + + pub fn get_cstr(&mut self) -> Result> { + let pos = self + .buffer + .iter() + .position(|&c| c == b'\0') + .ok_or(CursorError::UnexpectedEnd)?; + let (head, tail) = self.buffer.split_at(pos); + self.buffer = &tail[1..]; + Ok(Str(&head[..pos])) + } + + pub fn get_cstr_as_str(&mut self) -> Result<&'a str> { + str::from_utf8(&self.get_cstr()?).map_err(|_| CursorError::InvalidString) + } + + #[inline(always)] + pub fn get_u8(&mut self) -> Result { + self.get_array::<1>().map(|s| s[0]) + } + + #[inline(always)] + pub fn get_i8(&mut self) -> Result { + self.get_array::<1>().map(|s| s[0] as i8) + } + + impl_get! { + get_u16_le: u16 = from_le_bytes, + get_u32_le: u32 = from_le_bytes, + get_u64_le: u64 = from_le_bytes, + get_i16_le: i16 = from_le_bytes, + get_i32_le: i32 = from_le_bytes, + get_i64_le: i64 = from_le_bytes, + get_f32_le: f32 = from_le_bytes, + get_f64_le: f64 = from_le_bytes, + + get_u16_be: u16 = from_be_bytes, + get_u32_be: u32 = from_be_bytes, + get_u64_be: u64 = from_be_bytes, + get_i16_be: i16 = from_be_bytes, + get_i32_be: i32 = from_be_bytes, + get_i64_be: i64 = from_be_bytes, + get_f32_be: f32 = from_be_bytes, + get_f64_be: f64 = from_be_bytes, + + get_u16_ne: u16 = from_ne_bytes, + get_u32_ne: u32 = from_ne_bytes, + get_u64_ne: u64 = from_ne_bytes, + get_i16_ne: i16 = from_ne_bytes, + get_i32_ne: i32 = from_ne_bytes, + get_i64_ne: i64 = from_ne_bytes, + get_f32_ne: f32 = from_ne_bytes, + get_f64_ne: f64 = from_ne_bytes, + } + + pub fn expect(&mut self, s: &[u8]) -> Result<()> { + if self.buffer.starts_with(s) { + self.advance(s.len())?; + Ok(()) + } else { + Err(CursorError::Expect) + } + } + + pub fn expect_empty(&self) -> Result<()> { + if self.has_remaining() { + Err(CursorError::ExpectEmpty) + } else { + Ok(()) + } + } + + pub fn take_while(&mut self, mut cond: F) -> Result<&'a [u8]> + where + F: FnMut(u8) -> bool, + { + self.buffer + .iter() + .position(|&i| !cond(i)) + .ok_or(CursorError::UnexpectedEnd) + .and_then(|n| self.get_bytes(n)) + } + + pub fn take_while_or_all(&mut self, cond: F) -> &'a [u8] + where + F: FnMut(u8) -> bool, + { + self.take_while(cond).unwrap_or_else(|_| { + let (head, tail) = self.buffer.split_at(self.buffer.len()); + self.buffer = tail; + head + }) + } + + pub fn get_key_value_raw(&mut self) -> Result<&'a [u8]> { + let mut cur = *self; + match cur.get_u8()? { + b'\\' => { + let value = cur.take_while_or_all(|c| c != b'\\' && c != b'\n'); + *self = cur; + Ok(value) + } + _ => Err(CursorError::InvalidTableValue), + } + } + + pub fn get_key_value>(&mut self) -> Result { + T::get_key_value(self) + } + + pub fn skip_key_value>(&mut self) -> Result<()> { + T::get_key_value(self).map(|_| ()) + } + + pub fn get_key_raw(&mut self) -> Result<&'a [u8]> { + let mut cur = *self; + match cur.get_u8() { + Ok(b'\\') => { + let value = cur.take_while(|c| c != b'\\' && c != b'\n')?; + *self = cur; + Ok(value) + } + Ok(b'\n') | Err(CursorError::UnexpectedEnd) => Err(CursorError::TableEnd), + _ => Err(CursorError::InvalidTableKey), + } + } + + pub fn get_key>(&mut self) -> Result<(&'a [u8], T)> { + Ok((self.get_key_raw()?, self.get_key_value()?)) + } +} diff --git a/protocol/src/cursor/cursor_mut.rs b/protocol/src/cursor/cursor_mut.rs new file mode 100644 index 0000000..bd8e1ce --- /dev/null +++ b/protocol/src/cursor/cursor_mut.rs @@ -0,0 +1,207 @@ +// SPDX-License-Identifier: LGPL-3.0-only +// SPDX-FileCopyrightText: 2023 Denis Drakhnia + +#![cfg_attr(not(feature = "net"), allow(dead_code))] + +use core::{ + fmt::{self, Write}, + str, +}; + +use super::{CursorError, Result}; + +pub trait PutKeyValue { + fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>>; +} + +impl PutKeyValue for &T +where + T: PutKeyValue, +{ + fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { + (*self).put_key_value(cur) + } +} + +impl PutKeyValue for &str { + fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { + cur.put_str(self) + } +} + +impl PutKeyValue for bool { + fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { + cur.put_u8(if *self { b'1' } else { b'0' }) + } +} + +impl PutKeyValue for crate::server_info::ServerType { + fn put_key_value<'a, 'b>( + &self, + cur: &'b mut CursorMut<'a>, + ) -> Result<&'b mut CursorMut<'a>, CursorError> { + match self { + Self::Dedicated => cur.put_str("d"), + Self::Local => cur.put_str("l"), + Self::Proxy => cur.put_str("p"), + Self::Unknown => cur.put_str("?"), + } + } +} + +macro_rules! impl_put_key_value { + ($($t:ty),+ $(,)?) => { + $(impl PutKeyValue for $t { + fn put_key_value<'a, 'b>(&self, cur: &'b mut CursorMut<'a>) -> Result<&'b mut CursorMut<'a>> { + cur.put_as_str(self) + } + })+ + }; +} + +impl_put_key_value! { + u8, + u16, + u32, + u64, + + i8, + i16, + i32, + i64, + + f32, + f64, +} + +pub struct CursorMut<'a> { + buffer: &'a mut [u8], + pos: usize, +} + +macro_rules! impl_put { + ($($n:ident: $t:ty = $f:ident),+ $(,)?) => ( + $(#[inline] + pub fn $n(&mut self, n: $t) -> Result<&mut Self> { + self.put_array(&n.$f()) + })+ + ); +} + +impl<'a> CursorMut<'a> { + pub fn new(buffer: &'a mut [u8]) -> Self { + Self { buffer, pos: 0 } + } + + pub fn pos(&mut self) -> usize { + self.pos + } + + #[inline(always)] + pub fn available(&self) -> usize { + self.buffer.len() - self.pos + } + + pub fn advance(&mut self, count: usize, mut f: F) -> Result<&mut Self> + where + F: FnMut(&mut [u8]), + { + if count <= self.available() { + f(&mut self.buffer[self.pos..self.pos + count]); + self.pos += count; + Ok(self) + } else { + Err(CursorError::UnexpectedEnd) + } + } + + pub fn put_bytes(&mut self, s: &[u8]) -> Result<&mut Self> { + self.advance(s.len(), |i| { + i.copy_from_slice(s); + }) + } + + pub fn put_array(&mut self, s: &[u8; N]) -> Result<&mut Self> { + self.advance(N, |i| { + i.copy_from_slice(s); + }) + } + + pub fn put_str(&mut self, s: &str) -> Result<&mut Self> { + self.put_bytes(s.as_bytes()) + } + + pub fn put_cstr(&mut self, s: &str) -> Result<&mut Self> { + self.put_str(s)?.put_u8(0) + } + + #[inline(always)] + pub fn put_u8(&mut self, n: u8) -> Result<&mut Self> { + self.put_array(&[n]) + } + + #[inline(always)] + pub fn put_i8(&mut self, n: i8) -> Result<&mut Self> { + self.put_u8(n as u8) + } + + impl_put! { + put_u16_le: u16 = to_le_bytes, + put_u32_le: u32 = to_le_bytes, + put_u64_le: u64 = to_le_bytes, + put_i16_le: i16 = to_le_bytes, + put_i32_le: i32 = to_le_bytes, + put_i64_le: i64 = to_le_bytes, + put_f32_le: f32 = to_le_bytes, + put_f64_le: f64 = to_le_bytes, + + put_u16_be: u16 = to_be_bytes, + put_u32_be: u32 = to_be_bytes, + put_u64_be: u64 = to_be_bytes, + put_i16_be: i16 = to_be_bytes, + put_i32_be: i32 = to_be_bytes, + put_i64_be: i64 = to_be_bytes, + put_f32_be: f32 = to_be_bytes, + put_f64_be: f64 = to_be_bytes, + + put_u16_ne: u16 = to_ne_bytes, + put_u32_ne: u32 = to_ne_bytes, + put_u64_ne: u64 = to_ne_bytes, + put_i16_ne: i16 = to_ne_bytes, + put_i32_ne: i32 = to_ne_bytes, + put_i64_ne: i64 = to_ne_bytes, + put_f32_ne: f32 = to_ne_bytes, + put_f64_ne: f64 = to_ne_bytes, + } + + pub fn put_as_str(&mut self, value: T) -> Result<&mut Self> { + write!(self, "{}", value).map_err(|_| CursorError::UnexpectedEnd)?; + Ok(self) + } + + pub fn put_key_value(&mut self, value: T) -> Result<&mut Self> { + value.put_key_value(self) + } + + pub fn put_key_raw(&mut self, key: &str, value: &[u8]) -> Result<&mut Self> { + self.put_u8(b'\\')? + .put_str(key)? + .put_u8(b'\\')? + .put_bytes(value) + } + + pub fn put_key(&mut self, key: &str, value: T) -> Result<&mut Self> { + self.put_u8(b'\\')? + .put_str(key)? + .put_u8(b'\\')? + .put_key_value(value) + } +} + +impl fmt::Write for CursorMut<'_> { + fn write_str(&mut self, s: &str) -> fmt::Result { + self.put_bytes(s.as_bytes()) + .map(|_| ()) + .map_err(|_| fmt::Error) + } +} diff --git a/protocol/src/filter.rs b/protocol/src/filter.rs index 3235d2d..c8157b8 100644 --- a/protocol/src/filter.rs +++ b/protocol/src/filter.rs @@ -34,10 +34,16 @@ use core::{fmt, net::SocketAddr, str::FromStr}; use bitflags::bitflags; use crate::{ - cursor::{Cursor, GetKeyValue, PutKeyValue}, - server::{ServerAdd, ServerFlags, ServerType}, + cursor::{Cursor, CursorError, GetKeyValue, PutKeyValue}, + server_info::ServerInfo, wrappers::Str, - ServerInfo, {CursorError, Error}, + Error, +}; + +#[cfg(feature = "net")] +use crate::{ + net::server::ServerAdd, + server_info::{ServerFlags, ServerType}, }; bitflags! { @@ -65,6 +71,7 @@ bitflags! { } } +#[cfg(feature = "net")] impl From<&ServerAdd> for FilterFlags { fn from(info: &ServerAdd) -> Self { let mut flags = Self::empty(); @@ -309,13 +316,9 @@ impl fmt::Display for &Filter<'_> { #[cfg(test)] mod tests { - use std::net::SocketAddr; - - use crate::{cursor::CursorMut, wrappers::Str}; - use super::*; - type ServerInfo = crate::ServerInfo>; + use crate::{cursor::CursorMut, wrappers::Str}; macro_rules! tests { ($($name:ident$(($($predefined_f:ident: $predefined_v:expr),+ $(,)?))? { @@ -454,6 +457,17 @@ mod tests { .pos(); assert_eq!(&buf[..n], b"0.19.3"); } +} + +#[cfg(all(test, feature = "net"))] +mod match_tests { + use std::net::SocketAddr; + + use crate::{cursor::CursorMut, wrappers::Str}; + + use super::*; + + type ServerInfo = crate::ServerInfo>; macro_rules! servers { ($($addr:expr => $info:expr $(=> $func:expr)?)+) => ( diff --git a/protocol/src/lib.rs b/protocol/src/lib.rs index edd3adf..735e6d8 100644 --- a/protocol/src/lib.rs +++ b/protocol/src/lib.rs @@ -15,16 +15,19 @@ extern crate alloc; extern crate log; mod cursor; -mod server_info; -pub mod admin; +#[cfg(feature = "net")] +pub mod net; + pub mod color; pub mod filter; -pub mod game; -pub mod master; -pub mod server; +pub mod server_info; pub mod wrappers; +#[deprecated(since = "0.2.1", note = "use net module instead")] +#[cfg(feature = "net")] +pub use crate::net::{admin, game, master, server}; + use core::fmt; pub use cursor::CursorError; diff --git a/protocol/src/net.rs b/protocol/src/net.rs new file mode 100644 index 0000000..f83b7fe --- /dev/null +++ b/protocol/src/net.rs @@ -0,0 +1,6 @@ +//! Network packets decoders and encoders. + +pub mod admin; +pub mod game; +pub mod master; +pub mod server; diff --git a/protocol/src/admin.rs b/protocol/src/net/admin.rs similarity index 100% rename from protocol/src/admin.rs rename to protocol/src/net/admin.rs diff --git a/protocol/src/game.rs b/protocol/src/net/game.rs similarity index 98% rename from protocol/src/game.rs rename to protocol/src/net/game.rs index b0b7f12..5e8bda7 100644 --- a/protocol/src/game.rs +++ b/protocol/src/net/game.rs @@ -5,10 +5,12 @@ use core::{fmt, net::SocketAddr}; -use crate::cursor::{Cursor, CursorMut}; -use crate::filter::Filter; -use crate::server::Region; -use crate::Error; +use crate::{ + cursor::{Cursor, CursorMut}, + filter::Filter, + net::server::Region, + Error, +}; /// Request a list of server addresses from master servers. #[derive(Clone, Debug, PartialEq)] diff --git a/protocol/src/master.rs b/protocol/src/net/master.rs similarity index 99% rename from protocol/src/master.rs rename to protocol/src/net/master.rs index 8263336..bad5078 100644 --- a/protocol/src/master.rs +++ b/protocol/src/net/master.rs @@ -5,8 +5,10 @@ use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6}; -use super::cursor::{Cursor, CursorMut}; -use super::Error; +use crate::{ + cursor::{Cursor, CursorMut}, + Error, +}; /// Master server challenge response packet. #[derive(Clone, Debug, PartialEq)] diff --git a/protocol/src/server.rs b/protocol/src/net/server.rs similarity index 82% rename from protocol/src/server.rs rename to protocol/src/net/server.rs index bbd7421..6b5ca1c 100644 --- a/protocol/src/server.rs +++ b/protocol/src/net/server.rs @@ -5,12 +5,21 @@ use core::fmt; -use bitflags::bitflags; +use crate::{ + cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue}, + filter::Version, + wrappers::Str, + {CursorError, Error}, +}; -use super::cursor::{Cursor, CursorMut, GetKeyValue, PutKeyValue}; -use super::filter::Version; -use super::wrappers::Str; -use super::{CursorError, Error}; +#[deprecated(since = "0.2.1", note = "use server_info::Region instead")] +pub use crate::server_info::Region; + +#[deprecated(since = "0.2.1", note = "use server_info::ServerType instead")] +pub use crate::server_info::ServerType; + +#[deprecated(since = "0.2.1", note = "use server_info::ServerFlags instead")] +pub use crate::server_info::ServerFlags; /// Sended to a master server before `ServerAdd` packet. #[derive(Clone, Debug, PartialEq)] @@ -117,146 +126,6 @@ impl fmt::Display for Os { } } -/// Game server type. -#[derive(Copy, Clone, Debug, PartialEq)] -#[repr(u8)] -pub enum ServerType { - /// Dedicated server. - Dedicated, - /// Game client. - Local, - /// Spectator proxy. - Proxy, - /// Unknown. - Unknown, -} - -impl Default for ServerType { - fn default() -> Self { - Self::Unknown - } -} - -impl TryFrom<&[u8]> for ServerType { - type Error = CursorError; - - fn try_from(value: &[u8]) -> Result { - match value { - b"d" => Ok(Self::Dedicated), - b"l" => Ok(Self::Local), - b"p" => Ok(Self::Proxy), - _ => Ok(Self::Unknown), - } - } -} - -impl GetKeyValue<'_> for ServerType { - fn get_key_value(cur: &mut Cursor) -> Result { - cur.get_key_value_raw()?.try_into() - } -} - -impl PutKeyValue for ServerType { - fn put_key_value<'a, 'b>( - &self, - cur: &'b mut CursorMut<'a>, - ) -> Result<&'b mut CursorMut<'a>, CursorError> { - match self { - Self::Dedicated => cur.put_str("d"), - Self::Local => cur.put_str("l"), - Self::Proxy => cur.put_str("p"), - Self::Unknown => cur.put_str("?"), - } - } -} - -impl fmt::Display for ServerType { - fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { - use ServerType as E; - - let s = match self { - E::Dedicated => "dedicated", - E::Local => "local", - E::Proxy => "proxy", - E::Unknown => "unknown", - }; - - write!(fmt, "{}", s) - } -} - -/// The region of the world in which the server is located. -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -#[repr(u8)] -pub enum Region { - /// US East coast. - USEastCoast = 0x00, - /// US West coast. - USWestCoast = 0x01, - /// South America. - SouthAmerica = 0x02, - /// Europe. - Europe = 0x03, - /// Asia. - Asia = 0x04, - /// Australia. - Australia = 0x05, - /// Middle East. - MiddleEast = 0x06, - /// Africa. - Africa = 0x07, - /// Rest of the world. - RestOfTheWorld = 0xff, -} - -impl Default for Region { - fn default() -> Self { - Self::RestOfTheWorld - } -} - -impl TryFrom for Region { - type Error = CursorError; - - fn try_from(value: u8) -> Result { - match value { - 0x00 => Ok(Region::USEastCoast), - 0x01 => Ok(Region::USWestCoast), - 0x02 => Ok(Region::SouthAmerica), - 0x03 => Ok(Region::Europe), - 0x04 => Ok(Region::Asia), - 0x05 => Ok(Region::Australia), - 0x06 => Ok(Region::MiddleEast), - 0x07 => Ok(Region::Africa), - 0xff => Ok(Region::RestOfTheWorld), - _ => Err(CursorError::InvalidNumber), - } - } -} - -impl GetKeyValue<'_> for Region { - fn get_key_value(cur: &mut Cursor) -> Result { - cur.get_key_value::()?.try_into() - } -} - -bitflags! { - /// Additional server flags. - #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] - pub struct ServerFlags: u8 { - /// Server has bots. - const BOTS = 1 << 0; - /// Server is behind a password. - const PASSWORD = 1 << 1; - /// Server using anti-cheat. - const SECURE = 1 << 2; - /// Server is LAN. - const LAN = 1 << 3; - /// Server behind NAT. - const NAT = 1 << 4; - } -} - /// Add/update game server information on the master server. #[derive(Clone, Debug, PartialEq, Default)] pub struct ServerAdd { diff --git a/protocol/src/server_info.rs b/protocol/src/server_info.rs index 73be80b..54f5d09 100644 --- a/protocol/src/server_info.rs +++ b/protocol/src/server_info.rs @@ -1,12 +1,136 @@ // SPDX-License-Identifier: LGPL-3.0-only // SPDX-FileCopyrightText: 2023 Denis Drakhnia -#[cfg(feature = "alloc")] +//! Server info structures used in filter. + +use core::fmt; + +#[cfg(all(feature = "alloc", feature = "net"))] use alloc::boxed::Box; -use super::filter::{FilterFlags, Version}; -use super::server::{Region, ServerAdd}; -use super::wrappers::Str; +use bitflags::bitflags; + +use crate::{ + cursor::CursorError, + filter::{FilterFlags, Version}, +}; + +#[cfg(feature = "net")] +use crate::{net::server::ServerAdd, wrappers::Str}; + +/// The region of the world in which the server is located. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum Region { + /// US East coast. + USEastCoast = 0x00, + /// US West coast. + USWestCoast = 0x01, + /// South America. + SouthAmerica = 0x02, + /// Europe. + Europe = 0x03, + /// Asia. + Asia = 0x04, + /// Australia. + Australia = 0x05, + /// Middle East. + MiddleEast = 0x06, + /// Africa. + Africa = 0x07, + /// Rest of the world. + RestOfTheWorld = 0xff, +} + +impl Default for Region { + fn default() -> Self { + Self::RestOfTheWorld + } +} + +impl TryFrom for Region { + type Error = CursorError; + + fn try_from(value: u8) -> Result { + match value { + 0x00 => Ok(Region::USEastCoast), + 0x01 => Ok(Region::USWestCoast), + 0x02 => Ok(Region::SouthAmerica), + 0x03 => Ok(Region::Europe), + 0x04 => Ok(Region::Asia), + 0x05 => Ok(Region::Australia), + 0x06 => Ok(Region::MiddleEast), + 0x07 => Ok(Region::Africa), + 0xff => Ok(Region::RestOfTheWorld), + _ => Err(CursorError::InvalidNumber), + } + } +} + +/// Game server type. +#[derive(Copy, Clone, Debug, PartialEq)] +#[repr(u8)] +pub enum ServerType { + /// Dedicated server. + Dedicated, + /// Game client. + Local, + /// Spectator proxy. + Proxy, + /// Unknown. + Unknown, +} + +impl Default for ServerType { + fn default() -> Self { + Self::Unknown + } +} + +impl TryFrom<&[u8]> for ServerType { + type Error = CursorError; + + fn try_from(value: &[u8]) -> Result { + match value { + b"d" => Ok(Self::Dedicated), + b"l" => Ok(Self::Local), + b"p" => Ok(Self::Proxy), + _ => Ok(Self::Unknown), + } + } +} + +impl fmt::Display for ServerType { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + use ServerType as E; + + let s = match self { + E::Dedicated => "dedicated", + E::Local => "local", + E::Proxy => "proxy", + E::Unknown => "unknown", + }; + + write!(fmt, "{}", s) + } +} + +bitflags! { + /// Additional server flags. + #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] + pub struct ServerFlags: u8 { + /// Server has bots. + const BOTS = 1 << 0; + /// Server is behind a password. + const PASSWORD = 1 << 1; + /// Server using anti-cheat. + const SECURE = 1 << 2; + /// Server is LAN. + const LAN = 1 << 3; + /// Server behind NAT. + const NAT = 1 << 4; + } +} /// Game server information. #[derive(Clone, Debug)] @@ -25,6 +149,7 @@ pub struct ServerInfo { pub region: Region, } +#[cfg(feature = "net")] impl<'a> ServerInfo<&'a [u8]> { /// Creates a new `ServerInfo`. pub fn new(info: &ServerAdd>) -> Self { @@ -39,6 +164,7 @@ impl<'a> ServerInfo<&'a [u8]> { } } +#[cfg(feature = "net")] #[cfg(any(feature = "alloc", test))] impl ServerInfo> { /// Creates a new `ServerInfo`.