/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011  Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA   99021

*==LICENSE==*/
/*****************************************************************************
*
*   $/Plasma20/Sources/Plasma/NucleusLib/pnSimpleNet/pnSimpleNet.cpp
*   
***/

#include "Pch.h"
#pragma hdrstop


/*****************************************************************************
*
*   Local types
*
***/

// Per-connection state for one SimpleNet socket. Instances are created with
// NEWZERO (so every field starts zero/nil) and destroyed when the last
// AtomicRef reference is released.
struct SimpleNetConn : AtomicRef {
    LINK(SimpleNetConn)         link;           // membership in SimpleNetChannel::conns
    AsyncSocket                 sock;           // attached socket; nil before connect / after SimpleNetDisconnect
    AsyncCancelId               cancelId;       // id of a pending outbound connect; cleared once it completes
    unsigned                    channelId;      // key used to look the channel up in s_channels
    bool                        abandoned;      // set by SimpleNetDisconnect; suppresses the onError report
    struct ConnectParam *       connectParam;   // outbound connect request, freed when the connect resolves

    // Reassembly state for messages larger than kAsyncSocketBufferSize:
    // oversizeMsg points at the header stored at the start of oversizeBuffer.
    SimpleNet_MsgHeader *       oversizeMsg;
    ARRAY(byte)                 oversizeBuffer;

    ~SimpleNetConn () {
        // The conn must be removed from its channel's list before the last
        // reference goes away.
        ASSERT(!link.IsLinked());
    }
};

// A message channel, keyed by its unsigned channel id (the THashKeyVal key)
// and stored in s_channels. Holds the handlers that receive messages and
// errors for every connection attached to it.
struct SimpleNetChannel : AtomicRef, THashKeyVal<unsigned> {
    HASHLINK(SimpleNetChannel)  link;       // membership in s_channels

    FSimpleNetOnMsg             onMsg;      // per-message handler; returning false fails the read
    FSimpleNetOnError           onError;    // told about disconnects of non-abandoned conns

    LISTDECL(SimpleNetConn, link)   conns;  // connections currently attached to this channel

    SimpleNetChannel (unsigned channelId)
        : THashKeyVal<unsigned>(channelId)
    {
    }

    ~SimpleNetChannel () {
        // By destruction time the channel must be out of s_channels and
        // must have no connections left.
        ASSERT(!link.IsLinked());
        ASSERT(!conns.Head());
    }
};

struct ConnectParam {
    SimpleNetChannel *      channel;
    FSimpleNetOnConnect     callback;
    void *                  param;
    
    ~ConnectParam () {
        if (channel)
            channel->DecRef();
    }
};


/*****************************************************************************
*
*   Local data
*
***/

// True between SimpleNetInitialize and SimpleNetShutdown; checked by ASSERTs.
static bool                         s_running;
// Guards s_channels, each channel's conns list, and per-conn sock/cancelId/abandoned.
static CCritSect                    s_critsect;
// Accept filter installed by SimpleNetStartListening; nil when not listening.
static FSimpleNetQueryAccept        s_queryAccept;
static void *                       s_queryAcceptParam;

// All live channels, keyed by channel id. Each entry holds one channel reference.
static HASHTABLEDECL(
    SimpleNetChannel,
    THashKeyVal<unsigned>,
    link
) s_channels;


/*****************************************************************************
*
*   Local functions
*
***/

//============================================================================
// Outbound connect succeeded. The pending "Connecting" reference becomes the
// "Connected" reference, the requester's callback gets the live conn with
// kNetSuccess, and the ConnectParam (and its channel reference) is released.
static void NotifyConnSocketConnect (SimpleNetConn * conn) {
    conn->TransferRef("Connecting", "Connected");

    ConnectParam * request = conn->connectParam;
    request->callback(request->param, conn, kNetSuccess);

    DEL(request);
    conn->connectParam = nil;
}

//============================================================================
// Outbound connect failed. The conn is detached from its channel and the
// requester's callback receives a nil conn with kNetErrConnectFailed. Then
// the ConnectParam is freed, and both references taken in Connect()
// ("Connecting" and "Lifetime") are dropped.
static void NotifyConnSocketConnectFailed (SimpleNetConn * conn) {
    s_critsect.Enter();
    conn->link.Unlink();
    s_critsect.Leave();

    ConnectParam * request = conn->connectParam;
    request->callback(request->param, nil, kNetErrConnectFailed);

    DEL(request);
    conn->connectParam = nil;

    // NOTE(review): SimpleNetDestroyChannel calls SimpleNetDisconnect on
    // still-connecting conns, and that also releases "Lifetime". Confirm this
    // handler cannot run afterwards for the same conn, or the ref is
    // released twice.
    conn->DecRef("Connecting");
    conn->DecRef("Lifetime");
}

//============================================================================
// Socket-level disconnect of an established connection. Steps:
//   - detach the conn from its channel;
//   - report kNetErrDisconnected to the channel's onError, unless the
//     application already abandoned the conn via SimpleNetDisconnect;
//   - release the "Connected" reference.
static void NotifyConnSocketDisconnect (SimpleNetConn * conn) {

    bool abandoned;
    SimpleNetChannel * channel;
    s_critsect.Enter();
    {
        abandoned = conn->abandoned;
        if (nil != (channel = s_channels.Find(conn->channelId)))
            channel->IncRef();
        conn->link.Unlink();
    }
    s_critsect.Leave();

    if (channel) {
        // An abandoned conn was closed by the application itself, so there is
        // nothing to report.
        if (!abandoned)
            channel->onError(conn, kNetErrDisconnected);

        // Balance the IncRef above whether or not onError was called.
        channel->DecRef();
    }

    conn->DecRef("Connected");
}

//============================================================================
// Handles kNotifySocketRead for a connection. The received bytes are split
// into SimpleNet messages, and each complete message goes to the channel's
// onMsg. Messages larger than kAsyncSocketBufferSize are accumulated across
// reads in conn->oversizeBuffer. read->bytesProcessed is set to the number
// of bytes consumed; a trailing partial (normal-size) message is left
// unconsumed.
//
// Returns false when the channel no longer exists, when a header claims a
// size smaller than the header itself, or when onMsg rejects a message.
// Returns true otherwise.
static bool NotifyConnSocketRead (SimpleNetConn * conn, AsyncNotifySocketRead * read) {

    SimpleNetChannel * channel;
    s_critsect.Enter();
    {
        if (nil != (channel = s_channels.Find(conn->channelId)))
            channel->IncRef();
    }
    s_critsect.Leave();

    if (!channel)
        return false;

    bool result = true;

    const byte * curr = read->buffer;
    const byte * term = curr + read->bytes;

    while (curr < term) {
        // Continue assembling an oversize message. The buffer holds the
        // header followed by whatever body bytes have arrived so far, so
        // Count() is the number of bytes received for this message.
        if (conn->oversizeBuffer.Count()) {
            unsigned spaceLeft = conn->oversizeMsg->messageBytes - conn->oversizeBuffer.Count();
            unsigned available = (unsigned) (term - curr);
            unsigned copyBytes = min(spaceLeft, available);
            conn->oversizeBuffer.Add(curr, copyBytes);

            // Add() may move the storage; re-derive the header pointer.
            conn->oversizeMsg = (SimpleNet_MsgHeader *) conn->oversizeBuffer.Ptr();

            curr += copyBytes;

            // Wait until we have received the entire message
            if (copyBytes != spaceLeft)
                break;

            // Dispatch oversize msg
            if (!channel->onMsg(conn, conn->oversizeMsg)) {
                result = false;
                break;
            }

            conn->oversizeBuffer.SetCount(0);
            conn->oversizeMsg = nil;
            continue;
        }

        // Wait until we receive the entire message header
        unsigned available = (unsigned) (term - curr);
        if (available < sizeof(SimpleNet_MsgHeader))
            break;

        // Header of the message at the current position. It must not be read
        // from read->buffer, which would re-parse the first message every time.
        SimpleNet_MsgHeader * msg = (SimpleNet_MsgHeader *) curr;

        // Sanity check message size
        if (msg->messageBytes < sizeof(*msg)) {
            result = false;
            break;
        }

        // Handle oversized messages: seed the buffer with only the header so
        // that Count() tracks bytes received. The body is appended above.
        if (msg->messageBytes > kAsyncSocketBufferSize) {
            conn->oversizeBuffer.SetCount(0);
            conn->oversizeBuffer.Add(curr, sizeof(*msg));
            conn->oversizeMsg = (SimpleNet_MsgHeader *) conn->oversizeBuffer.Ptr();

            curr += sizeof(*msg);
            continue;
        }

        // Wait until we have received the entire message. Sizes are compared
        // directly so that no pointer beyond the buffer is ever formed.
        if (msg->messageBytes > available)
            break;
        curr += msg->messageBytes;

        // Dispatch msg
        if (!channel->onMsg(conn, msg)) {
            result = false;
            break;
        }
    }

    // Return count of bytes we processed
    read->bytesProcessed = curr - read->buffer;

    channel->DecRef();
    return result;
}

//============================================================================
// Single notification callback for SimpleNet sockets. It is registered for
// kConnTypeSimpleNet in SimpleNetInitialize, is passed to AsyncSocketConnect
// in Connect(), and dispatches on `code`:
//   kNotifySocketListenSuccess  - inbound connection; creates a SimpleNetConn
//   kNotifySocketConnectSuccess - outbound connect completed
//   kNotifySocketConnectFailed  - outbound connect failed
//   kNotifySocketDisconnect     - socket closed
//   kNotifySocketRead           - data received
// *userState is set to the SimpleNetConn once one is associated with the
// socket, and is read back for disconnect and read notifications. For
// outbound connects, the conn arrives in notify->param (the value handed to
// AsyncSocketConnect). Returns NotifyConnSocketRead's result for reads and
// true for every other code.
static bool AsyncNotifySocketProc (
    AsyncSocket         sock,
    EAsyncNotifySocket  code,
    AsyncNotifySocket * notify,
    void **             userState
) {
    bool result = true;
    SimpleNetConn * conn;

    switch (code) {
        case kNotifySocketListenSuccess: {

            AsyncNotifySocketListen * listen = (AsyncNotifySocketListen *) notify;

            // The listen buffer starts with the SimpleNet_ConnData that the
            // peer sent (see Connect()); consume it.
            // NOTE(review): there is no check that the buffer actually holds
            // sizeof(connect) bytes.
            const SimpleNet_ConnData & connect = *(const SimpleNet_ConnData *) listen->buffer;
            listen->bytesProcessed += sizeof(connect);

            SimpleNetChannel * channel;
            s_critsect.Enter();
            {
                if (nil != (channel = s_channels.Find(connect.channelId)))
                    channel->IncRef();
            }
            s_critsect.Leave();
            
            // NOTE(review): for an unknown channel, *userState is left unset.
            // Confirm that the socket layer will not later deliver read or
            // disconnect notifications that would hand a nil conn to
            // NotifyConnSocketRead or NotifyConnSocketDisconnect.
            if (!channel)
                break;
                
            // New inbound conn: "Lifetime" is released by SimpleNetDisconnect,
            // "Connected" by NotifyConnSocketDisconnect.
            conn = NEWZERO(SimpleNetConn);
            conn->channelId = channel->GetValue();
            conn->IncRef("Lifetime");
            conn->IncRef("Connected");
            conn->sock = sock;
            *userState = conn;
            
            // NOTE(review): s_queryAccept is nil whenever SimpleNetStopListening
            // has run (or listening never started), and it is called here
            // without a check.
            bool accepted = s_queryAccept(
                s_queryAcceptParam,
                channel->GetValue(),
                conn,
                listen->remoteAddr
            );
            
            if (!accepted) {
                SimpleNetDisconnect(conn);
            }
            else {
                s_critsect.Enter();
                {
                    channel->conns.Link(conn);
                }
                s_critsect.Leave();
            }
            
            channel->DecRef();
        }
        break;
        
        case kNotifySocketConnectSuccess: {
            conn = (SimpleNetConn *) notify->param;
            *userState = conn;
            bool abandoned;
            
            // Attach the socket and clear the pending-connect id under the
            // lock, because SimpleNetDisconnect reads both under s_critsect.
            s_critsect.Enter();
            {
                conn->sock      = sock;
                conn->cancelId  = 0;
                abandoned       = conn->abandoned;
            }
            s_critsect.Leave();
            
            // NOTE(review): on the abandoned path NotifyConnSocketConnect is
            // skipped. As a result "Connecting" is never transferred to
            // "Connected" and conn->connectParam is not freed here, even
            // though NotifyConnSocketDisconnect later releases "Connected".
            // Verify the ref accounting.
            if (abandoned)
                AsyncSocketDisconnect(sock, true);
            else
                NotifyConnSocketConnect(conn);
        }
        break;

        case kNotifySocketConnectFailed:
            conn = (SimpleNetConn *) notify->param;
            NotifyConnSocketConnectFailed(conn);
        break;

        case kNotifySocketDisconnect:
            conn = (SimpleNetConn *) *userState;
            NotifyConnSocketDisconnect(conn);
        break;

        case kNotifySocketRead:
            conn = (SimpleNetConn *) *userState;
            result = NotifyConnSocketRead(conn, (AsyncNotifySocketRead *) notify);
        break;
    }

    return result;
}

//============================================================================
// Starts an outbound connection to addr for the request described by cp.
// cp->channel must be non-nil. Ownership of cp passes to the new conn
// (conn->connectParam); the connect success and failure handlers free it.
// The conn starts with two references:
//   "Lifetime"   - released by SimpleNetDisconnect or on connect failure;
//   "Connecting" - transferred to "Connected" on success, released on failure.
static void Connect (const NetAddress & addr, ConnectParam * cp) {

    SimpleNetConn * conn = NEWZERO(SimpleNetConn);
    conn->channelId    = cp->channel->GetValue();
    conn->connectParam = cp;
    conn->IncRef("Lifetime");
    conn->IncRef("Connecting");

    // Connect packet: the generic async-socket header plus the target
    // channel id, which the listening side reads back as SimpleNet_ConnData.
    SimpleNet_Connect connect;
    connect.hdr.connType    = kConnTypeSimpleNet;
    connect.hdr.hdrBytes    = sizeof(connect.hdr);
    connect.hdr.buildId     = BuildId();
    connect.hdr.buildType   = BuildType();
    connect.hdr.branchId    = BranchId();
    connect.hdr.productId   = ProductId();
    connect.data.channelId  = cp->channel->GetValue();

    // Link the conn and issue the connect under one lock hold, so that
    // SimpleNetDisconnect (which reads cancelId under s_critsect) never sees
    // a linked conn whose cancelId has not been written yet. From here on
    // the conn is owned through its references; it must never be DEL'd.
    s_critsect.Enter();
    {
        cp->channel->conns.Link(conn);

        AsyncSocketConnect(
            &conn->cancelId,
            addr,
            AsyncNotifySocketProc,
            conn,
            &connect,
            sizeof(connect)
        );
    }
    s_critsect.Leave();
}

//============================================================================
static void AsyncLookupCallback (
    void *              param,
    const wchar         name[],
    unsigned            addrCount,
    const NetAddress    addrs[]
) {
    ConnectParam * cp = (ConnectParam *)param;

    if (!addrCount) {
        if (cp->callback)
            cp->callback(cp->param, nil, kNetErrNameLookupFailed);
        DEL(cp);
        return;
    }

    Connect(addrs[0], (ConnectParam *)param);
}


/*****************************************************************************
*
*   Exported functions
*
***/

//============================================================================
// Marks the module as running and registers AsyncNotifySocketProc with the
// async socket layer for kConnTypeSimpleNet.
void SimpleNetInitialize () {
    s_running = true;
    AsyncSocketRegisterNotifyProc(kConnTypeSimpleNet, AsyncNotifySocketProc);
}

//============================================================================
// Marks the module as stopped and unregisters AsyncNotifySocketProc for
// kConnTypeSimpleNet. Every channel must already have been removed with
// SimpleNetDestroyChannel.
void SimpleNetShutdown () {
    s_running = false;

    ASSERT(!s_channels.Head());

    AsyncSocketUnregisterNotifyProc(kConnTypeSimpleNet, AsyncNotifySocketProc);
}

//============================================================================
// Adds an (untagged) application reference to conn.
// Balance it with SimpleNetConnDecRef.
void SimpleNetConnIncRef (SimpleNetConn * conn) {
    ASSERT(s_running);
    ASSERT(conn);

    conn->IncRef();
}

//============================================================================
// Releases an application reference previously taken with
// SimpleNetConnIncRef. The conn may be destroyed by this call.
void SimpleNetConnDecRef (SimpleNetConn * conn) {
    ASSERT(s_running);
    ASSERT(conn);

    conn->DecRef();
}

//============================================================================
// Installs queryAccept as the filter consulted for each inbound connection,
// then starts listening on kNetDefaultSimpleNetPort (node 0, presumably the
// wildcard address; confirm against NetAddressFromNode).
// Returns true if AsyncSocketStartListening reported success (non-zero).
bool SimpleNetStartListening (
    FSimpleNetQueryAccept   queryAccept,
    void *                  param
) {
    ASSERT(s_running);
    ASSERT(queryAccept);
    ASSERT(!s_queryAccept);

    // The filter must be in place before the listener can deliver connections.
    s_queryAccept       = queryAccept;
    s_queryAcceptParam  = param;

    NetAddress listenAddr;
    NetAddressFromNode(0, kNetDefaultSimpleNetPort, &listenAddr);

    return AsyncSocketStartListening(listenAddr, nil) != 0;
}

//============================================================================
void SimpleNetStopListening () {

    ASSERT(s_running);

    NetAddress addr;
    NetAddressFromNode(0, kNetDefaultSimpleNetPort, &addr);
    AsyncSocketStopListening(addr, nil);

    s_queryAccept       = nil;
    s_queryAcceptParam  = nil;
}

//============================================================================
void SimpleNetCreateChannel (
    unsigned            channelId,
    FSimpleNetOnMsg     onMsg,
    FSimpleNetOnError   onError
) {
    ASSERT(s_running);

    SimpleNetChannel * channel = NEWZERO(SimpleNetChannel)(channelId);
    channel->IncRef();

    s_critsect.Enter();
    {
        #ifdef HS_DEBUGGING
        {
            SimpleNetChannel * existing = s_channels.Find(channelId);
            ASSERT(!existing);
        }
        #endif
        
        channel->onMsg      = onMsg;
        channel->onError    = onError;
        s_channels.Add(channel);
        channel->IncRef();
    }
    s_critsect.Leave();
    
    channel->DecRef();
}

//============================================================================
// Removes channel `channelId` from s_channels and abandons every connection
// still attached to it (via SimpleNetDisconnect). It then drops the channel
// reference owned by the table. Unknown ids are ignored.
void SimpleNetDestroyChannel (unsigned channelId) {

    ASSERT(s_running);

    SimpleNetChannel * channel;
    s_critsect.Enter();
    {
        if (nil != (channel = s_channels.Find(channelId))) {
            s_channels.Unlink(channel);
            while (SimpleNetConn * conn = channel->conns.Head()) {
                // Unlink BEFORE disconnecting. SimpleNetDisconnect releases the
                // conn's "Lifetime" reference, which can destroy it, and
                // ~SimpleNetConn asserts that the conn is no longer linked.
                channel->conns.Unlink(conn);
                // SimpleNetDisconnect re-enters s_critsect; this relies on
                // CCritSect being recursive for the owning thread.
                SimpleNetDisconnect(conn);
            }
        }
    }
    s_critsect.Leave();

    if (channel)
        channel->DecRef();
}

//============================================================================
void SimpleNetStartConnecting (
    unsigned            channelId,
    const wchar         addr[],
    FSimpleNetOnConnect onConnect,
    void *              param
) {
    ASSERT(s_running);
    ASSERT(onConnect);
    
    ConnectParam * cp = NEW(ConnectParam);
    cp->callback    = onConnect;
    cp->param       = param;
    
    s_critsect.Enter();
    {
        if (nil != (cp->channel = s_channels.Find(channelId)))
            cp->channel->IncRef();
    }
    s_critsect.Leave();
    
    ASSERT(cp->channel);

    // Do we need to lookup the address?
    const wchar * name = addr;
    while (unsigned ch = *name) {
        ++name;
        if (!(isdigit(ch) || ch == L'.' || ch == L':')) {

            AsyncCancelId cancelId;
            AsyncAddressLookupName(
                &cancelId,
                AsyncLookupCallback,
                addr,
                kNetDefaultSimpleNetPort,
                cp
            );
            break;
        }
    }
    if (!name[0]) {
        NetAddress netAddr;
        NetAddressFromString(&netAddr, addr, kNetDefaultSimpleNetPort);
        Connect(netAddr, cp);
    }
}

//============================================================================
// Abandons a connection. It marks the conn so that a later disconnect
// notification does not call the channel's onError, then closes the attached
// socket or, if there is none yet, cancels a pending outbound connect.
// Finally it releases the conn's "Lifetime" reference.
void SimpleNetDisconnect (
    SimpleNetConn * conn
) {
    ASSERT(s_running);
    ASSERT(conn);

    s_critsect.Enter();
    {
        conn->abandoned = true;

        if (!conn->sock) {
            // No socket attached yet: an outbound connect may still be pending.
            if (conn->cancelId) {
                AsyncSocketConnectCancel(AsyncNotifySocketProc, conn->cancelId);
                conn->cancelId = nil;
            }
        }
        else {
            AsyncSocketDisconnect(conn->sock, true);
            conn->sock = nil;
        }
    }
    s_critsect.Leave();

    // Balances IncRef("Lifetime") taken when the conn was created.
    conn->DecRef("Lifetime");
}

//============================================================================
// Passes msg (msg->messageBytes bytes, header included) to AsyncSocketSend on
// the connection's socket. If the conn has no socket attached (not yet
// connected, or detached by SimpleNetDisconnect), the message is silently
// dropped.
void SimpleNetSend (
    SimpleNetConn *         conn,
    SimpleNet_MsgHeader *   msg
) {
    ASSERT(s_running);
    ASSERT(msg);
    ASSERT(msg->messageBytes != (dword)-1);
    ASSERT(conn);

    s_critsect.Enter();
    {
        if (AsyncSocket sock = conn->sock)
            AsyncSocketSend(sock, msg, msg->messageBytes);
    }
    s_critsect.Leave();
}