//========= Copyright Valve Corporation, All rights reserved. ============// // // Purpose: // // $NoKeywords: $ // //=============================================================================// #pragma warning (disable:4127) #include <winsock2.h> #include <ws2tcpip.h> #pragma warning (default:4127) #include "iphelpers.h" #include "basetypes.h" #include <assert.h> #include "utllinkedlist.h" #include "utlvector.h" #include "tier1/strtools.h" // This automatically calls WSAStartup for the app at startup. class CIPStarter { public: CIPStarter() { WSADATA wsaData; WSAStartup( WINSOCK_VERSION, &wsaData ); } }; static CIPStarter g_Starter; unsigned long SampleMilliseconds() { CCycleCount cnt; cnt.Sample(); return cnt.GetMilliseconds(); } // ------------------------------------------------------------------------------------------ // // CChunkWalker. // ------------------------------------------------------------------------------------------ // CChunkWalker::CChunkWalker( void const * const *pChunks, const int *pChunkLengths, int nChunks ) { m_TotalLength = 0; for ( int i=0; i < nChunks; i++ ) m_TotalLength += pChunkLengths[i]; m_iCurChunk = 0; m_iCurChunkPos = 0; m_pChunks = pChunks; m_pChunkLengths = pChunkLengths; m_nChunks = nChunks; } int CChunkWalker::GetTotalLength() const { return m_TotalLength; } void CChunkWalker::CopyTo( void *pOut, int nBytes ) { unsigned char *pOutPos = (unsigned char*)pOut; int nBytesLeft = nBytes; while ( nBytesLeft > 0 ) { int toCopy = nBytesLeft; int curChunkLen = m_pChunkLengths[m_iCurChunk]; int amtLeft = curChunkLen - m_iCurChunkPos; if ( nBytesLeft > amtLeft ) { toCopy = amtLeft; } unsigned char *pCurChunkData = (unsigned char*)m_pChunks[m_iCurChunk]; memcpy( pOutPos, &pCurChunkData[m_iCurChunkPos], toCopy ); nBytesLeft -= toCopy; pOutPos += toCopy; // Slide up to the next chunk if we're done with the one we're on. m_iCurChunkPos += toCopy; assert( m_iCurChunkPos <= curChunkLen ); if ( m_iCurChunkPos == curChunkLen ) { ++m_iCurChunk; m_iCurChunkPos = 0; if ( m_iCurChunk == m_nChunks ) { assert( nBytesLeft == 0 ); } } } } // ------------------------------------------------------------------------------------------ // // CWaitTimer // ------------------------------------------------------------------------------------------ // bool g_bForceWaitTimers = false; CWaitTimer::CWaitTimer( double flSeconds ) { m_StartTime = SampleMilliseconds(); m_WaitMS = (unsigned long)( flSeconds * 1000.0 ); } bool CWaitTimer::ShouldKeepWaiting() { if ( m_WaitMS == 0 ) { return false; } else { return ( SampleMilliseconds() - m_StartTime ) <= m_WaitMS || g_bForceWaitTimers; } } // ------------------------------------------------------------------------------------------ // // CIPAddr. // ------------------------------------------------------------------------------------------ // CIPAddr::CIPAddr() { Init( 0, 0, 0, 0, 0 ); } CIPAddr::CIPAddr( const int inputIP[4], const int inputPort ) { Init( inputIP[0], inputIP[1], inputIP[2], inputIP[3], inputPort ); } CIPAddr::CIPAddr( int ip0, int ip1, int ip2, int ip3, int ipPort ) { Init( ip0, ip1, ip2, ip3, ipPort ); } void CIPAddr::Init( int ip0, int ip1, int ip2, int ip3, int ipPort ) { ip[0] = (unsigned char)ip0; ip[1] = (unsigned char)ip1; ip[2] = (unsigned char)ip2; ip[3] = (unsigned char)ip3; port = (unsigned short)ipPort; } bool CIPAddr::operator==( const CIPAddr &o ) const { return ip[0] == o.ip[0] && ip[1] == o.ip[1] && ip[2] == o.ip[2] && ip[3] == o.ip[3] && port == o.port; } bool CIPAddr::operator!=( const CIPAddr &o ) const { return !( *this == o ); } void CIPAddr::SetupLocal( int inPort ) { ip[0] = 0x7f; ip[1] = 0; ip[2] = 0; ip[3] = 1; port = inPort; } // ------------------------------------------------------------------------------------------ // // Static helpers. // ------------------------------------------------------------------------------------------ // static double IP_FloatTime() { CCycleCount cnt; cnt.Sample(); return cnt.GetSeconds(); } TIMEVAL SetupTimeVal( double flTimeout ) { TIMEVAL timeVal; timeVal.tv_sec = (long)flTimeout; timeVal.tv_usec = (long)( (flTimeout - (long)flTimeout) * 1000.0 ); return timeVal; } // Convert a CIPAddr to a sockaddr_in. void IPAddrToInAddr( const CIPAddr *pIn, in_addr *pOut ) { u_char *p = (u_char*)pOut; p[0] = pIn->ip[0]; p[1] = pIn->ip[1]; p[2] = pIn->ip[2]; p[3] = pIn->ip[3]; } // Convert a CIPAddr to a sockaddr_in. void IPAddrToSockAddr( const CIPAddr *pIn, struct sockaddr_in *pOut ) { memset( pOut, 0, sizeof(*pOut) ); pOut->sin_family = AF_INET; pOut->sin_port = htons( pIn->port ); IPAddrToInAddr( pIn, &pOut->sin_addr ); } // Convert a CIPAddr to a sockaddr_in. void SockAddrToIPAddr( const struct sockaddr_in *pIn, CIPAddr *pOut ) { const u_char *p = (const u_char*)&pIn->sin_addr; pOut->ip[0] = p[0]; pOut->ip[1] = p[1]; pOut->ip[2] = p[2]; pOut->ip[3] = p[3]; pOut->port = ntohs( pIn->sin_port ); } class CIPSocket : public ISocket { public: CIPSocket() { m_Socket = INVALID_SOCKET; m_bSetupToBroadcast = false; } virtual ~CIPSocket() { Term(); } // ISocket implementation. public: virtual void Release() { delete this; } virtual bool CreateSocket() { // Clear any old socket we had around. Term(); // Create a socket to send and receive through. SOCKET sock = socket( AF_INET, SOCK_DGRAM, IPPROTO_IP ); if ( sock == INVALID_SOCKET ) { Assert( false ); return false; } // Nonblocking please.. int status; DWORD val = 1; status = ioctlsocket( sock, FIONBIO, &val ); if ( status != 0 ) { assert( false ); closesocket( sock ); return false; } m_Socket = sock; return true; } // Called after we have a socket. virtual bool BindPart2( const CIPAddr *pAddr ) { Assert( m_Socket != INVALID_SOCKET ); // bind to it! sockaddr_in addr; IPAddrToSockAddr( pAddr, &addr ); int status = bind( m_Socket, (sockaddr*)&addr, sizeof(addr) ); if ( status == 0 ) { return true; } else { Term(); return false; } } virtual bool Bind( const CIPAddr *pAddr ) { if ( !CreateSocket() ) return false; return BindPart2( pAddr ); } virtual bool BindToAny( const unsigned short port ) { // (INADDR_ANY) CIPAddr addr; addr.ip[0] = addr.ip[1] = addr.ip[2] = addr.ip[3] = 0; addr.port = port; return Bind( &addr ); } virtual bool ListenToMulticastStream( const CIPAddr &addr, const CIPAddr &localInterface ) { ip_mreq mr; IPAddrToInAddr( &addr, &mr.imr_multiaddr ); IPAddrToInAddr( &localInterface, &mr.imr_interface ); // This helps a lot if the stream is sending really fast. int rcvBuf = 1024*1024*2; setsockopt( m_Socket, SOL_SOCKET, SO_RCVBUF, (char*)&rcvBuf, sizeof( rcvBuf ) ); if ( setsockopt( m_Socket, IPPROTO_IP, IP_ADD_MEMBERSHIP, (char*)&mr, sizeof( mr ) ) == 0 ) { // Remember this so we do IP_DEL_MEMBERSHIP on shutdown. m_bMulticastGroupMembership = true; m_MulticastGroupMREQ = mr; return true; } else { return false; } } virtual bool Broadcast( const void *pData, const int len, const unsigned short port ) { assert( m_Socket != INVALID_SOCKET ); // Make sure we're setup to broadcast. if ( !m_bSetupToBroadcast ) { BOOL bBroadcast = true; if ( setsockopt( m_Socket, SOL_SOCKET, SO_BROADCAST, (char*)&bBroadcast, sizeof( bBroadcast ) ) != 0 ) { assert( false ); return false; } m_bSetupToBroadcast = true; } CIPAddr addr; addr.ip[0] = addr.ip[1] = addr.ip[2] = addr.ip[3] = 0xFF; addr.port = port; return SendTo( &addr, pData, len ); } virtual bool SendTo( const CIPAddr *pAddr, const void *pData, const int len ) { return SendChunksTo( pAddr, &pData, &len, 1 ); } virtual bool SendChunksTo( const CIPAddr *pAddr, void const * const *pChunks, const int *pChunkLengths, int nChunks ) { WSABUF bufs[32]; if ( nChunks > 32 ) { Error( "CIPSocket::SendChunksTo: too many chunks (%d).", nChunks ); } int nTotalBytes = 0; for ( int i=0; i < nChunks; i++ ) { bufs[i].len = pChunkLengths[i]; bufs[i].buf = (char*)pChunks[i]; nTotalBytes += pChunkLengths[i]; } assert( m_Socket != INVALID_SOCKET ); // Translate the address. sockaddr_in addr; IPAddrToSockAddr( pAddr, &addr ); DWORD dwNumBytesSent = 0; DWORD ret = WSASendTo( m_Socket, bufs, nChunks, &dwNumBytesSent, 0, (sockaddr*)&addr, sizeof( addr ), NULL, NULL ); return ret == 0 && (int)dwNumBytesSent == nTotalBytes; } virtual int RecvFrom( void *pData, int maxDataLen, CIPAddr *pFrom ) { assert( m_Socket != INVALID_SOCKET ); fd_set readSet; readSet.fd_count = 1; readSet.fd_array[0] = m_Socket; TIMEVAL timeVal = SetupTimeVal( 0 ); // See if it has a packet waiting. int status = select( 0, &readSet, NULL, NULL, &timeVal ); if ( status == 0 || status == SOCKET_ERROR ) return -1; // Get the data. sockaddr_in sender; int fromSize = sizeof( sockaddr_in ); status = recvfrom( m_Socket, (char*)pData, maxDataLen, 0, (struct sockaddr*)&sender, &fromSize ); if ( status == 0 || status == SOCKET_ERROR ) { return -1; } else { if ( pFrom ) { SockAddrToIPAddr( &sender, pFrom ); } m_flLastRecvTime = IP_FloatTime(); return status; } } virtual double GetRecvTimeout() { return IP_FloatTime() - m_flLastRecvTime; } private: void Term() { if ( m_Socket != INVALID_SOCKET ) { if ( m_bMulticastGroupMembership ) { // Undo our multicast group membership. setsockopt( m_Socket, IPPROTO_IP, IP_DROP_MEMBERSHIP, (char*)&m_MulticastGroupMREQ, sizeof( m_MulticastGroupMREQ ) ); } closesocket( m_Socket ); m_Socket = INVALID_SOCKET; } m_bSetupToBroadcast = false; m_bMulticastGroupMembership = false; } private: SOCKET m_Socket; bool m_bMulticastGroupMembership; // Did we join a multicast group? ip_mreq m_MulticastGroupMREQ; bool m_bSetupToBroadcast; double m_flLastRecvTime; bool m_bListenSocket; }; ISocket* CreateIPSocket() { return new CIPSocket; } ISocket* CreateMulticastListenSocket( const CIPAddr &addr, const CIPAddr &localInterface ) { CIPSocket *pSocket = new CIPSocket; CIPAddr bindAddr = localInterface; bindAddr.port = addr.port; if ( pSocket->Bind( &bindAddr ) && pSocket->ListenToMulticastStream( addr, localInterface ) ) { return pSocket; } else { pSocket->Release(); return NULL; } } bool ConvertStringToIPAddr( const char *pStr, CIPAddr *pOut ) { char ipStr[512]; const char *pColon = strchr( pStr, ':' ); if ( pColon ) { int toCopy = pColon - pStr; if ( toCopy < 2 || toCopy > sizeof(ipStr)-1 ) { assert( false ); return false; } memcpy( ipStr, pStr, toCopy ); ipStr[toCopy] = 0; pOut->port = (unsigned short)atoi( pColon+1 ); } else { strncpy( ipStr, pStr, sizeof( ipStr ) ); ipStr[ sizeof(ipStr)-1 ] = 0; } if ( ipStr[0] >= '0' && ipStr[0] <= '9' ) { // It's numbers. int ip[4]; sscanf( ipStr, "%d.%d.%d.%d", &ip[0], &ip[1], &ip[2], &ip[3] ); pOut->ip[0] = (unsigned char)ip[0]; pOut->ip[1] = (unsigned char)ip[1]; pOut->ip[2] = (unsigned char)ip[2]; pOut->ip[3] = (unsigned char)ip[3]; } else { // It's a text string. struct hostent *pHost = gethostbyname( ipStr ); if( !pHost ) return false; pOut->ip[0] = pHost->h_addr_list[0][0]; pOut->ip[1] = pHost->h_addr_list[0][1]; pOut->ip[2] = pHost->h_addr_list[0][2]; pOut->ip[3] = pHost->h_addr_list[0][3]; } return true; } bool ConvertIPAddrToString( const CIPAddr *pIn, char *pOut, int outLen ) { in_addr addr; addr.S_un.S_un_b.s_b1 = pIn->ip[0]; addr.S_un.S_un_b.s_b2 = pIn->ip[1]; addr.S_un.S_un_b.s_b3 = pIn->ip[2]; addr.S_un.S_un_b.s_b4 = pIn->ip[3]; HOSTENT *pEnt = gethostbyaddr( (char*)&addr, sizeof( addr ), AF_INET ); if ( pEnt ) { Q_strncpy( pOut, pEnt->h_name, outLen ); return true; } else { return false; } } void IP_GetLastErrorString( char *pStr, int maxLen ) { char *lpMsgBuf; FormatMessage( FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language (LPTSTR) &lpMsgBuf, 0, NULL ); Q_strncpy( pStr, lpMsgBuf, maxLen ); LocalFree( lpMsgBuf ); }