You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1178 lines
26 KiB
1178 lines
26 KiB
//========= Copyright Valve Corporation, All rights reserved. ============// |
|
// |
|
// Purpose: |
|
// |
|
// $NoKeywords: $ |
|
//=============================================================================// |
|
|
|
|
|
//#define PARANOID |
|
|
|
#if defined( PARANOID ) |
|
#include <stdlib.h> |
|
#include <crtdbg.h> |
|
#endif |
|
|
|
#include <winsock2.h> |
|
#include <mswsock.h> |
|
#include "tcpsocket.h" |
|
#include "tier1/utllinkedlist.h" |
|
#include <stdio.h> |
|
#include "threadhelpers.h" |
|
#include "tier0/dbg.h" |
|
|
|
|
|
|
|
#error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead." |
|
|
|
|
|
extern TIMEVAL SetupTimeVal( double flTimeout ); |
|
extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut ); |
|
extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut ); |
|
|
|
|
|
// Special values sent in place of a packet's size prefix. Real packet sizes are always
// positive because SendChunks() never sends an empty payload.
static const int SENTINEL_DISCONNECT = -1;	// The peer is closing the connection.
static const int SENTINEL_KEEPALIVE = -2;	// Keepalive only; no payload follows.

static const int KEEPALIVE_INTERVAL_MS = 3000;			// keepalives are sent every N MS
static const double KEEPALIVE_TIMEOUT_SECONDS = 15.0;	// connections timeout after this long

// Toggled by TCPSocket_EnableTimeout(). When false, CheckConnectionLost() never declares a
// connection timed out.
static bool g_bEnableTCPTimeout = true;
|
|
|
|
|
// One fully received packet payload, queued in CTCPSocket::m_RecvDatas until Recv() picks it up.
// Allocated with malloc( sizeof( CRecvData ) - 1 + payloadSize ) so that m_Data can hold the
// whole payload.
class CRecvData
{
public:
	int				m_Count;	// Number of payload bytes stored in m_Data.
	unsigned char	m_Data[1];	// First byte of the variable-length payload.
};
|
|
|
|
|
|
|
// Creates an overlapped TCP socket and binds it to pAddr.
// Returns the bound socket, or INVALID_SOCKET if either the socket could not be created
// or bind() failed (in which case the socket is closed again).
SOCKET TCPBind( const CIPAddr *pAddr )
{
	SOCKET hSocket = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
	if ( hSocket == INVALID_SOCKET )
	{
		Assert( false );
		return INVALID_SOCKET;
	}

	sockaddr_in localAddr;
	IPAddrToSockAddr( pAddr, &localAddr );

	if ( bind( hSocket, (sockaddr*)&localAddr, sizeof( localAddr ) ) != 0 )
	{
		closesocket( hSocket );
		return INVALID_SOCKET;
	}

	return hSocket;
}
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------------------- // |
|
// TCP sockets. |
|
// ---------------------------------------------------------------------------------------- // |
|
|
|
// Tags stored in COverlappedPlus::m_OPType so the IO thread can tell which kind of
// operation a dequeued completion belongs to.
enum
{
	OP_RECV = 111,			// A WSARecv posted on CTCPSocket::m_RecvOverlapped.
	OP_SEND = OP_RECV + 1	// A WSASend posted on a SendBuf_t.
};
|
|
|
// We use this for all OVERLAPPED structures. |
|
class COverlappedPlus : public WSAOVERLAPPED |
|
{ |
|
public: |
|
COverlappedPlus() |
|
{ |
|
memset( this, 0, sizeof( WSAOVERLAPPED ) ); |
|
} |
|
|
|
int m_OPType; // One of the OP_ defines. |
|
}; |
|
|
|
typedef struct SendBuf_t |
|
{ |
|
COverlappedPlus m_Overlapped; |
|
int m_Index; // Index into m_SendBufs. |
|
int m_DataLength; |
|
char m_Data[1]; |
|
} SendBuf_s; |
|
|
|
|
|
// These manage a thread that calls SendKeepalive() on all TCPSockets. |
|
// AddGlobalTCPSocket shouldn't be called until you're ready for SendKeepalive() to be called. |
|
class CTCPSocket; |
|
void AddGlobalTCPSocket( CTCPSocket *pSocket ); |
|
void RemoveGlobalTCPSocket( CTCPSocket *pSocket ); |
|
|
|
|
|
|
|
// ------------------------------------------------------------------------------------------ // |
|
// CTCPSocket implementation. |
|
// ------------------------------------------------------------------------------------------ // |
|
|
|
class CTCPSocket : public ITCPSocket |
|
{ |
|
friend class CTCPListenSocket; |
|
|
|
public: |
|
|
|
CTCPSocket() |
|
{ |
|
m_Socket = INVALID_SOCKET; |
|
m_bConnected = false; |
|
|
|
m_hIOCP = NULL; |
|
|
|
m_bShouldExitThreads = false; |
|
m_bConnectionLost = false; |
|
m_nSizeBytesReceived = 0; |
|
|
|
m_pIncomingData = NULL; |
|
|
|
memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) ); |
|
m_RecvOverlapped.m_OPType = OP_RECV; |
|
|
|
m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL ); |
|
m_RecvStage = -1; |
|
|
|
m_MainThreadID = GetCurrentThreadId(); |
|
} |
|
|
|
virtual ~CTCPSocket() |
|
{ |
|
Term(); |
|
CloseHandle( m_hRecvSignal ); |
|
} |
|
|
|
void Term() |
|
{ |
|
Assert( GetCurrentThreadId() == m_MainThreadID ); |
|
|
|
RemoveGlobalTCPSocket( this ); |
|
|
|
if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost ) |
|
{ |
|
SendDisconnectSentinel(); |
|
|
|
// Give the sends a second to complete. SO_LINGER is having trouble for some reason. |
|
WaitForSendsToComplete( 1 ); |
|
} |
|
|
|
|
|
StopThreads(); |
|
|
|
if ( m_Socket != INVALID_SOCKET ) |
|
{ |
|
closesocket( m_Socket ); |
|
m_Socket = INVALID_SOCKET; |
|
} |
|
|
|
if ( m_hIOCP ) |
|
{ |
|
CloseHandle( m_hIOCP ); |
|
m_hIOCP = NULL; |
|
} |
|
|
|
m_bConnected = false; |
|
m_bConnectionLost = true; |
|
m_RecvStage = -1; |
|
|
|
FOR_EACH_LL( m_SendBufs, i ) |
|
{ |
|
SendBuf_t *pSendBuf = m_SendBufs[i]; |
|
ParanoidMemoryCheck( pSendBuf ); |
|
free( pSendBuf ); |
|
} |
|
m_SendBufs.Purge(); |
|
|
|
FOR_EACH_LL( m_RecvDatas, j ) |
|
{ |
|
CRecvData *pRecvData = m_RecvDatas[j]; |
|
ParanoidMemoryCheck( pRecvData ); |
|
free( pRecvData ); |
|
} |
|
m_RecvDatas.Purge(); |
|
|
|
if ( m_pIncomingData ) |
|
{ |
|
ParanoidMemoryCheck( m_pIncomingData ); |
|
free( m_pIncomingData ); |
|
m_pIncomingData = 0; |
|
} |
|
} |
|
|
|
virtual void Release() |
|
{ |
|
delete this; |
|
} |
|
|
|
|
|
void ParanoidMemoryCheck( void *ptr = NULL ) |
|
{ |
|
#if defined( PARANOID ) |
|
Assert( _CrtIsValidHeapPointer( this ) ); |
|
|
|
if ( ptr ) |
|
{ |
|
Assert( _CrtIsValidHeapPointer( ptr ) ); |
|
} |
|
|
|
Assert( _CrtCheckMemory() == TRUE ); |
|
#endif |
|
} |
|
|
|
|
|
virtual bool BindToAny( const unsigned short port ) |
|
{ |
|
Term(); |
|
|
|
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY |
|
m_Socket = TCPBind( &addr ); |
|
if ( m_Socket == INVALID_SOCKET ) |
|
{ |
|
return false; |
|
} |
|
else |
|
{ |
|
SetInitialSocketOptions(); |
|
return true; |
|
} |
|
} |
|
|
|
|
|
// Set the initial socket options that we want. |
|
void SetInitialSocketOptions() |
|
{ |
|
// Set nodelay to improve latency. |
|
BOOL val = TRUE; |
|
setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) ); |
|
|
|
// Make it linger for 3 seconds when it exits. |
|
LINGER linger; |
|
linger.l_onoff = 1; |
|
linger.l_linger = 3; |
|
setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) ); |
|
} |
|
|
|
|
|
// Called only by main thread interface functions. |
|
// Returns true if the connection is lost. |
|
bool CheckConnectionLost() |
|
{ |
|
Assert( GetCurrentThreadId() == m_MainThreadID ); |
|
|
|
if ( m_Socket == SOCKET_ERROR ) |
|
return true; |
|
|
|
// Have we timed out? |
|
if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) ) |
|
{ |
|
SetConnectionLost( "Connection timed out." ); |
|
} |
|
|
|
// Has any thread posted that the connection has been lost? |
|
CCriticalSectionLock postLock( &m_ConnectionLostCS ); |
|
postLock.Lock(); |
|
if ( m_bConnectionLost ) |
|
{ |
|
Term(); |
|
return true; |
|
} |
|
else |
|
{ |
|
return false; |
|
} |
|
} |
|
|
|
// Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost. |
|
void SetConnectionLost( const char *pErrorString, int err = -1 ) |
|
{ |
|
CCriticalSectionLock postLock( &m_ConnectionLostCS ); |
|
postLock.Lock(); |
|
m_bConnectionLost = true; |
|
postLock.Unlock(); |
|
|
|
// Handle it right away if we're in the main thread. If we're in an IO thread, |
|
// it has to wait until the next interface function calls CheckConnectionLost(). |
|
if ( GetCurrentThreadId() == m_MainThreadID ) |
|
{ |
|
Term(); |
|
} |
|
|
|
if ( pErrorString ) |
|
{ |
|
m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 ); |
|
} |
|
else |
|
{ |
|
char *lpMsgBuf; |
|
FormatMessage( |
|
FORMAT_MESSAGE_ALLOCATE_BUFFER | |
|
FORMAT_MESSAGE_FROM_SYSTEM | |
|
FORMAT_MESSAGE_IGNORE_INSERTS, |
|
NULL, |
|
err, |
|
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language |
|
(LPTSTR) &lpMsgBuf, |
|
0, |
|
NULL |
|
); |
|
|
|
m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 ); |
|
LocalFree( lpMsgBuf ); |
|
} |
|
} |
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- // |
|
// The receive code. |
|
// -------------------------------------------------------------------------------------------------- // |
|
|
|
virtual bool StartWaitingForSize( bool bFresh ) |
|
{ |
|
Assert( m_Socket != INVALID_SOCKET ); |
|
Assert( m_bConnected ); |
|
|
|
m_RecvStage = 0; |
|
m_RecvDataSize = -1; |
|
if ( bFresh ) |
|
m_nSizeBytesReceived = 0; |
|
|
|
DWORD dwNumBytesReceived = 0; |
|
WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived }; |
|
DWORD dwFlags = 0; |
|
|
|
int status = WSARecv( |
|
m_Socket, |
|
&buf, |
|
1, |
|
&dwNumBytesReceived, |
|
&dwFlags, |
|
&m_RecvOverlapped, |
|
NULL ); |
|
|
|
int err = -1; |
|
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) |
|
{ |
|
SetConnectionLost( NULL, err ); |
|
return false; |
|
} |
|
else |
|
{ |
|
return true; |
|
} |
|
} |
|
|
|
|
|
bool PostNextDataPart() |
|
{ |
|
DWORD dwNumBytesReceived = 0; |
|
WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived }; |
|
DWORD dwFlags = 0; |
|
|
|
int status = WSARecv( |
|
m_Socket, |
|
&buf, |
|
1, |
|
&dwNumBytesReceived, |
|
&dwFlags, |
|
&m_RecvOverlapped, |
|
NULL ); |
|
|
|
int err = -1; |
|
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) |
|
{ |
|
SetConnectionLost( NULL, err ); |
|
return false; |
|
} |
|
else |
|
{ |
|
return true; |
|
} |
|
} |
|
|
|
|
|
bool StartWaitingForData() |
|
{ |
|
Assert( m_Socket != INVALID_SOCKET ); |
|
Assert( m_RecvStage == 0 ); |
|
Assert( m_bConnected ); |
|
Assert( m_RecvDataSize > 0 ); |
|
|
|
m_RecvStage = 1; |
|
|
|
// Add a CRecvData element. |
|
ParanoidMemoryCheck(); |
|
m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize ); |
|
if ( !m_pIncomingData ) |
|
{ |
|
char str[512]; |
|
_snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize ); |
|
SetConnectionLost( str ); |
|
return false; |
|
} |
|
|
|
m_pIncomingData->m_Count = m_RecvDataSize; |
|
|
|
m_AmountReceived = 0; |
|
|
|
return PostNextDataPart(); |
|
} |
|
|
|
virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout ) |
|
{ |
|
if ( CheckConnectionLost() ) |
|
return false; |
|
|
|
// Wait in 50ms chunks, checking for disconnections along the way. |
|
bool bGotData = false; |
|
DWORD msToWait = (DWORD)( flTimeout * 1000.0 ); |
|
do |
|
{ |
|
DWORD curWaitTime = min( msToWait, 50 ); |
|
DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime ); |
|
if ( ret == WAIT_OBJECT_0 ) |
|
{ |
|
bGotData = true; |
|
break; |
|
} |
|
|
|
// Did the connection timeout? |
|
if ( CheckConnectionLost() ) |
|
return false; |
|
|
|
msToWait -= curWaitTime; |
|
} while ( msToWait ); |
|
|
|
// If we never got a WAIT_OBJECT_0, then we never received anything. |
|
if ( !bGotData ) |
|
return false; |
|
|
|
|
|
CCriticalSectionLock csLock( &m_RecvDataCS ); |
|
csLock.Lock(); |
|
|
|
// Pickup the head m_RecvDatas element. |
|
CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ]; |
|
data.CopyArray( pRecvData->m_Data, pRecvData->m_Count ); |
|
|
|
// Now free it. |
|
m_RecvDatas.Remove( m_RecvDatas.Head() ); |
|
ParanoidMemoryCheck( pRecvData ); |
|
free( pRecvData ); |
|
|
|
// Set the event again for the next time around, if there is more data waiting. |
|
if ( m_RecvDatas.Count() > 0 ) |
|
SetEvent( m_hRecvSignal ); |
|
|
|
return true; |
|
} |
|
|
|
// INSIDE IO THREAD. |
|
void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) |
|
{ |
|
if ( dwNumBytes == 0 ) |
|
{ |
|
SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" ); |
|
return; |
|
} |
|
|
|
m_LastRecvTime = Plat_FloatTime(); |
|
if ( m_RecvStage == 0 ) |
|
{ |
|
m_nSizeBytesReceived += dwNumBytes; |
|
if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) ) |
|
{ |
|
// Size of -1 means the other size is breaking the connection. |
|
if ( m_RecvDataSize == SENTINEL_DISCONNECT ) |
|
{ |
|
SetConnectionLost( "Got a graceful disconnect message." ); |
|
return; |
|
} |
|
else if ( m_RecvDataSize == SENTINEL_KEEPALIVE ) |
|
{ |
|
// No data follows this. Just let m_LastRecvTime get updated. |
|
StartWaitingForSize( true ); |
|
return; |
|
} |
|
|
|
StartWaitingForData(); |
|
} |
|
else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) ) |
|
{ |
|
// Handle the case where we only got some of the data (maybe one of the clients got disconnected). |
|
StartWaitingForSize( false ); |
|
} |
|
else |
|
{ |
|
// This case should never ever happen! |
|
#if defined( _DEBUG ) |
|
__asm int 3; |
|
#endif |
|
|
|
SetConnectionLost( "Received too much data in a packet!" ); |
|
return; |
|
} |
|
} |
|
else if ( m_RecvStage == 1 ) |
|
{ |
|
// Got the data, make sure we got it all. |
|
m_AmountReceived += dwNumBytes; |
|
|
|
// Sanity check. |
|
#if defined( _DEBUG ) |
|
Assert( m_RecvDataSize == m_pIncomingData->m_Count ); |
|
Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads. |
|
#endif |
|
|
|
if ( m_AmountReceived == m_RecvDataSize ) |
|
{ |
|
m_RecvStage = 2; |
|
|
|
// Add the data to the list of packets waiting to be picked up. |
|
CCriticalSectionLock csLock( &m_RecvDataCS ); |
|
csLock.Lock(); |
|
|
|
m_RecvDatas.AddToTail( m_pIncomingData ); |
|
m_pIncomingData = NULL; |
|
|
|
if ( m_RecvDatas.Count() == 1 ) |
|
SetEvent( m_hRecvSignal ); // Notify the Recv() function. |
|
|
|
StartWaitingForSize( true ); |
|
} |
|
else |
|
{ |
|
PostNextDataPart(); |
|
} |
|
} |
|
else |
|
{ |
|
Assert( false ); |
|
} |
|
} |
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- // |
|
// The send code. |
|
// -------------------------------------------------------------------------------------------------- // |
|
|
|
virtual void WaitForSendsToComplete( double flTimeout ) |
|
{ |
|
CWaitTimer waitTimer( flTimeout ); |
|
while ( 1 ) |
|
{ |
|
CCriticalSectionLock sendBufLock( &m_SendCS ); |
|
sendBufLock.Lock(); |
|
if( m_SendBufs.Count() == 0 ) |
|
return; |
|
sendBufLock.Unlock(); |
|
|
|
if ( waitTimer.ShouldKeepWaiting() ) |
|
Sleep( 10 ); |
|
else |
|
break; |
|
} |
|
} |
|
|
|
|
|
// This is called in the keepalive thread. |
|
void SendKeepalive() |
|
{ |
|
// Send a message saying we're exiting. |
|
ParanoidMemoryCheck(); |
|
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); |
|
if ( !pBuf ) |
|
{ |
|
SetConnectionLost( "malloc() in SendKeepalive() failed." ); |
|
return; |
|
} |
|
|
|
pBuf->m_DataLength = sizeof( int ); |
|
*((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE; |
|
InternalSendDataBuf( pBuf ); |
|
} |
|
|
|
|
|
void SendDisconnectSentinel() |
|
{ |
|
// Send a message saying we're exiting. |
|
ParanoidMemoryCheck(); |
|
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) ); |
|
if ( pBuf ) |
|
{ |
|
pBuf->m_DataLength = sizeof( int ); |
|
*((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting. |
|
InternalSendDataBuf( pBuf ); |
|
} |
|
} |
|
|
|
|
|
virtual bool Send( const void *pData, int len ) |
|
{ |
|
const void *pChunks[1] = { pData }; |
|
int chunkLengths[1] = { len }; |
|
return SendChunks( pChunks, chunkLengths, 1 ); |
|
} |
|
|
|
|
|
virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks ) |
|
{ |
|
if ( CheckConnectionLost() ) |
|
return false; |
|
|
|
CChunkWalker walker( pChunks, pChunkLengths, nChunks ); |
|
int totalLength = walker.GetTotalLength(); |
|
|
|
if ( !totalLength ) |
|
return true; |
|
|
|
// Create a buffer to hold the data and copy the data in. |
|
ParanoidMemoryCheck(); |
|
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) ); |
|
if ( !pBuf ) |
|
{ |
|
char str[512]; |
|
_snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength ); |
|
SetConnectionLost( str ); |
|
return false; |
|
} |
|
|
|
pBuf->m_DataLength = totalLength + sizeof( int ); |
|
|
|
int *pByteCountPos = (int*)pBuf->m_Data; |
|
*pByteCountPos = totalLength; |
|
|
|
char *pDataPos = &pBuf->m_Data[ sizeof( int ) ]; |
|
walker.CopyTo( pDataPos, totalLength ); |
|
|
|
int status = InternalSendDataBuf( pBuf ); |
|
int err = -1; |
|
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING ) |
|
{ |
|
SetConnectionLost( NULL, err ); |
|
return false; |
|
} |
|
else |
|
{ |
|
return true; |
|
} |
|
} |
|
|
|
|
|
int InternalSendDataBuf( SendBuf_t *pBuf ) |
|
{ |
|
// Protect against interference from the keepalive thread. |
|
CCriticalSectionLock csLock( &m_SendCS ); |
|
csLock.Lock(); |
|
|
|
|
|
pBuf->m_Overlapped.m_OPType = OP_SEND; |
|
pBuf->m_Overlapped.hEvent = NULL; |
|
|
|
// Add it to our list of buffers. |
|
pBuf->m_Index = m_SendBufs.AddToTail( pBuf ); |
|
|
|
// Tell Winsock to send it. |
|
WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data }; |
|
|
|
DWORD dwNumBytesSent = 0; |
|
return WSASend( |
|
m_Socket, |
|
&buf, |
|
1, |
|
&dwNumBytesSent, |
|
0, |
|
&pBuf->m_Overlapped, |
|
NULL ); |
|
} |
|
|
|
|
|
// INSIDE IO THREAD. |
|
void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes ) |
|
{ |
|
if ( dwNumBytes == 0 ) |
|
{ |
|
SetConnectionLost( "0 bytes in HandleSendCompletion." ); |
|
return; |
|
} |
|
|
|
// Just free the buffer. |
|
SendBuf_t *pBuf = (SendBuf_t*)pInfo; |
|
Assert( dwNumBytes == (DWORD)pBuf->m_DataLength ); |
|
|
|
CCriticalSectionLock sendBufLock( &m_SendCS ); |
|
sendBufLock.Lock(); |
|
m_SendBufs.Remove( pBuf->m_Index ); |
|
sendBufLock.Unlock(); |
|
|
|
ParanoidMemoryCheck( pBuf ); |
|
free( pBuf ); |
|
} |
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- // |
|
// The connect code. |
|
// -------------------------------------------------------------------------------------------------- // |
|
|
|
virtual bool BeginConnect( const CIPAddr &inputAddr ) |
|
{ |
|
sockaddr_in addr; |
|
IPAddrToSockAddr( &inputAddr, &addr ); |
|
|
|
m_bConnected = false; |
|
int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) ); |
|
ret=ret; |
|
|
|
return true; |
|
} |
|
|
|
|
|
virtual bool UpdateConnect() |
|
{ |
|
// We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. |
|
fd_set writeSet; |
|
writeSet.fd_count = 1; |
|
writeSet.fd_array[0] = m_Socket; |
|
TIMEVAL timeVal = SetupTimeVal( 0 ); |
|
|
|
// See if it has a packet waiting. |
|
int status = select( 0, NULL, &writeSet, NULL, &timeVal ); |
|
if ( status > 0 ) |
|
{ |
|
SetupConnected(); |
|
return true; |
|
} |
|
|
|
return false; |
|
} |
|
|
|
|
|
void SetupConnected() |
|
{ |
|
m_bConnected = true; |
|
m_bConnectionLost = false; |
|
m_LastRecvTime = Plat_FloatTime(); |
|
|
|
CreateThreads(); |
|
StartWaitingForSize( true ); |
|
AddGlobalTCPSocket( this ); |
|
} |
|
|
|
|
|
virtual bool IsConnected() |
|
{ |
|
CheckConnectionLost(); |
|
return m_bConnected; |
|
} |
|
|
|
|
|
virtual void GetDisconnectReason( CUtlVector<char> &reason ) |
|
{ |
|
reason = m_ErrorString; |
|
} |
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- // |
|
// Threads code. |
|
// -------------------------------------------------------------------------------------------------- // |
|
|
|
// Create our IO Completion Port threads. |
|
bool CreateThreads() |
|
{ |
|
int nThreads = 1; |
|
SetShouldExitThreads( false ); |
|
|
|
// Create our IO completion port and hook it to our socket. |
|
m_hIOCP = CreateIoCompletionPort( |
|
INVALID_HANDLE_VALUE, NULL, 0, 0); |
|
|
|
m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads ); |
|
|
|
for ( int i=0; i < nThreads; i++ ) |
|
{ |
|
DWORD dwThreadID = 0; |
|
HANDLE hThread = CreateThread( |
|
NULL, |
|
0, |
|
&CTCPSocket::StaticThreadFn, |
|
this, |
|
0, |
|
&dwThreadID ); |
|
|
|
if ( hThread ) |
|
{ |
|
SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL ); |
|
m_Threads.AddToTail( hThread ); |
|
} |
|
else |
|
{ |
|
StopThreads(); |
|
return false; |
|
} |
|
} |
|
|
|
return true; |
|
} |
|
|
|
|
|
void StopThreads() |
|
{ |
|
// Tell the threads to exit, then wait for them to do so. |
|
SetShouldExitThreads( true ); |
|
WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE ); |
|
|
|
for ( int i=0; i < m_Threads.Count(); i++ ) |
|
{ |
|
CloseHandle( m_Threads[i] ); |
|
} |
|
m_Threads.Purge(); |
|
} |
|
|
|
|
|
void SetShouldExitThreads( bool bShouldExit ) |
|
{ |
|
CCriticalSectionLock lock( &m_ThreadsCS ); |
|
lock.Lock(); |
|
m_bShouldExitThreads = bShouldExit; |
|
} |
|
|
|
|
|
bool ShouldExitThreads() |
|
{ |
|
CCriticalSectionLock lock( &m_ThreadsCS ); |
|
lock.Lock(); |
|
|
|
bool bRet = m_bShouldExitThreads; |
|
return bRet; |
|
} |
|
|
|
|
|
DWORD ThreadFn() |
|
{ |
|
while ( 1 ) |
|
{ |
|
DWORD dwNumBytes = 0; |
|
unsigned long pInputTCPSocket; |
|
LPOVERLAPPED pOverlapped; |
|
|
|
if ( GetQueuedCompletionStatus( |
|
m_hIOCP, // the port we're listening on |
|
&dwNumBytes, // # bytes received on the port |
|
&pInputTCPSocket,// "completion key" = CTCPSocket* |
|
&pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend. |
|
100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted. |
|
) ) |
|
{ |
|
COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped; |
|
ParanoidMemoryCheck( pInfo ); |
|
|
|
if ( pInfo->m_OPType == OP_RECV ) |
|
{ |
|
Assert( pInfo == &m_RecvOverlapped ); |
|
HandleRecvCompletion( pInfo, dwNumBytes ); |
|
} |
|
else |
|
{ |
|
Assert( pInfo->m_OPType == OP_SEND ); |
|
HandleSendCompletion( pInfo, dwNumBytes ); |
|
} |
|
} |
|
|
|
if ( ShouldExitThreads() ) |
|
break; |
|
} |
|
|
|
return 0; |
|
} |
|
|
|
|
|
static DWORD WINAPI StaticThreadFn( LPVOID pParameter ) |
|
{ |
|
return ((CTCPSocket*)pParameter)->ThreadFn(); |
|
} |
|
|
|
|
|
|
|
private: |
|
|
|
SOCKET m_Socket; |
|
bool m_bConnected; |
|
|
|
|
|
// m_RecvOverlapped is setup to first wait for the size, then the data. |
|
// Then it is not posted until the app grabs the data. |
|
HANDLE m_hRecvSignal; // Tells Recv() when we have data. |
|
COverlappedPlus m_RecvOverlapped; |
|
int m_RecvStage; // -1 = not initialized |
|
// 0 = waiting for size |
|
// 1 = waiting for data |
|
// 2 = waiting for app to pickup the data |
|
|
|
CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up. |
|
CRecvData *m_pIncomingData; // The packet we're currently receiving. |
|
CCriticalSection m_RecvDataCS; // This protects adds and removes in the list. |
|
|
|
// These reference the element at the tail of m_RecvData. It is the current one getting |
|
volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet? |
|
int m_RecvDataSize; // this is received over the network |
|
int m_AmountReceived; // How much we've received so far. |
|
|
|
// Last time we received anything from this connection. Used to determine if the connection is |
|
// still active. |
|
double m_LastRecvTime; |
|
|
|
|
|
// Outgoing send buffers. |
|
CUtlLinkedList<SendBuf_t*,int> m_SendBufs; |
|
CCriticalSection m_SendCS; |
|
|
|
|
|
// All the threads waiting for IO. |
|
CUtlVector<HANDLE> m_Threads; |
|
HANDLE m_hIOCP; |
|
|
|
// Used during shutdown. |
|
volatile bool m_bShouldExitThreads; |
|
CCriticalSection m_ThreadsCS; |
|
|
|
// For debugging. |
|
DWORD m_MainThreadID; |
|
|
|
// Set by the main thread or IO threads to signal connection lost. |
|
bool m_bConnectionLost; |
|
CCriticalSection m_ConnectionLostCS; |
|
|
|
// This is set when we get disconnected. |
|
CUtlVector<char> m_ErrorString; |
|
}; |
|
|
|
|
|
// ------------------------------------------------------------------------------------------ // |
|
// ITCPListenSocket implementation. |
|
// ------------------------------------------------------------------------------------------ // |
|
|
|
class CTCPListenSocket : public ITCPListenSocket |
|
{ |
|
public: |
|
|
|
CTCPListenSocket() |
|
{ |
|
m_Socket = INVALID_SOCKET; |
|
} |
|
|
|
|
|
virtual ~CTCPListenSocket() |
|
{ |
|
if ( m_Socket != INVALID_SOCKET ) |
|
{ |
|
closesocket( m_Socket ); |
|
} |
|
} |
|
|
|
|
|
// The main function to create one of these suckers. |
|
static ITCPListenSocket* Create( const unsigned short port, int nQueueLength ) |
|
{ |
|
CTCPListenSocket *pRet = new CTCPListenSocket; |
|
if ( !pRet ) |
|
return NULL; |
|
|
|
// Bind it to a socket and start listening. |
|
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY |
|
pRet->m_Socket = TCPBind( &addr ); |
|
if ( pRet->m_Socket == INVALID_SOCKET || |
|
listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 ) |
|
{ |
|
pRet->Release(); |
|
return false; |
|
} |
|
|
|
return pRet; |
|
} |
|
|
|
|
|
virtual void Release() |
|
{ |
|
delete this; |
|
} |
|
|
|
|
|
virtual ITCPSocket* UpdateListen( CIPAddr *pAddr ) |
|
{ |
|
// We're still ok.. just wait until the socket becomes writable (is connected) or we timeout. |
|
fd_set readSet; |
|
readSet.fd_count = 1; |
|
readSet.fd_array[0] = m_Socket; |
|
TIMEVAL timeVal = SetupTimeVal( 0 ); |
|
|
|
// Wait until it connects. |
|
int status = select( 0, &readSet, NULL, NULL, &timeVal ); |
|
if ( status > 0 ) |
|
{ |
|
sockaddr_in addr; |
|
int addrSize = sizeof( addr ); |
|
|
|
// Now accept the final connection. |
|
SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize ); |
|
if ( newSock == INVALID_SOCKET ) |
|
{ |
|
Assert( false ); |
|
} |
|
else |
|
{ |
|
CTCPSocket *pRet = new CTCPSocket; |
|
if ( !pRet ) |
|
{ |
|
closesocket( newSock ); |
|
return NULL; |
|
} |
|
|
|
pRet->m_Socket = newSock; |
|
pRet->SetInitialSocketOptions(); |
|
pRet->SetupConnected(); |
|
|
|
// Report the address.. |
|
SockAddrToIPAddr( &addr, pAddr ); |
|
|
|
return pRet; |
|
} |
|
} |
|
|
|
return NULL; |
|
} |
|
|
|
|
|
private: |
|
SOCKET m_Socket; |
|
}; |
|
|
|
|
|
|
|
ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength ) |
|
{ |
|
return CTCPListenSocket::Create( port, nQueueLength ); |
|
} |
|
|
|
|
|
ITCPSocket* CreateTCPSocket() |
|
{ |
|
return new CTCPSocket; |
|
} |
|
|
|
|
|
void TCPSocket_EnableTimeout( bool bEnable ) |
|
{ |
|
g_bEnableTCPTimeout = bEnable; |
|
} |
|
|
|
|
|
// --------------------------------------------------------------------------------- // |
|
// This thread sends keepalives on all active TCP sockets. |
|
// --------------------------------------------------------------------------------- // |
|
|
|
HANDLE g_hKeepaliveThread; |
|
HANDLE g_hKeepaliveThreadSignal; |
|
HANDLE g_hKeepaliveThreadReply; |
|
CUtlLinkedList<CTCPSocket*,int> g_TCPSockets; |
|
CCriticalSection g_TCPSocketsCS; |
|
|
|
|
|
DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter ) |
|
{ |
|
while ( 1 ) |
|
{ |
|
if ( WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS ) == WAIT_OBJECT_0 ) |
|
break; |
|
|
|
// Tell all TCP sockets to send a keepalive. |
|
CCriticalSectionLock csLock( &g_TCPSocketsCS ); |
|
csLock.Lock(); |
|
|
|
FOR_EACH_LL( g_TCPSockets, i ) |
|
{ |
|
g_TCPSockets[i]->SendKeepalive(); |
|
} |
|
} |
|
|
|
SetEvent( g_hKeepaliveThreadReply ); |
|
return 0; |
|
} |
|
|
|
|
|
void AddGlobalTCPSocket( CTCPSocket *pSocket ) |
|
{ |
|
CCriticalSectionLock csLock( &g_TCPSocketsCS ); |
|
csLock.Lock(); |
|
|
|
Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() ); |
|
g_TCPSockets.AddToTail( pSocket ); |
|
|
|
// If this is the first one, create the keepalive thread. |
|
if ( g_TCPSockets.Count() == 1 ) |
|
{ |
|
g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL ); |
|
g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL ); |
|
|
|
DWORD dwThreadID = 0; |
|
g_hKeepaliveThread = CreateThread( |
|
NULL, |
|
0, |
|
TCPKeepaliveThread, |
|
NULL, |
|
0, |
|
&dwThreadID |
|
); |
|
} |
|
} |
|
|
|
|
|
void RemoveGlobalTCPSocket( CTCPSocket *pSocket ) |
|
{ |
|
bool bThreadRunning = false; |
|
DWORD dwExitCode = 0; |
|
if ( GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) && dwExitCode == STILL_ACTIVE ) |
|
{ |
|
bThreadRunning = true; |
|
} |
|
|
|
CCriticalSectionLock csLock( &g_TCPSocketsCS ); |
|
csLock.Lock(); |
|
|
|
int index = g_TCPSockets.Find( pSocket ); |
|
if ( index != g_TCPSockets.InvalidIndex() ) |
|
{ |
|
g_TCPSockets.Remove( index ); |
|
|
|
// If this was the last one, delete the thread. |
|
if ( g_TCPSockets.Count() == 0 ) |
|
{ |
|
csLock.Unlock(); |
|
|
|
if ( bThreadRunning ) |
|
{ |
|
SetEvent( g_hKeepaliveThreadSignal ); |
|
WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE ); |
|
} |
|
|
|
CloseHandle( g_hKeepaliveThreadSignal ); |
|
CloseHandle( g_hKeepaliveThreadReply ); |
|
CloseHandle( g_hKeepaliveThread ); |
|
return; |
|
} |
|
} |
|
|
|
csLock.Unlock(); |
|
}
|
|
|