1178 lines
26 KiB
C++
1178 lines
26 KiB
C++
//========= Copyright Valve Corporation, All rights reserved. ============//
|
|
//
|
|
// Purpose:
|
|
//
|
|
// $NoKeywords: $
|
|
//=============================================================================//
|
|
|
|
|
|
//#define PARANOID
|
|
|
|
#if defined( PARANOID )
|
|
#include <stdlib.h>
|
|
#include <crtdbg.h>
|
|
#endif
|
|
|
|
#include <winsock2.h>
|
|
#include <mswsock.h>
|
|
#include "tcpsocket.h"
|
|
#include "tier1/utllinkedlist.h"
|
|
#include <stdio.h>
|
|
#include "threadhelpers.h"
|
|
#include "tier0/dbg.h"
|
|
|
|
|
|
|
|
#error "I am TCPSocket and I suck. Use IThreadedTCPSocket or ThreadedTCPSocketEmu instead."
|
|
|
|
|
|
extern TIMEVAL SetupTimeVal( double flTimeout );
|
|
extern void IPAddrToSockAddr( const CIPAddr *pIn, sockaddr_in *pOut );
|
|
extern void SockAddrToIPAddr( const sockaddr_in *pIn, CIPAddr *pOut );
|
|
|
|
|
|
// Special values sent in place of a packet's int size prefix. No payload follows either one.
#define SENTINEL_DISCONNECT		(-1)	// Peer is closing the connection gracefully.
#define SENTINEL_KEEPALIVE		(-2)	// Heartbeat; only refreshes the receiver's last-receive time.

// Keepalive timing.
#define KEEPALIVE_INTERVAL_MS		3000	// The keepalive thread sends a keepalive every N milliseconds.
#define KEEPALIVE_TIMEOUT_SECONDS	15.0	// A connection with no incoming traffic for this long is dropped.

// When false, CTCPSocket never drops a connection for inactivity. See TCPSocket_EnableTimeout().
static bool g_bEnableTCPTimeout = true;
|
|
|
|
|
|
// One fully received packet, queued in CTCPSocket::m_RecvDatas until Recv() copies it out.
// Allocated with malloc( sizeof( CRecvData ) - 1 + n ) so that m_Data can hold n bytes.
class CRecvData
{
public:
	int				m_Count;	// Number of payload bytes stored in m_Data.
	unsigned char	m_Data[1];	// Payload; the allocation is sized to hold m_Count bytes.
};
|
|
|
|
|
|
|
|
// Creates an overlapped-capable TCP socket and binds it to *pAddr.
// Returns the bound socket, or INVALID_SOCKET on failure. If bind() fails, the freshly
// created socket is closed before returning.
SOCKET TCPBind( const CIPAddr *pAddr )
{
	// WSA_FLAG_OVERLAPPED is needed so the socket can be driven through an IO completion port.
	SOCKET hSocket = WSASocket( AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED );
	if ( hSocket == INVALID_SOCKET )
	{
		Assert( false );
		return INVALID_SOCKET;
	}

	sockaddr_in localAddr;
	IPAddrToSockAddr( pAddr, &localAddr );

	if ( bind( hSocket, (sockaddr*)&localAddr, sizeof( localAddr ) ) != 0 )
	{
		closesocket( hSocket );
		return INVALID_SOCKET;
	}

	return hSocket;
}
|
|
|
|
|
|
|
|
// ---------------------------------------------------------------------------------------- //
|
|
// TCP sockets.
|
|
// ---------------------------------------------------------------------------------------- //
|
|
|
|
// Tags stored in COverlappedPlus::m_OPType so that the IO thread can tell which kind of
// overlapped operation just completed.
enum
{
	OP_RECV = 111,
	OP_SEND = OP_RECV + 1
};
|
|
|
|
// We use this for all OVERLAPPED structures.
|
|
class COverlappedPlus : public WSAOVERLAPPED
|
|
{
|
|
public:
|
|
COverlappedPlus()
|
|
{
|
|
memset( this, 0, sizeof( WSAOVERLAPPED ) );
|
|
}
|
|
|
|
int m_OPType; // One of the OP_ defines.
|
|
};
|
|
|
|
typedef struct SendBuf_t
|
|
{
|
|
COverlappedPlus m_Overlapped;
|
|
int m_Index; // Index into m_SendBufs.
|
|
int m_DataLength;
|
|
char m_Data[1];
|
|
} SendBuf_s;
|
|
|
|
|
|
// Registry driving the keepalive thread, which periodically calls SendKeepalive() on every
// socket registered here. Register a socket with AddGlobalTCPSocket() only once it is ready
// to have SendKeepalive() called on it. Unregister it with RemoveGlobalTCPSocket() before
// tearing it down.
class CTCPSocket;

void AddGlobalTCPSocket( CTCPSocket *pSocket );
void RemoveGlobalTCPSocket( CTCPSocket *pSocket );
|
|
|
|
|
|
|
|
// ------------------------------------------------------------------------------------------ //
|
|
// CTCPSocket implementation.
|
|
// ------------------------------------------------------------------------------------------ //
|
|
|
|
class CTCPSocket : public ITCPSocket
|
|
{
|
|
friend class CTCPListenSocket;
|
|
|
|
public:
|
|
|
|
CTCPSocket()
|
|
{
|
|
m_Socket = INVALID_SOCKET;
|
|
m_bConnected = false;
|
|
|
|
m_hIOCP = NULL;
|
|
|
|
m_bShouldExitThreads = false;
|
|
m_bConnectionLost = false;
|
|
m_nSizeBytesReceived = 0;
|
|
|
|
m_pIncomingData = NULL;
|
|
|
|
memset( &m_RecvOverlapped, 0, sizeof( m_RecvOverlapped ) );
|
|
m_RecvOverlapped.m_OPType = OP_RECV;
|
|
|
|
m_hRecvSignal = CreateEvent( NULL, FALSE, FALSE, NULL );
|
|
m_RecvStage = -1;
|
|
|
|
m_MainThreadID = GetCurrentThreadId();
|
|
}
|
|
|
|
virtual ~CTCPSocket()
|
|
{
|
|
Term();
|
|
CloseHandle( m_hRecvSignal );
|
|
}
|
|
|
|
void Term()
|
|
{
|
|
Assert( GetCurrentThreadId() == m_MainThreadID );
|
|
|
|
RemoveGlobalTCPSocket( this );
|
|
|
|
if ( m_Socket != SOCKET_ERROR && !m_bConnectionLost )
|
|
{
|
|
SendDisconnectSentinel();
|
|
|
|
// Give the sends a second to complete. SO_LINGER is having trouble for some reason.
|
|
WaitForSendsToComplete( 1 );
|
|
}
|
|
|
|
|
|
StopThreads();
|
|
|
|
if ( m_Socket != INVALID_SOCKET )
|
|
{
|
|
closesocket( m_Socket );
|
|
m_Socket = INVALID_SOCKET;
|
|
}
|
|
|
|
if ( m_hIOCP )
|
|
{
|
|
CloseHandle( m_hIOCP );
|
|
m_hIOCP = NULL;
|
|
}
|
|
|
|
m_bConnected = false;
|
|
m_bConnectionLost = true;
|
|
m_RecvStage = -1;
|
|
|
|
FOR_EACH_LL( m_SendBufs, i )
|
|
{
|
|
SendBuf_t *pSendBuf = m_SendBufs[i];
|
|
ParanoidMemoryCheck( pSendBuf );
|
|
free( pSendBuf );
|
|
}
|
|
m_SendBufs.Purge();
|
|
|
|
FOR_EACH_LL( m_RecvDatas, j )
|
|
{
|
|
CRecvData *pRecvData = m_RecvDatas[j];
|
|
ParanoidMemoryCheck( pRecvData );
|
|
free( pRecvData );
|
|
}
|
|
m_RecvDatas.Purge();
|
|
|
|
if ( m_pIncomingData )
|
|
{
|
|
ParanoidMemoryCheck( m_pIncomingData );
|
|
free( m_pIncomingData );
|
|
m_pIncomingData = 0;
|
|
}
|
|
}
|
|
|
|
virtual void Release()
|
|
{
|
|
delete this;
|
|
}
|
|
|
|
|
|
void ParanoidMemoryCheck( void *ptr = NULL )
|
|
{
|
|
#if defined( PARANOID )
|
|
Assert( _CrtIsValidHeapPointer( this ) );
|
|
|
|
if ( ptr )
|
|
{
|
|
Assert( _CrtIsValidHeapPointer( ptr ) );
|
|
}
|
|
|
|
Assert( _CrtCheckMemory() == TRUE );
|
|
#endif
|
|
}
|
|
|
|
|
|
virtual bool BindToAny( const unsigned short port )
|
|
{
|
|
Term();
|
|
|
|
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
|
|
m_Socket = TCPBind( &addr );
|
|
if ( m_Socket == INVALID_SOCKET )
|
|
{
|
|
return false;
|
|
}
|
|
else
|
|
{
|
|
SetInitialSocketOptions();
|
|
return true;
|
|
}
|
|
}
|
|
|
|
|
|
// Set the initial socket options that we want.
|
|
void SetInitialSocketOptions()
|
|
{
|
|
// Set nodelay to improve latency.
|
|
BOOL val = TRUE;
|
|
setsockopt( m_Socket, IPPROTO_TCP, TCP_NODELAY, (const char FAR *)&val, sizeof(BOOL) );
|
|
|
|
// Make it linger for 3 seconds when it exits.
|
|
LINGER linger;
|
|
linger.l_onoff = 1;
|
|
linger.l_linger = 3;
|
|
setsockopt( m_Socket, SOL_SOCKET, SO_LINGER, (char*)&linger, sizeof( linger ) );
|
|
}
|
|
|
|
|
|
// Called only by main thread interface functions.
|
|
// Returns true if the connection is lost.
|
|
bool CheckConnectionLost()
|
|
{
|
|
Assert( GetCurrentThreadId() == m_MainThreadID );
|
|
|
|
if ( m_Socket == SOCKET_ERROR )
|
|
return true;
|
|
|
|
// Have we timed out?
|
|
if ( g_bEnableTCPTimeout && (Plat_FloatTime() - m_LastRecvTime > KEEPALIVE_TIMEOUT_SECONDS) )
|
|
{
|
|
SetConnectionLost( "Connection timed out." );
|
|
}
|
|
|
|
// Has any thread posted that the connection has been lost?
|
|
CCriticalSectionLock postLock( &m_ConnectionLostCS );
|
|
postLock.Lock();
|
|
if ( m_bConnectionLost )
|
|
{
|
|
Term();
|
|
return true;
|
|
}
|
|
else
|
|
{
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Called by any thread. All interface functions call CheckConnectionLost() and return errors if it's lost.
|
|
void SetConnectionLost( const char *pErrorString, int err = -1 )
|
|
{
|
|
CCriticalSectionLock postLock( &m_ConnectionLostCS );
|
|
postLock.Lock();
|
|
m_bConnectionLost = true;
|
|
postLock.Unlock();
|
|
|
|
// Handle it right away if we're in the main thread. If we're in an IO thread,
|
|
// it has to wait until the next interface function calls CheckConnectionLost().
|
|
if ( GetCurrentThreadId() == m_MainThreadID )
|
|
{
|
|
Term();
|
|
}
|
|
|
|
if ( pErrorString )
|
|
{
|
|
m_ErrorString.CopyArray( pErrorString, strlen( pErrorString ) + 1 );
|
|
}
|
|
else
|
|
{
|
|
char *lpMsgBuf;
|
|
FormatMessage(
|
|
FORMAT_MESSAGE_ALLOCATE_BUFFER |
|
|
FORMAT_MESSAGE_FROM_SYSTEM |
|
|
FORMAT_MESSAGE_IGNORE_INSERTS,
|
|
NULL,
|
|
err,
|
|
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), // Default language
|
|
(LPTSTR) &lpMsgBuf,
|
|
0,
|
|
NULL
|
|
);
|
|
|
|
m_ErrorString.CopyArray( lpMsgBuf, strlen( lpMsgBuf ) + 1 );
|
|
LocalFree( lpMsgBuf );
|
|
}
|
|
}
|
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
// The receive code.
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
|
|
virtual bool StartWaitingForSize( bool bFresh )
|
|
{
|
|
Assert( m_Socket != INVALID_SOCKET );
|
|
Assert( m_bConnected );
|
|
|
|
m_RecvStage = 0;
|
|
m_RecvDataSize = -1;
|
|
if ( bFresh )
|
|
m_nSizeBytesReceived = 0;
|
|
|
|
DWORD dwNumBytesReceived = 0;
|
|
WSABUF buf = { sizeof( &m_RecvDataSize ) - m_nSizeBytesReceived, ((char*)&m_RecvDataSize) + m_nSizeBytesReceived };
|
|
DWORD dwFlags = 0;
|
|
|
|
int status = WSARecv(
|
|
m_Socket,
|
|
&buf,
|
|
1,
|
|
&dwNumBytesReceived,
|
|
&dwFlags,
|
|
&m_RecvOverlapped,
|
|
NULL );
|
|
|
|
int err = -1;
|
|
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
|
|
{
|
|
SetConnectionLost( NULL, err );
|
|
return false;
|
|
}
|
|
else
|
|
{
|
|
return true;
|
|
}
|
|
}
|
|
|
|
|
|
bool PostNextDataPart()
|
|
{
|
|
DWORD dwNumBytesReceived = 0;
|
|
WSABUF buf = { m_RecvDataSize - m_AmountReceived, (char*)m_pIncomingData->m_Data + m_AmountReceived };
|
|
DWORD dwFlags = 0;
|
|
|
|
int status = WSARecv(
|
|
m_Socket,
|
|
&buf,
|
|
1,
|
|
&dwNumBytesReceived,
|
|
&dwFlags,
|
|
&m_RecvOverlapped,
|
|
NULL );
|
|
|
|
int err = -1;
|
|
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
|
|
{
|
|
SetConnectionLost( NULL, err );
|
|
return false;
|
|
}
|
|
else
|
|
{
|
|
return true;
|
|
}
|
|
}
|
|
|
|
|
|
bool StartWaitingForData()
|
|
{
|
|
Assert( m_Socket != INVALID_SOCKET );
|
|
Assert( m_RecvStage == 0 );
|
|
Assert( m_bConnected );
|
|
Assert( m_RecvDataSize > 0 );
|
|
|
|
m_RecvStage = 1;
|
|
|
|
// Add a CRecvData element.
|
|
ParanoidMemoryCheck();
|
|
m_pIncomingData = (CRecvData*)malloc( sizeof( CRecvData ) - 1 + m_RecvDataSize );
|
|
if ( !m_pIncomingData )
|
|
{
|
|
char str[512];
|
|
_snprintf( str, sizeof( str ), "malloc() failed. m_RecvDataSize = %d\n", m_RecvDataSize );
|
|
SetConnectionLost( str );
|
|
return false;
|
|
}
|
|
|
|
m_pIncomingData->m_Count = m_RecvDataSize;
|
|
|
|
m_AmountReceived = 0;
|
|
|
|
return PostNextDataPart();
|
|
}
|
|
|
|
virtual bool Recv( CUtlVector<unsigned char> &data, double flTimeout )
|
|
{
|
|
if ( CheckConnectionLost() )
|
|
return false;
|
|
|
|
// Wait in 50ms chunks, checking for disconnections along the way.
|
|
bool bGotData = false;
|
|
DWORD msToWait = (DWORD)( flTimeout * 1000.0 );
|
|
do
|
|
{
|
|
DWORD curWaitTime = min( msToWait, 50 );
|
|
DWORD ret = WaitForSingleObject( m_hRecvSignal, curWaitTime );
|
|
if ( ret == WAIT_OBJECT_0 )
|
|
{
|
|
bGotData = true;
|
|
break;
|
|
}
|
|
|
|
// Did the connection timeout?
|
|
if ( CheckConnectionLost() )
|
|
return false;
|
|
|
|
msToWait -= curWaitTime;
|
|
} while ( msToWait );
|
|
|
|
// If we never got a WAIT_OBJECT_0, then we never received anything.
|
|
if ( !bGotData )
|
|
return false;
|
|
|
|
|
|
CCriticalSectionLock csLock( &m_RecvDataCS );
|
|
csLock.Lock();
|
|
|
|
// Pickup the head m_RecvDatas element.
|
|
CRecvData *pRecvData = m_RecvDatas[ m_RecvDatas.Head() ];
|
|
data.CopyArray( pRecvData->m_Data, pRecvData->m_Count );
|
|
|
|
// Now free it.
|
|
m_RecvDatas.Remove( m_RecvDatas.Head() );
|
|
ParanoidMemoryCheck( pRecvData );
|
|
free( pRecvData );
|
|
|
|
// Set the event again for the next time around, if there is more data waiting.
|
|
if ( m_RecvDatas.Count() > 0 )
|
|
SetEvent( m_hRecvSignal );
|
|
|
|
return true;
|
|
}
|
|
|
|
// INSIDE IO THREAD.
|
|
void HandleRecvCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
|
|
{
|
|
if ( dwNumBytes == 0 )
|
|
{
|
|
SetConnectionLost( "Got 0 bytes in HandleRecvCompletion" );
|
|
return;
|
|
}
|
|
|
|
m_LastRecvTime = Plat_FloatTime();
|
|
if ( m_RecvStage == 0 )
|
|
{
|
|
m_nSizeBytesReceived += dwNumBytes;
|
|
if ( m_nSizeBytesReceived == sizeof( m_RecvDataSize ) )
|
|
{
|
|
// Size of -1 means the other size is breaking the connection.
|
|
if ( m_RecvDataSize == SENTINEL_DISCONNECT )
|
|
{
|
|
SetConnectionLost( "Got a graceful disconnect message." );
|
|
return;
|
|
}
|
|
else if ( m_RecvDataSize == SENTINEL_KEEPALIVE )
|
|
{
|
|
// No data follows this. Just let m_LastRecvTime get updated.
|
|
StartWaitingForSize( true );
|
|
return;
|
|
}
|
|
|
|
StartWaitingForData();
|
|
}
|
|
else if ( m_nSizeBytesReceived < sizeof( m_RecvDataSize ) )
|
|
{
|
|
// Handle the case where we only got some of the data (maybe one of the clients got disconnected).
|
|
StartWaitingForSize( false );
|
|
}
|
|
else
|
|
{
|
|
// This case should never ever happen!
|
|
#if defined( _DEBUG )
|
|
__asm int 3;
|
|
#endif
|
|
|
|
SetConnectionLost( "Received too much data in a packet!" );
|
|
return;
|
|
}
|
|
}
|
|
else if ( m_RecvStage == 1 )
|
|
{
|
|
// Got the data, make sure we got it all.
|
|
m_AmountReceived += dwNumBytes;
|
|
|
|
// Sanity check.
|
|
#if defined( _DEBUG )
|
|
Assert( m_RecvDataSize == m_pIncomingData->m_Count );
|
|
Assert( m_AmountReceived <= m_RecvDataSize ); // TODO: make this threadsafe for multiple IO threads.
|
|
#endif
|
|
|
|
if ( m_AmountReceived == m_RecvDataSize )
|
|
{
|
|
m_RecvStage = 2;
|
|
|
|
// Add the data to the list of packets waiting to be picked up.
|
|
CCriticalSectionLock csLock( &m_RecvDataCS );
|
|
csLock.Lock();
|
|
|
|
m_RecvDatas.AddToTail( m_pIncomingData );
|
|
m_pIncomingData = NULL;
|
|
|
|
if ( m_RecvDatas.Count() == 1 )
|
|
SetEvent( m_hRecvSignal ); // Notify the Recv() function.
|
|
|
|
StartWaitingForSize( true );
|
|
}
|
|
else
|
|
{
|
|
PostNextDataPart();
|
|
}
|
|
}
|
|
else
|
|
{
|
|
Assert( false );
|
|
}
|
|
}
|
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
// The send code.
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
|
|
virtual void WaitForSendsToComplete( double flTimeout )
|
|
{
|
|
CWaitTimer waitTimer( flTimeout );
|
|
while ( 1 )
|
|
{
|
|
CCriticalSectionLock sendBufLock( &m_SendCS );
|
|
sendBufLock.Lock();
|
|
if( m_SendBufs.Count() == 0 )
|
|
return;
|
|
sendBufLock.Unlock();
|
|
|
|
if ( waitTimer.ShouldKeepWaiting() )
|
|
Sleep( 10 );
|
|
else
|
|
break;
|
|
}
|
|
}
|
|
|
|
|
|
// This is called in the keepalive thread.
|
|
void SendKeepalive()
|
|
{
|
|
// Send a message saying we're exiting.
|
|
ParanoidMemoryCheck();
|
|
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
|
|
if ( !pBuf )
|
|
{
|
|
SetConnectionLost( "malloc() in SendKeepalive() failed." );
|
|
return;
|
|
}
|
|
|
|
pBuf->m_DataLength = sizeof( int );
|
|
*((int*)pBuf->m_Data) = SENTINEL_KEEPALIVE;
|
|
InternalSendDataBuf( pBuf );
|
|
}
|
|
|
|
|
|
void SendDisconnectSentinel()
|
|
{
|
|
// Send a message saying we're exiting.
|
|
ParanoidMemoryCheck();
|
|
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + sizeof( int ) );
|
|
if ( pBuf )
|
|
{
|
|
pBuf->m_DataLength = sizeof( int );
|
|
*((int*)pBuf->m_Data) = SENTINEL_DISCONNECT; // This signifies that we're exiting.
|
|
InternalSendDataBuf( pBuf );
|
|
}
|
|
}
|
|
|
|
|
|
virtual bool Send( const void *pData, int len )
|
|
{
|
|
const void *pChunks[1] = { pData };
|
|
int chunkLengths[1] = { len };
|
|
return SendChunks( pChunks, chunkLengths, 1 );
|
|
}
|
|
|
|
|
|
virtual bool SendChunks( void const * const *pChunks, const int *pChunkLengths, int nChunks )
|
|
{
|
|
if ( CheckConnectionLost() )
|
|
return false;
|
|
|
|
CChunkWalker walker( pChunks, pChunkLengths, nChunks );
|
|
int totalLength = walker.GetTotalLength();
|
|
|
|
if ( !totalLength )
|
|
return true;
|
|
|
|
// Create a buffer to hold the data and copy the data in.
|
|
ParanoidMemoryCheck();
|
|
SendBuf_t *pBuf = (SendBuf_t*)malloc( sizeof( SendBuf_t ) - 1 + totalLength + sizeof( int ) );
|
|
if ( !pBuf )
|
|
{
|
|
char str[512];
|
|
_snprintf( str, sizeof( str ), "malloc() in SendChunks() failed. totalLength = %d.", totalLength );
|
|
SetConnectionLost( str );
|
|
return false;
|
|
}
|
|
|
|
pBuf->m_DataLength = totalLength + sizeof( int );
|
|
|
|
int *pByteCountPos = (int*)pBuf->m_Data;
|
|
*pByteCountPos = totalLength;
|
|
|
|
char *pDataPos = &pBuf->m_Data[ sizeof( int ) ];
|
|
walker.CopyTo( pDataPos, totalLength );
|
|
|
|
int status = InternalSendDataBuf( pBuf );
|
|
int err = -1;
|
|
if ( status == SOCKET_ERROR && (err = WSAGetLastError()) != ERROR_IO_PENDING )
|
|
{
|
|
SetConnectionLost( NULL, err );
|
|
return false;
|
|
}
|
|
else
|
|
{
|
|
return true;
|
|
}
|
|
}
|
|
|
|
|
|
int InternalSendDataBuf( SendBuf_t *pBuf )
|
|
{
|
|
// Protect against interference from the keepalive thread.
|
|
CCriticalSectionLock csLock( &m_SendCS );
|
|
csLock.Lock();
|
|
|
|
|
|
pBuf->m_Overlapped.m_OPType = OP_SEND;
|
|
pBuf->m_Overlapped.hEvent = NULL;
|
|
|
|
// Add it to our list of buffers.
|
|
pBuf->m_Index = m_SendBufs.AddToTail( pBuf );
|
|
|
|
// Tell Winsock to send it.
|
|
WSABUF buf = { pBuf->m_DataLength, pBuf->m_Data };
|
|
|
|
DWORD dwNumBytesSent = 0;
|
|
return WSASend(
|
|
m_Socket,
|
|
&buf,
|
|
1,
|
|
&dwNumBytesSent,
|
|
0,
|
|
&pBuf->m_Overlapped,
|
|
NULL );
|
|
}
|
|
|
|
|
|
// INSIDE IO THREAD.
|
|
void HandleSendCompletion( COverlappedPlus *pInfo, DWORD dwNumBytes )
|
|
{
|
|
if ( dwNumBytes == 0 )
|
|
{
|
|
SetConnectionLost( "0 bytes in HandleSendCompletion." );
|
|
return;
|
|
}
|
|
|
|
// Just free the buffer.
|
|
SendBuf_t *pBuf = (SendBuf_t*)pInfo;
|
|
Assert( dwNumBytes == (DWORD)pBuf->m_DataLength );
|
|
|
|
CCriticalSectionLock sendBufLock( &m_SendCS );
|
|
sendBufLock.Lock();
|
|
m_SendBufs.Remove( pBuf->m_Index );
|
|
sendBufLock.Unlock();
|
|
|
|
ParanoidMemoryCheck( pBuf );
|
|
free( pBuf );
|
|
}
|
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
// The connect code.
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
|
|
virtual bool BeginConnect( const CIPAddr &inputAddr )
|
|
{
|
|
sockaddr_in addr;
|
|
IPAddrToSockAddr( &inputAddr, &addr );
|
|
|
|
m_bConnected = false;
|
|
int ret = connect( m_Socket, (struct sockaddr*)&addr, sizeof( addr ) );
|
|
ret=ret;
|
|
|
|
return true;
|
|
}
|
|
|
|
|
|
virtual bool UpdateConnect()
|
|
{
|
|
// We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
|
|
fd_set writeSet;
|
|
writeSet.fd_count = 1;
|
|
writeSet.fd_array[0] = m_Socket;
|
|
TIMEVAL timeVal = SetupTimeVal( 0 );
|
|
|
|
// See if it has a packet waiting.
|
|
int status = select( 0, NULL, &writeSet, NULL, &timeVal );
|
|
if ( status > 0 )
|
|
{
|
|
SetupConnected();
|
|
return true;
|
|
}
|
|
|
|
return false;
|
|
}
|
|
|
|
|
|
void SetupConnected()
|
|
{
|
|
m_bConnected = true;
|
|
m_bConnectionLost = false;
|
|
m_LastRecvTime = Plat_FloatTime();
|
|
|
|
CreateThreads();
|
|
StartWaitingForSize( true );
|
|
AddGlobalTCPSocket( this );
|
|
}
|
|
|
|
|
|
virtual bool IsConnected()
|
|
{
|
|
CheckConnectionLost();
|
|
return m_bConnected;
|
|
}
|
|
|
|
|
|
virtual void GetDisconnectReason( CUtlVector<char> &reason )
|
|
{
|
|
reason = m_ErrorString;
|
|
}
|
|
|
|
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
// Threads code.
|
|
// -------------------------------------------------------------------------------------------------- //
|
|
|
|
// Create our IO Completion Port threads.
|
|
bool CreateThreads()
|
|
{
|
|
int nThreads = 1;
|
|
SetShouldExitThreads( false );
|
|
|
|
// Create our IO completion port and hook it to our socket.
|
|
m_hIOCP = CreateIoCompletionPort(
|
|
INVALID_HANDLE_VALUE, NULL, 0, 0);
|
|
|
|
m_hIOCP = CreateIoCompletionPort( (HANDLE)m_Socket, m_hIOCP, (unsigned long)this, nThreads );
|
|
|
|
for ( int i=0; i < nThreads; i++ )
|
|
{
|
|
DWORD dwThreadID = 0;
|
|
HANDLE hThread = CreateThread(
|
|
NULL,
|
|
0,
|
|
&CTCPSocket::StaticThreadFn,
|
|
this,
|
|
0,
|
|
&dwThreadID );
|
|
|
|
if ( hThread )
|
|
{
|
|
SetThreadPriority( hThread, THREAD_PRIORITY_ABOVE_NORMAL );
|
|
m_Threads.AddToTail( hThread );
|
|
}
|
|
else
|
|
{
|
|
StopThreads();
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
|
|
void StopThreads()
|
|
{
|
|
// Tell the threads to exit, then wait for them to do so.
|
|
SetShouldExitThreads( true );
|
|
WaitForMultipleObjects( m_Threads.Count(), m_Threads.Base(), TRUE, INFINITE );
|
|
|
|
for ( int i=0; i < m_Threads.Count(); i++ )
|
|
{
|
|
CloseHandle( m_Threads[i] );
|
|
}
|
|
m_Threads.Purge();
|
|
}
|
|
|
|
|
|
void SetShouldExitThreads( bool bShouldExit )
|
|
{
|
|
CCriticalSectionLock lock( &m_ThreadsCS );
|
|
lock.Lock();
|
|
m_bShouldExitThreads = bShouldExit;
|
|
}
|
|
|
|
|
|
bool ShouldExitThreads()
|
|
{
|
|
CCriticalSectionLock lock( &m_ThreadsCS );
|
|
lock.Lock();
|
|
|
|
bool bRet = m_bShouldExitThreads;
|
|
return bRet;
|
|
}
|
|
|
|
|
|
DWORD ThreadFn()
|
|
{
|
|
while ( 1 )
|
|
{
|
|
DWORD dwNumBytes = 0;
|
|
unsigned long pInputTCPSocket;
|
|
LPOVERLAPPED pOverlapped;
|
|
|
|
if ( GetQueuedCompletionStatus(
|
|
m_hIOCP, // the port we're listening on
|
|
&dwNumBytes, // # bytes received on the port
|
|
&pInputTCPSocket,// "completion key" = CTCPSocket*
|
|
&pOverlapped, // the overlapped info that was passed into AcceptEx, WSARecv, or WSASend.
|
|
100 // listen for 100ms at a time so we can exit gracefully when the socket is deleted.
|
|
) )
|
|
{
|
|
COverlappedPlus *pInfo = (COverlappedPlus*)pOverlapped;
|
|
ParanoidMemoryCheck( pInfo );
|
|
|
|
if ( pInfo->m_OPType == OP_RECV )
|
|
{
|
|
Assert( pInfo == &m_RecvOverlapped );
|
|
HandleRecvCompletion( pInfo, dwNumBytes );
|
|
}
|
|
else
|
|
{
|
|
Assert( pInfo->m_OPType == OP_SEND );
|
|
HandleSendCompletion( pInfo, dwNumBytes );
|
|
}
|
|
}
|
|
|
|
if ( ShouldExitThreads() )
|
|
break;
|
|
}
|
|
|
|
return 0;
|
|
}
|
|
|
|
|
|
static DWORD WINAPI StaticThreadFn( LPVOID pParameter )
|
|
{
|
|
return ((CTCPSocket*)pParameter)->ThreadFn();
|
|
}
|
|
|
|
|
|
|
|
private:
|
|
|
|
SOCKET m_Socket;
|
|
bool m_bConnected;
|
|
|
|
|
|
// m_RecvOverlapped is setup to first wait for the size, then the data.
|
|
// Then it is not posted until the app grabs the data.
|
|
HANDLE m_hRecvSignal; // Tells Recv() when we have data.
|
|
COverlappedPlus m_RecvOverlapped;
|
|
int m_RecvStage; // -1 = not initialized
|
|
// 0 = waiting for size
|
|
// 1 = waiting for data
|
|
// 2 = waiting for app to pickup the data
|
|
|
|
CUtlLinkedList<CRecvData*,int> m_RecvDatas; // The head element is the next one to be picked up.
|
|
CRecvData *m_pIncomingData; // The packet we're currently receiving.
|
|
CCriticalSection m_RecvDataCS; // This protects adds and removes in the list.
|
|
|
|
// These reference the element at the tail of m_RecvData. It is the current one getting
|
|
volatile int m_nSizeBytesReceived; // How much of m_RecvDataSize have we received yet?
|
|
int m_RecvDataSize; // this is received over the network
|
|
int m_AmountReceived; // How much we've received so far.
|
|
|
|
// Last time we received anything from this connection. Used to determine if the connection is
|
|
// still active.
|
|
double m_LastRecvTime;
|
|
|
|
|
|
// Outgoing send buffers.
|
|
CUtlLinkedList<SendBuf_t*,int> m_SendBufs;
|
|
CCriticalSection m_SendCS;
|
|
|
|
|
|
// All the threads waiting for IO.
|
|
CUtlVector<HANDLE> m_Threads;
|
|
HANDLE m_hIOCP;
|
|
|
|
// Used during shutdown.
|
|
volatile bool m_bShouldExitThreads;
|
|
CCriticalSection m_ThreadsCS;
|
|
|
|
// For debugging.
|
|
DWORD m_MainThreadID;
|
|
|
|
// Set by the main thread or IO threads to signal connection lost.
|
|
bool m_bConnectionLost;
|
|
CCriticalSection m_ConnectionLostCS;
|
|
|
|
// This is set when we get disconnected.
|
|
CUtlVector<char> m_ErrorString;
|
|
};
|
|
|
|
|
|
// ------------------------------------------------------------------------------------------ //
|
|
// ITCPListenSocket implementation.
|
|
// ------------------------------------------------------------------------------------------ //
|
|
|
|
class CTCPListenSocket : public ITCPListenSocket
|
|
{
|
|
public:
|
|
|
|
CTCPListenSocket()
|
|
{
|
|
m_Socket = INVALID_SOCKET;
|
|
}
|
|
|
|
|
|
virtual ~CTCPListenSocket()
|
|
{
|
|
if ( m_Socket != INVALID_SOCKET )
|
|
{
|
|
closesocket( m_Socket );
|
|
}
|
|
}
|
|
|
|
|
|
// The main function to create one of these suckers.
|
|
static ITCPListenSocket* Create( const unsigned short port, int nQueueLength )
|
|
{
|
|
CTCPListenSocket *pRet = new CTCPListenSocket;
|
|
if ( !pRet )
|
|
return NULL;
|
|
|
|
// Bind it to a socket and start listening.
|
|
CIPAddr addr( 0, 0, 0, 0, port ); // INADDR_ANY
|
|
pRet->m_Socket = TCPBind( &addr );
|
|
if ( pRet->m_Socket == INVALID_SOCKET ||
|
|
listen( pRet->m_Socket, nQueueLength == -1 ? SOMAXCONN : nQueueLength ) != 0 )
|
|
{
|
|
pRet->Release();
|
|
return false;
|
|
}
|
|
|
|
return pRet;
|
|
}
|
|
|
|
|
|
virtual void Release()
|
|
{
|
|
delete this;
|
|
}
|
|
|
|
|
|
virtual ITCPSocket* UpdateListen( CIPAddr *pAddr )
|
|
{
|
|
// We're still ok.. just wait until the socket becomes writable (is connected) or we timeout.
|
|
fd_set readSet;
|
|
readSet.fd_count = 1;
|
|
readSet.fd_array[0] = m_Socket;
|
|
TIMEVAL timeVal = SetupTimeVal( 0 );
|
|
|
|
// Wait until it connects.
|
|
int status = select( 0, &readSet, NULL, NULL, &timeVal );
|
|
if ( status > 0 )
|
|
{
|
|
sockaddr_in addr;
|
|
int addrSize = sizeof( addr );
|
|
|
|
// Now accept the final connection.
|
|
SOCKET newSock = accept( m_Socket, (struct sockaddr*)&addr, &addrSize );
|
|
if ( newSock == INVALID_SOCKET )
|
|
{
|
|
Assert( false );
|
|
}
|
|
else
|
|
{
|
|
CTCPSocket *pRet = new CTCPSocket;
|
|
if ( !pRet )
|
|
{
|
|
closesocket( newSock );
|
|
return NULL;
|
|
}
|
|
|
|
pRet->m_Socket = newSock;
|
|
pRet->SetInitialSocketOptions();
|
|
pRet->SetupConnected();
|
|
|
|
// Report the address..
|
|
SockAddrToIPAddr( &addr, pAddr );
|
|
|
|
return pRet;
|
|
}
|
|
}
|
|
|
|
return NULL;
|
|
}
|
|
|
|
|
|
private:
|
|
SOCKET m_Socket;
|
|
};
|
|
|
|
|
|
|
|
// Public factory: a socket listening on INADDR_ANY:port with the given backlog
// (-1 for SOMAXCONN). Returns NULL on failure; free the result with Release().
ITCPListenSocket* CreateTCPListenSocket( const unsigned short port, int nQueueLength )
{
	ITCPListenSocket *pListenSocket = CTCPListenSocket::Create( port, nQueueLength );
	return pListenSocket;
}
|
|
|
|
|
|
// Public factory: a new, unconnected socket. The caller owns it and frees it with Release().
ITCPSocket* CreateTCPSocket()
{
	CTCPSocket *pSocket = new CTCPSocket;
	return pSocket;
}
|
|
|
|
|
|
void TCPSocket_EnableTimeout( bool bEnable )
|
|
{
|
|
g_bEnableTCPTimeout = bEnable;
|
|
}
|
|
|
|
|
|
// --------------------------------------------------------------------------------- //
|
|
// This thread sends keepalives on all active TCP sockets.
|
|
// --------------------------------------------------------------------------------- //
|
|
|
|
HANDLE g_hKeepaliveThread;
|
|
HANDLE g_hKeepaliveThreadSignal;
|
|
HANDLE g_hKeepaliveThreadReply;
|
|
CUtlLinkedList<CTCPSocket*,int> g_TCPSockets;
|
|
CCriticalSection g_TCPSocketsCS;
|
|
|
|
|
|
// Keepalive thread body. Every KEEPALIVE_INTERVAL_MS it calls SendKeepalive() on each
// registered socket while holding g_TCPSocketsCS. When g_hKeepaliveThreadSignal is set, it
// acknowledges through g_hKeepaliveThreadReply and exits. pParameter is unused.
DWORD WINAPI TCPKeepaliveThread( LPVOID pParameter )
{
	for ( ;; )
	{
		DWORD waitResult = WaitForSingleObject( g_hKeepaliveThreadSignal, KEEPALIVE_INTERVAL_MS );
		if ( waitResult == WAIT_OBJECT_0 )
			break; // Shutdown requested.

		// The lock is released at the end of each iteration, before waiting again.
		CCriticalSectionLock csLock( &g_TCPSocketsCS );
		csLock.Lock();

		for ( int iSocket = g_TCPSockets.Head(); iSocket != g_TCPSockets.InvalidIndex(); iSocket = g_TCPSockets.Next( iSocket ) )
		{
			g_TCPSockets[iSocket]->SendKeepalive();
		}
	}

	SetEvent( g_hKeepaliveThreadReply );
	return 0;
}
|
|
|
|
|
|
// Registers pSocket for periodic keepalives. The registration that takes the list from empty
// to one entry also creates the keepalive thread and its signal/reply events.
void AddGlobalTCPSocket( CTCPSocket *pSocket )
{
	CCriticalSectionLock csLock( &g_TCPSocketsCS );
	csLock.Lock();

	Assert( g_TCPSockets.Find( pSocket ) == g_TCPSockets.InvalidIndex() );
	g_TCPSockets.AddToTail( pSocket );

	// Any other count means an earlier registration already started the thread.
	if ( g_TCPSockets.Count() != 1 )
		return;

	g_hKeepaliveThreadSignal = CreateEvent( NULL, false, false, NULL );
	g_hKeepaliveThreadReply = CreateEvent( NULL, false, false, NULL );

	DWORD dwThreadID = 0;
	g_hKeepaliveThread = CreateThread( NULL, 0, TCPKeepaliveThread, NULL, 0, &dwThreadID );
}
|
|
|
|
|
|
// Unregisters pSocket; a no-op if it is not registered. When the last socket leaves, the
// keepalive thread is asked to stop, and its handles are closed and cleared.
void RemoveGlobalTCPSocket( CTCPSocket *pSocket )
{
	CCriticalSectionLock csLock( &g_TCPSocketsCS );
	csLock.Lock();

	int index = g_TCPSockets.Find( pSocket );
	if ( index == g_TCPSockets.InvalidIndex() )
	{
		csLock.Unlock();
		return;
	}

	g_TCPSockets.Remove( index );
	if ( g_TCPSockets.Count() != 0 )
	{
		csLock.Unlock();
		return;
	}

	// That was the last one, so delete the thread. Release the lock first: the keepalive thread
	// takes g_TCPSocketsCS every iteration, and holding it here while waiting could deadlock.
	csLock.Unlock();

	// Only wait for the acknowledgement if the thread is still alive. A thread that has already
	// exited will never set the reply event. The handle is queried only here, while it is known
	// to be open (previously this ran on every call, even after the handle had been closed).
	DWORD dwExitCode = 0;
	bool bThreadRunning = ( g_hKeepaliveThread != NULL ) &&
		GetExitCodeThread( g_hKeepaliveThread, &dwExitCode ) &&
		dwExitCode == STILL_ACTIVE;

	if ( bThreadRunning )
	{
		SetEvent( g_hKeepaliveThreadSignal );
		WaitForSingleObject( g_hKeepaliveThreadReply, INFINITE );
	}

	// Clear the globals so that no later call can use a closed (and possibly recycled) handle.
	if ( g_hKeepaliveThreadSignal )
	{
		CloseHandle( g_hKeepaliveThreadSignal );
		g_hKeepaliveThreadSignal = NULL;
	}
	if ( g_hKeepaliveThreadReply )
	{
		CloseHandle( g_hKeepaliveThreadReply );
		g_hKeepaliveThreadReply = NULL;
	}
	if ( g_hKeepaliveThread )
	{
		CloseHandle( g_hKeepaliveThread );
		g_hKeepaliveThread = NULL;
	}
}
|