/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011  Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA   99021

*==LICENSE==*/
/*****************************************************************************
*
*   $/Plasma20/Sources/Plasma/NucleusLib/pnNetCli/pnNcCli.cpp
*   
***/

#include "Pch.h"
#pragma hdrstop

// Uncomment to enable verbose per-message tracing in this file.
// When NCCLI_DEBUGGING is defined, NCCLI_LOG forwards to LogMsg;
// otherwise NCCLI_LOG expands to NULL_STMT.
//#define NCCLI_DEBUGGING
#ifdef NCCLI_DEBUGGING
# pragma message("Compiling pnNetCli with debugging on")
# define NCCLI_LOG  LogMsg
#else
# define NCCLI_LOG  NULL_STMT
#endif

// Uncomment to skip the CryptEncrypt/CryptDecrypt steps in this file.
// The handshake still runs and the connection still switches to
// kNetCliModeEncrypted, but data is sent/received in the clear.
//#define NO_ENCRYPTION

namespace pnNetCli {

/*****************************************************************************
*
*   Private types and constants
*
***/

// Handshake/encryption state of a NetCli connection.
enum ENetCliMode {
    // Accepting side (created by NetCliListenAccept); waiting for the
    // client's kNetCliCli2SrvConnect packet.
    kNetCliModeServerStart,
    // Connecting side (created by NetCliConnectAccept); waiting for the
    // server's kNetCliSrv2CliEncrypt or kNetCliSrv2CliError packet.
    kNetCliModeClientStart,
    // Handshake finished; all traffic goes through cryptIn/cryptOut
    // (unless NO_ENCRYPTION is defined) and is parsed by DispatchData.
    kNetCliModeEncrypted,
    kNumNetCliModes
};

} using namespace pnNetCli;


/*****************************************************************************
*
*   Opaque types
*
***/

// connection structure attached to each socket
//
// ConnCreate creates one NetCli per socket (via NetCliConnectAccept or
// NetCliListenAccept), and NetCliDelete destroys it. It holds the outgoing
// staging buffer, the resumable parse state for incoming messages, and the
// handshake/encryption state. ConnCreate sets the hash key to kNilGuid.
struct NetCli : THashKeyVal<Uuid> {

    // communication channel
    AsyncSocket             sock;           // nil after NetCliClearSocket; sends are skipped while nil
    ENetProtocol            protocol;       // protocol id passed to ConnCreate
    NetMsgChannel *         channel;        // locked in ConnCreate, unlocked in NetCliDelete; source of message definitions
    bool                    server;         // NOTE(review): not referenced anywhere in this file

    // message queue
    LINK(NetCli)            link;           // linked into queue->list by BufferedSendData
    NetCliQueue *           queue;          // optional; assigned by NetCliSetQueue

    // message send/recv
    const NetMsgInitRecv *  recvMsg;        // message DispatchData is currently parsing (nil between messages)
    const NetMsgField *     recvField;      // next field of recvMsg to parse
    unsigned                recvFieldBytes; // pending byte count for a var-length or string field
    bool                    recvDispatch;   // false after NetCliDisconnect (and, in SERVER builds, after a dispatch failure)
    byte *                  sendCurr;       // points into sendBuffer
    CInputAccumulator       input;          // received (already decrypted) bytes awaiting parsing

    // Message encryption
    ENetCliMode             mode;
    FNetCliEncrypt          encryptFcn;     // invoked with the handshake result (kNetSuccess or the server's error)
    byte                    seed[kNetMaxSymmetricSeedBytes]; // server: seed sent to the client; client: overwritten with the DH client seed in ClientConnect
    CryptKey *              cryptIn;        // decrypts incoming data once mode == kNetCliModeEncrypted
    CryptKey *              cryptOut;       // encrypts outgoing data once mode == kNetCliModeEncrypted
    void *                  encryptParam;   // opaque value passed back to encryptFcn

    // Message buffers
    byte                    sendBuffer[kAsyncSocketBufferSize]; // outgoing bytes staged until flushed
    ARRAY(byte)             recvBuffer;     // reassembled message handed to the recv handler
};

// A list of connections that have buffered outgoing data.
// BufferedSendData links a connection into 'list' after staging a message
// when the connection has a queue assigned (see NetCliSetQueue).
struct NetCliQueue {
    LISTDECL(NetCli, link)      list;
    // NOTE(review): lastSendMs and flushTimeMs are not referenced in this
    // file; presumably used by the queue owner to schedule flushes. Confirm.
    unsigned                    lastSendMs;
    unsigned                    flushTimeMs;
};


namespace pnNetCli {

/*****************************************************************************
*
*   Private data
*
***/

 

/*****************************************************************************
*
*   Internal functions
*
***/

//============================================================================
// Sends 'bytes' bytes starting at 'data' on the connection's socket.
// In encrypted mode (and unless NO_ENCRYPTION is defined), the payload is
// first copied into a scratch buffer and encrypted there with cryptOut.
// Payloads of up to 2048 bytes use stack scratch space; larger ones use the
// heap. Nothing is sent if the socket has been cleared.
static void PutBufferOnWire (NetCli * cli, void * data, unsigned bytes) {
    byte * heapScratch = NULL;
    void * payload     = data;

#ifndef NO_ENCRYPTION
    if (cli->mode == kNetCliModeEncrypted) {
        byte * scratch;
        if (bytes > 2048)
            scratch = heapScratch = (byte *) ALLOC(bytes);
        else
            scratch = ALLOCA(byte, bytes);

        MemCopy(scratch, payload, bytes);
        CryptEncrypt(cli->cryptOut, bytes, scratch);
        payload = scratch;
    }
#endif

    if (cli->sock)
        AsyncSocketSend(cli->sock, payload, bytes);

    // Release the heap scratch buffer, if one was allocated
    FREE(heapScratch);
}

//============================================================================
// Writes everything staged in cli->sendBuffer to the wire and rewinds the
// staging cursor to the start of the buffer.
static void FlushSendBuffer (NetCli * cli) {
    // Number of bytes accumulated since the last flush
    const unsigned pending = (unsigned) (cli->sendCurr - cli->sendBuffer);
    ASSERT(pending <= arrsize(cli->sendBuffer));

    PutBufferOnWire(cli, cli->sendBuffer, pending);

    // Buffer is empty again
    cli->sendCurr = cli->sendBuffer;
}

//===========================================================================
// Appends 'bytes' bytes of 'data' to the connection's send staging buffer.
// Whenever the buffer fills completely, it is flushed and copying resumes
// from the start. A payload larger than the whole buffer bypasses staging:
// pending data is flushed and a heap copy of the payload is sent in one call,
// leaving fragmentation to the OS.
static void AddToSendBuffer (
    NetCli *            cli,
    unsigned            bytes,
    void const * const  data
) {
    if (bytes > arrsize(cli->sendBuffer)) {
        FlushSendBuffer(cli);
        void * oversize = ALLOC(bytes);
        MemCopy(oversize, data, bytes);
        PutBufferOnWire(cli, oversize, bytes);
        FREE(oversize);
        return;
    }

    byte const * cursor = (byte const *) data;
    byte * const bufEnd = &cli->sendBuffer[arrsize(cli->sendBuffer)];
    for (;;) {
        // Copy as much as fits in the space remaining
        const unsigned room  = (unsigned) (bufEnd - cli->sendCurr);
        const unsigned chunk = (bytes < room) ? bytes : room;

        for (unsigned i = 0; i < chunk; ++i)
            *cli->sendCurr++ = cursor[i];
        ASSERT(cli->sendCurr - cli->sendBuffer <= sizeof(cli->sendBuffer));

        // Done once the data fit without filling the buffer
        if (chunk != room)
            break;

        // Buffer is full: push it out and continue with what is left
        cursor += chunk;
        bytes  -= chunk;
        FlushSendBuffer(cli);
    }
}

//============================================================================
// Serializes one outgoing message into the connection's send staging buffer
// (via AddToSendBuffer), following the field layout of the matching send
// message definition on cli->channel.
//
// msg[] layout, as consumed below:
//   msg[0]              message id (written as a byte-swapped word)
//   msg[1..fieldCount-1] one entry per field of the definition:
//     - Integer/Real with count <= 1: the value itself is stored in the
//       unsigned_ptr slot (its first cmd->size bytes are sent)
//     - Integer/Real arrays, String, Data, Ptr, VarPtr: the slot holds a
//       pointer to the data
//     - VarCount: the slot holds the element count (sent as a dword); the
//       following VarPtr field sends count * element-size bytes
//
// Does nothing if the socket has been cleared. If a queue is attached, the
// connection is linked into it afterwards (presumably so the queue owner
// flushes it later — confirm against NetCliQueue users).
static void BufferedSendData (
    NetCli *            cli,
    const unsigned_ptr  msg[], 
    unsigned            fieldCount
) {
    // Helper macros local to this function. NOTE: they are never #undef'd,
    // so they remain defined for the rest of the file.
    #define ASSERT_MSG_VALID(expr)          \
        ASSERTMSG(expr, "Invalid message definition");

    // Byte-swaps 'c' as type 't' with Endian() and appends it to the buffer
    #define WRITE_SWAPPED_INT(t,c) {        \
        ASSERT(sizeof(t) == sizeof(c));     \
        t endianCount = Endian((t)c);       \
        AddToSendBuffer(cli, sizeof(t), (const void *) &endianCount);   \
    }


    ASSERT(cli);
    ASSERT(msg);
    ASSERT(fieldCount);

    if (!cli->sock)
        return;

    unsigned_ptr const * const msgEnd = msg + fieldCount;

    // NOTE(review): sendMsg is dereferenced without a nil check; this relies
    // on NetMsgChannelFindSendMessage always finding the id.
    const NetMsgInitSend * sendMsg = NetMsgChannelFindSendMessage(cli->channel, msg[0]);
    ASSERT(msg[0] == sendMsg->msg.messageId);
    ASSERT(fieldCount-1 == sendMsg->msg.count);

    // insert messageId into command stream
    const word msgId = (word) msg[0];
    WRITE_SWAPPED_INT(word, msgId);
    ++msg;
    ASSERT_MSG_VALID(msg < msgEnd);

    // insert fields into command stream
    // varCount/varSize carry a VarCount field's count and element size over
    // to the VarPtr field that follows it.
    dword varCount  = 0;
    dword varSize   = 0;
    const NetMsgField * cmd     = sendMsg->msg.fields;
    const NetMsgField * cmdEnd  = cmd + sendMsg->msg.count;
    for (; cmd < cmdEnd; ++msg, ++cmd) {
        switch (cmd->type) {
            case kNetMsgFieldInteger: {
                // count 0 in the definition means a single value
                const unsigned count = cmd->count ? cmd->count : 1;
                const unsigned bytes = cmd->size * count;
                void * temp = ALLOCA(byte, bytes);
                
                if (count == 1)
                    // Single values are passed by value
                    EndianCopy(temp, (const byte *) msg, count, cmd->size);
                else
                    // Value arrays are passed in by ptr
                    EndianCopy(temp, (const byte *) *msg, count, cmd->size);
                
                // Write values to send buffer
                AddToSendBuffer(cli, bytes, temp);
            }
            break;

            case kNetMsgFieldReal: {
                // Reals are sent without byte-swapping
                const unsigned count = cmd->count ? cmd->count : 1;
                const unsigned bytes = cmd->size * count;
                
                if (count == 1)
                    // Single values are passed in by value
                    AddToSendBuffer(cli, bytes, (const void *) msg);
                else
                    // Value arrays are passed in by ptr
                    AddToSendBuffer(cli, bytes, (const void *) *msg);
            }
            break;

            case kNetMsgFieldString: {
                // Use less-than instead of less-or-equal because
                // we reserve one space for the NULL terminator
                const word length = (word) StrLen((const wchar *) *msg);
                ASSERT_MSG_VALID(length < cmd->count);
                // Write actual string length
                WRITE_SWAPPED_INT(word, length);
                // Write string data (without the NULL terminator)
                AddToSendBuffer(cli, length * sizeof(wchar), (const void *) *msg);
            }
            break;

            case kNetMsgFieldData:
            case kNetMsgFieldRawData: {
                // write values to send buffer
                AddToSendBuffer(cli, cmd->count * cmd->size, (const void *) *msg);
            }
            break;

            case kNetMsgFieldVarCount: {
                ASSERT(!varCount);
                ASSERT(!varSize);
                // remember the element size
                varSize  = cmd->size;
                // write the actual element count
                varCount = (dword) *msg;
                WRITE_SWAPPED_INT(dword, varCount);
            }
            break;

            case kNetMsgFieldVarPtr:
            case kNetMsgFieldRawVarPtr: {
                ASSERT(varSize);
                // write var sized array
                AddToSendBuffer(cli, varCount * varSize, (const void *) *msg);
                varCount    = 0;
                varSize     = 0;
            }
            break;

            case kNetMsgFieldPtr:
            case kNetMsgFieldRawPtr: {
                // write values
                AddToSendBuffer(cli, cmd->count * cmd->size, (const void *) *msg);
            }
            break;

            DEFAULT_FATAL(cmd->type);
        }
    }

    // prepare to flush this connection
    if (cli->queue)
        cli->queue->list.Link(cli);
}

//===========================================================================
// Parses as many complete messages as possible out of cli->input and passes
// each one to its registered recv handler.
//
// Parsing is resumable. cli->recvMsg, recvField and recvFieldBytes record how
// far into the current message we got, so when the accumulator runs dry
// (NEED_MORE_DATA) the next call continues at the same field.
//
// Each message is rebuilt in cli->recvBuffer as a dword message id followed
// by the field values in definition order. Integers are byte-swapped with
// EndianConvert. Strings are expanded to their full field width and
// NULL-terminated.
//
// Returns true when the input is exhausted or more data is needed. Returns
// false on a protocol error (unknown message id, out-of-range string length
// or var-array size) or when a handler reports failure.
static bool DispatchData (NetCli * cli, void * param) {

    word msgId = 0;
    while (!cli->input.Eof()) {
        // if we're not already parsing a message, start a new one
        if (!cli->recvMsg) {
            // get next message id
            if (!cli->input.Get(sizeof(msgId), &msgId))
                goto NEED_MORE_DATA;

            msgId = Endian(msgId);

            if (nil == (cli->recvMsg = NetMsgChannelFindRecvMessage(cli->channel, msgId)))
                goto ERR_NO_HANDLER;

            // prepare to start parsing new fields
            ASSERT(!cli->recvField);
            ASSERT(!cli->recvFieldBytes);
            cli->recvField = cli->recvMsg->msg.fields;
            cli->recvBuffer.ZeroCount();
            cli->recvBuffer.Reserve(kAsyncSocketBufferSize);

            // store the message id as dword into the destination buffer
            dword * recvMsgId = (dword *) cli->recvBuffer.New(sizeof(dword));
            *recvMsgId = msgId;
        }

        for (
            const NetMsgField * end = cli->recvMsg->msg.fields + cli->recvMsg->msg.count;
            cli->recvField < end;
            ++cli->recvField
        ) {
            switch (cli->recvField->type) {
                case kNetMsgFieldInteger: {
                    const unsigned count
                        = cli->recvField->count
                        ? cli->recvField->count
                        : 1;

                    // Get integer values
                    const unsigned bytes = count * cli->recvField->size;
                    byte * data = cli->recvBuffer.New(bytes);
                    if (!cli->input.Get(bytes, data)) {
                        cli->recvBuffer.ShrinkBy(bytes);
                        goto NEED_MORE_DATA;
                    }

                    // Byte-swap integers
                    EndianConvert(
                        data,
                        count,
                        cli->recvField->size
                    );

                    // Field complete
                }
                break;

                case kNetMsgFieldReal: {
                    const unsigned count
                        = cli->recvField->count
                        ? cli->recvField->count
                        : 1;

                    // Get float values
                    const unsigned bytes = count * cli->recvField->size;
                    byte * data = cli->recvBuffer.New(bytes);
                    if (!cli->input.Get(bytes, data)) {
                        cli->recvBuffer.ShrinkBy(bytes);
                        goto NEED_MORE_DATA;
                    }

                    // Field complete
                }
                break;

                case kNetMsgFieldData:
                case kNetMsgFieldRawData: {
                    // Read fixed-length data into destination buffer
                    const unsigned bytes = cli->recvField->count * cli->recvField->size;
                    byte * data = cli->recvBuffer.New(bytes);
                    if (!cli->input.Get(bytes, data)) {
                        cli->recvBuffer.ShrinkBy(bytes);
                        goto NEED_MORE_DATA;
                    }

                    // Field complete
                }
                break;

                case kNetMsgFieldVarCount: {
                    // Read var count field into destination buffer
                    const unsigned bytes = sizeof(dword);
                    byte * data = cli->recvBuffer.New(bytes);
                    if (!cli->input.Get(bytes, data)) {
                        cli->recvBuffer.ShrinkBy(bytes);
                        goto NEED_MORE_DATA;
                    }

                    // Byte-swap value
                    EndianConvert((dword *) data, 1);

                    // The element count comes straight off the wire. Reject
                    // counts whose total byte size would wrap around; a
                    // wrapped size would make the count stored in recvBuffer
                    // disagree with the number of bytes actually read for the
                    // following var-length field.
                    const dword elemCount = *(const dword *) data;
                    const unsigned elemSize = cli->recvField->size;
                    if (elemSize && elemCount > ((unsigned) -1) / elemSize)
                        goto ERR_BAD_COUNT;

                    // Prepare to read var-length field
                    cli->recvFieldBytes = elemCount * elemSize;

                    // Field complete
                }
                break;

                case kNetMsgFieldVarPtr:
                case kNetMsgFieldRawVarPtr: {
                    // Read var-length data into destination buffer
                    // (size was computed by the preceding VarCount field)
                    const unsigned bytes = cli->recvFieldBytes;
                    byte * data = cli->recvBuffer.New(bytes);
                    if (!cli->input.Get(bytes, data)) {
                        cli->recvBuffer.ShrinkBy(bytes);
                        goto NEED_MORE_DATA;
                    }

                    // Field complete
                    cli->recvFieldBytes = 0;
                }
                break;

                case kNetMsgFieldString: {
                    // A non-zero recvFieldBytes means the length prefix was
                    // already read on an earlier call that ran out of data
                    if (!cli->recvFieldBytes) {
                        // Read string length
                        word length;
                        if (!cli->input.Get(sizeof(word), &length))
                            goto NEED_MORE_DATA;
                        cli->recvFieldBytes = Endian(length) * sizeof(wchar);

                        // Validate size. Use >= instead of > to leave room for the NULL terminator.
                        if (cli->recvFieldBytes >= cli->recvField->count * cli->recvField->size)
                            goto ERR_BAD_COUNT;
                    }

                    // Reserve the full field width in the destination buffer
                    const unsigned bytes = cli->recvField->count * cli->recvField->size;
                    byte * data = cli->recvBuffer.New(bytes);
                    // Read only the characters actually sent (may be fewer than the field width)
                    if (!cli->input.Get(cli->recvFieldBytes, data)) {
                        cli->recvBuffer.ShrinkBy(bytes);
                        goto NEED_MORE_DATA;
                    }

                    // Insert NULL terminator
                    * (wchar *)(data + cli->recvFieldBytes) = 0;

                    // IDEA: fill the remainder with a freaky byte pattern

                    // Field complete
                    cli->recvFieldBytes = 0;
                }
                break;
            }
        }

        // dispatch message to handler function
        NCCLI_LOG(kLogPerf, L"pnNetCli: Dispatching. msg: %S. cli: %p", cli->recvMsg ? cli->recvMsg->msg.name : "(unknown)", cli);
        if (!cli->recvMsg->recv(cli->recvBuffer.Ptr(), cli->recvBuffer.Count(), param))
            goto ERR_DISPATCH_FAILED;
        
        // prepare to start next message
        cli->recvMsg        = nil;
        cli->recvField      = 0;
        cli->recvFieldBytes = 0;

        // Release oversize message buffer
        if (cli->recvBuffer.Count() > kAsyncSocketBufferSize)
            cli->recvBuffer.Clear();
    }

    return true;

// these are used for convenience in setting breakpoints
NEED_MORE_DATA:
    NCCLI_LOG(kLogPerf, L"pnNetCli: NEED_MORE_DATA. msg: %S (%u). cli: %p", cli->recvMsg ? cli->recvMsg->msg.name : "(unknown)", msgId, cli);
    return true;

ERR_BAD_COUNT:
    LogMsg(kLogError, L"pnNetCli: ERR_BAD_COUNT. msg: %S (%u). cli: %p", cli->recvMsg ? cli->recvMsg->msg.name : "(unknown)", msgId, cli);
    return false;

ERR_NO_HANDLER:
    LogMsg(kLogError, L"pnNetCli: ERR_NO_HANDLER. msg: %S (%u). cli: %p", cli->recvMsg ? cli->recvMsg->msg.name : "(unknown)", msgId, cli);
    return false;

ERR_DISPATCH_FAILED:
    LogMsg(kLogError, L"pnNetCli: ERR_DISPATCH_FAILED. msg: %S (%u). cli: %p", cli->recvMsg ? cli->recvMsg->msg.name : "(unknown)", msgId, cli);
    return false;
}


namespace Connect {
/*****************************************************************************
*
*   NetCli connect protocol
*
***/

// Handshake packets are byte-packed (PshPack1/PopPack) so the structs can be
// sent and received directly as raw bytes.
#include <PshPack1.h>
// Handshake message ids; also index into s_recvTbl.
enum {
    kNetCliCli2SrvConnect,  // client -> server: client's DH value
    kNetCliSrv2CliEncrypt,  // server -> client: server seed
    kNetCliSrv2CliError,    // server -> client: connection rejected
    kNumNetCliMsgs
};

// Common prefix of every handshake packet.
struct NetCli_PacketHeader {
    byte    message;        // one of the kNetCli* ids above
    byte    length;         // total packet length in bytes, header included
};

// length may be smaller than sizeof() when fewer DH bytes are sent
// (see ClientConnect).
struct NetCli_Cli2Srv_Connect : NetCli_PacketHeader {
    byte    dh_y_data[kNetDiffieHellmanKeyBits / 8];
};

struct NetCli_Srv2Cli_Encrypt : NetCli_PacketHeader {
    byte    serverSeed[kNetMaxSymmetricSeedBytes];
};

struct NetCli_Srv2Cli_Error : NetCli_PacketHeader {
    dword   error;              // ENetError
};
#include <PopPack.h>


//===========================================================================
static void CreateSymmetricKey (
    unsigned        serverBytes,
    const byte *    serverSeed,
    unsigned        clientBytes,
    const byte *    clientSeed,
    unsigned        outputBytes,
    byte *          outputSeed
) {
    ASSERT(clientBytes == kNetMaxSymmetricSeedBytes);
    ASSERT(serverBytes == kNetMaxSymmetricSeedBytes);
    ASSERT(outputBytes == kNetMaxSymmetricSeedBytes);
    for (unsigned i = 0; i < outputBytes; ++i)
        outputSeed[i] = (byte) (clientSeed[i] ^ serverSeed[i]);
}

//============================================================================
// Client side of the handshake. Runs NetMsgCryptClientStart and overwrites
// cli->seed with the resulting client seed, which ClientRecvEncrypt uses
// later. Then sends serverSeed's bytes to the server in a Cli2Srv_Connect
// packet (only if a socket is attached).
static void ClientConnect (NetCli * cli) {
    BigNum clientSeed;
    BigNum serverSeed;
    NetMsgCryptClientStart(
        cli->channel,
        sizeof(cli->seed),
        cli->seed,
        &clientSeed,
        &serverSeed
    );

    // Keep the client seed (zero-padded, truncated to fit)
    ZERO(cli->seed);
    unsigned clientSeedBytes;
    const void * clientSeedData = clientSeed.GetData(&clientSeedBytes);
    MemCopy(cli->seed, clientSeedData, min(clientSeedBytes, sizeof(cli->seed)));

    // Without a socket there is nobody to send the connect packet to
    if (!cli->sock)
        return;

    NetCli_Cli2Srv_Connect connect;
    unsigned yBytes;
    const void * yData = serverSeed.GetData(&yBytes);
    ASSERTMSG(yBytes <= sizeof(connect.dh_y_data), "4");

    // length covers the header plus only the DH bytes actually present
    connect.message = kNetCliCli2SrvConnect;
    connect.length  = (byte) (sizeof(connect) - sizeof(connect.dh_y_data) +  yBytes);
    MemCopy(connect.dh_y_data, yData, yBytes);
    AsyncSocketSend(cli->sock, &connect, connect.length);
}

//============================================================================
// Server-side handler for kNetCliCli2SrvConnect. Processing order:
//   1. Reply in the clear with the server seed.
//   2. Derive the client seed from the client's DH value.
//   3. Switch the connection to RC4 encryption using the combined seed.
//   4. Report success through encryptFcn.
// Returns false (rejecting the packet) if the connection is not in
// kNetCliModeServerStart or the packet is shorter than a full connect
// message.
static bool ServerRecvConnect (
    NetCli *                    cli,
    const NetCli_PacketHeader & pkt
) {
    if (cli->mode != kNetCliModeServerStart)
        return false;
    if (pkt.length < sizeof(NetCli_Cli2Srv_Connect))
        return false;

    const NetCli_Cli2Srv_Connect & connect =
        * (const NetCli_Cli2Srv_Connect *) &pkt;

    // Send our seed to the client (unencrypted)
    if (cli->sock) {
        NetCli_Srv2Cli_Encrypt encryptMsg;
        encryptMsg.message = kNetCliSrv2CliEncrypt;
        encryptMsg.length  = sizeof(encryptMsg);
        MemCopy(encryptMsg.serverSeed, cli->seed, sizeof(encryptMsg.serverSeed));
        AsyncSocketSend(cli->sock, &encryptMsg, sizeof(encryptMsg));
    }

    // Derive the client seed from the DH bytes following the header
    byte clientSeed[kNetMaxSymmetricSeedBytes];
    {
        BigNum clientValue;
        NetMsgCryptServerConnect(
            cli->channel,
            connect.length - sizeof(pkt),
            connect.dh_y_data,
            &clientValue
        );

        ZERO(clientSeed);
        unsigned valueBytes;
        const void * valueData = clientValue.GetData(&valueBytes);
        MemCopy(clientSeed, valueData, min(valueBytes, sizeof(clientSeed)));
    }

    // Shared seed = server seed combined with client seed
    byte sharedSeed[kNetMaxSymmetricSeedBytes];
    CreateSymmetricKey(
        sizeof(cli->seed),  cli->seed,
        sizeof(clientSeed), clientSeed,
        sizeof(sharedSeed), sharedSeed
    );

    cli->mode     = kNetCliModeEncrypted;
    cli->cryptIn  = CryptKeyCreate(kCryptRc4, sizeof(sharedSeed), sharedSeed);
    cli->cryptOut = CryptKeyCreate(kCryptRc4, sizeof(sharedSeed), sharedSeed);

    return cli->encryptFcn(kNetSuccess, cli->encryptParam);
}

//============================================================================
// Client-side handler for kNetCliSrv2CliEncrypt. Combines the server seed
// from the packet with the client seed held in cli->seed, switches the
// connection to RC4 encryption, and reports success through encryptFcn.
// Returns false if the connection is not in kNetCliModeClientStart or the
// packet is not exactly one encrypt message long.
static bool ClientRecvEncrypt (
    NetCli *                    cli,
    const NetCli_PacketHeader & pkt
) {
    if (cli->mode != kNetCliModeClientStart || pkt.length != sizeof(NetCli_Srv2Cli_Encrypt))
        return false;

    const NetCli_Srv2Cli_Encrypt & encryptMsg =
        * (const NetCli_Srv2Cli_Encrypt *) &pkt;

    byte sharedSeed[kNetMaxSymmetricSeedBytes];
    CreateSymmetricKey(
        sizeof(encryptMsg.serverSeed), encryptMsg.serverSeed,
        sizeof(cli->seed),             cli->seed,
        sizeof(sharedSeed),            sharedSeed
    );

    cli->mode     = kNetCliModeEncrypted;
    cli->cryptIn  = CryptKeyCreate(kCryptRc4, sizeof(sharedSeed), sharedSeed);
    cli->cryptOut = CryptKeyCreate(kCryptRc4, sizeof(sharedSeed), sharedSeed);

    return cli->encryptFcn(kNetSuccess, cli->encryptParam);
}

//============================================================================
// Client-side handler for kNetCliSrv2CliError. When received in
// kNetCliModeClientStart with a full-sized packet, forwards the server's
// error code to encryptFcn. Always returns false, so the packet is never
// treated as consumed.
static bool ClientRecvError (
    NetCli *                    cli,
    const NetCli_PacketHeader & pkt
) {
    const bool valid =
        cli->mode == kNetCliModeClientStart &&
        pkt.length >= sizeof(NetCli_Srv2Cli_Error);

    if (valid) {
        const NetCli_Srv2Cli_Error & errorMsg =
            * (const NetCli_Srv2Cli_Error *) &pkt;
        cli->encryptFcn((ENetError) errorMsg.error, cli->encryptParam);
    }

    return false;
}

//============================================================================
// Handler for one handshake packet. Returns true if the packet was accepted;
// false causes DispatchPacket to report failure.
typedef bool (* FNetCliPacket)(
    NetCli *                    cli,
    const NetCli_PacketHeader & pkt
);

// Disabled alternative: separate SERVER/CLIENT tables that would register
// only the handlers for that side.
#if 0

#ifdef SERVER
static const FNetCliPacket s_recvTbl[kNumNetCliMsgs] = {
    ServerRecvConnect,
    nil,
    nil,
};
#endif

#ifdef CLIENT
static const FNetCliPacket s_recvTbl[kNumNetCliMsgs] = {
    nil,
    ClientRecvEncrypt,
    ClientRecvError,
};
#endif

#else // 0

// Active table: every handler is registered on both sides. Each handler
// checks cli->mode itself and rejects packets meant for the other side.
// Indexed by NetCli_PacketHeader::message.
static const FNetCliPacket s_recvTbl[kNumNetCliMsgs] = {
    ServerRecvConnect,
    ClientRecvEncrypt,
    ClientRecvError,
};

#endif // 0

//===========================================================================
// Handles the single handshake packet at the start of 'data'.
// Returns the number of bytes it occupied (its header length) on success.
// Returns 0 on failure: too few bytes for the header or for the full packet,
// an unknown message id, no registered handler, or the handler rejected it.
static unsigned DispatchPacket (
    NetCli *        cli,
    unsigned        bytes,
    const byte      data[]
) {
    const NetCli_PacketHeader & header = * (const NetCli_PacketHeader *) data;

    // Need the whole header, then the whole packet it describes
    if (bytes < sizeof(header) || header.length > bytes)
        return 0;

    if (header.message >= kNumNetCliMsgs)
        return 0;

    const FNetCliPacket handler = s_recvTbl[header.message];
    if (!handler || !handler(cli, header))
        return 0;

    return header.length;
}

} // namespace Connect


/*****************************************************************************
*
*   NetCli implementation
*
***/

//===========================================================================
// Puts a connection's send and receive state back to "nothing in flight":
// empty staging buffer, no partially parsed message, receive dispatch
// enabled, and empty input/reassembly buffers.
static void ResetSendRecv (NetCli * cli) {
    // Outgoing side
    cli->sendCurr = cli->sendBuffer;

    // Incoming side
    cli->recvMsg        = nil;
    cli->recvField      = nil;
    cli->recvFieldBytes = 0;
    cli->recvDispatch   = true;
    cli->input.Clear();
    cli->recvBuffer.Clear();
}

//===========================================================================
// Allocates a zero-initialized NetCli for 'sock'. The message channel for
// 'protocol' is locked first, with the server flag set when 'mode' is
// kNetCliModeServerStart. Returns nil if no channel could be locked.
static NetCli * ConnCreate (
    AsyncSocket     sock,
    unsigned        protocol,
    ENetCliMode     mode
) {
    const bool isServer = (mode == kNetCliModeServerStart);
    unsigned largestRecv;
    NetMsgChannel * const channel = NetMsgChannelLock(protocol, isServer, &largestRecv);
    if (channel == nil)
        return nil;

    NetCli * const conn = NEWZERO(NetCli);
    conn->sock     = sock;
    conn->protocol = (ENetProtocol) protocol;
    conn->channel  = channel;
    conn->mode     = mode;
    conn->SetValue(kNilGuid);

    ResetSendRecv(conn);
    return conn;
}

//===========================================================================
// Initializes cli->seed. If the caller supplied seed bytes, they are copied
// in (truncated to the seed size). Otherwise a random seed is generated.
static void SetConnSeed (
    NetCli *        cli,
    unsigned        seedBytes,
    const byte      seedData[]
) {
    if (!seedBytes) {
        CryptCreateRandomSeed(sizeof(cli->seed), cli->seed);
        return;
    }

    MemCopy(cli->seed, seedData, min(sizeof(cli->seed), seedBytes));
}

} using namespace pnNetCli;


/*****************************************************************************
*
*   Exports
*
***/

//============================================================================
// Creates the client (connecting) side of a connection on 'sock' and starts
// the handshake by sending the connect packet. Nagling is disabled when
// 'unbuffered' is true. Returns nil if the connection could not be created.
NetCli * NetCliConnectAccept (
    AsyncSocket         sock,
    unsigned            protocol,
    bool                unbuffered,
    FNetCliEncrypt      encryptFcn,
    unsigned            seedBytes,
    const byte          seedData[],
    void *              encryptParam
) {
    NetCli * const cli = ConnCreate(sock, protocol, kNetCliModeClientStart);
    if (!cli)
        return cli;

    AsyncSocketEnableNagling(sock, !unbuffered);
    cli->encryptFcn   = encryptFcn;
    cli->encryptParam = encryptParam;
    SetConnSeed(cli, seedBytes, seedData);
    Connect::ClientConnect(cli);
    return cli;
}

//============================================================================
// SERVER builds only. Creates the server (accepting) side of a connection
// on 'sock'. The handshake proceeds when the client's connect packet arrives
// through NetCliDispatch. Returns nil if the connection could not be created.
#ifdef SERVER
NetCli * NetCliListenAccept (
    AsyncSocket         sock,
    unsigned            protocol,
    bool                unbuffered,
    FNetCliEncrypt      encryptFcn,
    unsigned            seedBytes,
    const byte          seedData[],
    void *              encryptParam
) {
    NetCli * const cli = ConnCreate(sock, protocol, kNetCliModeServerStart);
    if (!cli)
        return cli;

    AsyncSocketEnableNagling(sock, !unbuffered);
    cli->encryptFcn   = encryptFcn;
    cli->encryptParam = encryptParam;
    SetConnSeed(cli, seedBytes, seedData);
    return cli;
}
#endif

//============================================================================
// SERVER builds only. Sends a Srv2Cli_Error handshake packet carrying
// 'error' on 'sock'. Does nothing if 'sock' is nil.
#ifdef SERVER
void NetCliListenReject (
    AsyncSocket     sock,
    ENetError       error
) {
    if (!sock)
        return;

    Connect::NetCli_Srv2Cli_Error rejectMsg;
    rejectMsg.message = Connect::kNetCliSrv2CliError;
    rejectMsg.length  = sizeof(rejectMsg);
    rejectMsg.error   = error;
    AsyncSocketSend(sock, &rejectMsg, sizeof(rejectMsg));
}
#endif

//============================================================================
// Detaches the socket from the connection. Subsequent sends are silently
// skipped, and NetCliDisconnect/NetCliDelete no longer touch a socket.
void NetCliClearSocket (NetCli * cli) {
    AsyncSocket & sock = cli->sock;
    sock = nil;
}

//============================================================================
// Assigns (or, with nil, removes) the queue that BufferedSendData links this
// connection into after staging a message.
void NetCliSetQueue (
    NetCli *        cli,
    NetCliQueue *   queue
) {
    NetCliQueue *& slot = cli->queue;
    slot = queue;
}

//============================================================================
// Disconnects the connection's socket (if any) and stops further receive
// dispatch. For a graceful close (!hardClose), staged outgoing data is
// flushed first so the socket layer can finish sending it.
void NetCliDisconnect (
    NetCli *        cli,
    bool            hardClose
) {
    const bool graceful = !hardClose;
    if (graceful)
        NetCliFlush(cli);

    if (cli->sock)
        AsyncSocketDisconnect(cli->sock, hardClose);

    // NetCliDispatch refuses further input from here on
    cli->recvDispatch = false;
}

//============================================================================
// Destroys a connection. Unlocks its message channel, optionally deletes its
// socket, closes any crypt keys created during the handshake, clears its
// buffers, and frees the NetCli.
void NetCliDelete (
    NetCli *        cli,
    bool            deleteSocket
) {
    NetMsgChannelUnlock(cli->channel);

    if (deleteSocket && cli->sock)
        AsyncSocketDelete(cli->sock);

    // Keys exist only once the handshake has switched to encrypted mode
    CryptKey * const keys[] = { cli->cryptIn, cli->cryptOut };
    for (unsigned i = 0; i < arrsize(keys); ++i) {
        if (keys[i])
            CryptKeyClose(keys[i]);
    }

    cli->input.Clear();
    cli->recvBuffer.Clear();

    DEL(cli);
}

//============================================================================
// Sends any outgoing data currently staged in the connection's send buffer.
// Does nothing if the buffer is empty.
void NetCliFlush (
    NetCli *        cli
) {
    const bool empty = (cli->sendCurr == cli->sendBuffer);
    if (empty)
        return;

    FlushSendBuffer(cli);
}

//============================================================================
// Public entry point for sending a message. 'msg' holds the message id
// followed by one entry per field ('count' entries in total). The message is
// staged through BufferedSendData and is not necessarily sent immediately.
void NetCliSend (
    NetCli *            cli,
    const unsigned_ptr  msg[], 
    unsigned            count
) {
    const unsigned fieldCount = count;
    BufferedSendData(cli, msg, fieldCount);
}

//============================================================================
// Feeds 'bytes' received bytes into the connection.
//
// While the handshake is in progress, plaintext handshake packets are
// consumed one at a time. If a packet switches the connection to encrypted
// mode, any bytes after it are handled as encrypted data in the same call.
// In encrypted mode, data is decrypted (unless NO_ENCRYPTION is defined),
// appended to the input accumulator, and parsed by DispatchData. In SERVER
// builds, a parse/handler failure disables further dispatch.
//
// Returns false if dispatch is disabled or a handshake packet is rejected.
// Otherwise returns cli->recvDispatch after encrypted processing, or true
// once all handshake bytes have been consumed.
bool NetCliDispatch (
    NetCli *        cli,
    const byte      data[],
    unsigned        bytes,
    void *          param
) {
    if (!cli->recvDispatch)
        return false;

    // Handshake phase: one plaintext connect packet per iteration
    while (cli->mode != kNetCliModeEncrypted) {
        const unsigned consumed = Connect::DispatchPacket(cli, bytes, data);
        if (!consumed)
            return false;

        data  += consumed;
        bytes -= consumed;
        if (!bytes)
            return true;
    }

    // Encrypted phase
    byte * heapScratch = NULL;

#ifndef NO_ENCRYPTION
    // Small payloads decrypt in stack space; large ones on the heap
    byte * plain;
    if (bytes > 2048)
        plain = heapScratch = (byte *) ALLOC(bytes);
    else
        plain = ALLOCA(byte, bytes);

    MemCopy(plain, data, bytes);
    CryptDecrypt(cli->cryptIn, bytes, plain);
    data = plain;
#endif

    cli->input.Add(bytes, data);
    const bool dispatched = DispatchData(cli, param);

#ifdef SERVER
    cli->recvDispatch = dispatched;
#endif

    FREE(heapScratch);

    cli->input.Compact();
    return cli->recvDispatch;
}