/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011  Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA   99021

*==LICENSE==*/
/*****************************************************************************
*
*   $/Plasma20/Sources/Plasma/NucleusLib/pnNetCli/pnNcChannel.cpp
*   
***/

#include "Pch.h"
#pragma hdrstop


namespace pnNetCli {

/*****************************************************************************
*
*   Private
*
***/

// Lock wrapper used to guard the channel list.
//
// Enter/Leave always take/release the critical section. EnterSafe/LeaveSafe
// do so only when m_init is set (the constructor sets it) and are no-ops
// otherwise.
struct ChannelCrit {
    ChannelCrit ();
    ~ChannelCrit ();

    // Unconditional lock / unlock
    void Enter () { m_critsect.Enter(); }
    void Leave () { m_critsect.Leave(); }

    // Lock / unlock, skipped while m_init is false
    void EnterSafe () {
        if (m_init)
            m_critsect.Enter();
    }
    void LeaveSafe () {
        if (m_init)
            m_critsect.Leave();
    }

private:
    bool        m_init;         // set to true by the constructor
    CCritSect   m_critsect;
};

// Message table for one (protocol, server) pair. Reference counted through
// AtomicRef: the channel list holds a "ChannelLink" reference for as long as
// the channel is linked, and each NetMsgChannelLock caller holds a
// "ChannelLock" reference.
struct NetMsgChannel : AtomicRef {
    LINK(NetMsgChannel)     m_link;     // list linkage; s_channels is configured with this member's offset
    unsigned                m_protocol;
    bool                    m_server;

    // Message definitions
    unsigned                m_largestRecv;  // largest ValidateMsg() result over all registered recv messages
    ARRAY(NetMsgInitSend)   m_sendMsgs;     // indexed by message id; msg.count == 0 marks an unused slot
    ARRAY(NetMsgInitRecv)   m_recvMsgs;     // indexed by message id; msg.count == 0 marks an unused slot

    // Diffie-Hellman constants
    unsigned                m_dh_g;
    BigNum                  m_dh_xa;    // client: dh_x     server: dh_a
    BigNum                  m_dh_n;
};

// Guards s_channels; its destructor also unlinks every channel and deletes
// the list.
static ChannelCrit              s_channelCrit;
// List of registered channels. Allocated on first use by
// FindOrCreateChannel_CS and reset to nil by ~ChannelCrit.
static LIST(NetMsgChannel) *    s_channels;


/****************************************************************************
*
*   ChannelCrit
*
***/

//===========================================================================
// Flags the object as constructed so that EnterSafe/LeaveSafe start using
// the critical section. The assignment is in the constructor body, so it
// runs after m_critsect has been constructed.
// NOTE(review): presumably this exists so that registration calls made
// during static initialization, before this object is constructed, skip the
// lock instead of touching an unconstructed critical section -- confirm.
ChannelCrit::ChannelCrit () {
    m_init = true;
}

//===========================================================================
// Tears down the channel list: every channel still linked is unlinked and
// the list's "ChannelLink" reference on it is dropped, then the list itself
// is deleted and the global pointer cleared.
ChannelCrit::~ChannelCrit () {
    EnterSafe();
    if (s_channels != nil) {
        // Keep popping the head until the list reports it is empty
        for (;;) {
            NetMsgChannel * const head = s_channels->Head();
            if (!head)
                break;

            s_channels->Unlink(head);
            head->DecRef("ChannelLink");
        }

        DEL(s_channels);
        s_channels = nil;
    }
    LeaveSafe();
}


/*****************************************************************************
*
*   Internal functions
*
***/

//===========================================================================
// Checks the field layout of a message definition and returns the maximum
// size of the message in bytes (including the message id).
//
// Layout rules, each reported through FATAL when violated:
//   - a definition may contain at most one variable length count field and
//     at most one variable length field
//   - a variable length field must directly follow a variable length count
//     field
//   - no field may follow the variable length field
static unsigned ValidateMsg (const NetMsg & msg) {
    ASSERT(msg.fields);
    ASSERT(msg.count);

    unsigned maxBytes = sizeof(word);    // for message id

    // These flags describe the message as a whole, so they must persist
    // across fields. If they were re-initialized for every field, the
    // "only one" and "must be last" checks below could never trigger.
    bool gotVarCount            = false;
    bool gotVarField            = false;
    bool prevFieldWasVarCount   = false;

    for (unsigned i = 0; i < msg.count; i++) {
        const NetMsgField & field = msg.fields[i];

        if (field.type == kNetMsgFieldVarCount) {
            if (gotVarField || gotVarCount)
                FATAL("Msg definition may only include one variable length field");
            gotVarCount = true;
        }
        else if (field.type == kNetMsgFieldVarPtr || field.type == kNetMsgFieldRawVarPtr) {
            // The count field introducing this field has legitimately set
            // gotVarCount, so only a second variable length field is an error
            if (gotVarField)
                FATAL("Msg definition may only include one variable length field");
            if (!prevFieldWasVarCount)
                FATAL("Variable length field must preceded by variable length count field");
            gotVarField = true;
        }
        else if (gotVarField) {
            FATAL("Variable length field must be the last field in message definition");
        }

        prevFieldWasVarCount = false;
        switch (field.type) {
            case kNetMsgFieldInteger:
                maxBytes += sizeof(qword);
            break;

            case kNetMsgFieldReal:
                maxBytes += sizeof(double);
            break;

            // contributes nothing here; the preceding count field's
            // count * size is what gets added for the variable part
            case kNetMsgFieldVarPtr:
            case kNetMsgFieldRawVarPtr:
            break;

            case kNetMsgFieldVarCount:
                prevFieldWasVarCount = true;
                // fall-thru...
            case kNetMsgFieldString:
            case kNetMsgFieldPtr:
            case kNetMsgFieldRawPtr:
            case kNetMsgFieldData:
            case kNetMsgFieldRawData:
                maxBytes += field.count * field.size;
            break;

            DEFAULT_FATAL(field.type);
        }
    }

    return maxBytes;
}


//===========================================================================
// Returns the largest messageId found in msgs[0..count), or 0 when count is
// zero. Every entry is expected to have a non-empty field list.
template<class T>
static unsigned MaxMsgId (const T msgs[], unsigned count) {
    unsigned highest = 0;

    unsigned i = 0;
    while (i < count) {
        ASSERT(msgs[i].msg.count);
        if (msgs[i].msg.messageId > highest)
            highest = msgs[i].msg.messageId;
        ++i;
    }

    return highest;
}

//===========================================================================
// Copies count send-message definitions into the channel's send table, each
// stored at the index given by its message id, and validates every one.
// The _CS suffix follows the file's convention for functions called with
// the channel critical section held.
static void AddSendMsgs_CS (
    NetMsgChannel *         channel,
    const NetMsgInitSend    src[],
    unsigned                count
) {
    // make sure the table can be indexed by the highest id being added
    const unsigned highestId = MaxMsgId(src, count);
    channel->m_sendMsgs.GrowToFit(highestId, true);

    for (unsigned i = 0; i < count; ++i) {
        const NetMsgInitSend & def = src[i];
        NetMsgInitSend * const dst = &channel->m_sendMsgs[def.msg.messageId];

        // a slot with a non-zero field count is already taken
        ASSERT(!dst->msg.count);

        *dst = def;
        ValidateMsg(dst->msg);
    }
}

//===========================================================================
// Copies count recv-message definitions (including their handlers) into the
// channel's recv table, each stored at the index given by its message id,
// and keeps m_largestRecv at the largest validated message size seen.
static void AddRecvMsgs_CS (
    NetMsgChannel *         channel,
    const NetMsgInitRecv    src[],
    unsigned                count
) {
    // make sure the table can be indexed by the highest id being added
    const unsigned highestId = MaxMsgId(src, count);
    channel->m_recvMsgs.GrowToFit(highestId, true);

    const NetMsgInitRecv * const term = src + count;
    while (src < term) {
        // every receivable message needs a handler
        ASSERT(src->recv);

        NetMsgInitRecv * const dst = &channel->m_recvMsgs[src->msg.messageId];

        // a slot with a non-zero field count is already taken
        ASSERT(!dst->msg.count);

        *dst = *src;

        const unsigned bytes = ValidateMsg(dst->msg);
        if (bytes > channel->m_largestRecv)
            channel->m_largestRecv = bytes;

        ++src;
    }
}

//===========================================================================
// Looks up the channel registered for (protocol, server). Returns nil when
// the channel list has not been created yet or holds no matching entry.
static NetMsgChannel * FindChannel_CS (unsigned protocol, bool server) {
    if (s_channels) {
        for (NetMsgChannel * cur = s_channels->Head(); cur; cur = s_channels->Next(cur)) {
            if (cur->m_protocol != protocol)
                continue;
            if (cur->m_server != server)
                continue;
            return cur;
        }
    }

    return nil;
}

//===========================================================================
// Returns the channel for (protocol, server), creating and linking a new one
// if none exists. The channel list itself is allocated on first use. A newly
// created channel gets a "ChannelLink" reference owned by the list.
static NetMsgChannel * FindOrCreateChannel_CS (unsigned protocol, bool server) {
    // lazily create the list
    if (!s_channels) {
        s_channels = NEW(LIST(NetMsgChannel));
        s_channels->SetLinkOffset(offsetof(NetMsgChannel, m_link));
    }

    // reuse an existing registration when there is one
    if (NetMsgChannel * const existing = FindChannel_CS(protocol, server))
        return existing;

    NetMsgChannel * const created = NEW(NetMsgChannel);
    created->m_protocol     = protocol;
    created->m_server       = server;
    created->m_largestRecv  = 0;

    s_channels->Link(created);
    created->IncRef("ChannelLink");

    return created;
}


/*****************************************************************************
*
*   Module functions
*
***/

//============================================================================
// Finds the channel registered for (protocol, server) and takes a
// "ChannelLock" reference on it. *largestRecv receives the channel's largest
// recv message size. When no such channel exists, returns nil and stores 0
// in *largestRecv.
NetMsgChannel * NetMsgChannelLock (
    unsigned        protocol,
    bool            server,
    unsigned *      largestRecv
) {
    s_channelCrit.Enter();

    NetMsgChannel * const channel = FindChannel_CS(protocol, server);
    if (channel) {
        *largestRecv = channel->m_largestRecv;
        channel->IncRef("ChannelLock");
    }
    else {
        *largestRecv = 0;
    }

    s_channelCrit.Leave();
    return channel;
}

//============================================================================
// Drops the "ChannelLock" reference taken by NetMsgChannelLock. The release
// is performed while holding the channel critical section. channel is
// dereferenced unconditionally, so it must not be nil.
void NetMsgChannelUnlock (
    NetMsgChannel * channel
) {
    s_channelCrit.Enter();
    channel->DecRef("ChannelLock");
    s_channelCrit.Leave();
}

//============================================================================
// Returns the recv-message definition registered under messageId, or nil if
// the id lies outside the table or its slot is unused (field count of zero).
const NetMsgInitRecv * NetMsgChannelFindRecvMessage (
    NetMsgChannel * channel,
    unsigned        messageId
) {
    if (messageId < channel->m_recvMsgs.Count()) {
        const NetMsgInitRecv * const entry = &channel->m_recvMsgs[messageId];

        // only slots that were actually registered carry fields
        if (entry->msg.count)
            return entry;
    }

    return nil;
}

//============================================================================
// Returns the send-message definition registered under messageId. Unlike the
// recv lookup, an out-of-range or unregistered id is not reported through
// the return value; both conditions are only checked by assertions.
const NetMsgInitSend * NetMsgChannelFindSendMessage (
    NetMsgChannel * channel,
    unsigned        messageId
) {
    // the id must index into the send table
    ASSERT(messageId < channel->m_sendMsgs.Count());

    const NetMsgInitSend * const sendMsg = &channel->m_sendMsgs[messageId];

    // the slot must have been registered
    ASSERTMSG(sendMsg->msg.count, "NetMsg not found for send");

    return sendMsg;
}

//============================================================================
void NetMsgChannelGetDhConstants (
    const NetMsgChannel *   channel,
    unsigned *              dh_g,
    const BigNum **         dh_xa,
    const BigNum **         dh_n
) {
    *dh_g   =  channel->m_dh_g;
    *dh_xa  = &channel->m_dh_xa;
    *dh_n   = &channel->m_dh_n;
}


}   // namespace pnNetCli


/*****************************************************************************
*
*   Exports
*
***/

//===========================================================================
// Registers (or extends) the message tables and Diffie-Hellman constants of
// the channel identified by (protocol, server), creating the channel if it
// does not exist yet. Either message array may be skipped by passing a count
// of zero. The channel critical section is taken through the "Safe" entry
// points, which are no-ops while the critical section object is not flagged
// as constructed.
void NetMsgProtocolRegister (
    unsigned                protocol,
    bool                    server,
    const NetMsgInitSend    sendMsgs[],
    unsigned                sendMsgCount,
    const NetMsgInitRecv    recvMsgs[],
    unsigned                recvMsgCount,
    unsigned                dh_g,
    const BigNum &          dh_xa,    // client: dh_x     server: dh_a
    const BigNum &          dh_n
) {
    s_channelCrit.EnterSafe();

    NetMsgChannel * channel = FindOrCreateChannel_CS(protocol, server);

    // The only reference allowed at this point is the list's own link
    // reference. Anything more means a connection already uses this
    // protocol, and connections read the channel without locking it once
    // they have linked to it, so it must not be modified underneath them.
    ASSERT(channel->GetRefCount() == 1);

    channel->m_dh_g     = dh_g;
    channel->m_dh_xa    = dh_xa;
    channel->m_dh_n     = dh_n;

    if (sendMsgCount != 0)
        AddSendMsgs_CS(channel, sendMsgs, sendMsgCount);
    if (recvMsgCount != 0)
        AddRecvMsgs_CS(channel, recvMsgs, recvMsgCount);

    s_channelCrit.LeaveSafe();
}

//===========================================================================
void NetMsgProtocolDestroy (unsigned protocol, bool server) {
    s_channelCrit.EnterSafe();
    if (NetMsgChannel * channel = FindChannel_CS(protocol, server)) {
        s_channels->Unlink(channel);
        channel->DecRef("ChannelLink");
    }
    s_channelCrit.LeaveSafe();
}