/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011  Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA   99021

*==LICENSE==*/
/*****************************************************************************
*
*   $/Plasma20/Sources/Plasma/PubUtilLib/plNetGameLib/Private/plNglCsr.cpp
*   
***/

#include "../Pch.h"
#pragma hdrstop


namespace Ngl { namespace Csr {
/*****************************************************************************
*
*   Internal types
*
***/

// Completion information supplied by the caller of NetCliCsrStartConnect:
// the callback to run and the opaque pointer that is handed back to it.
struct ConnectParam {
    // Invoked by ConnectedNotifyTrans::Post; skipped there when nil
    FNetCliCsrConnectedCallback callback;
    // Passed through to 'callback' unchanged
    void *                      param;
};

//============================================================================
// Connection record
//============================================================================
// One client-to-CSR-server connection. Reference counted through AtomicRef;
// the reference tags used in this file are "Lifetime", "Connecting",
// "Connected", "PingTimer" and "AcquireConn".
struct CliCsConn : AtomicRef {
    // Linkage for the s_conns list
    LINK(CliCsConn) link;

    // Guards pingTimer and serializes NetCliSend/NetCliFlush on 'cli'
    CCritSect       critsect;
    // Socket handle; tested by UnlinkAndAbandonConn_CS and AutoPing
    AsyncSocket     sock;
    // Filled in by AsyncSocketConnect; while non-zero, abandoning the
    // connection cancels the pending connect instead of disconnecting
    AsyncCancelId   cancelId;
    // Created by NetCliConnectAccept, deleted in the destructor
    NetCli *        cli;
    // Address handed to Connect()
    NetAddress      addr;
    // Connection id from ConnNextSequence(); used to cancel transactions
    // by connection and returned by CsrGetConnId()
    unsigned        seq;
    // Set by UnlinkAndAbandonConn_CS; suppresses error reporting when the
    // socket later fails or disconnects
    bool            abandoned;
    // Values copied out of Csr2Cli_RegisterReply
    unsigned        serverChallenge;
    unsigned        latestBuildId;
    // Owned by the connection until Recv_RegisterReply hands it to a
    // ConnectedNotifyTrans; deleted in the destructor otherwise
    ConnectParam *  connectParam;

    // Ping state: timer created by AutoPing, time of the last ping sent,
    // and time of the last socket read (both from GetNonZeroTimeMs)
    AsyncTimer *    pingTimer;
    unsigned        pingSendTimeMs;
    unsigned        lastHeardTimeMs;

    CliCsConn ();
    ~CliCsConn ();

    // Start / stop the periodic ping timer; TimerPing is the timer body
    void AutoPing ();
    void StopAutoPing ();
    void TimerPing ();

    // Sends one message (field array) on 'cli' and flushes it
    void Send (const unsigned_ptr fields[], unsigned count);
};


//============================================================================
// Transaction objects
//============================================================================
// Notification transaction created by Recv_RegisterReply. Its Post() runs the
// caller's connected-callback with the build id reported by the server.
struct ConnectedNotifyTrans : NetNotifyTrans {

    // Owned by this transaction; deleted in the destructor
    ConnectParam *  m_connectParam;
    // Build id taken from the register reply
    unsigned        m_latestBuildId;

    // cp:  callback information; ownership passes to this object
    // lbi: latest build id to report to the callback
    ConnectedNotifyTrans (ConnectParam * cp, unsigned lbi)
    : NetNotifyTrans(kCsrConnectedNotifyTrans)
    , m_connectParam(cp)
    , m_latestBuildId(lbi)
    { }
    ~ConnectedNotifyTrans () {
        DEL(m_connectParam);
    }
    // Invokes the callback stored in m_connectParam, if any
    void Post ();
};

// Transaction that sends a CSR login request and reports the reply to a
// caller-supplied callback.
struct LoginRequestTrans : NetCsrTrans {

    // Request data copied from the constructor arguments
    wchar                   m_csrName[kMaxAccountNameLength];
    ShaDigest               m_namePassHash;
    FNetCliCsrLoginCallback m_callback;
    void *                  m_param;
    
    // Reply data filled in by Recv() and handed to the callback by Post()
    Uuid                    m_csrId;
    unsigned                m_csrFlags;
    
    LoginRequestTrans (
        const wchar             csrName[],
        const ShaDigest &       namePassHash,
        FNetCliCsrLoginCallback callback,
        void *                  param
    );

    // Builds and sends the login request; returns false when no connection
    // could be acquired
    bool Send ();
    // Delivers the result to m_callback
    void Post ();
    // Consumes a Csr2Cli_LoginReply message
    bool Recv (
        const byte  msg[],
        unsigned    bytes
    );
};


/*****************************************************************************
*
*   Internal data
*
***/

// Indices into the s_perf counter array.
enum {
    // Number of live CliCsConn objects (adjusted in its ctor/dtor)
    kPerfConnCount,
    // Not referenced anywhere in this file
    kPingDisabled,
    // Number of counters; sizes s_perf
    kNumPerf
};

// Set by CsrInitialize, cleared by CsrDestroy; gates error reporting when a
// socket fails or disconnects
static bool                         s_running;
// Guards s_conns and s_active
static CCritSect                    s_critsect;
// Connections linked by Connect() that have not been unlinked yet
static LISTDECL(CliCsConn, link)    s_conns;
// Connection assigned in ConnEncrypt after a successful handshake, or nil
static CliCsConn *                  s_active;
// Counters indexed by the kPerf* enum
static long                         s_perf[kNumPerf];


/*****************************************************************************
*
*   Internal functions
*
***/

//===========================================================================
// Returns TimeGetMs(), substituting 1 when the clock reads exactly 0 so the
// result is never zero.
static unsigned GetNonZeroTimeMs () {
    const unsigned now = TimeGetMs();
    return now ? now : 1;
}

//============================================================================
// Returns the active connection with one reference added under 'tag', or nil
// when there is no active connection or it has no NetCli yet.
// Reads s_active directly; GetConnIncRef wraps this call in s_critsect.
static CliCsConn * GetConnIncRef_CS (const char tag[]) {

    CliCsConn * const active = s_active;
    if (!active || !active->cli)
        return nil;

    active->IncRef(tag);
    return active;
}

//============================================================================
// Locked wrapper around GetConnIncRef_CS: returns the active connection with
// a reference added under 'tag', or nil.
static CliCsConn * GetConnIncRef (const char tag[]) {
    s_critsect.Enter();
    CliCsConn * const found = GetConnIncRef_CS(tag);
    s_critsect.Leave();
    return found;
}

//============================================================================
// Removes 'conn' from s_conns, marks it abandoned and starts tearing it down.
// Every caller in this file holds s_critsect around the call.
static void UnlinkAndAbandonConn_CS (CliCsConn * conn) {

    s_conns.Unlink(conn);
    conn->abandoned = true;

    // Connect still pending: cancel it. The "Lifetime" reference is dropped
    // by NotifyConnSocketConnectFailed (presumably triggered by the cancel).
    if (conn->cancelId) {
        AsyncSocketConnectCancel(nil, conn->cancelId);
        conn->cancelId = 0;
        return;
    }

    // Socket exists: disconnect it. NotifyConnSocketDisconnect drops the
    // "Lifetime" reference.
    if (conn->sock) {
        AsyncSocketDisconnect(conn->sock, true);
        return;
    }

    // Neither pending nor connected: release the lifetime reference here
    conn->DecRef("Lifetime");
}

//============================================================================
// Sends a Cli2Csr register request on 'conn'; the reply is handled by
// Recv_RegisterReply.
static void SendRegisterRequest (CliCsConn * conn) {

    const unsigned_ptr fields[] = {
        kCli2Csr_RegisterRequest,
        0
    };
    const unsigned fieldCount = arrsize(fields);

    conn->Send(fields, fieldCount);
}

//============================================================================
// Callback handed to NetCliConnectAccept. On success the connection becomes
// the active one, pinging starts and the register request is sent.
// 'param' is the CliCsConn. Returns whether 'error' was a success code.
static bool ConnEncrypt (ENetError error, void * param) {

    const bool succeeded = IS_NET_SUCCESS(error);
    if (!succeeded)
        return succeeded;

    CliCsConn * const conn = (CliCsConn *) param;

    s_critsect.Enter();
    s_active = conn;
    conn->AutoPing();
    // Hold a reference across the send below, which runs outside the lock
    conn->IncRef();
    s_critsect.Leave();

    SendRegisterRequest(conn);
    conn->DecRef();

    return succeeded;
}

//============================================================================
// Creates the NetCli for a freshly connected socket and stores it on the
// connection; ConnEncrypt is passed as the callback with 'conn' as its
// parameter.
static void NotifyConnSocketConnect (CliCsConn * conn) {

    NetCli * const accepted = NetCliConnectAccept(
        conn->sock,
        kNetProtocolCli2Csr,
        false,
        ConnEncrypt,
        0,
        nil,
        conn
    );
    conn->cli = accepted;
}

//============================================================================
// Handles a failed connect: unlinks the connection, cancels its transactions,
// drops the "Connecting" and "Lifetime" references and, when appropriate,
// reports kNetErrConnectFailed.
static void NotifyConnSocketConnectFailed (CliCsConn * conn) {

    // Report only while the module is running, the connection was not
    // deliberately abandoned, and it is the active one (or none is active).
    bool shouldReport = false;

    s_critsect.Enter();
    {
        conn->cancelId = 0;
        s_conns.Unlink(conn);

        const bool wasActive = (conn == s_active);
        if (s_running && !conn->abandoned)
            shouldReport = !s_active || wasActive;

        if (wasActive)
            s_active = nil;
    }
    s_critsect.Leave();

    NetTransCancelByConnId(conn->seq, kNetErrTimeout);
    conn->DecRef("Connecting");
    conn->DecRef("Lifetime");

    if (shouldReport)
        ReportNetError(kNetProtocolCli2Csr, kNetErrConnectFailed);
}

//============================================================================
// Handles a socket disconnect: stops pinging, unlinks the connection, cancels
// its transactions, drops the "Connected" and "Lifetime" references and, when
// appropriate, reports kNetErrDisconnected.
static void NotifyConnSocketDisconnect (CliCsConn * conn) {

    conn->StopAutoPing();

    // Report only while the module is running, the connection was not
    // deliberately abandoned, and it is the active one (or none is active).
    bool shouldReport = false;

    s_critsect.Enter();
    {
        s_conns.Unlink(conn);

        const bool wasActive = (conn == s_active);
        if (s_running && !conn->abandoned)
            shouldReport = !s_active || wasActive;

        if (wasActive)
            s_active = nil;
    }
    s_critsect.Leave();

    // Cancel every transaction still in flight on this connection
    NetTransCancelByConnId(conn->seq, kNetErrTimeout);
    conn->DecRef("Connected");
    conn->DecRef("Lifetime");

    if (shouldReport)
        ReportNetError(kNetProtocolCli2Csr, kNetErrDisconnected);
}

//============================================================================
// Handles incoming socket data: records the time, dispatches the whole buffer
// through the connection's NetCli and marks every byte as processed.
// Returns the result of NetCliDispatch.
static bool NotifyConnSocketRead (CliCsConn * conn, AsyncNotifySocketRead * read) {

    conn->lastHeardTimeMs = GetNonZeroTimeMs();

    const bool dispatched = NetCliDispatch(
        conn->cli,
        read->buffer,
        read->bytes,
        conn
    );
    read->bytesProcessed += read->bytes;

    return dispatched;
}

//============================================================================
// Socket event dispatcher registered with AsyncSocketConnect.
//   sock      - socket the event refers to
//   code      - which event occurred
//   notify    - event data; for connect events notify->param is the CliCsConn
//   userState - per-socket slot; set to the CliCsConn on connect success and
//               read back for disconnect/read events
// Returns true except for read events, where the result of
// NotifyConnSocketRead is returned.
static bool SocketNotifyCallback (
    AsyncSocket         sock,
    EAsyncNotifySocket  code,
    AsyncNotifySocket * notify,
    void **             userState
) {
    bool result = true;
    CliCsConn * conn;

    switch (code) {
        case kNotifySocketConnectSuccess: {
            conn = (CliCsConn *) notify->param;
            *userState = conn;
            conn->TransferRef("Connecting", "Connected");

            // Record the socket and clear the connect cancel id under the
            // module lock, and sample 'abandoned' at the same time: the
            // connection may have been abandoned (by a newer Connect, a
            // disconnect request or shutdown) while the connect was pending.
            bool abandoned;
            s_critsect.Enter();
            {
                conn->sock      = sock;
                conn->cancelId  = 0;
                abandoned       = conn->abandoned;
            }
            s_critsect.Leave();

            if (abandoned)
                AsyncSocketDisconnect(sock, true);
            else
                NotifyConnSocketConnect(conn);
        }
        break;

        case kNotifySocketConnectFailed:
            conn = (CliCsConn *) notify->param;
            NotifyConnSocketConnectFailed(conn);
        break;

        case kNotifySocketDisconnect:
            conn = (CliCsConn *) *userState;
            NotifyConnSocketDisconnect(conn);
        break;

        case kNotifySocketRead:
            conn = (CliCsConn *) *userState;
            result = NotifyConnSocketRead(conn, (AsyncNotifySocketRead *) notify);
        break;

        default:
            // Other notifications need no handling here
        break;
    }

    return result;
}

//============================================================================
static void Connect (
    const NetAddress &  addr,
    ConnectParam *      cp
) {
    CliCsConn * conn = NEWZERO(CliCsConn);
    conn->addr              = addr;
    conn->seq               = ConnNextSequence();
    conn->lastHeardTimeMs   = GetNonZeroTimeMs();
    conn->connectParam      = cp;

    conn->IncRef("Lifetime");
    conn->IncRef("Connecting");

    s_critsect.Enter();
    {
        while (CliCsConn * conn = s_conns.Head())
            UnlinkAndAbandonConn_CS(conn);
        s_conns.Link(conn);
    }
    s_critsect.Leave();

    Cli2Csr_Connect connect;
    connect.hdr.connType    = kConnTypeCliToCsr;
    connect.hdr.hdrBytes    = sizeof(connect.hdr);
    connect.hdr.buildId     = BuildId();
    connect.hdr.buildType   = BuildType();
    connect.hdr.branchId    = BranchId();
    connect.hdr.productId   = ProductId();
    connect.data.dataBytes  = sizeof(connect.data);

    AsyncSocketConnect(
        &conn->cancelId,
        addr,
        SocketNotifyCallback,
        conn,
        &connect,
        sizeof(connect),
        0,
        0
    );
}

//============================================================================
static void AsyncLookupCallback (
    void *              param,
    const wchar         name[],
    unsigned            addrCount,
    const NetAddress    addrs[]
) {
    if (!addrCount) {
        ReportNetError(kNetProtocolCli2Auth, kNetErrNameLookupFailed);
        return;
    }

    // Only connect to one server   
    addrCount = MIN(addrCount, 1);

    for (unsigned i = 0; i < addrCount; ++i) {
        Connect(addrs[i], (ConnectParam *)param);
    }
}


/*****************************************************************************
*
*   Message handlers
*
***/

//============================================================================
// Handler for Csr2Cli_PingReply: forwards the raw message to the transaction
// identified by the reply's transId. Always returns true.
static bool Recv_PingReply (
    const byte  msg[],
    unsigned    bytes,
    void *
) {
    const Csr2Cli_PingReply * const reply = (const Csr2Cli_PingReply *) msg;
    NetTransRecv(reply->transId, msg, bytes);
    return true;
}

//============================================================================
static bool Recv_RegisterReply (
    const byte  msg[],
    unsigned    ,
    void *      param
) {
    CliCsConn * conn = (CliCsConn *)param;

    const Csr2Cli_RegisterReply & reply = *(const Csr2Cli_RegisterReply *)msg;

    conn->serverChallenge   = reply.serverChallenge;
    conn->latestBuildId     = reply.csrBuildId;

    ConnectedNotifyTrans * trans = NEW(ConnectedNotifyTrans)(
        conn->connectParam,
        conn->latestBuildId
    );
    NetTransSend(trans);
    
    conn->connectParam = nil;

    return true;
}

//============================================================================
// Handler for Csr2Cli_LoginReply: forwards the raw message to the transaction
// identified by the reply's transId. Always returns true.
static bool Recv_LoginReply (
    const byte  msg[],
    unsigned    bytes,
    void *
) {
    const Csr2Cli_LoginReply * const reply = (const Csr2Cli_LoginReply *) msg;
    NetTransRecv(reply->transId, msg, bytes);
    return true;
}


/*****************************************************************************
*
*   Protocol
*
***/

// Messages this client sends to the CSR server; registered with
// NetMsgProtocolRegister in CsrInitialize. MSG(s) expands to the
// kNetMsg_Cli2Csr_<s> message descriptor.
#define MSG(s)  kNetMsg_Cli2Csr_##s
static NetMsgInitSend s_send[] = {
    { MSG(PingRequest)          },
    { MSG(RegisterRequest)      },
    { MSG(LoginRequest)         },
};
#undef MSG

// Messages this client receives from the CSR server; registered with
// NetMsgProtocolRegister in CsrInitialize. MSG(s) expands to the
// kNetMsg_Csr2Cli_<s> descriptor followed by its Recv_<s> handler.
#define MSG(s)  kNetMsg_Csr2Cli_##s, Recv_##s
static NetMsgInitRecv s_recv[] = {
    { MSG(PingReply)            },
    { MSG(RegisterReply)        },
    { MSG(LoginReply)           },
};
#undef MSG


/*****************************************************************************
*
*   CliCsConn
*
***/

//===========================================================================
static unsigned CliCsConnTimerDestroyed (void * param) {
    
    CliCsConn * conn = (CliCsConn *) param;
    conn->DecRef("PingTimer");
    return kAsyncTimeInfinite;
}

//===========================================================================
static unsigned CliCsConnPingTimerProc (void * param) {
    
    ((CliCsConn *) param)->TimerPing();
    return kPingIntervalMs;
}

//============================================================================
// Counts the new connection in s_perf[kPerfConnCount]; CsrDestroy(true)
// polls that counter until it reaches zero.
CliCsConn::CliCsConn () {

    long * const liveCount = &s_perf[kPerfConnCount];
    AtomicAdd(liveCount, 1);
}

//============================================================================
// Releases what the connection still owns and removes it from the live
// connection count.
CliCsConn::~CliCsConn () {

    // 'cli' is deleted here, after all references have been removed
    if (cli) {
        NetCliDelete(cli, true);
    }

    // Still set only when no ConnectedNotifyTrans took ownership of it
    DEL(connectParam);

    long * const liveCount = &s_perf[kPerfConnCount];
    AtomicAdd(liveCount, -1);
}

//============================================================================
// Creates the ping timer. Takes a "PingTimer" reference that is released by
// CliCsConnTimerDestroyed once StopAutoPing deletes the timer.
void CliCsConn::AutoPing () {
    ASSERT(!pingTimer);

    IncRef("PingTimer");

    critsect.Enter();
    // With a socket the initial delay is 0; without one it is
    // kAsyncTimeInfinite
    AsyncTimerCreate(
        &pingTimer,
        CliCsConnPingTimerProc,
        sock ? 0 : kAsyncTimeInfinite,
        this
    );
    critsect.Leave();
}

//============================================================================
void CliCsConn::StopAutoPing () {
    critsect.Enter();
    {
        if (AsyncTimer * timer = pingTimer) {
            pingTimer = nil;
            AsyncTimerDeleteCallback(timer, CliCsConnTimerDestroyed);
        }
    }
    critsect.Leave();
}

//============================================================================
// Sends one ping request on this connection and records the send time in
// pingSendTimeMs. Called from the ping timer (CliCsConnPingTimerProc).
void CliCsConn::TimerPing () {

    // Send a ping request
    pingSendTimeMs = GetNonZeroTimeMs();
    
    // This is a Cli2Csr connection, so use the Csr message id (matching the
    // kNetMsg_Cli2Csr_PingRequest entry in s_send), not the Auth one.
    const unsigned_ptr msg[] = {
        kCli2Csr_PingRequest,
                    0,      // not a transaction
                    pingSendTimeMs,
                    0,      // no payload
                    nil
    };
    
    Send(msg, arrsize(msg));
}

//============================================================================
// Sends one message on this connection's NetCli and flushes it immediately.
//   fields - message field array
//   count  - number of entries in 'fields'
// The connection's critsect is held across the send and the flush.
void CliCsConn::Send (const unsigned_ptr fields[], unsigned count) {

    critsect.Enter();
    NetCliSend(cli, fields, count);
    NetCliFlush(cli);
    critsect.Leave();
}


/*****************************************************************************
*
*   ConnectedNotifyTrans
*
***/

//============================================================================
// Runs the caller's connected-callback with its parameter and the latest
// build id. Does nothing when there is no ConnectParam or no callback.
void ConnectedNotifyTrans::Post () {

    if (!m_connectParam)
        return;
    if (!m_connectParam->callback)
        return;

    m_connectParam->callback(m_connectParam->param, m_latestBuildId);
}


/*****************************************************************************
*
*   LoginRequestTrans
*
***/

//============================================================================
// Builds a login transaction.
//   csrName      - account name; copied (bounded) into m_csrName
//   namePassHash - hash combined with the challenges in Send()
//   callback     - required; invoked by Post()
//   param        - opaque pointer handed back to 'callback'
LoginRequestTrans::LoginRequestTrans (
    const wchar             csrName[],
    const ShaDigest &       namePassHash,
    FNetCliCsrLoginCallback callback,
    void *                  param
) : NetCsrTrans(kCsrLoginTrans)
,   m_namePassHash(namePassHash)
,   m_callback(callback)
,   m_param(param)
// The reply fields are only written by Recv(), but Post() passes them to
// the callback unconditionally; give them defined values for the case where
// the transaction finishes without a reply.
// NOTE(review): m_csrId() zeroes the id only if Uuid has no user-provided
// default constructor - confirm against the Uuid declaration.
,   m_csrId()
,   m_csrFlags(0)
{
    ASSERT(callback);
    StrCopy(m_csrName, csrName, arrsize(m_csrName));
}

//============================================================================
// Sends the login request on the acquired connection.
// Returns false when no connection is available, true once the message has
// been handed to the connection.
bool LoginRequestTrans::Send () {

    if (!AcquireConn())
        return false;
        
    ShaDigest challengeHash;
    dword clientChallenge = 0;
    
    CryptCreateRandomSeed(
        sizeof(clientChallenge),
        (byte *) &clientChallenge
    );

    // Use the server challenge of the connection this request is sent on.
    // m_conn is held by reference count; s_active may be nil or may have
    // been replaced by another connection, and reading it here would happen
    // without s_critsect.
    CryptHashPasswordChallenge(
        clientChallenge,
        m_conn->serverChallenge,
        m_namePassHash,
        &challengeHash
    );

    const unsigned_ptr msg[] = {
        kCli2Csr_LoginRequest,
                        m_transId,
                        clientChallenge,
        (unsigned_ptr)  m_csrName,
        (unsigned_ptr)  &challengeHash
    };

    m_conn->Send(msg, arrsize(msg));

    return true;
}

//============================================================================
// Hands the transaction result, the caller's parameter and the reply fields
// (CSR id and flags) to the login callback.
void LoginRequestTrans::Post () {
    m_callback(m_result, m_param, m_csrId, m_csrFlags);
}

//============================================================================
// Consumes a Csr2Cli_LoginReply: copies result, CSR id and flags out of the
// message and marks the transaction complete. 'bytes' is not examined.
// Always returns true.
bool LoginRequestTrans::Recv (
    const byte  msg[],
    unsigned    bytes
) {
    const Csr2Cli_LoginReply * const reply = (const Csr2Cli_LoginReply *) msg;

    m_result    = reply->result;
    m_csrId     = reply->csrId;
    m_csrFlags  = reply->csrFlags;

    m_state     = kTransStateComplete;
    return true;
}


} using namespace Csr;


/*****************************************************************************
*
*   NetCsrTrans
*
***/

//============================================================================
// Base constructor for CSR transactions: registers the transaction under the
// Cli2Csr protocol with the given type and starts without a connection.
NetCsrTrans::NetCsrTrans (ETransType transType)
:   NetTrans(kNetProtocolCli2Csr, transType)
,   m_conn(nil)
{
}

//============================================================================
// Releases the connection reference, if one was acquired.
NetCsrTrans::~NetCsrTrans () {
    ReleaseConn();
}

//============================================================================
// Makes sure the transaction holds a referenced connection (tag
// "AcquireConn"). Returns true when m_conn is set, false when no active
// connection is available.
bool NetCsrTrans::AcquireConn () {
    if (m_conn)
        return true;

    m_conn = GetConnIncRef("AcquireConn");
    return m_conn != nil;
}

//============================================================================
// Drops the "AcquireConn" reference and clears m_conn; does nothing when no
// connection is held.
void NetCsrTrans::ReleaseConn () {
    if (!m_conn)
        return;

    m_conn->DecRef("AcquireConn");
    m_conn = nil;
}


/*****************************************************************************
*
*   Module functions
*
***/

//============================================================================
// Marks the module as running and registers the Cli2Csr protocol: its send
// and receive message tables and the kCsrDh* key-exchange constants.
void CsrInitialize () {

    s_running = true;

    NetMsgProtocolRegister(
        kNetProtocolCli2Csr,
        false,
        s_send,
        arrsize(s_send),
        s_recv,
        arrsize(s_recv),
        kCsrDhGValue,
        BigNum(sizeof(kCsrDhXData), kCsrDhXData),
        BigNum(sizeof(kCsrDhNData), kCsrDhNData)
    );
}

//============================================================================
// Shuts the module down: cancels all Cli2Csr transactions, unregisters the
// protocol and abandons every connection.
//   wait - when true, keeps pumping NetTransUpdate() until every CliCsConn
//          object has been destroyed
void CsrDestroy (bool wait) {

    s_running = false;

    NetTransCancelByProtocol(kNetProtocolCli2Csr, kNetErrRemoteShutdown);
    NetMsgProtocolDestroy(kNetProtocolCli2Csr, false);

    s_critsect.Enter();
    {
        for (CliCsConn * conn = s_conns.Head(); conn; conn = s_conns.Head())
            UnlinkAndAbandonConn_CS(conn);
        s_active = nil;
    }
    s_critsect.Leave();

    if (wait) {
        // The counter drops to zero once the last connection is destructed
        while (s_perf[kPerfConnCount]) {
            NetTransUpdate();
            AsyncSleep(10);
        }
    }
}

//============================================================================
// Returns true when there is an active connection that already has a NetCli.
bool CsrQueryConnected () {

    s_critsect.Enter();
    const bool connected = s_active && s_active->cli;
    s_critsect.Leave();

    return connected;
}

//============================================================================
// Returns the sequence id of the active connection, or 0 when there is none.
unsigned CsrGetConnId () {

    unsigned connId = 0;

    s_critsect.Enter();
    if (s_active)
        connId = s_active->seq;
    s_critsect.Leave();

    return connId;
}

} using namespace Ngl;


/*****************************************************************************
*
*   Exports
*
***/

//============================================================================
void NetCliCsrStartConnect (
    const wchar *               addrList[],
    unsigned                    addrCount,
    FNetCliCsrConnectedCallback callback,
    void *                      param
) {
    // Only connect to one server
    addrCount = min(addrCount, 1);

    for (unsigned i = 0; i < addrCount; ++i) {
        // Do we need to lookup the address?
        const wchar * name = addrList[i];
        while (unsigned ch = *name) {
            ++name;
            if (!(isdigit(ch) || ch == L'.' || ch == L':')) {
                ConnectParam * cp = NEW(ConnectParam);
                cp->callback    = callback;
                cp->param       = param;

                AsyncCancelId cancelId;
                AsyncAddressLookupName(
                    &cancelId,
                    AsyncLookupCallback,
                    addrList[i],
                    kNetDefaultClientPort,
                    cp
                );
                break;
            }
        }
        if (!name[0]) {
            NetAddress addr;
            NetAddressFromString(&addr, addrList[i], kNetDefaultClientPort);
            
            ConnectParam * cp = NEW(ConnectParam);
            cp->callback    = callback;
            cp->param       = param;

            Connect(addr, cp);
        }
    }
}

//============================================================================
// Abandons every linked connection and clears the active connection.
void NetCliCsrDisconnect () {

    s_critsect.Enter();
    {
        for (CliCsConn * conn = s_conns.Head(); conn; conn = s_conns.Head())
            UnlinkAndAbandonConn_CS(conn);
        s_active = nil;
    }
    s_critsect.Leave();
}

//============================================================================
// Queues a CSR login request.
//   csrName      - account name
//   namePassHash - name/password hash used to answer the server challenge
//   callback     - receives the result, CSR id and flags
//   param        - opaque pointer handed back to 'callback'
void NetCliCsrLoginRequest (
    const wchar             csrName[],
    const ShaDigest &       namePassHash,
    FNetCliCsrLoginCallback callback,
    void *                  param
) {
    NetTransSend(
        NEW(LoginRequestTrans)(csrName, namePassHash, callback, param)
    );
}