//========= Copyright Valve Corporation, All rights reserved. ============//
//
// Purpose: 
//
// $NoKeywords: $
//
//=============================================================================//

#pragma warning (disable:4127)
#include <winsock2.h>
#include <ws2tcpip.h>
#pragma warning (default:4127)

#include "iphelpers.h"
#include "basetypes.h"
#include <assert.h>
#include "utllinkedlist.h"
#include "utlvector.h"
#include "tier1/strtools.h"


// This automatically calls WSAStartup for the app at startup.
class CIPStarter
{
public:
	CIPStarter()
	{
		WSADATA wsaData;
		WSAStartup( WINSOCK_VERSION, &wsaData );
	}
};
static CIPStarter g_Starter;


unsigned long SampleMilliseconds()
{
	CCycleCount cnt;
	cnt.Sample();
	return cnt.GetMilliseconds();
}


// ------------------------------------------------------------------------------------------ //
// CChunkWalker.
// ------------------------------------------------------------------------------------------ //

// Sets up a walker over nChunks buffers.
//
// pChunks       - array of nChunks pointers to the chunk data (stored, not copied).
// pChunkLengths - array of nChunks byte lengths, one per chunk (stored, not copied).
// nChunks       - number of entries in both arrays.
//
// The walker starts at the first byte of the first chunk, and the sum of all the
// chunk lengths is cached for GetTotalLength().
CChunkWalker::CChunkWalker( void const * const *pChunks, const int *pChunkLengths, int nChunks )
{
	// Remember the chunk list.
	m_pChunks = pChunks;
	m_pChunkLengths = pChunkLengths;
	m_nChunks = nChunks;

	// Start reading at the very beginning.
	m_iCurChunk = 0;
	m_iCurChunkPos = 0;

	// Add up the size of all the chunks.
	int nTotal = 0;
	for ( int iChunk=0; iChunk < nChunks; ++iChunk )
	{
		nTotal += pChunkLengths[iChunk];
	}
	m_TotalLength = nTotal;
}

// Returns the combined byte length of all chunks, as summed up in the constructor.
int CChunkWalker::GetTotalLength() const
{
	const int nTotalBytes = m_TotalLength;
	return nTotalBytes;
}

// Copies the next nBytes bytes of chunk data into pOut and advances the walker's
// read position (m_iCurChunk / m_iCurChunkPos) past them.
//
// pOut   - destination buffer; must have room for nBytes bytes.
// nBytes - number of bytes to copy. Values <= 0 copy nothing.
//
// Asking for more bytes than remain in the chunks is a caller error: it asserts in
// debug builds, and in all builds the copy stops at the end of the last chunk rather
// than indexing past the end of the chunk arrays.
void CChunkWalker::CopyTo( void *pOut, int nBytes )
{
	unsigned char *pOutPos = (unsigned char*)pOut;

	int nBytesLeft = nBytes;
	while ( nBytesLeft > 0 )
	{
		// Ran out of chunks with bytes still requested. Don't read m_pChunkLengths[] or
		// m_pChunks[] out of bounds.
		if ( m_iCurChunk >= m_nChunks )
		{
			assert( false );
			break;
		}

		int curChunkLen = m_pChunkLengths[m_iCurChunk];

		// Copy whichever is smaller: what the caller still wants, or what's left in this chunk.
		int amtLeft = curChunkLen - m_iCurChunkPos;
		int toCopy = nBytesLeft;
		if ( toCopy > amtLeft )
		{
			toCopy = amtLeft;
		}

		// Skip the memcpy for empty chunks (their data pointer may not be valid).
		if ( toCopy > 0 )
		{
			const unsigned char *pCurChunkData = (const unsigned char*)m_pChunks[m_iCurChunk];
			memcpy( pOutPos, &pCurChunkData[m_iCurChunkPos], toCopy );
			nBytesLeft -= toCopy;
			pOutPos += toCopy;
			m_iCurChunkPos += toCopy;
		}

		// Slide up to the next chunk if we're done with the one we're on.
		assert( m_iCurChunkPos <= curChunkLen );
		if ( m_iCurChunkPos >= curChunkLen )
		{
			++m_iCurChunk;
			m_iCurChunkPos = 0;
		}
	}
}


// ------------------------------------------------------------------------------------------ //
// CWaitTimer
// ------------------------------------------------------------------------------------------ //

// When true, CWaitTimer::ShouldKeepWaiting() keeps returning true for any timer that was
// created with a nonzero wait time, even after that wait time has elapsed.
bool g_bForceWaitTimers = false;

// Starts a timer that lasts flSeconds seconds, measured from the moment of construction.
// The duration is stored in whole milliseconds (the fractional millisecond is dropped).
CWaitTimer::CWaitTimer( double flSeconds )
{
	// How long to wait, in milliseconds.
	m_WaitMS = (unsigned long)( flSeconds * 1000.0 );

	// When the wait began.
	m_StartTime = SampleMilliseconds();
}


// Returns true while the caller should keep waiting.
//
// A timer with a zero wait time never waits (not even when g_bForceWaitTimers is set).
// Otherwise the answer is true until more than m_WaitMS milliseconds have passed since
// construction, and stays true after that if g_bForceWaitTimers is set.
bool CWaitTimer::ShouldKeepWaiting()
{
	// Zero-length timers are always expired.
	if ( m_WaitMS == 0 )
		return false;

	const unsigned long elapsedMS = SampleMilliseconds() - m_StartTime;
	if ( elapsedMS <= m_WaitMS )
		return true;

	// Time is up; only the global override can keep us waiting.
	return g_bForceWaitTimers;
}



// ------------------------------------------------------------------------------------------ //
// CIPAddr.
// ------------------------------------------------------------------------------------------ //

// Default constructor: address 0.0.0.0, port 0.
CIPAddr::CIPAddr()
{
	const int zero = 0;
	Init( zero, zero, zero, zero, zero );
}


// Builds an address from an array of four octets (inputIP[0] is the first octet)
// plus a port. The values are narrowed by Init().
CIPAddr::CIPAddr( const int inputIP[4], const int inputPort )
{
	const int *pOctets = inputIP;
	Init( pOctets[0], pOctets[1], pOctets[2], pOctets[3], inputPort );
}


// Builds the address ip0.ip1.ip2.ip3 with the given port.
// All the work is done by Init(), which narrows the values.
CIPAddr::CIPAddr( int ip0, int ip1, int ip2, int ip3, int ipPort )
{
	this->Init( ip0, ip1, ip2, ip3, ipPort );
}


void CIPAddr::Init( int ip0, int ip1, int ip2, int ip3, int ipPort )
{
	ip[0] = (unsigned char)ip0;
	ip[1] = (unsigned char)ip1;
	ip[2] = (unsigned char)ip2;
	ip[3] = (unsigned char)ip3;
	port  = (unsigned short)ipPort;
}

// Two addresses are equal when all four octets and the port match.
bool CIPAddr::operator==( const CIPAddr &o ) const
{
	for ( int i=0; i < 4; ++i )
	{
		if ( ip[i] != o.ip[i] )
			return false;
	}

	return port == o.port;
}


// Exact negation of operator==.
bool CIPAddr::operator!=( const CIPAddr &o ) const
{
	const bool bSame = ( *this == o );
	return !bSame;
}


void CIPAddr::SetupLocal( int inPort )
{
	ip[0] = 0x7f;
	ip[1] = 0;
	ip[2] = 0;
	ip[3] = 1;
	port = inPort;
}


// ------------------------------------------------------------------------------------------ //
// Static helpers.
// ------------------------------------------------------------------------------------------ //

static double IP_FloatTime()
{
	CCycleCount cnt;
	cnt.Sample();
	return cnt.GetSeconds();
}

// Converts a timeout in (fractional) seconds into a TIMEVAL for use with select().
//
// flTimeout - timeout in seconds, e.g. 0.25 for a quarter of a second.
//
// Returns a TIMEVAL with tv_sec holding the whole seconds and tv_usec holding the
// fractional remainder in MICROseconds (one second is 1,000,000 microseconds).
TIMEVAL SetupTimeVal( double flTimeout )
{
	const long nWholeSeconds = (long)flTimeout;

	TIMEVAL timeVal;
	timeVal.tv_sec = nWholeSeconds;
	// tv_usec is in microseconds, so scale the fractional second by one million.
	timeVal.tv_usec = (long)( (flTimeout - nWholeSeconds) * 1000000.0 );
	return timeVal;
}

// Convert a CIPAddr to a sockaddr_in.
void IPAddrToInAddr( const CIPAddr *pIn, in_addr *pOut )
{
	u_char *p = (u_char*)pOut;
	p[0] = pIn->ip[0];
	p[1] = pIn->ip[1];
	p[2] = pIn->ip[2];
	p[3] = pIn->ip[3];
}

// Convert a CIPAddr to a sockaddr_in.
// *pOut is zeroed first, then filled in as an AF_INET address with the port converted
// by htons() and the octets copied by IPAddrToInAddr().
void IPAddrToSockAddr( const CIPAddr *pIn, struct sockaddr_in *pOut )
{
	// Zero the whole structure so no field is left with stale data.
	memset( pOut, 0, sizeof(*pOut) );

	IPAddrToInAddr( pIn, &pOut->sin_addr );
	pOut->sin_port = htons( pIn->port );
	pOut->sin_family = AF_INET;
}

// Convert a CIPAddr to a sockaddr_in.
void SockAddrToIPAddr( const struct sockaddr_in *pIn, CIPAddr *pOut )
{
	const u_char *p = (const u_char*)&pIn->sin_addr;
	pOut->ip[0] = p[0];
	pOut->ip[1] = p[1];
	pOut->ip[2] = p[2];
	pOut->ip[3] = p[3];
	pOut->port = ntohs( pIn->sin_port );
}


class CIPSocket : public ISocket
{
public:
					CIPSocket()
					{
						m_Socket = INVALID_SOCKET;
						m_bSetupToBroadcast = false;
					}

	virtual			~CIPSocket()
	{
		Term();
	}


// ISocket implementation.
public:

	virtual void	Release()
	{
		delete this;
	}

	
	virtual bool CreateSocket()
	{
		// Clear any old socket we had around.
		Term();

		// Create a socket to send and receive through.
		SOCKET sock = socket( AF_INET, SOCK_DGRAM, IPPROTO_IP );
		if ( sock == INVALID_SOCKET )
		{
			Assert( false );
			return false;
		}

		// Nonblocking please..
		int status;
		DWORD val = 1;
		status = ioctlsocket( sock, FIONBIO, &val );
		if ( status != 0 )
		{
			assert( false );
			closesocket( sock );
			return false;
		}

		m_Socket = sock;
		return true;
	}


	// Called after we have a socket.
	virtual bool BindPart2( const CIPAddr *pAddr )
	{
		Assert( m_Socket != INVALID_SOCKET );

		// bind to it!
		sockaddr_in addr;
		IPAddrToSockAddr( pAddr, &addr );

		int status = bind( m_Socket, (sockaddr*)&addr, sizeof(addr) );
		if ( status == 0 )
		{
			return true;
		}
		else
		{
			Term();
			return false;
		}
	}


	virtual bool	Bind( const CIPAddr *pAddr )
	{
		if ( !CreateSocket() )
			return false;
    
		return BindPart2( pAddr );
	}

	virtual bool	BindToAny( const unsigned short port )
	{
		// (INADDR_ANY)
		CIPAddr addr;
		addr.ip[0] = addr.ip[1] = addr.ip[2] = addr.ip[3] = 0;
		addr.port = port;
		return Bind( &addr );
	}

	virtual bool	ListenToMulticastStream( const CIPAddr &addr, const CIPAddr &localInterface )
	{
		ip_mreq mr;
		IPAddrToInAddr( &addr, &mr.imr_multiaddr );
		IPAddrToInAddr( &localInterface, &mr.imr_interface );

		// This helps a lot if the stream is sending really fast.
		int rcvBuf = 1024*1024*2;
		setsockopt( m_Socket, SOL_SOCKET, SO_RCVBUF, (char*)&rcvBuf, sizeof( rcvBuf ) );

		if ( setsockopt( m_Socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, (char*)&mr, sizeof( mr ) ) == 0 )
		{
			// Remember this so we do IP_DEL_MEMBERSHIP on shutdown.
			m_bMulticastGroupMembership = true;
			m_MulticastGroupMREQ = mr;
			return true;
		}
		else
		{
			return false;
		}
	}
	
	virtual bool	Broadcast( const void *pData, const int len, const unsigned short port )
	{
		assert( m_Socket != INVALID_SOCKET );

		// Make sure we're setup to broadcast.
		if ( !m_bSetupToBroadcast )
		{
			BOOL bBroadcast = true;
			if ( setsockopt( m_Socket, SOL_SOCKET, SO_BROADCAST, (char*)&bBroadcast, sizeof( bBroadcast ) ) != 0 )
			{
				assert( false );
				return false;
			}

			m_bSetupToBroadcast = true;
		}

		CIPAddr addr;
		addr.ip[0] = addr.ip[1] = addr.ip[2] = addr.ip[3] = 0xFF;
		addr.port = port;
		return SendTo( &addr, pData, len );
	}
	
	virtual bool	SendTo( const CIPAddr *pAddr, const void *pData, const int len )
	{
		return SendChunksTo( pAddr, &pData, &len, 1 );
	}

	virtual bool SendChunksTo( const CIPAddr *pAddr, void const * const *pChunks, const int *pChunkLengths, int nChunks )
	{
		WSABUF bufs[32];
		if ( nChunks > 32 )
		{
			Error( "CIPSocket::SendChunksTo: too many chunks (%d).", nChunks );
		}

		int nTotalBytes = 0;
		for ( int i=0; i < nChunks; i++ )
		{
			bufs[i].len = pChunkLengths[i];
			bufs[i].buf = (char*)pChunks[i];
			nTotalBytes += pChunkLengths[i];
		}

		assert( m_Socket != INVALID_SOCKET );

		// Translate the address.
		sockaddr_in addr;
		IPAddrToSockAddr( pAddr, &addr );

		DWORD dwNumBytesSent = 0;
		DWORD ret = WSASendTo( 
			m_Socket, 
			bufs,
			nChunks,
			&dwNumBytesSent,
			0,
			(sockaddr*)&addr,
			sizeof( addr ),
			NULL,
			NULL
			);

		return ret == 0 && (int)dwNumBytesSent == nTotalBytes;
	}

	virtual int		RecvFrom( void *pData, int maxDataLen, CIPAddr *pFrom )
	{
		assert( m_Socket != INVALID_SOCKET );

		fd_set readSet;
		readSet.fd_count = 1;
		readSet.fd_array[0] = m_Socket;

		TIMEVAL timeVal = SetupTimeVal( 0 );

		// See if it has a packet waiting.
		int status = select( 0, &readSet, NULL, NULL, &timeVal );
		if ( status == 0 || status == SOCKET_ERROR )
			return -1;

		// Get the data.
		sockaddr_in sender;
		int fromSize = sizeof( sockaddr_in );
		status = recvfrom( m_Socket, (char*)pData, maxDataLen, 0, (struct sockaddr*)&sender, &fromSize );
		if ( status == 0 || status == SOCKET_ERROR )
		{
			return -1;
		}
		else
		{
			if ( pFrom )
			{
				SockAddrToIPAddr( &sender, pFrom );
			}
	
			m_flLastRecvTime = IP_FloatTime();
			return status;
		}
	}

	virtual double	GetRecvTimeout()
	{
		return IP_FloatTime() - m_flLastRecvTime;
	}
	

private:

	void			Term()
	{
		if ( m_Socket != INVALID_SOCKET )
		{
			if ( m_bMulticastGroupMembership )
			{
				// Undo our multicast group membership.
				setsockopt( m_Socket, IPPROTO_IP, IP_DROP_MEMBERSHIP, (char*)&m_MulticastGroupMREQ, sizeof( m_MulticastGroupMREQ ) );
			}

			closesocket( m_Socket );
			m_Socket = INVALID_SOCKET;
		}

		m_bSetupToBroadcast = false;
		m_bMulticastGroupMembership = false;
	}


private:

	SOCKET			m_Socket;
	
	bool			m_bMulticastGroupMembership; // Did we join a multicast group?
	ip_mreq			m_MulticastGroupMREQ;

	bool			m_bSetupToBroadcast;
	double			m_flLastRecvTime;
	bool			m_bListenSocket;
};



// Allocates a new CIPSocket and returns it through the ISocket interface.
// The returned object has no underlying socket until one of its Bind calls succeeds;
// dispose of it with Release().
ISocket* CreateIPSocket()
{
	ISocket *pSocket = new CIPSocket;
	return pSocket;
}


// Creates a socket that listens to a multicast stream.
//
// addr           - multicast group address; its port is the port that gets bound.
// localInterface - local interface address to bind to and to join the group on
//                  (its own port is ignored).
//
// Returns the new socket, or NULL if binding or joining the group failed.
ISocket* CreateMulticastListenSocket( 
	const CIPAddr &addr, 
	const CIPAddr &localInterface )
{
	CIPSocket *pSocket = new CIPSocket;

	// Bind to the local interface, but on the multicast stream's port.
	CIPAddr bindAddr = localInterface;
	bindAddr.port = addr.port;

	// The join is only attempted if the bind worked.
	const bool bListening =
		pSocket->Bind( &bindAddr ) &&
		pSocket->ListenToMulticastStream( addr, localInterface );

	if ( !bListening )
	{
		pSocket->Release();
		return NULL;
	}

	return pSocket;
}


// Parses an address string of the form "a.b.c.d", "a.b.c.d:port", "hostname" or
// "hostname:port" into *pOut.
//
// pStr - null-terminated address string.
// pOut - receives the result. pOut->ip is written on success. pOut->port is written
//        only when pStr contains a ':'; without one, the port is left untouched.
//
// Returns true on success. Returns false if the part before ':' is shorter than 2
// characters or too long for the internal buffer (this also asserts), or if the
// address is neither a valid dotted-decimal IP nor a host name that resolves.
//
// Strings starting with a digit are tried as dotted-decimal first (four octets, each
// 0-255). If that fails they are handed to gethostbyname() like any other name, so
// host names that begin with a digit still work. Note that the name lookup can block.
bool ConvertStringToIPAddr( const char *pStr, CIPAddr *pOut )
{
	char ipStr[512];

	// Split off the ":port" suffix, if there is one.
	const char *pColon = strchr( pStr, ':' );
	if ( pColon )
	{
		int toCopy = (int)( pColon - pStr );
		if ( toCopy < 2 || toCopy > (int)sizeof(ipStr)-1 )
		{
			assert( false );
			return false;
		}

		memcpy( ipStr, pStr, toCopy );
		ipStr[toCopy] = 0;

		pOut->port = (unsigned short)atoi( pColon+1 );
	}
	else
	{
		strncpy( ipStr, pStr, sizeof( ipStr ) );
		ipStr[ sizeof(ipStr)-1 ] = 0;
	}

	if ( ipStr[0] >= '0' && ipStr[0] <= '9' )
	{
		// It looks like numbers. Only accept it if all four octets parsed and are in range;
		// otherwise we'd be storing uninitialized or truncated values.
		int ip[4] = { 0, 0, 0, 0 };
		if ( sscanf( ipStr, "%d.%d.%d.%d", &ip[0], &ip[1], &ip[2], &ip[3] ) == 4 )
		{
			bool bInRange = true;
			for ( int i=0; i < 4; i++ )
			{
				if ( ip[i] < 0 || ip[i] > 255 )
					bInRange = false;
			}

			if ( bInRange )
			{
				for ( int i=0; i < 4; i++ )
					pOut->ip[i] = (unsigned char)ip[i];

				return true;
			}
		}

		// Not a dotted-decimal address after all; fall through and try it as a name.
	}

	// It's a text string.
	struct hostent *pHost = gethostbyname( ipStr );
	if ( !pHost )
		return false;

	// Make sure there's actually a 4-byte address to read.
	if ( !pHost->h_addr_list || !pHost->h_addr_list[0] || pHost->h_length < 4 )
		return false;

	pOut->ip[0] = pHost->h_addr_list[0][0];
	pOut->ip[1] = pHost->h_addr_list[0][1];
	pOut->ip[2] = pHost->h_addr_list[0][2];
	pOut->ip[3] = pHost->h_addr_list[0][3];

	return true;
}


bool ConvertIPAddrToString( const CIPAddr *pIn, char *pOut, int outLen )
{
	in_addr addr;
	addr.S_un.S_un_b.s_b1 = pIn->ip[0];
	addr.S_un.S_un_b.s_b2 = pIn->ip[1];
	addr.S_un.S_un_b.s_b3 = pIn->ip[2];
	addr.S_un.S_un_b.s_b4 = pIn->ip[3];

	HOSTENT *pEnt = gethostbyaddr( (char*)&addr, sizeof( addr ), AF_INET );
	if ( pEnt )
	{
		Q_strncpy( pOut, pEnt->h_name, outLen );
		return true;
	}
	else
	{
		return false;
	}
}


// Writes the system's message text for GetLastError() into pStr.
//
// pStr   - buffer that receives the message; nothing happens if it is NULL.
// maxLen - size passed to Q_strncpy for pStr; nothing happens if it is <= 0.
//
// If the system has no message for the error code (FormatMessage fails), pStr is set
// to an empty string.
//
// NOTE(review): the message buffer is a char*, so this assumes FormatMessage resolves
// to its ANSI variant (i.e. the build does not define UNICODE) - confirm.
void IP_GetLastErrorString( char *pStr, int maxLen )
{
	if ( !pStr || maxLen <= 0 )
		return;

	pStr[0] = 0;

	// Must start out NULL: FormatMessage leaves it alone when it fails.
	char *lpMsgBuf = NULL;
	if ( FormatMessage( 
			FORMAT_MESSAGE_ALLOCATE_BUFFER | 
			FORMAT_MESSAGE_FROM_SYSTEM | 
			FORMAT_MESSAGE_IGNORE_INSERTS,
			NULL,
			GetLastError(),
			MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
			(LPTSTR) &lpMsgBuf,	// With ALLOCATE_BUFFER this receives a pointer to the allocated text.
			0,
			NULL 
		) != 0 && lpMsgBuf )
	{
		Q_strncpy( pStr, lpMsgBuf, maxLen );
	}

	// The buffer was allocated for us by FormatMessage; release it if we got one.
	if ( lpMsgBuf )
	{
		LocalFree( lpMsgBuf );
	}
}