/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011  Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA   99021

*==LICENSE==*/
/*****************************************************************************
*
*   $/Plasma20/Sources/Plasma/PubUtilLib/plNetGameLib/Private/plNglFile.cpp
*   
***/

#include "../Pch.h"
#pragma hdrstop

// Define this if the file servers are running behind load-balancing hardware.
// It changes the logic by which the decision to attempt a reconnect is made.
#define LOAD_BALANCER_HARDWARE


namespace Ngl { namespace File {
/*****************************************************************************
*
*   Private
*
***/

// State for one client -> file server connection. Instances are reference
// counted (AtomicRef) with string-tagged references such as "Lifetime",
// "Connecting", "Connected", "ReconnectTimer" and "PingTimer", and are linked
// into the file-scope s_conns list while a connect attempt or socket is live.
struct CliFileConn : AtomicRef {
	LINK(CliFileConn)	link;
	CLock				sockLock; // to protect the socket pointer so we don't nuke it while using it
	AsyncSocket			sock;		// nil while disconnected; assigned on kNotifySocketConnectSuccess
	wchar				name[MAX_PATH];
	NetAddress			addr;
	unsigned			seq;		// connection id handed to NetTransCancelByConnId()
	ARRAY(byte)			recvBuffer;	// bytes received but not yet dispatched as whole messages
	AsyncCancelId		cancelId;	// non-zero while an AsyncSocketConnect() is outstanding
	bool				abandoned;	// set by UnlinkAndAbandonConn_CS(); such a conn never becomes s_active
	unsigned			buildId;	// sent to the server in the connect packet
	unsigned			serverType;	// sent to the server in the connect packet

	CCritSect			timerCritsect; // critsect for both timers

	// Reconnection
	AsyncTimer *		reconnectTimer;
	unsigned			reconnectStartMs;	// zero means "reconnect immediately" to StartAutoReconnect()
	unsigned			connectStartMs;		// time the socket last connected successfully
	unsigned			numImmediateDisconnects;
	unsigned			numFailedConnects;

	// Ping
	AsyncTimer *		pingTimer;
	unsigned			pingSendTimeMs;
	unsigned			lastHeardTimeMs;

	CliFileConn ();
	~CliFileConn ();

	// This function should be called during object construction
	// to initiate connection attempts to the remote host whenever
	// the socket is disconnected.
	void AutoReconnect ();
	bool AutoReconnectEnabled () {return (reconnectTimer != nil);}
	void StopAutoReconnect (); // call before destruction
	void StartAutoReconnect ();
	void TimerReconnect ();

	// ping
	void AutoPing ();
	void StopAutoPing ();
	void TimerPing ();
	
	// Sends bytes on the socket if one is attached; otherwise does nothing.
	void Send (const void * data, unsigned bytes);

	void Destroy(); // cleans up the socket and buffer

	// Routes one complete received message to the matching Recv_* handler.
	void Dispatch (const Cli2File_MsgHeader * msg);
	bool Recv_PingReply (const File2Cli_PingReply * msg);
	bool Recv_BuildIdReply (const File2Cli_BuildIdReply * msg);
	bool Recv_BuildIdUpdate (const File2Cli_BuildIdUpdate * msg);
	bool Recv_ManifestReply (const File2Cli_ManifestReply * msg);
	bool Recv_FileDownloadReply (const File2Cli_FileDownloadReply * msg);
};


//============================================================================
// BuildIdRequestTrans
//============================================================================
// Transaction that asks the file server for a build id and reports the
// answer to a caller-supplied callback.
struct BuildIdRequestTrans : NetFileTrans {
	FNetCliFileBuildIdRequestCallback	m_callback;	// invoked from Post()
	void *								m_param;	// opaque value passed back to m_callback

	unsigned							m_buildId;	// filled in by Recv() from a successful reply
	
	BuildIdRequestTrans (
		FNetCliFileBuildIdRequestCallback	callback,
		void *								param
	);

	// Sends the request; returns false if no connection could be acquired.
	bool Send ();
	// Delivers the result to m_callback.
	void Post ();
	// Consumes a File2Cli_BuildIdReply and completes the transaction.
	bool Recv (
		const byte	msg[],
		unsigned	bytes
	);
};

//============================================================================
// ManifestRequestTrans
//============================================================================
// Transaction that requests a file manifest for a named group and collects
// the parsed entries before handing them to a caller-supplied callback.
struct ManifestRequestTrans : NetFileTrans {
	FNetCliFileManifestRequestCallback	m_callback;	// invoked from Post()
	void *								m_param;	// opaque value passed back to m_callback
	wchar								m_group[MAX_PATH];	// manifest group name; empty if none was given
	unsigned							m_buildId;	// sent to the server with the request

	ARRAY(NetCliFileManifestEntry)		m_manifest;	// entries parsed so far
	unsigned							m_numEntriesReceived;	// index of the next entry to fill in m_manifest

	ManifestRequestTrans (
		FNetCliFileManifestRequestCallback	callback,
		void *								param,
		const wchar							group[],
		unsigned							buildId
	);

	// Sends the request; returns false if no connection could be acquired.
	bool Send ();
	// Delivers the result and the collected manifest to m_callback.
	void Post ();
	// Acknowledges and parses one File2Cli_ManifestReply; returns false
	// when the payload is malformed.
	bool Recv (
		const byte	msg[],
		unsigned	bytes
	);
};

//============================================================================
// DownloadRequestTrans
//============================================================================
// Transaction that downloads one file from the file server. Received chunks
// are queued as RcvdFileDownloadChunkTrans sub-transactions by Recv().
struct DownloadRequestTrans : NetFileTrans {
	FNetCliFileDownloadRequestCallback	m_callback;	// invoked from Post()
	void *								m_param;	// opaque value passed back to m_callback

	wchar								m_filename[MAX_PATH];	// name of the file being requested
	hsStream *							m_writer;	// destination stream, forwarded to each chunk transaction
	unsigned							m_buildId;	// sent to the server with the request
	
	unsigned							m_totalBytesReceived;	// running total compared against the reply's totalFileSize

	DownloadRequestTrans (
		FNetCliFileDownloadRequestCallback	callback,
		void *								param,
		const wchar							filename[],
		hsStream *							writer,
		unsigned							buildId
	);

	// Sends the request; returns false if no connection could be acquired.
	bool Send ();
	// Delivers the result, filename and writer to m_callback.
	void Post ();
	// Acknowledges one File2Cli_FileDownloadReply and queues its data.
	bool Recv (
		const byte	msg[],
		unsigned	bytes
	);
};

//============================================================================
// RcvdFileDownloadChunkTrans
//============================================================================
struct RcvdFileDownloadChunkTrans : NetNotifyTrans {

	unsigned	bytes;
	byte *		data;
	hsStream *	writer;

	RcvdFileDownloadChunkTrans () : NetNotifyTrans (kFileRcvdFileDownloadChunkTrans) {}
	~RcvdFileDownloadChunkTrans ();
    void Post ();
};


/*****************************************************************************
*
*   Private data
*
***/

// Indices into the s_perf counter array.
enum {
	kPerfConnCount,		// adjusted by the CliFileConn constructor and destructor
	kNumPerf
};

static bool							s_running;			// asserted by Connect(); reconnect timers abandon their conn when false
static CCritSect					s_critsect;			// held around accesses to s_conns and s_active
static LISTDECL(CliFileConn, link)	s_conns;			// connections with a connect attempt or socket in flight
static CliFileConn *				s_active;			// connection handed out by GetConnIncRef(), or nil
static long							s_perf[kNumPerf];
static unsigned						s_connectBuildId;	// copied into each new connection's buildId
static unsigned						s_serverType;		// copied into each new connection's serverType

// Invoked from CliFileConn::Recv_BuildIdUpdate() when set.
static FNetCliFileBuildIdUpdateCallback	s_buildIdCallback = nil;

// Disconnects within this many milliseconds of connecting are treated as
// "immediate" by NotifyConnSocketDisconnect().
const unsigned kMinValidConnectionMs				= 25 * 1000;



/*****************************************************************************
*
*   Internal functions
*
***/

//===========================================================================
// Returns the current TimeGetMs() value, substituting 1 for 0 so callers
// can reserve zero as a "not set" marker.
static unsigned GetNonZeroTimeMs () {
	const unsigned nowMs = TimeGetMs();
	return nowMs ? nowMs : 1;
}

//============================================================================
// Returns the active connection with an extra reference tagged 'tag', or nil
// when there is no active connection. The _CS suffix marks this as the
// variant GetConnIncRef() calls while already holding s_critsect.
static CliFileConn * GetConnIncRef_CS (const char tag[]) {
	CliFileConn * const active = s_active;
	if (!active)
		return nil;

	active->IncRef(tag);
	return active;
}

//============================================================================
// Locking wrapper around GetConnIncRef_CS(): takes s_critsect, fetches the
// active connection (with an added reference), and releases the lock.
static CliFileConn * GetConnIncRef (const char tag[]) {
	s_critsect.Enter();
	CliFileConn * const result = GetConnIncRef_CS(tag);
	s_critsect.Leave();
	return result;
}

//============================================================================
// Removes 'conn' from s_conns, marks it abandoned, stops its reconnect timer
// and tears down whatever is in flight (a pending connect or a live socket).
// Callers in this file invoke it while holding s_critsect.
static void UnlinkAndAbandonConn_CS (CliFileConn * conn) {
	s_conns.Unlink(conn);
	conn->abandoned = true;

	if (conn->AutoReconnectEnabled())
		conn->StopAutoReconnect();

	// The "Lifetime" reference is dropped here only when there is neither a
	// pending connect nor a socket; in the other two cases the connect-failed
	// or disconnect notification is left to drop it.
	bool needsDecref = true;
	if (conn->cancelId) {
		// A connect attempt is still outstanding: cancel it.
		AsyncSocketConnectCancel(nil, conn->cancelId);
		conn->cancelId  = 0;
		needsDecref = false;
	}
	else {
		conn->sockLock.EnterRead();
		if (conn->sock) {
			AsyncSocketDisconnect(conn->sock, true);
			needsDecref = false;
		}
		conn->sockLock.LeaveRead();
	}
	if (needsDecref) {
		conn->DecRef("Lifetime");
	}
}

//============================================================================
// Handles a successful socket connect: converts the "Connecting" reference
// into a "Connected" one, records the connect time, clears the failure count,
// and then either promotes the connection to s_active (starting its ping
// timer) or, if it was abandoned in the meantime, disconnects it again.
static void NotifyConnSocketConnect (CliFileConn * conn) {
	conn->TransferRef("Connecting", "Connected");
	conn->connectStartMs	= TimeGetMs();
	conn->numFailedConnects	= 0;

	s_critsect.Enter();
	if (conn->abandoned) {
		// Nobody wants this connection any more; drop the fresh socket.
		conn->sockLock.EnterRead();
		AsyncSocketDisconnect(conn->sock, true);
		conn->sockLock.LeaveRead();
	}
	else {
		// Make this the active server
		conn->AutoPing();
		s_active = conn;
	}
	s_critsect.Leave();
}

//============================================================================
// Handles a failed connect attempt: unlinks the connection, cancels its
// transactions, then either reports a fatal connect error (client builds,
// after kMaxFailedConnects failures) or schedules another attempt. The
// "Connecting" reference taken by TimerReconnect() is always released.
static void NotifyConnSocketConnectFailed (CliFileConn * conn) {
    s_critsect.Enter();
    {
		conn->cancelId = 0;
        s_conns.Unlink(conn);

        if (conn == s_active)
            s_active = nil;
    }
    s_critsect.Leave();
    
    // Cancel all transactions in progress on this connection.
    NetTransCancelByConnId(conn->seq, kNetErrTimeout);
	
#ifndef SERVER
	// Client apps fail if unable to connect for a time
    if (++conn->numFailedConnects >= kMaxFailedConnects) {
		ReportNetError(kNetProtocolCli2File, kNetErrConnectFailed);
	}
	else
#endif // ndef SERVER
	// In SERVER builds the block below is unconditional; in client builds it
	// is the 'else' arm of the failure-count check above.
	{
		// start reconnect, if we are doing that
		if (s_running && conn->AutoReconnectEnabled())
			conn->StartAutoReconnect();
		else
			conn->DecRef("Lifetime"); // if we are not reconnecting, this socket is done, so remove the lifetime ref
	}
	conn->DecRef("Connecting");
}

//============================================================================
// Handles a socket disconnect: stops pinging, unlinks the connection,
// cancels its transactions, decides how soon (if at all) to reconnect, and
// finally releases the "Connected" reference.
static void NotifyConnSocketDisconnect (CliFileConn * conn) {
	conn->StopAutoPing();
    s_critsect.Enter();
    {
		conn->cancelId = 0;
        s_conns.Unlink(conn);
			
        if (conn == s_active)
            s_active = nil;
    }
    s_critsect.Leave();

    // Cancel all transactions in progress on this connection.
    NetTransCancelByConnId(conn->seq, kNetErrTimeout);


	// Set when too many immediate disconnects have occurred and the error
	// should be reported instead of scheduling another reconnect.
	bool notify = false;

#ifdef SERVER
	{
		// Long-lived connection: reconnect right away. Short-lived one:
		// push the reconnect out by kMaxReconnectIntervalMs.
		if (TimeGetMs() - conn->connectStartMs > kMinValidConnectionMs)
			conn->reconnectStartMs = 0;
		else
			conn->reconnectStartMs = GetNonZeroTimeMs() + kMaxReconnectIntervalMs;
	}
#else
	{
	#ifndef LOAD_BALANCER_HARDWARE
		// If the connection to the remote server was open for longer than
		// kMinValidConnectionMs then assume that the connection was to
		// a valid server and try to perform reconnection immediately. If
		// less time elapsed then the connection was likely to a server
		// with an open port but with no notification procedure registered
		// for this type of communication channel.
		if (TimeGetMs() - conn->connectStartMs > kMinValidConnectionMs) {
			conn->reconnectStartMs = 0;
		}
		else {
			if (++conn->numImmediateDisconnects < kMaxImmediateDisconnects)
				conn->reconnectStartMs = GetNonZeroTimeMs() + kMaxReconnectIntervalMs;
			else
				notify = true;
		}
	#else
		// File server is running behind a load-balancer, so the next connection may
		// send us to a new server, therefore attempt a reconnection to the same
		// address even if the disconnect was immediate.  This is safe because the
		// file server is stateless with respect to clients.
		if (TimeGetMs() - conn->connectStartMs <= kMinValidConnectionMs) {
			// Immediate disconnect: retry after a delay until the limit of
			// kMaxImmediateDisconnects is reached, then report the error.
			if (++conn->numImmediateDisconnects < kMaxImmediateDisconnects)
				conn->reconnectStartMs = GetNonZeroTimeMs() + kMaxReconnectIntervalMs;
			else
				notify = true;
		}
		else {
			// disconnect was not immediate. attempt a reconnect unless we're shutting down
			conn->numImmediateDisconnects = 0;
			conn->reconnectStartMs = 0;
		}
	#endif	// LOAD_BALANCER_HARDWARE
	}
#endif // ndef SERVER

	if (notify) {
		// No reconnect is scheduled and the socket is not destroyed on this path.
		ReportNetError(kNetProtocolCli2File, kNetErrDisconnected);
	}
	else {	
		// clean up the socket and start reconnect, if we are doing that
		conn->Destroy();
		if (conn->AutoReconnectEnabled())
			conn->StartAutoReconnect();
		else
			conn->DecRef("Lifetime"); // if we are not reconnecting, this socket is done, so remove the lifetime ref
	}

	conn->DecRef("Connected");
}

//============================================================================
// Appends newly received bytes to the connection's receive buffer and
// dispatches every complete message found there. Each message starts with a
// dword giving its total size in bytes.
//
// Returns true when all complete messages were dispatched (a trailing partial
// message stays buffered). Returns false when the declared message size is
// smaller than a message header, which can only come from a corrupt or
// hostile stream.
// NOTE(review): false is assumed to make the socket layer drop the
// connection — confirm against the AsyncSocket notify contract.
static bool NotifyConnSocketRead (CliFileConn * conn, AsyncNotifySocketRead * read) {
	conn->lastHeardTimeMs = GetNonZeroTimeMs();
	conn->recvBuffer.Add(read->buffer, read->bytes);
	read->bytesProcessed += read->bytes;

	for (;;) {
		// Need at least the size prefix before anything can be decided
		if (conn->recvBuffer.Count() < sizeof(dword))
			return true;

		dword msgSize = *(dword *)conn->recvBuffer.Ptr();

		// A size below the header size is never valid. Without this check a
		// size of zero would dispatch the same bytes forever (nothing is
		// ever removed from the buffer), and a tiny size would let Dispatch()
		// read header fields that lie beyond the message.
		if (msgSize < sizeof(Cli2File_MsgHeader))
			return false;

		// Wait for the rest of the message to arrive
		if (conn->recvBuffer.Count() < msgSize)
			return true;

		const Cli2File_MsgHeader * msg = (const Cli2File_MsgHeader *) conn->recvBuffer.Ptr();
		conn->Dispatch(msg);

		// Remove the dispatched message from the front of the buffer
		conn->recvBuffer.Move(0, msgSize, conn->recvBuffer.Count() - msgSize);
		conn->recvBuffer.ShrinkBy(msgSize);
	}
}

//============================================================================
// Socket notification entry point passed to AsyncSocketConnect(). Recovers
// the CliFileConn associated with the event and forwards to the matching
// NotifyConnSocket* handler. Only the read handler can produce a false
// result; every other event (including unhandled codes) yields true.
static bool SocketNotifyCallback (
	AsyncSocket			sock,
	EAsyncNotifySocket	code,
	AsyncNotifySocket *	notify,
	void **				userState
) {
	switch (code) {
		case kNotifySocketConnectSuccess: {
			// The connection object arrives as the connect parameter; stash
			// it in the per-socket user state for later notifications.
			CliFileConn * const conn = (CliFileConn *) notify->param;
			*userState = conn;

			s_critsect.Enter();
			conn->sockLock.EnterWrite();
			conn->sock = sock;
			conn->sockLock.LeaveWrite();
			conn->cancelId = 0;
			s_critsect.Leave();

			NotifyConnSocketConnect(conn);
			return true;
		}

		case kNotifySocketConnectFailed:
			NotifyConnSocketConnectFailed((CliFileConn *) notify->param);
			return true;

		case kNotifySocketDisconnect:
			NotifyConnSocketDisconnect((CliFileConn *) *userState);
			return true;

		case kNotifySocketRead:
			return NotifyConnSocketRead(
				(CliFileConn *) *userState,
				(AsyncNotifySocketRead *) notify
			);
	}

	return true;
}

//============================================================================
// Starts an asynchronous connect for 'conn'. Every other connection currently
// in s_conns is abandoned first, so 'conn' ends up as the list's only member.
static void Connect (CliFileConn * conn) {
	ASSERT(s_running);

	conn->pingSendTimeMs = 0;

	s_critsect.Enter();
	for (;;) {
		CliFileConn * const head = s_conns.Head();
		if (!head)
			break;

		// Our own entry is simply unlinked (it is re-linked below); anything
		// else is torn down.
		if (head == conn)
			s_conns.Unlink(head);
		else
			UnlinkAndAbandonConn_CS(head);
	}
	s_conns.Link(conn);
	s_critsect.Leave();

	// Build the connect packet sent to the server once the socket opens
	Cli2File_Connect packet;
	packet.hdr.connType		= kConnTypeCliToFile;
	packet.hdr.hdrBytes		= sizeof(packet.hdr);
	packet.hdr.buildId		= kFileSrvBuildId;
	packet.hdr.buildType	= BuildType();
	packet.hdr.branchId		= BranchId();
	packet.hdr.productId	= ProductId();
	packet.data.buildId		= conn->buildId;
	packet.data.serverType	= conn->serverType;
	packet.data.dataBytes	= sizeof(packet.data);

	AsyncSocketConnect(
		&conn->cancelId,
		conn->addr,
		SocketNotifyCallback,
		conn,
		&packet,
		sizeof(packet),
		0,
		0
	);
}

//============================================================================
static void Connect (
	const wchar			name[],
	const NetAddress &	addr
) {
	ASSERT(s_running);
	
    CliFileConn * conn = NEWZERO(CliFileConn);
    StrCopy(conn->name, name, arrsize(conn->name));
    conn->addr			= addr;
	conn->buildId		= s_connectBuildId;
	conn->serverType	= s_serverType;
    conn->seq			= ConnNextSequence();
	conn->lastHeardTimeMs	= GetNonZeroTimeMs();	// used in connect timeout, and ping timeout

    conn->IncRef("Lifetime");
	conn->AutoReconnect();
}

//============================================================================
static void AsyncLookupCallback (
	void *				param,
	const wchar			name[],
	unsigned			addrCount,
	const NetAddress	addrs[]
) {
    if (!addrCount) {
		ReportNetError(kNetProtocolCli2File, kNetErrNameLookupFailed);
		return;
	}

	for (unsigned i = 0; i < addrCount; ++i) {
		Connect(name, addrs[i]);
	}
}

/*****************************************************************************
*
*   CliFileConn
*
***/

//============================================================================
// Counts the new connection in the kPerfConnCount perf counter; the
// destructor undoes this.
CliFileConn::CliFileConn () {
	long * const connCounter = &s_perf[kPerfConnCount];
	AtomicAdd(connCounter, 1);
}

//============================================================================
// Final teardown. By this point no connect may be pending and the reconnect
// timer must already have been stopped; the socket and receive buffer are
// released and the connection is removed from the perf counter.
CliFileConn::~CliFileConn () {
	ASSERT(!cancelId);
	ASSERT(!reconnectTimer);

	Destroy();

	long * const connCounter = &s_perf[kPerfConnCount];
	AtomicAdd(connCounter, -1);
}

//===========================================================================
// Reconnect timer body. While the subsystem is running it takes a
// "Connecting" reference and launches a connect attempt; otherwise it
// abandons this connection.
void CliFileConn::TimerReconnect () {
	ASSERT(!sock);
	ASSERT(!cancelId);

	if (s_running) {
		IncRef("Connecting");

		// Record when this attempt began. The value must never be zero,
		// because StartAutoReconnect() treats zero as "no previous attempt,
		// reconnect immediately".
		reconnectStartMs = GetNonZeroTimeMs();

		Connect(this);
		return;
	}

	// Shutting down: tear this connection down instead of reconnecting
	s_critsect.Enter();
	UnlinkAndAbandonConn_CS(this);
	s_critsect.Leave();
}

//===========================================================================
static unsigned CliFileConnTimerReconnectProc (void * param) {
	((CliFileConn *) param)->TimerReconnect();
	return kAsyncTimeInfinite;
}

//===========================================================================
// Called after a disconnect or failed connect to schedule the next
// connection attempt. Does nothing if auto-reconnect is not enabled.
void CliFileConn::StartAutoReconnect () {
	timerCritsect.Enter();
	if (reconnectTimer) {
		// Attempts are spaced at regular intervals: reconnectStartMs holds
		// the earliest time the next attempt may start (zero = right now).
		// If that time has already passed, fire immediately; otherwise wait
		// out the remainder.
		unsigned delayMs = 0;
		if (reconnectStartMs) {
			// Interpreting the unsigned difference as signed copes with the
			// millisecond clock wrapping around.
			const signed deltaMs = (signed)(reconnectStartMs - GetNonZeroTimeMs());
			if (deltaMs > 0)
				delayMs = (unsigned) deltaMs;
		}
		AsyncTimerUpdate(reconnectTimer, delayMs);
	}
	timerCritsect.Leave();
}

//===========================================================================
// This function should be called during object construction
// to initiate connection attempts to the remote host whenever
// the socket is disconnected.
// Creates the reconnect timer (which must not exist yet) with an immediate
// first callback, and takes a "ReconnectTimer" reference on the connection
// for as long as the timer lives.
void CliFileConn::AutoReconnect () {
	const unsigned kFireImmediately = 0;

	timerCritsect.Enter();
	ASSERT(!reconnectTimer);
	IncRef("ReconnectTimer");
	AsyncTimerCreate(&reconnectTimer, CliFileConnTimerReconnectProc, kFireImmediately, this);
	timerCritsect.Leave();
}

//===========================================================================
static unsigned CliFileConnTimerDestroyed (void * param) {
	CliFileConn * sock = (CliFileConn *) param;
	sock->DecRef("TimerDestroyed");
	return kAsyncTimeInfinite;
}

//============================================================================
void CliFileConn::StopAutoReconnect () {
	timerCritsect.Enter();
	{
		if (AsyncTimer * timer = reconnectTimer) {
			reconnectTimer = nil;
			AsyncTimerDeleteCallback(timer, CliFileConnTimerDestroyed);
		}
	}
	timerCritsect.Leave();
}

//===========================================================================
static unsigned CliFileConnPingTimerProc (void * param) {
	((CliFileConn *) param)->TimerPing();
	return kPingIntervalMs;
}

//============================================================================
// Creates the ping timer (which must not exist yet) and takes a "PingTimer"
// reference for it. The first ping fires immediately when a socket is
// attached; with no socket the timer is created with an infinite period.
void CliFileConn::AutoPing () {
	ASSERT(!pingTimer);
	IncRef("PingTimer");

	timerCritsect.Enter();
	{
		unsigned firstFireMs = kAsyncTimeInfinite;

		sockLock.EnterRead();
		if (sock)
			firstFireMs = 0;
		sockLock.LeaveRead();

		AsyncTimerCreate(&pingTimer, CliFileConnPingTimerProc, firstFireMs, this);
	}
	timerCritsect.Leave();
}

//============================================================================
void CliFileConn::StopAutoPing () {
	timerCritsect.Enter();
	{
		if (AsyncTimer * timer = pingTimer) {
			pingTimer = nil;
			AsyncTimerDeleteCallback(timer, CliFileConnTimerDestroyed);
		}
	}
	timerCritsect.Leave();
}

//============================================================================
// Ping timer body: if a socket is attached, stamps pingSendTimeMs and sends a
// Cli2File_PingRequest carrying that timestamp. The stale-connection check is
// currently compiled out (#if 0), so no timeout disconnect happens here.
void CliFileConn::TimerPing () {
	sockLock.EnterRead();
	// Single-pass loop: used only so 'break' can jump to the unlock below.
	for (;;) {
		if (!sock) // make sure it exists
			break;
#if 0
		// if the time difference between when we last sent a ping and when we last
		// heard from the server is >= 3x the ping interval, the socket is stale.
		if (pingSendTimeMs && abs(int(pingSendTimeMs - lastHeardTimeMs)) >= kPingTimeoutMs) {
			// ping timed out, disconnect the socket
			AsyncSocketDisconnect(sock, true);
		}
		else
#endif
		{
			// Send a ping request
			pingSendTimeMs = GetNonZeroTimeMs();

			Cli2File_PingRequest msg;
			msg.messageId = kCli2File_PingRequest;
			msg.messageBytes = sizeof(msg);
			msg.pingTimeMs = pingSendTimeMs;

			// read locks are reentrant, so calling Send is ok here within the read lock
			Send(&msg, msg.messageBytes);
		}
		break;
	}
	sockLock.LeaveRead();
}

//============================================================================
void CliFileConn::Destroy () {
	AsyncSocket oldSock = nil;

	sockLock.EnterWrite();
	{
		SWAP(oldSock, sock);
	}
	sockLock.LeaveWrite();

	if (oldSock)
		AsyncSocketDelete(oldSock);
	recvBuffer.Clear();
}

//============================================================================
void CliFileConn::Send (const void * data, unsigned bytes) {
	sockLock.EnterRead();
	if (sock) {
		AsyncSocketSend(sock, data, bytes);
	}
	sockLock.LeaveRead();
}

//============================================================================
// Routes one complete message to the Recv_* handler that matches its
// messageId. The handlers' bool results are ignored. An id with no case goes
// to DEFAULT_FATAL.
void CliFileConn::Dispatch (const Cli2File_MsgHeader * msg) {

// DISPATCH(X) expands to: case kFile2Cli_X: Recv_X((const File2Cli_X *) msg); break
#define DISPATCH(a) case kFile2Cli_##a: Recv_##a((const File2Cli_##a *) msg); break
	switch (msg->messageId) {
		DISPATCH(PingReply);
		DISPATCH(BuildIdReply);
		DISPATCH(BuildIdUpdate);
		DISPATCH(ManifestReply);
		DISPATCH(FileDownloadReply);
		DEFAULT_FATAL(msg->messageId)
	}
#undef DISPATCH
}

//============================================================================
// Ping replies need no processing here; the receive path has already
// refreshed lastHeardTimeMs before dispatching. Always returns true.
bool CliFileConn::Recv_PingReply (
	const File2Cli_PingReply * msg
) {
	const bool handled = true;
	return handled;
}

//============================================================================
bool CliFileConn::Recv_BuildIdReply (
	const File2Cli_BuildIdReply * msg
) {
	NetTransRecv(msg->transId, (const byte *)msg, msg->messageBytes);

	return true;
}

//============================================================================
// Handles an unsolicited build-id update by passing the new build id to the
// registered callback, if there is one. Always returns true.
bool CliFileConn::Recv_BuildIdUpdate (
	const File2Cli_BuildIdUpdate * msg
) {
	if (!s_buildIdCallback)
		return true;

	s_buildIdCallback(msg->buildId);
	return true;
}

//============================================================================
bool CliFileConn::Recv_ManifestReply (
	const File2Cli_ManifestReply * msg
) {
	NetTransRecv(msg->transId, (const byte *)msg, msg->messageBytes);

	return true;
}

//============================================================================
bool CliFileConn::Recv_FileDownloadReply (
	const File2Cli_FileDownloadReply * msg
) {
	NetTransRecv(msg->transId, (const byte *)msg, msg->messageBytes);

	return true;
}


/*****************************************************************************
*
*   BuildIdRequestTrans
*
***/

//============================================================================
// Stores the completion callback and its opaque parameter. m_buildId starts
// at zero so that Post() hands the callback a defined value even when the
// transaction completes without a successful reply (Recv() only assigns
// m_buildId on success).
BuildIdRequestTrans::BuildIdRequestTrans (
	FNetCliFileBuildIdRequestCallback	callback,
	void *								param
) : NetFileTrans(kBuildIdRequestTrans)
,	m_callback(callback)
,	m_param(param)
,	m_buildId(0)
{}

//============================================================================
bool BuildIdRequestTrans::Send () {
	if (!AcquireConn())
		return false;

	Cli2File_BuildIdRequest buildIdReq;
	buildIdReq.messageId = kCli2File_BuildIdRequest;
	buildIdReq.transId = m_transId;
	buildIdReq.messageBytes = sizeof(buildIdReq);

	m_conn->Send(&buildIdReq, buildIdReq.messageBytes);	
	
	return true;
}

//============================================================================
// Reports the outcome to the caller: result code, the caller's opaque
// parameter, and the build id.
void BuildIdRequestTrans::Post () {
	m_callback(
		m_result,
		m_param,
		m_buildId
	);
}

//============================================================================
// Consumes the server's build-id reply. The build id is captured only when
// the reply's result is not an error; either way the transaction is marked
// complete with the reply's result code. Always returns true.
bool BuildIdRequestTrans::Recv (
	const byte	msg[],
	unsigned	bytes
) {
	const File2Cli_BuildIdReply * const reply = (const File2Cli_BuildIdReply *) msg;

	// Only a successful reply carries a meaningful build id
	if (!IS_NET_ERROR(reply->result))
		m_buildId = reply->buildId;

	// Success or failure, this single reply finishes the transaction
	m_result	= reply->result;
	m_state		= kTransStateComplete;
	return true;
}


/*****************************************************************************
*
*   ManifestRequestTrans
*
***/

//============================================================================
// Stores the callback, its opaque parameter and the build id, resets the
// received-entry counter, and copies the group name (a nil group becomes the
// empty string). Initializers are listed in member declaration order, which
// is the order they execute in regardless of how they are written.
ManifestRequestTrans::ManifestRequestTrans (
	FNetCliFileManifestRequestCallback	callback,
	void *								param,
	const wchar							group[],
	unsigned							buildId
) : NetFileTrans(kManifestRequestTrans)
,	m_callback(callback)
,	m_param(param)
,	m_buildId(buildId)
,	m_numEntriesReceived(0)
{
	if (!group) {
		m_group[0] = L'\0';
		return;
	}
	StrCopy(m_group, group, arrsize(m_group));
}

//============================================================================
bool ManifestRequestTrans::Send () {
	if (!AcquireConn())
		return false;

	Cli2File_ManifestRequest manifestReq;
	StrCopy(manifestReq.group, m_group, arrsize(manifestReq.group));
	manifestReq.messageId = kCli2File_ManifestRequest;
	manifestReq.transId = m_transId;
	manifestReq.messageBytes = sizeof(manifestReq);
	manifestReq.buildId = m_buildId;

	m_conn->Send(&manifestReq, manifestReq.messageBytes);	

	return true;
}

//============================================================================
void ManifestRequestTrans::Post () {
	m_callback(m_result, m_param, m_group, m_manifest.Ptr(), m_manifest.Count());
}

//============================================================================
// Copies a string from the message at curMsgPtr into str (capacity maxStrLen
// wchars), forces a terminator into the last slot, and stores the length of
// the resulting string in *length.
void ReadStringFromMsg(const wchar* curMsgPtr, wchar str[], unsigned maxStrLen, unsigned* length) {
	const unsigned lastSlot = maxStrLen - 1;

	StrCopy(str, curMsgPtr, maxStrLen);
	str[lastSlot] = L'\0'; // guarantee termination even if the copy was cut short

	*length = StrLen(str);
}

//============================================================================
// Decodes a value stored in the message as two consecutive wchars: the first
// supplies the high 16 bits, the second the low bits. The result is written
// to *val.
void ReadUnsignedFromMsg(const wchar* curMsgPtr, unsigned* val) {
	// Widen to unsigned before shifting. Shifting the promoted (signed int)
	// value overflows whenever the high word is 0x8000 or above, which is
	// undefined behavior for signed integers.
	const unsigned hiWord = (unsigned) curMsgPtr[0];
	const unsigned loWord = (unsigned) curMsgPtr[1];
	(*val) = (hiWord << 16) + loWord;
}

//============================================================================
bool ManifestRequestTrans::Recv (
	const byte	msg[],
	unsigned	bytes
) {
	m_timeoutAtMs = TimeGetMs() + NetTransGetTimeoutMs(); // Reset the timeout counter

	const File2Cli_ManifestReply & reply = *(const File2Cli_ManifestReply *) msg;

	dword numFiles = reply.numFiles;
	dword wcharCount = reply.wcharCount;
	const wchar* curChar = reply.manifestData;

	// tell the server we got the data
	Cli2File_ManifestEntryAck manifestAck;
	manifestAck.messageId = kCli2File_ManifestEntryAck;
	manifestAck.transId = reply.transId;
	manifestAck.messageBytes = sizeof(manifestAck);
	manifestAck.readerId = reply.readerId;

	m_conn->Send(&manifestAck, manifestAck.messageBytes);	

	// if wcharCount is 2, the data only contains the terminator "\0\0" and we
	// don't need to convert anything (and we are done)
	if ((IS_NET_ERROR(reply.result)) || (wcharCount == 2)) {
		// we have a problem... or we have nothing to so, so we're done
		m_result	= reply.result;
		m_state		= kTransStateComplete;
		return true;
	}

	if (numFiles > m_manifest.Count())
		m_manifest.SetCount(numFiles); // reserve the space ahead of time

	// manifestData format: "clientFile\0downloadFile\0md5\0filesize\0zipsize\0flags\0...\0\0"
	bool done = false;
	while (!done) {
		if (wcharCount == 0)
		{
			done = true;
			break;
		}

		// copy the data over to our array (m_numEntriesReceived is the current index)
		NetCliFileManifestEntry& entry = m_manifest[m_numEntriesReceived];

		// --------------------------------------------------------------------
		// read in the clientFilename
		unsigned filenameLen;
		ReadStringFromMsg(curChar, entry.clientName, arrsize(entry.clientName), &filenameLen);
		curChar += filenameLen; // advance the pointer
		wcharCount -= filenameLen; // keep track of the amount remaining
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // something is screwy, abort and disconnect

		// point it at the downloadFile
		curChar++;
		wcharCount--;

		// --------------------------------------------------------------------
		// read in the downloadFilename
		ReadStringFromMsg(curChar, entry.downloadName, arrsize(entry.downloadName), &filenameLen);
		curChar += filenameLen; // advance the pointer
		wcharCount -= filenameLen; // keep track of the amount remaining
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // something is screwy, abort and disconnect

		// point it at the md5
		curChar++;
		wcharCount--;

		// --------------------------------------------------------------------
		// read in the md5
		ReadStringFromMsg(curChar, entry.md5, arrsize(entry.md5), &filenameLen);
		curChar += filenameLen; // advance the pointer
		wcharCount -= filenameLen; // keep track of the amount remaining
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // something is screwy, abort and disconnect

		// point it at the md5 for compressed files
		curChar++; 
		wcharCount--;

		// --------------------------------------------------------------------
		// read in the md5 for compressed files
		ReadStringFromMsg(curChar, entry.md5compressed, arrsize(entry.md5compressed), &filenameLen);
		curChar += filenameLen; // advance the pointer
		wcharCount -= filenameLen; // keep track of the amount remaining
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // something is screwy, abort and disconnect

		// point it at the first part of the filesize value (format: 0xHHHHLLLL)
		curChar++; 
		wcharCount--;

		// --------------------------------------------------------------------
		if (wcharCount < 2) // we have to have 2 chars for the size
			return false; // screwy data
		ReadUnsignedFromMsg(curChar, &entry.fileSize);
		curChar += 2;
		wcharCount -= 2;
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // screwy data

		// point it at the first part of the zipsize value (format: 0xHHHHLLLL)
		curChar++; 
		wcharCount--;

		// --------------------------------------------------------------------
		if (wcharCount < 2) // we have to have 2 chars for the size
			return false; // screwy data
		ReadUnsignedFromMsg(curChar, &entry.zipSize);
		curChar += 2;
		wcharCount -= 2;
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // screwy data

		// point it at the first part of the flags value (format: 0xHHHHLLLL)
		curChar++; 
		wcharCount--;

		// --------------------------------------------------------------------
		if (wcharCount < 2) // we have to have 2 chars for the size
			return false; // screwy data
		ReadUnsignedFromMsg(curChar, &entry.flags);
		curChar += 2;
		wcharCount -= 2;
		if ((*curChar != L'\0') || (wcharCount <= 0))
			return false; // screwy data

		// --------------------------------------------------------------------
		// point it at either the second part of the terminator, or the next filename
		curChar++;
		wcharCount--;

		// do sanity checking
		if (*curChar == L'\0') {
			// we hit the terminator
			if (wcharCount != 1)
				return false; // invalid data, we shouldn't have any more
			done = true; // we're done
		}
		else if (wcharCount < 14)
			// we must have at least three 1-char strings, three nulls, three 32-bit ints, and 2-char terminator left (3+3+6+2)
			return false; // screwy data

		// increment entries received
		m_numEntriesReceived++;
		if ((m_numEntriesReceived >= numFiles) && !done) {
			// too much data, abort
			return false;
		}
	}
	
	// check for completion
	if (m_numEntriesReceived >= numFiles)
	{
		// all entires received, mark as complete
		m_result	= reply.result;
		m_state		= kTransStateComplete;
	}
	return true;
}

/*****************************************************************************
*
*   DownloadRequestTrans
*
***/

//============================================================================
// Stores the callback, its opaque parameter, the destination stream and the
// build id, zeroes the received-byte total, and copies the filename.
// Initializers are listed in member declaration order, which is the order
// they execute in regardless of how they are written.
DownloadRequestTrans::DownloadRequestTrans (
	FNetCliFileDownloadRequestCallback	callback,
	void *								param,
	const wchar							filename[],
	hsStream *							writer,
	unsigned							buildId
) : NetFileTrans(kDownloadRequestTrans)
,	m_callback(callback)
,	m_param(param)
,	m_writer(writer)
,	m_buildId(buildId)
,	m_totalBytesReceived(0)
{
	// Chunk writes are issued as "sub transactions" after this transaction
	// starts, yet they have to finish before this one is allowed to.
	m_hasSubTrans = true;

	StrCopy(m_filename, filename, arrsize(m_filename));
}

//============================================================================
// Sends the download request for m_filename / m_buildId on the acquired
// connection. Returns false, without sending anything, when no connection
// could be acquired.
bool DownloadRequestTrans::Send () {
	if (!AcquireConn())
		return false;

	Cli2File_FileDownloadRequest filedownloadReq;
	// Bound the copy by the DESTINATION's capacity (as ManifestRequestTrans::Send
	// does), so the copy stays in bounds even if the wire field is smaller
	// than m_filename.
	StrCopy(filedownloadReq.filename, m_filename, arrsize(filedownloadReq.filename));
	filedownloadReq.messageId = kCli2File_FileDownloadRequest;
	filedownloadReq.transId = m_transId;
	filedownloadReq.messageBytes = sizeof(filedownloadReq);
	filedownloadReq.buildId = m_buildId;

	m_conn->Send(&filedownloadReq, sizeof(filedownloadReq));

	return true;
}

//============================================================================
// Reports the outcome of the download to the caller-supplied callback,
// passing back the result code, the caller's param, the requested filename
// and the writer stream.
void DownloadRequestTrans::Post () {
	m_callback(
		m_result,
		m_param,
		m_filename,
		m_writer
	);
}

//============================================================================
bool DownloadRequestTrans::Recv (
	const byte	msg[],
	unsigned	bytes
) {
	m_timeoutAtMs = TimeGetMs() + NetTransGetTimeoutMs(); // Reset the timeout counter

	const File2Cli_FileDownloadReply & reply = *(const File2Cli_FileDownloadReply *) msg;

	dword byteCount = reply.byteCount;
	const byte* data = reply.fileData;

	// tell the server we got the data
	Cli2File_FileDownloadChunkAck fileAck;
	fileAck.messageId = kCli2File_FileDownloadChunkAck;
	fileAck.transId = reply.transId;
	fileAck.messageBytes = sizeof(fileAck);
	fileAck.readerId = reply.readerId;

	m_conn->Send(&fileAck, fileAck.messageBytes);

	if (IS_NET_ERROR(reply.result)) {
		// we have a problem... indicate we are done and abort
		m_result	= reply.result;
		m_state		= kTransStateComplete;
		return true;
	}

	// we have data to write, so queue it for write in the main thread (we're
	// currently in a net recv thread)
	if (byteCount > 0) {
		RcvdFileDownloadChunkTrans * writeTrans = NEW(RcvdFileDownloadChunkTrans);
		writeTrans->writer	= m_writer;
		writeTrans->bytes	= byteCount;
		writeTrans->data	= (byte *)ALLOC(byteCount);
		MemCopy(writeTrans->data, data, byteCount);
		NetTransSend(writeTrans);
	}
	m_totalBytesReceived += byteCount;

	if (m_totalBytesReceived >= reply.totalFileSize) {
		// all bytes received, mark as complete
		m_result	= reply.result;
		m_state		= kTransStateComplete;
	}
	return true;
}

/*****************************************************************************
*
*   RcvdFileDownloadChunkTrans
*
***/

//============================================================================
// Releases the chunk buffer held in `data` (DownloadRequestTrans::Recv
// fills it with an ALLOC'd copy of the received bytes).
RcvdFileDownloadChunkTrans::~RcvdFileDownloadChunkTrans ()
{
	FREE(data);
}

//============================================================================
// Writes the buffered chunk (`bytes` bytes starting at `data`) to the
// writer stream, then marks this sub transaction as successfully completed.
void RcvdFileDownloadChunkTrans::Post ()
{
	writer->Write(bytes, data);

	// The write result is not inspected; the chunk always reports success
	m_result = kNetSuccess;
	m_state	 = kTransStateComplete;
}


} using namespace File;


/*****************************************************************************
*
*   NetFileTrans
*
***/

//============================================================================
// Creates a transaction on the Cli2File protocol with the given transaction
// type. No connection is held until AcquireConn() succeeds.
NetFileTrans::NetFileTrans (ETransType transType)
:	NetTrans(kNetProtocolCli2File, transType)
,	m_conn(nil)
{ }

//============================================================================
// Drops the connection reference, if one is still held.
NetFileTrans::~NetFileTrans ()
{
	ReleaseConn();
}

//============================================================================
// Ensures m_conn is set, fetching a ref-counted connection via
// GetConnIncRef() when none is held yet.
//
// Returns true if m_conn is non-nil afterwards.
bool NetFileTrans::AcquireConn () {
	// Already holding a connection: nothing to do
	if (m_conn)
		return true;

	m_conn = GetConnIncRef("AcquireConn");
	return m_conn != nil;
}

//============================================================================
// Gives back the reference taken by AcquireConn() and clears m_conn.
// Does nothing when no connection is held.
void NetFileTrans::ReleaseConn () {
	if (!m_conn)
		return;

	// The tag matches the one passed to GetConnIncRef() in AcquireConn()
	m_conn->DecRef("AcquireConn");
	m_conn = nil;
}


/*****************************************************************************
*
*   Protected functions
*
***/

//============================================================================
// Marks the file client module as running (FileDestroy clears the flag).
void FileInitialize ()
{
	s_running = true;
}

//============================================================================
// Shuts the file client module down: clears the running flag, cancels all
// Cli2File transactions with kNetErrRemoteShutdown, destroys the Cli2File
// message protocol, and abandons every connection.
//
//   wait - when true, keep pumping NetTransUpdate() in 10 ms steps until
//          s_perf[kPerfConnCount] reaches zero
//          (presumably the number of live connections - TODO confirm)
void FileDestroy (bool wait) {
	s_running = false;

	NetTransCancelByProtocol(kNetProtocolCli2File, kNetErrRemoteShutdown);
	NetMsgProtocolDestroy(kNetProtocolCli2File, false);

	// Abandon all connections while holding the module lock
	s_critsect.Enter();
	for (CliFileConn * conn = s_conns.Head(); conn; conn = s_conns.Head())
		UnlinkAndAbandonConn_CS(conn);
	s_active = nil;
	s_critsect.Leave();

	if (wait) {
		while (s_perf[kPerfConnCount]) {
			NetTransUpdate();
			AsyncSleep(10);
		}
	}
}

//============================================================================
// Returns true when an active file connection (s_active) is currently set.
// The pointer is sampled under s_critsect.
bool FileQueryConnected () {
	s_critsect.Enter();
	const bool connected = (s_active != nil);
	s_critsect.Leave();

	return connected;
}

//============================================================================
// Returns the `seq` value of the active file connection, or 0 when there
// is no active connection. s_active is read under s_critsect.
unsigned FileGetConnId () {
	unsigned id = 0;

	s_critsect.Enter();
	if (s_active)
		id = s_active->seq;
	s_critsect.Leave();

	return id;
}

} using namespace Ngl;

/*****************************************************************************
*
*   Exported functions
*
***/

//============================================================================
void NetCliFileStartConnect (
	const wchar *	fileAddrList[],
	unsigned		fileAddrCount,
	bool			isPatcher /* = false */
) {
	// TEMP: Only connect to one file server until we fill out this module
	// to choose the "best" file connection.
	fileAddrCount = min(fileAddrCount, 1);
	s_connectBuildId = isPatcher ? kFileSrvBuildId : BuildId();
	s_serverType = kSrvTypeNone;

	for (unsigned i = 0; i < fileAddrCount; ++i) {
		// Do we need to lookup the address?
		const wchar * name = fileAddrList[i];
		while (unsigned ch = *name) {
			++name;
			if (!(isdigit(ch) || ch == L'.' || ch == L':')) {
				AsyncCancelId cancelId;
				AsyncAddressLookupName(
					&cancelId,
					AsyncLookupCallback,
					fileAddrList[i],
					kNetDefaultClientPort,
					nil
				);
				break;
			}
		}
		if (!name[0]) {
			NetAddress addr;
			NetAddressFromString(&addr, fileAddrList[i], kNetDefaultClientPort);
			Connect(fileAddrList[i], addr);
		}
	}
}

//============================================================================
void NetCliFileStartConnectAsServer (
	const wchar *	fileAddrList[],
	unsigned		fileAddrCount,
	unsigned		serverType,
	unsigned		serverBuildId
) {
	// TEMP: Only connect to one file server until we fill out this module
	// to choose the "best" file connection.
	fileAddrCount = min(fileAddrCount, 1);
	s_connectBuildId = serverBuildId;
	s_serverType = serverType;

	for (unsigned i = 0; i < fileAddrCount; ++i) {
		// Do we need to lookup the address?
		const wchar * name = fileAddrList[i];
		while (unsigned ch = *name) {
			++name;
			if (!(isdigit(ch) || ch == L'.' || ch == L':')) {
				AsyncCancelId cancelId;
				AsyncAddressLookupName(
					&cancelId,
					AsyncLookupCallback,
					fileAddrList[i],
					kNetDefaultClientPort,
					nil
				);
				break;
			}
		}
		if (!name[0]) {
			NetAddress addr;
			NetAddressFromString(&addr, fileAddrList[i], kNetDefaultServerPort);
			Connect(fileAddrList[i], addr);
		}
	}
}

//============================================================================
// Abandons every file connection and clears the active connection pointer,
// all under s_critsect.
void NetCliFileDisconnect () {
	s_critsect.Enter();

	// UnlinkAndAbandonConn_CS removes the connection from s_conns, so the
	// list head advances until the list is empty
	for (CliFileConn * conn = s_conns.Head(); conn; conn = s_conns.Head())
		UnlinkAndAbandonConn_CS(conn);
	s_active = nil;

	s_critsect.Leave();
}

//============================================================================
// Queues a BuildIdRequestTrans.
//
//   callback - forwarded to the transaction
//   param    - opaque value forwarded to the transaction
void NetCliFileBuildIdRequest (
	FNetCliFileBuildIdRequestCallback	callback,
	void *								param
) {
	NetTransSend(NEW(BuildIdRequestTrans)(callback, param));
}

//============================================================================
// Stores the build id update callback in s_buildIdCallback, replacing any
// previously registered one.
void NetCliFileRegisterBuildIdUpdate (
	FNetCliFileBuildIdUpdateCallback	callback
) {
	s_buildIdCallback = callback;
}

//============================================================================
// Queues a ManifestRequestTrans for the given manifest group.
//
//   callback - forwarded to the transaction
//   param    - opaque value forwarded to the transaction
//   group    - manifest group name forwarded to the transaction
//   buildId  - build id forwarded to the transaction
void NetCliFileManifestRequest (
	FNetCliFileManifestRequestCallback	callback,
	void *								param,
	const wchar							group[],
	unsigned							buildId /* = 0 */
) {
	NetTransSend(NEW(ManifestRequestTrans)(callback, param, group, buildId));
}

//============================================================================
// Queues a DownloadRequestTrans for one file.
//
//   filename - name of the file to download
//   writer   - stream that receives the downloaded bytes
//   callback - forwarded to the transaction
//   param    - opaque value forwarded to the transaction
//   buildId  - build id forwarded to the transaction
//
// Note that the transaction constructor takes (callback, param, filename,
// writer, buildId), which differs from this function's parameter order.
void NetCliFileDownloadRequest (
	const wchar							filename[],
	hsStream *							writer,
	FNetCliFileDownloadRequestCallback	callback,
	void *								param,
	unsigned							buildId /* = 0 */
) {
	NetTransSend(
		NEW(DownloadRequestTrans)(callback, param, filename, writer, buildId)
	);
}