// source-engine/utils/vmpi/tcpsocket.cpp
//
// 1179 lines
// 26 KiB
// C++
// Raw Permalink Normal View History
//
// 2020-04-22 12:56:21 -04:00
//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose:
//
// $NoKeywords: $
//=============================================================================//
//#define PARANOID
#if defined( PARANOID )
#include <stdlib.h>
#include <crtdbg.h>
#endif
#include <winsock2.h>
#include <mswsock.h>
#include "tcpsocket.h"
#include "tier1/utllinkedlist.h"
#include <stdio.h>
#include "threadhelpers.h"
#include "tier0/dbg.h"
#error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead."
// Helpers declared extern here; their definitions live in another translation unit.
extern TIMEVAL SetupTimeVal( double flTimeout );
extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut );
extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut );
// Special values of the int size prefix that precedes every message. They are control
// messages, not payload lengths.
#define SENTINEL_DISCONNECT -1		// the peer is closing the connection
#define SENTINEL_KEEPALIVE -2		// no payload follows; receiving it only refreshes m_LastRecvTime
#define KEEPALIVE_INTERVAL_MS 3000 // keepalives are sent every N MS
#define KEEPALIVE_TIMEOUT_SECONDS 15.0 // connections timeout after this long
// Set via TCPSocket_EnableTimeout(). When false, CTCPSocket::CheckConnectionLost() never
// declares a connection timed out.
static bool g_bEnableTCPTimeout = true;
// One fully received packet waiting for Recv() to pick it up.
// Allocated with malloc( sizeof( CRecvData ) - 1 + payloadSize ): m_Data is a variable-length
// tail that really holds m_Count bytes.
struct CRecvData
{
	int				m_Count;	// payload size in bytes
	unsigned char	m_Data[1];	// first byte of the payload (over-allocated)
};
// Creates an overlapped TCP socket (usable with IO completion ports) and binds it to pAddr.
// Returns the bound socket, or INVALID_SOCKET if creation or bind() failed. On bind failure
// the socket is closed before returning.
SOCKET TCPBind( const CIPAddr *pAddr )
{
	SOCKET hSocket = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
	if ( hSocket == INVALID_SOCKET )
	{
		Assert( false );
		return INVALID_SOCKET;
	}

	sockaddr_in localAddr;
	IPAddrToSockAddr( pAddr, &localAddr );

	if ( bind( hSocket, (sockaddr*)&localAddr, sizeof( localAddr ) ) != 0 )
	{
		closesocket( hSocket );
		return INVALID_SOCKET;
	}

	return hSocket;
}
// ---------------------------------------------------------------------------------------- //
// TCP sockets.
// ---------------------------------------------------------------------------------------- //
// Tags stored in COverlappedPlus::m_OPType, so the IO thread can tell which kind of
// operation a dequeued completion belongs to.
enum
{
	OP_RECV = 111,	// the receive posted on CTCPSocket::m_RecvOverlapped
	OP_SEND = 112	// a WSASend of a SendBuf_t
};
// We use this for all OVERLAPPED structures.
class COverlappedPlus : public WSAOVERLAPPED
{
public:
COverlappedPlus()
{
memset( this, 0, sizeof( WSAOVERLAPPED ) );
}
int m_OPType; // One of the OP_ defines.
};
// One outgoing message, allocated with malloc( sizeof( SendBuf_t ) - 1 + m_DataLength ).
// m_Overlapped must stay the first member: the IO thread receives a COverlappedPlus* for a
// completed send and casts it straight back to SendBuf_t* (see HandleSendCompletion).
typedef struct SendBuf_t
{
	COverlappedPlus m_Overlapped;
	int m_Index;		// Index into m_SendBufs.
	int m_DataLength;	// number of bytes in m_Data (including any size prefix written by the sender)
	char m_Data[1];		// first byte of the data (over-allocated)
} SendBuf_s;
// These manage a thread that calls SendKeepalive() on all TCPSockets.
// AddGlobalTCPSocket shouldn't be called until you're ready for SendKeepalive() to be called.
class CTCPSocket;
void AddGlobalTCPSocket( CTCPSocket *pSocket );
void RemoveGlobalTCPSocket( CTCPSocket *pSocket );
// ------------------------------------------------------------------------------------------ //
// CTCPSocket implementation.
// ------------------------------------------------------------------------------------------ //
class CTCPSocket : public ITCPSocket
{
friend class CTCPListenSocket;
public:
CTCPSocket()
{
m_Socket = INVALID_SOCKET;
m_bConnected = false;
m_hIOCP = NULL;
m_bShouldExitThreads = false;
m_bConnectionLost = false;
m_nSizeBytesReceived = 0;
m_pIncomingData = NULL;
memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) );
m_RecvOverlapped.m_OPType = OP_RECV;
m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL );
m_RecvStage = -1;
m_MainThreadID = GetCurrentThreadId();
}
virtual ~CTCPSocket()
{
Term();
CloseHandle( m_hRecvSignal );
}
void Term()
{
Assert( GetCurrentThreadId() == m_MainThreadID );
RemoveGlobalTCPSocket( this );
if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost )
{
SendDisconnectSentinel();
// Give the sends a second to complete. SO_LINGER is having trouble for some reason.
WaitForSendsToComplete( 1 );
}
StopThreads();
if ( m_Socket != INVALID_SOCKET )
{
closesocket( m_Socket );
m_Socket = INVALID_SOCKET;
}
if ( m_hIOCP )
{
CloseHandle( m_hIOCP );
m_hIOCP = NULL;
}
m_bConnected = false;
m_bConnectionLost = true;
m_RecvStage = -1;
FOR_EACH_LL( m_SendBufs, i )
{
SendBuf_t *pSendBuf = m_SendBufs[i];
ParanoidMemoryCheck( pSendBuf );
free( pSendBuf );
}
m_SendBufs.Purge();
FOR_EACH_LL( m_RecvDatas, j )
{
CRecvData *pRecvData = m_RecvDatas[j];
ParanoidMemoryCheck( pRecvData );
free( pRecvData );
}
m_RecvDatas.Purge();
if ( m_pIncomingData )
{
ParanoidMemoryCheck( m_pIncomingData );
free( m_pIncomingData );
m_pIncomingData = 0;
}
}
virtual void Release()
{
delete this;
}
void ParanoidMemoryCheck( void *ptr = NULL )
{
#if defined( PARANOID )
Assert( _CrtIsValidHeapPointer( this ) );
if ( ptr )
{
Assert( _CrtIsValidHeapPointer( ptr ) );
}
Assert( _CrtCheckMemory() == TRUE );
#endif
}
virtual bool BindToAny( const unsigned short port )
{
Term();
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
m_Socket = TCPBind( &addr );
if ( m_Socket == INVALID_SOCKET )
{
return false;
}
else
{
SetInitialSocketOptions();
return true;
}
}
// Set the initial socket options that we want.
void SetInitialSocketOptions()
{
// Set nodelay to improve latency.
BOOL val = TRUE;
setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) );
// Make it linger for 3 seconds when it exits.
LINGER linger;
linger.l_onoff = 1;
linger.l_linger = 3;
setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) );
}
// Called only by main thread interface functions.
// Returns true if the connection is lost.
bool CheckConnectionLost()
{
Assert( GetCurrentThreadId() == m_MainThreadID );
if ( m_Socket == SOCKET_ERROR )
return true;
// Have we timed out?
if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) )
{
SetConnectionLost( "Connection timed out." );
}
// Has any thread posted that the connection has been lost?
CCriticalSectionLock postLock( &m_ConnectionLostCS );
postLock.Lock();
if ( m_bConnectionLost )
{
Term();
return true;
}
else
{
return false;
}
}
// Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost.
void SetConnectionLost( const char *pErrorString, int err = -1 )
{
CCriticalSectionLock postLock( &m_ConnectionLostCS );
postLock.Lock();
m_bConnectionLost = true;
postLock.Unlock();
// Handle it right away if we're in the main thread. If we're in an IO thread,
// it has to wait until the next interface function calls CheckConnectionLost().
if ( GetCurrentThreadId() == m_MainThreadID )
{
Term();
}
if ( pErrorString )
{
m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 );
}
else
{
char *lpMsgBuf;
FormatMessage(
FORMAT_MESSAGE_ALLOCATE_BUFFER |
FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL,
err,
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
(LPTSTR) &lpMsgBuf,
0,
NULL
);
m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 );
LocalFree( lpMsgBuf );
}
}
// -------------------------------------------------------------------------------------------------- //
// The receive code.
// -------------------------------------------------------------------------------------------------- //
virtual bool StartWaitingForSize( bool bFresh )
{
Assert( m_Socket != INVALID_SOCKET );
Assert( m_bConnected );
m_RecvStage = 0;
m_RecvDataSize = -1;
if ( bFresh )
m_nSizeBytesReceived = 0;
DWORD dwNumBytesReceived = 0;
WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived };
DWORD dwFlags = 0;
int status = WSARecv(
m_Socket,
&buf,
1,
&dwNumBytesReceived,
&dwFlags,
&m_RecvOverlapped,
NULL );
int err = -1;
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
{
SetConnectionLost( NULL, err );
return false;
}
else
{
return true;
}
}
bool PostNextDataPart()
{
DWORD dwNumBytesReceived = 0;
WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived };
DWORD dwFlags = 0;
int status = WSARecv(
m_Socket,
&buf,
1,
&dwNumBytesReceived,
&dwFlags,
&m_RecvOverlapped,
NULL );
int err = -1;
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
{
SetConnectionLost( NULL, err );
return false;
}
else
{
return true;
}
}
bool StartWaitingForData()
{
Assert( m_Socket != INVALID_SOCKET );
Assert( m_RecvStage == 0 );
Assert( m_bConnected );
Assert( m_RecvDataSize > 0 );
m_RecvStage = 1;
// Add a CRecvData element.
ParanoidMemoryCheck();
m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize );
if ( !m_pIncomingData )
{
char str[512];
_snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize );
SetConnectionLost( str );
return false;
}
m_pIncomingData->m_Count = m_RecvDataSize;
m_AmountReceived = 0;
return PostNextDataPart();
}
virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout )
{
if ( CheckConnectionLost() )
return false;
// Wait in 50ms chunks, checking for disconnections along the way.
bool bGotData = false;
DWORD msToWait = (DWORD)( flTimeout * 1000.0 );
do
{
DWORD curWaitTime = min( msToWait, 50 );
DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime );
if ( ret == WAIT_OBJECT_0 )
{
bGotData = true;
break;
}
// Did the connection timeout?
if ( CheckConnectionLost() )
return false;
msToWait -= curWaitTime;
} while ( msToWait );
// If we never got a WAIT_OBJECT_0, then we never received anything.
if ( !bGotData )
return false;
CCriticalSectionLock csLock( &m_RecvDataCS );
csLock.Lock();
// Pickup the head m_RecvDatas element.
CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ];
data.CopyArray( pRecvData->m_Data, pRecvData->m_Count );
// Now free it.
m_RecvDatas.Remove( m_RecvDatas.Head() );
ParanoidMemoryCheck( pRecvData );
free( pRecvData );
// Set the event again for the next time around, if there is more data waiting.
if ( m_RecvDatas.Count() > 0 )
SetEvent( m_hRecvSignal );
return true;
}
// INSIDE IO THREAD.
void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
{
if ( dwNumBytes == 0 )
{
SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" );
return;
}
m_LastRecvTime = Plat_FloatTime();
if ( m_RecvStage == 0 )
{
m_nSizeBytesReceived += dwNumBytes;
if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) )
{
// Size of -1 means the other size is breaking the connection.
if ( m_RecvDataSize == SENTINEL_DISCONNECT )
{
SetConnectionLost( "Got a graceful disconnect message." );
return;
}
else if ( m_RecvDataSize == SENTINEL_KEEPALIVE )
{
// No data follows this. Just let m_LastRecvTime get updated.
StartWaitingForSize( true );
return;
}
StartWaitingForData();
}
else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) )
{
// Handle the case where we only got some of the data (maybe one of the clients got disconnected).
StartWaitingForSize( false );
}
else
{
// This case should never ever happen!
#if defined( _DEBUG )
__asm int 3;
#endif
SetConnectionLost( "Received too much data in a packet!" );
return;
}
}
else if ( m_RecvStage == 1 )
{
// Got the data, make sure we got it all.
m_AmountReceived += dwNumBytes;
// Sanity check.
#if defined( _DEBUG )
Assert( m_RecvDataSize == m_pIncomingData->m_Count );
Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads.
#endif
if ( m_AmountReceived == m_RecvDataSize )
{
m_RecvStage = 2;
// Add the data to the list of packets waiting to be picked up.
CCriticalSectionLock csLock( &m_RecvDataCS );
csLock.Lock();
m_RecvDatas.AddToTail( m_pIncomingData );
m_pIncomingData = NULL;
if ( m_RecvDatas.Count() == 1 )
SetEvent( m_hRecvSignal ); // Notify the Recv() function.
StartWaitingForSize( true );
}
else
{
PostNextDataPart();
}
}
else
{
Assert( false );
}
}
// -------------------------------------------------------------------------------------------------- //
// The send code.
// -------------------------------------------------------------------------------------------------- //
virtual void WaitForSendsToComplete( double flTimeout )
{
CWaitTimer waitTimer( flTimeout );
while ( 1 )
{
CCriticalSectionLock sendBufLock( &m_SendCS );
sendBufLock.Lock();
if( m_SendBufs.Count() == 0 )
return;
sendBufLock.Unlock();
if ( waitTimer.ShouldKeepWaiting() )
Sleep( 10 );
else
break;
}
}
// This is called in the keepalive thread.
void SendKeepalive()
{
// Send a message saying we're exiting.
ParanoidMemoryCheck();
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
if ( !pBuf )
{
SetConnectionLost( "malloc() in SendKeepalive() failed." );
return;
}
pBuf->m_DataLength = sizeof( int );
*((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE;
InternalSendDataBuf( pBuf );
}
void SendDisconnectSentinel()
{
// Send a message saying we're exiting.
ParanoidMemoryCheck();
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
if ( pBuf )
{
pBuf->m_DataLength = sizeof( int );
*((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting.
InternalSendDataBuf( pBuf );
}
}
virtual bool Send( const void *pData, int len )
{
const void *pChunks[1] = { pData };
int chunkLengths[1] = { len };
return SendChunks( pChunks, chunkLengths, 1 );
}
virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
{
if ( CheckConnectionLost() )
return false;
CChunkWalker walker( pChunks, pChunkLengths, nChunks );
int totalLength = walker.GetTotalLength();
if ( !totalLength )
return true;
// Create a buffer to hold the data and copy the data in.
ParanoidMemoryCheck();
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) );
if ( !pBuf )
{
char str[512];
_snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength );
SetConnectionLost( str );
return false;
}
pBuf->m_DataLength = totalLength + sizeof( int );
int *pByteCountPos = (int*)pBuf->m_Data;
*pByteCountPos = totalLength;
char *pDataPos = &pBuf->m_Data[ sizeof( int ) ];
walker.CopyTo( pDataPos, totalLength );
int status = InternalSendDataBuf( pBuf );
int err = -1;
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
{
SetConnectionLost( NULL, err );
return false;
}
else
{
return true;
}
}
int InternalSendDataBuf( SendBuf_t *pBuf )
{
// Protect against interference from the keepalive thread.
CCriticalSectionLock csLock( &m_SendCS );
csLock.Lock();
pBuf->m_Overlapped.m_OPType = OP_SEND;
pBuf->m_Overlapped.hEvent = NULL;
// Add it to our list of buffers.
pBuf->m_Index = m_SendBufs.AddToTail( pBuf );
// Tell Winsock to send it.
WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data };
DWORD dwNumBytesSent = 0;
return WSASend(
m_Socket,
&buf,
1,
&dwNumBytesSent,
0,
&pBuf->m_Overlapped,
NULL );
}
// INSIDE IO THREAD.
void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
{
if ( dwNumBytes == 0 )
{
SetConnectionLost( "0 bytes in HandleSendCompletion." );
return;
}
// Just free the buffer.
SendBuf_t *pBuf = (SendBuf_t*)pInfo;
Assert( dwNumBytes == (DWORD)pBuf->m_DataLength );
CCriticalSectionLock sendBufLock( &m_SendCS );
sendBufLock.Lock();
m_SendBufs.Remove( pBuf->m_Index );
sendBufLock.Unlock();
ParanoidMemoryCheck( pBuf );
free( pBuf );
}
// -------------------------------------------------------------------------------------------------- //
// The connect code.
// -------------------------------------------------------------------------------------------------- //
virtual bool BeginConnect( const CIPAddr &inputAddr )
{
sockaddr_in addr;
IPAddrToSockAddr( &inputAddr, &addr );
m_bConnected = false;
int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) );
ret=ret;
return true;
}
virtual bool UpdateConnect()
{
// We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
fd_set writeSet;
writeSet.fd_count = 1;
writeSet.fd_array[0] = m_Socket;
TIMEVAL timeVal = SetupTimeVal( 0 );
// See if it has a packet waiting.
int status = select( 0, NULL, &writeSet, NULL, &timeVal );
if ( status > 0 )
{
SetupConnected();
return true;
}
return false;
}
void SetupConnected()
{
m_bConnected = true;
m_bConnectionLost = false;
m_LastRecvTime = Plat_FloatTime();
CreateThreads();
StartWaitingForSize( true );
AddGlobalTCPSocket( this );
}
virtual bool IsConnected()
{
CheckConnectionLost();
return m_bConnected;
}
virtual void GetDisconnectReason( CUtlVector<char> &reason )
{
reason = m_ErrorString;
}
// -------------------------------------------------------------------------------------------------- //
// Threads code.
// -------------------------------------------------------------------------------------------------- //
// Create our IO Completion Port threads.
bool CreateThreads()
{
int nThreads = 1;
SetShouldExitThreads( false );
// Create our IO completion port and hook it to our socket.
m_hIOCP = CreateIoCompletionPort(
INVALID_HANDLE_VALUE, NULL, 0, 0);
m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads );
for ( int i=0; i < nThreads; i++ )
{
DWORD dwThreadID = 0;
HANDLE hThread = CreateThread(
NULL,
0,
&CTCPSocket::StaticThreadFn,
this,
0,
&dwThreadID );
if ( hThread )
{
SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL );
m_Threads.AddToTail( hThread );
}
else
{
StopThreads();
return false;
}
}
return true;
}
void StopThreads()
{
// Tell the threads to exit, then wait for them to do so.
SetShouldExitThreads( true );
WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE );
for ( int i=0; i < m_Threads.Count(); i++ )
{
CloseHandle( m_Threads[i] );
}
m_Threads.Purge();
}
void SetShouldExitThreads( bool bShouldExit )
{
CCriticalSectionLock lock( &m_ThreadsCS );
lock.Lock();
m_bShouldExitThreads = bShouldExit;
}
bool ShouldExitThreads()
{
CCriticalSectionLock lock( &m_ThreadsCS );
lock.Lock();
bool bRet = m_bShouldExitThreads;
return bRet;
}
DWORD ThreadFn()
{
while ( 1 )
{
DWORD dwNumBytes = 0;
unsigned long pInputTCPSocket;
LPOVERLAPPED pOverlapped;
if ( GetQueuedCompletionStatus(
m_hIOCP, // the port we're listening on
&dwNumBytes, // # bytes received on the port
&pInputTCPSocket,// "completion key" = CTCPSocket*
&pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend.
100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted.
) )
{
COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped;
ParanoidMemoryCheck( pInfo );
if ( pInfo->m_OPType == OP_RECV )
{
Assert( pInfo == &m_RecvOverlapped );
HandleRecvCompletion( pInfo, dwNumBytes );
}
else
{
Assert( pInfo->m_OPType == OP_SEND );
HandleSendCompletion( pInfo, dwNumBytes );
}
}
if ( ShouldExitThreads() )
break;
}
return 0;
}
static DWORD WINAPI StaticThreadFn( LPVOID pParameter )
{
return ((CTCPSocket*)pParameter)->ThreadFn();
}
private:
SOCKET m_Socket;
bool m_bConnected;
// m_RecvOverlapped is setup to first wait for the size, then the data.
// Then it is not posted until the app grabs the data.
HANDLE m_hRecvSignal; // Tells Recv() when we have data.
COverlappedPlus m_RecvOverlapped;
int m_RecvStage; // -1 = not initialized
// 0 = waiting for size
// 1 = waiting for data
// 2 = waiting for app to pickup the data
CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up.
CRecvData *m_pIncomingData; // The packet we're currently receiving.
CCriticalSection m_RecvDataCS; // This protects adds and removes in the list.
// These reference the element at the tail of m_RecvData. It is the current one getting
volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet?
int m_RecvDataSize; // this is received over the network
int m_AmountReceived; // How much we've received so far.
// Last time we received anything from this connection. Used to determine if the connection is
// still active.
double m_LastRecvTime;
// Outgoing send buffers.
CUtlLinkedList<SendBuf_t*,int> m_SendBufs;
CCriticalSection m_SendCS;
// All the threads waiting for IO.
CUtlVector<HANDLE> m_Threads;
HANDLE m_hIOCP;
// Used during shutdown.
volatile bool m_bShouldExitThreads;
CCriticalSection m_ThreadsCS;
// For debugging.
DWORD m_MainThreadID;
// Set by the main thread or IO threads to signal connection lost.
bool m_bConnectionLost;
CCriticalSection m_ConnectionLostCS;
// This is set when we get disconnected.
CUtlVector<char> m_ErrorString;
};
// ------------------------------------------------------------------------------------------ //
// ITCPListenSocket implementation.
// ------------------------------------------------------------------------------------------ //
// Listening socket that hands out connected CTCPSockets via non-blocking UpdateListen() polls.
class CTCPListenSocket : public ITCPListenSocket
{
public:
	CTCPListenSocket()
	{
		m_Socket = INVALID_SOCKET;
	}

	virtual ~CTCPListenSocket()
	{
		if ( m_Socket != INVALID_SOCKET )
		{
			closesocket( m_Socket );
		}
	}

	// The main function to create one of these suckers.
	// Binds to INADDR_ANY:port and starts listening. nQueueLength == -1 means SOMAXCONN.
	// Returns NULL on failure.
	static ITCPListenSocket* Create( const unsigned short port, int nQueueLength )
	{
		CTCPListenSocket *pRet = new CTCPListenSocket;
		if ( !pRet )
			return NULL;

		// Bind it to a socket and start listening.
		CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
		pRet->m_Socket = TCPBind( &addr );

		if ( pRet->m_Socket == INVALID_SOCKET ||
			listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 )
		{
			pRet->Release();
			return NULL;	// was `return false`, which is not a valid null pointer constant in C++11
		}

		return pRet;
	}

	virtual void Release()
	{
		delete this;
	}

	// Non-blocking poll for an incoming connection. Returns a new connected socket (and fills
	// *pAddr with the peer address if pAddr is non-NULL), or NULL if nothing was accepted.
	virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
	{
		if ( m_Socket == INVALID_SOCKET )
			return NULL;

		// A listening socket becomes readable when a connection is waiting to be accepted.
		fd_set readSet;
		FD_ZERO( &readSet );
		FD_SET( m_Socket, &readSet );

		TIMEVAL timeVal = SetupTimeVal( 0 );

		if ( select( 0, &readSet, NULL, NULL, &timeVal ) <= 0 )
			return NULL;

		sockaddr_in addr;
		int addrSize = sizeof( addr );

		// Now accept the final connection. This can legitimately fail, e.g. if the client reset
		// the connection before we got to it, so just report "nothing yet".
		SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize );
		if ( newSock == INVALID_SOCKET )
			return NULL;

		CTCPSocket *pRet = new CTCPSocket;
		if ( !pRet )
		{
			closesocket( newSock );
			return NULL;
		}

		pRet->m_Socket = newSock;
		pRet->SetInitialSocketOptions();
		pRet->SetupConnected();

		// Report the address..
		if ( pAddr )
			SockAddrToIPAddr( &addr, pAddr );

		return pRet;
	}

private:
	SOCKET m_Socket;
};
// Public factory: a listening socket on INADDR_ANY:port, or NULL on failure.
// nQueueLength == -1 selects SOMAXCONN.
ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength )
{
	ITCPListenSocket *pListenSocket = CTCPListenSocket::Create( port, nQueueLength );
	return pListenSocket;
}
// Public factory: an unbound, unconnected socket. Call BindToAny() and BeginConnect() next.
ITCPSocket* CreateTCPSocket()
{
	CTCPSocket *pSocket = new CTCPSocket;
	return pSocket;
}
// Globally enables or disables the receive timeout (KEEPALIVE_TIMEOUT_SECONDS) that
// CTCPSocket::CheckConnectionLost() applies to every socket.
void TCPSocket_EnableTimeout( bool bEnable )
{
	g_bEnableTCPTimeout = bEnable;
}
// --------------------------------------------------------------------------------- //
// This thread sends keepalives on all active TCP sockets.
// --------------------------------------------------------------------------------- //
// Keepalive thread state. The thread and its two events are created by AddGlobalTCPSocket()
// when the first socket is registered, and closed by RemoveGlobalTCPSocket() when the last
// one is removed.
HANDLE g_hKeepaliveThread;
HANDLE g_hKeepaliveThreadSignal;	// set to ask the keepalive thread to exit
HANDLE g_hKeepaliveThreadReply;		// set by the keepalive thread right before it exits
CUtlLinkedList<CTCPSocket*,int> g_TCPSockets;	// sockets that get SendKeepalive(); guarded by g_TCPSocketsCS
CCriticalSection g_TCPSocketsCS;
// Keepalive thread body. Every KEEPALIVE_INTERVAL_MS, asks each registered socket to send a
// keepalive. Exits when g_hKeepaliveThreadSignal is set, acknowledging via g_hKeepaliveThreadReply.
DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter )
{
	while ( WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS ) != WAIT_OBJECT_0 )
	{
		// Hold the list lock for the whole pass so sockets can't be unregistered mid-iteration.
		CCriticalSectionLock socketsLock( &g_TCPSocketsCS );
		socketsLock.Lock();

		FOR_EACH_LL( g_TCPSockets, iSocket )
		{
			CTCPSocket *pSocket = g_TCPSockets[iSocket];
			pSocket->SendKeepalive();
		}
	}

	SetEvent( g_hKeepaliveThreadReply );
	return 0;
}
// Registers pSocket for periodic keepalives. The keepalive thread (and its events) only exist
// while at least one socket is registered, so the first registration starts them.
void AddGlobalTCPSocket( CTCPSocket *pSocket )
{
	CCriticalSectionLock socketsLock( &g_TCPSocketsCS );
	socketsLock.Lock();

	Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() );
	g_TCPSockets.AddToTail( pSocket );

	if ( g_TCPSockets.Count() != 1 )
		return;

	g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL );
	g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL );

	DWORD threadId = 0;
	g_hKeepaliveThread = CreateThread( NULL, 0, TCPKeepaliveThread, NULL, 0, &threadId );
}
// Unregisters pSocket from keepalives. Unregistering a socket that isn't registered is a no-op.
// When the last socket goes away, stops the keepalive thread and closes its handles.
void RemoveGlobalTCPSocket( CTCPSocket *pSocket )
{
	CCriticalSectionLock csLock( &g_TCPSocketsCS );
	csLock.Lock();

	int index = g_TCPSockets.Find( pSocket );
	if ( index == g_TCPSockets.InvalidIndex() )
		return;

	g_TCPSockets.Remove( index );
	if ( g_TCPSockets.Count() > 0 )
		return;

	// This was the last one, so shut the thread down. Release the lock first, because the
	// thread needs g_TCPSocketsCS to finish the pass it may be in the middle of.
	csLock.Unlock();

	// Only wait for the reply if the thread is actually alive. It may have been terminated
	// already (e.g. during process shutdown), and then we would wait forever. The liveness
	// check is done here, only when needed, so it never runs on a handle closed by an
	// earlier shutdown.
	DWORD dwExitCode = 0;
	if ( g_hKeepaliveThread &&
		GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) &&
		dwExitCode == STILL_ACTIVE )
	{
		SetEvent( g_hKeepaliveThreadSignal );
		WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE );
	}

	// Close the handles and clear the globals so no later call can use a stale (and possibly
	// recycled) handle value.
	if ( g_hKeepaliveThreadSignal )
	{
		CloseHandle( g_hKeepaliveThreadSignal );
		g_hKeepaliveThreadSignal = NULL;
	}
	if ( g_hKeepaliveThreadReply )
	{
		CloseHandle( g_hKeepaliveThreadReply );
		g_hKeepaliveThreadReply = NULL;
	}
	if ( g_hKeepaliveThread )
	{
		CloseHandle( g_hKeepaliveThread );
		g_hKeepaliveThread = NULL;
	}
}