883 lines
22 KiB

/*==LICENSE==*
CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011 Cyan Worlds, Inc.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
You can contact Cyan Worlds, Inc. by email legal@cyan.com
or by snail mail at:
Cyan Worlds, Inc.
14617 N Newport Hwy
Mead, WA 99021
*==LICENSE==*/
/*****************************************************************************
*
* $/Plasma20/Sources/Plasma/PubUtilLib/plNetGameLib/Private/plNglCsr.cpp
*
***/
#include "../Pch.h"
#pragma hdrstop
namespace Ngl { namespace Csr {
/*****************************************************************************
*
* Internal types
*
***/
// Caller-supplied completion target for a connect request. Owned by the
// connection (CliCsConn::connectParam) until the server's RegisterReply
// arrives, then by the ConnectedNotifyTrans that invokes the callback.
struct ConnectParam {
    FNetCliCsrConnectedCallback     callback;   // may be nil; checked before use
    void *                          param;      // opaque value handed back to callback
};
//============================================================================
// Connection record
//============================================================================
// A single client connection to a CSR server.
// Reference counted via AtomicRef; the reference tags used in this file are
// "Lifetime", "Connecting", "Connected", "PingTimer" and "AcquireConn".
// Instances are allocated zero-filled (NEWZERO) by Connect().
struct CliCsConn : AtomicRef {
LINK(CliCsConn) link;               // membership in the module-wide s_conns list
CCritSect critsect;                 // guards pingTimer and sends through 'cli'
AsyncSocket sock;                   // nil while not connected; checked by UnlinkAndAbandonConn_CS and AutoPing
AsyncCancelId cancelId;             // pending AsyncSocketConnect request; reset to 0 when cancelled or failed
NetCli * cli;                       // protocol layer from NetCliConnectAccept; deleted in the destructor
NetAddress addr;                    // server address passed to Connect()
unsigned seq;                       // id from ConnNextSequence(); used to cancel this connection's transactions
bool abandoned;                     // set by UnlinkAndAbandonConn_CS; suppresses error reporting
unsigned serverChallenge;           // from RegisterReply; mixed into the login password hash
unsigned latestBuildId;             // csrBuildId from RegisterReply
ConnectParam * connectParam;        // owned; handed to ConnectedNotifyTrans on RegisterReply, else freed in dtor
// ping
AsyncTimer * pingTimer;             // created by AutoPing, released by StopAutoPing
unsigned pingSendTimeMs;            // time (ms, never 0) the last ping was sent
unsigned lastHeardTimeMs;           // time (ms, never 0) of the last socket read
CliCsConn ();
~CliCsConn ();
void AutoPing ();                   // start the periodic ping timer
void StopAutoPing ();               // delete the ping timer, if any
void TimerPing ();                  // ping timer callback: send one ping request
void Send (const unsigned_ptr fields[], unsigned count);   // send + flush under critsect
};
//============================================================================
// Transaction objects
//============================================================================
// Notification transaction posted once the server answers our register
// request. It takes ownership of the ConnectParam and frees it on destruction.
struct ConnectedNotifyTrans : NetNotifyTrans {
    ConnectParam *  m_connectParam;     // owned; may be nil
    unsigned        m_latestBuildId;    // build id reported by the server

    ConnectedNotifyTrans (ConnectParam * connectParam, unsigned latestBuildId)
        : NetNotifyTrans(kCsrConnectedNotifyTrans)
        , m_connectParam(connectParam)
        , m_latestBuildId(latestBuildId)
    {
    }

    ~ConnectedNotifyTrans () { DEL(m_connectParam); }

    void Post ();
};
// Transaction carrying a CSR login: sends the name plus a challenge hash,
// waits for Csr2Cli_LoginReply, then reports the result through m_callback.
struct LoginRequestTrans : NetCsrTrans {
// send
wchar m_csrName[kMaxAccountNameLength];    // bounded copy of the caller's name
ShaDigest m_namePassHash;                  // combined with the challenges in Send()
FNetCliCsrLoginCallback m_callback;        // required (ASSERTed in the constructor)
void * m_param;                            // opaque value passed back to m_callback
// recv (filled in by Recv)
Uuid m_csrId;
unsigned m_csrFlags;
LoginRequestTrans (
const wchar csrName[],
const ShaDigest & namePassHash,
FNetCliCsrLoginCallback callback,
void * param
);
bool Send ();      // returns false when no active connection could be acquired
void Post ();      // invokes m_callback with the result
bool Recv (
const byte msg[],
unsigned bytes
);
};
/*****************************************************************************
*
* Internal data
*
***/
// Indices into s_perf.
enum {
kPerfConnCount,     // live CliCsConn objects (adjusted in ctor/dtor; polled by CsrDestroy)
kPingDisabled,      // NOTE(review): not referenced in the code shown here
kNumPerf
};
static bool s_running;                          // true between CsrInitialize and CsrDestroy
static CCritSect s_critsect;                    // guards s_conns and s_active
static LISTDECL(CliCsConn, link) s_conns;       // connections not yet torn down or abandoned
static CliCsConn * s_active;                    // connection that completed encryption; nil if none
static long s_perf[kNumPerf];                   // counters, updated with AtomicAdd
/*****************************************************************************
*
* Internal functions
*
***/
//===========================================================================
// Current TimeGetMs() value, mapped so that the result is never 0.
static unsigned GetNonZeroTimeMs () {
    const unsigned nowMs = TimeGetMs();
    return nowMs ? nowMs : 1;
}
//============================================================================
// Returns the active connection with an extra reference tagged 'tag', or nil
// if there is no active connection or it has no NetCli yet.
// Caller must hold s_critsect.
static CliCsConn * GetConnIncRef_CS (const char tag[]) {
    CliCsConn * active = s_active;
    if (!active || !active->cli)
        return nil;

    active->IncRef(tag);
    return active;
}
//============================================================================
// Locking wrapper around GetConnIncRef_CS.
static CliCsConn * GetConnIncRef (const char tag[]) {
    s_critsect.Enter();
    CliCsConn * const found = GetConnIncRef_CS(tag);
    s_critsect.Leave();
    return found;
}
//============================================================================
// Removes 'conn' from s_conns, marks it abandoned, and starts tearing it down.
// The teardown depends on how far the connection got:
//   - connect still pending -> cancel it
//   - socket open           -> disconnect it
//   - neither               -> drop the "Lifetime" reference directly
// Caller must hold s_critsect.
static void UnlinkAndAbandonConn_CS (CliCsConn * conn) {
    s_conns.Unlink(conn);
    conn->abandoned = true;

    if (conn->cancelId) {
        AsyncSocketConnectCancel(nil, conn->cancelId);
        conn->cancelId = 0;
        return;
    }

    if (conn->sock) {
        AsyncSocketDisconnect(conn->sock, true);
        return;
    }

    conn->DecRef("Lifetime");
}
//============================================================================
// Sends Cli2Csr_RegisterRequest on 'conn'.
static void SendRegisterRequest (CliCsConn * conn) {
    const unsigned_ptr fields[] = { kCli2Csr_RegisterRequest, 0 };
    conn->Send(fields, arrsize(fields));
}
//============================================================================
// Encryption-complete callback passed to NetCliConnectAccept.
// On success, makes the connection active, starts pinging, and registers
// with the server. Returns whether 'error' indicates success.
static bool ConnEncrypt (ENetError error, void * param) {
    if (!IS_NET_SUCCESS(error))
        return false;

    CliCsConn * const conn = (CliCsConn *) param;

    s_critsect.Enter();
    {
        s_active = conn;
        conn->AutoPing();
        // Keep 'conn' alive across the unlocked send below.
        conn->IncRef();
    }
    s_critsect.Leave();

    SendRegisterRequest(conn);
    conn->DecRef();
    return true;
}
//============================================================================
// Wraps the freshly connected socket in a Cli2Csr NetCli. ConnEncrypt is
// called with 'conn' once the encryption handshake finishes.
static void NotifyConnSocketConnect (CliCsConn * conn) {
    conn->cli = NetCliConnectAccept(
        conn->sock,
        kNetProtocolCli2Csr,
        false,
        ConnEncrypt,
        0,
        nil,
        conn    // passed to ConnEncrypt as 'param'
    );
}
//============================================================================
// Handles kNotifySocketConnectFailed for 'conn'.
// Unlinks it, clears s_active if it pointed here, and cancels the
// connection's outstanding transactions with kNetErrTimeout. Then drops the
// "Connecting" and "Lifetime" references, after which 'conn' may be
// destroyed. kNetErrConnectFailed is reported unless the module is shutting
// down, the connection was abandoned, or a different connection is active.
static void NotifyConnSocketConnectFailed (CliCsConn * conn) {
bool notify;
s_critsect.Enter();
{
conn->cancelId = 0;
s_conns.Unlink(conn);
notify
= s_running
&& !conn->abandoned
&& (!s_active || conn == s_active);
if (conn == s_active)
s_active = nil;
}
s_critsect.Leave();
// Must run before the DecRefs below, which may free 'conn'.
NetTransCancelByConnId(conn->seq, kNetErrTimeout);
conn->DecRef("Connecting");
conn->DecRef("Lifetime");
// Only the local 'notify' is used from here on.
if (notify)
ReportNetError(kNetProtocolCli2Csr, kNetErrConnectFailed);
}
//============================================================================
// Handles kNotifySocketDisconnect for 'conn'.
// Stops the ping timer, unlinks the connection, clears s_active if it
// pointed here, and cancels its outstanding transactions. Then drops the
// "Connected" and "Lifetime" references, after which 'conn' may be destroyed.
// kNetErrDisconnected is reported unless the module is shutting down, the
// connection was abandoned, or a different connection is active.
static void NotifyConnSocketDisconnect (CliCsConn * conn) {
conn->StopAutoPing();
bool notify;
s_critsect.Enter();
{
s_conns.Unlink(conn);
notify
= s_running
&& !conn->abandoned
&& (!s_active || conn == s_active);
if (conn == s_active)
s_active = nil;
}
s_critsect.Leave();
// Cancel all transactions in process on this connection.
// Must run before the DecRefs below, which may free 'conn'.
NetTransCancelByConnId(conn->seq, kNetErrTimeout);
conn->DecRef("Connected");
conn->DecRef("Lifetime");
if (notify)
ReportNetError(kNetProtocolCli2Csr, kNetErrDisconnected);
}
//============================================================================
// Handles kNotifySocketRead: stamps the last-heard time, hands the bytes to
// the NetCli dispatcher, and marks the whole buffer as consumed.
// Returns the dispatcher's result.
static bool NotifyConnSocketRead (CliCsConn * conn, AsyncNotifySocketRead * read) {
    conn->lastHeardTimeMs = GetNonZeroTimeMs();

    const bool dispatched = NetCliDispatch(conn->cli, read->buffer, read->bytes, conn);
    read->bytesProcessed += read->bytes;

    return dispatched;
}
//============================================================================
// Socket event callback for CSR connections (registered by Connect()).
// For connect notifications the connection comes from notify->param. For
// later events it comes from *userState, which is set on connect success.
// Returns false only when read dispatch fails.
static bool SocketNotifyCallback (
    AsyncSocket             sock,
    EAsyncNotifySocket      code,
    AsyncNotifySocket *     notify,
    void **                 userState
) {
    bool result = true;
    CliCsConn * conn;

    switch (code) {
        case kNotifySocketConnectSuccess: {
            conn = (CliCsConn *) notify->param;
            *userState = conn;
            conn->TransferRef("Connecting", "Connected");

            // Publish the socket and clear the pending-connect id under the
            // lock. This lets UnlinkAndAbandonConn_CS disconnect (rather than
            // cancel) from now on. Also find out whether the connection was
            // abandoned while the connect was in flight. When the module is
            // shutting down, the connection is always treated as abandoned.
            bool abandoned = true;
            if (s_running) {
                s_critsect.Enter();
                {
                    conn->sock     = sock;
                    conn->cancelId = 0;
                    abandoned      = conn->abandoned;
                }
                s_critsect.Leave();
            }

            if (abandoned)
                AsyncSocketDisconnect(sock, true);
            else
                NotifyConnSocketConnect(conn);
        }
        break;

        case kNotifySocketConnectFailed:
            conn = (CliCsConn *) notify->param;
            NotifyConnSocketConnectFailed(conn);
        break;

        case kNotifySocketDisconnect:
            conn = (CliCsConn *) *userState;
            NotifyConnSocketDisconnect(conn);
        break;

        case kNotifySocketRead:
            conn = (CliCsConn *) *userState;
            result = NotifyConnSocketRead(conn, (AsyncNotifySocketRead *) notify);
        break;

        default:
        break;
    }

    return result;
}
//============================================================================
static void Connect (
const NetAddress & addr,
ConnectParam * cp
) {
CliCsConn * conn = NEWZERO(CliCsConn);
conn->addr = addr;
conn->seq = ConnNextSequence();
conn->lastHeardTimeMs = GetNonZeroTimeMs();
conn->connectParam = cp;
conn->IncRef("Lifetime");
conn->IncRef("Connecting");
s_critsect.Enter();
{
while (CliCsConn * conn = s_conns.Head())
UnlinkAndAbandonConn_CS(conn);
s_conns.Link(conn);
}
s_critsect.Leave();
Cli2Csr_Connect connect;
connect.hdr.connType = kConnTypeCliToCsr;
connect.hdr.hdrBytes = sizeof(connect.hdr);
connect.hdr.buildId = BuildId();
connect.hdr.buildType = BuildType();
connect.hdr.branchId = BranchId();
connect.hdr.productId = ProductId();
connect.data.dataBytes = sizeof(connect.data);
AsyncSocketConnect(
&conn->cancelId,
addr,
SocketNotifyCallback,
conn,
&connect,
sizeof(connect),
0,
0
);
}
//============================================================================
static void AsyncLookupCallback (
void * param,
const wchar name[],
unsigned addrCount,
const NetAddress addrs[]
) {
if (!addrCount) {
ReportNetError(kNetProtocolCli2Auth, kNetErrNameLookupFailed);
return;
}
// Only connect to one server
addrCount = MIN(addrCount, 1);
for (unsigned i = 0; i < addrCount; ++i) {
Connect(addrs[i], (ConnectParam *)param);
}
}
/*****************************************************************************
*
* Message handlers
*
***/
//============================================================================
// Handler for Csr2Cli_PingReply: routes the message to the transaction
// named by its transId.
static bool Recv_PingReply (
    const byte  msg[],
    unsigned    bytes,
    void *
) {
    const Csr2Cli_PingReply * reply = (const Csr2Cli_PingReply *) msg;
    NetTransRecv(reply->transId, msg, bytes);
    return true;
}
//============================================================================
static bool Recv_RegisterReply (
const byte msg[],
unsigned ,
void * param
) {
CliCsConn * conn = (CliCsConn *)param;
const Csr2Cli_RegisterReply & reply = *(const Csr2Cli_RegisterReply *)msg;
conn->serverChallenge = reply.serverChallenge;
conn->latestBuildId = reply.csrBuildId;
ConnectedNotifyTrans * trans = NEW(ConnectedNotifyTrans)(
conn->connectParam,
conn->latestBuildId
);
NetTransSend(trans);
conn->connectParam = nil;
return true;
}
//============================================================================
// Handler for Csr2Cli_LoginReply: routes the message to the pending
// LoginRequestTrans identified by transId.
static bool Recv_LoginReply (
    const byte  msg[],
    unsigned    bytes,
    void *
) {
    const Csr2Cli_LoginReply * reply = (const Csr2Cli_LoginReply *) msg;
    NetTransRecv(reply->transId, msg, bytes);
    return true;
}
/*****************************************************************************
*
* Protocol
*
***/
// Outgoing (client -> CSR) message definitions, registered in CsrInitialize.
#define MSG(s) kNetMsg_Cli2Csr_##s
static NetMsgInitSend s_send[] = {
{ MSG(PingRequest) },
{ MSG(RegisterRequest) },
{ MSG(LoginRequest) },
};
#undef MSG
// Incoming (CSR -> client) message definitions, each paired with its
// Recv_<name> handler; registered in CsrInitialize.
#define MSG(s) kNetMsg_Csr2Cli_##s, Recv_##s
static NetMsgInitRecv s_recv[] = {
{ MSG(PingReply) },
{ MSG(RegisterReply) },
{ MSG(LoginReply) },
};
#undef MSG
/*****************************************************************************
*
* CliCsConn
*
***/
//===========================================================================
static unsigned CliCsConnTimerDestroyed (void * param) {
CliCsConn * conn = (CliCsConn *) param;
conn->DecRef("PingTimer");
return kAsyncTimeInfinite;
}
//===========================================================================
static unsigned CliCsConnPingTimerProc (void * param) {
((CliCsConn *) param)->TimerPing();
return kPingIntervalMs;
}
//============================================================================
// Counts the connection in kPerfConnCount. CsrDestroy(true) waits for this
// counter to drop back to zero.
CliCsConn::CliCsConn () {
    AtomicAdd(&s_perf[kPerfConnCount], 1);
}
//============================================================================
// Runs once the last reference is gone. Releases the NetCli and any
// ConnectParam still owned, then uncounts the connection.
CliCsConn::~CliCsConn () {
    // 'cli' is deleted only here, after all refs have been removed.
    if (cli != nil)
        NetCliDelete(cli, true);

    DEL(connectParam);
    AtomicAdd(&s_perf[kPerfConnCount], -1);
}
//============================================================================
// Creates the periodic ping timer. It fires immediately when a socket is
// present, and never (kAsyncTimeInfinite) otherwise. The timer holds a
// "PingTimer" reference, released by CliCsConnTimerDestroyed.
void CliCsConn::AutoPing () {
    ASSERT(!pingTimer);
    IncRef("PingTimer");

    critsect.Enter();
    {
        const unsigned firstDelayMs = sock ? 0 : kAsyncTimeInfinite;
        AsyncTimerCreate(&pingTimer, CliCsConnPingTimerProc, firstDelayMs, this);
    }
    critsect.Leave();
}
//============================================================================
void CliCsConn::StopAutoPing () {
critsect.Enter();
{
if (AsyncTimer * timer = pingTimer) {
pingTimer = nil;
AsyncTimerDeleteCallback(timer, CliCsConnTimerDestroyed);
}
}
critsect.Leave();
}
//============================================================================
// Sends one (non-transactional) ping request on this CSR connection and
// records the send time.
void CliCsConn::TimerPing () {
    pingSendTimeMs = GetNonZeroTimeMs();

    // This connection speaks the Cli2Csr protocol (s_send registers
    // kNetMsg_Cli2Csr_PingRequest). The id previously used here,
    // kCli2Auth_PingRequest, belongs to the auth protocol.
    const unsigned_ptr msg[] = {
        kCli2Csr_PingRequest,
        0,                  // not a transaction
        pingSendTimeMs,
        0,                  // no payload
        nil
    };

    Send(msg, arrsize(msg));
}
//============================================================================
// Queues one message on 'cli' and flushes it, holding critsect so that
// concurrent senders on this connection do not interleave.
void CliCsConn::Send (const unsigned_ptr fields[], unsigned count) {
    critsect.Enter();
    NetCliSend(cli, fields, count);
    NetCliFlush(cli);
    critsect.Leave();
}
/*****************************************************************************
*
* ConnectedNotifyTrans
*
***/
//============================================================================
// Invokes the user's connected callback with the server's build id.
// Does nothing if there are no parameters or no callback.
void ConnectedNotifyTrans::Post () {
    if (!m_connectParam || !m_connectParam->callback)
        return;

    m_connectParam->callback(m_connectParam->param, m_latestBuildId);
}
/*****************************************************************************
*
* LoginRequestTrans
*
***/
//============================================================================
// Captures the login inputs. The name is copied (truncated to
// kMaxAccountNameLength), so the caller's buffer need not outlive the
// request. 'callback' is required.
LoginRequestTrans::LoginRequestTrans (
    const wchar                 csrName[],
    const ShaDigest &           namePassHash,
    FNetCliCsrLoginCallback     callback,
    void *                      param
)   : NetCsrTrans(kCsrLoginTrans)
    , m_namePassHash(namePassHash)
    , m_callback(callback)
    , m_param(param)
{
    ASSERT(callback);
    StrCopy(m_csrName, csrName, arrsize(m_csrName));
}
//============================================================================
// Sends Cli2Csr_LoginRequest. A random client challenge is generated and
// hashed together with the connection's server challenge and the caller's
// name/password hash. Returns false if no active connection is available.
bool LoginRequestTrans::Send () {
    if (!AcquireConn())
        return false;

    dword clientChallenge = 0;
    CryptCreateRandomSeed(
        sizeof(clientChallenge),
        (byte *) &clientChallenge
    );

    // Use the connection this transaction holds a reference to. Reading
    // s_active here would be unsynchronized. s_active may already be nil
    // after a disconnect, or may point at a different connection than the
    // one the request is sent on.
    ShaDigest challengeHash;
    CryptHashPasswordChallenge(
        clientChallenge,
        m_conn->serverChallenge,
        m_namePassHash,
        &challengeHash
    );

    const unsigned_ptr msg[] = {
        kCli2Csr_LoginRequest,
        m_transId,
        clientChallenge,
        (unsigned_ptr) m_csrName,
        (unsigned_ptr) &challengeHash
    };
    m_conn->Send(msg, arrsize(msg));
    return true;
}
//============================================================================
// Delivers the login outcome (result, csr id, flags) to the caller.
void LoginRequestTrans::Post () {
    m_callback(m_result, m_param, m_csrId, m_csrFlags);
}
//============================================================================
// Copies result, csr id and flags out of Csr2Cli_LoginReply and marks the
// transaction complete.
bool LoginRequestTrans::Recv (
    const byte  msg[],
    unsigned
) {
    const Csr2Cli_LoginReply * reply = (const Csr2Cli_LoginReply *) msg;

    m_result   = reply->result;
    m_csrId    = reply->csrId;
    m_csrFlags = reply->csrFlags;

    m_state = kTransStateComplete;
    return true;
}
} using namespace Csr;
/*****************************************************************************
*
* NetCsrTrans
*
***/
//============================================================================
// Base for Cli2Csr transactions. It starts without a connection; see
// AcquireConn.
NetCsrTrans::NetCsrTrans (ETransType transType)
    : NetTrans(kNetProtocolCli2Csr, transType)
    , m_conn(nil)
{ }
//============================================================================
// Drops the connection reference, if one was acquired.
NetCsrTrans::~NetCsrTrans () {
    ReleaseConn();
}
//============================================================================
// Takes an "AcquireConn" reference on the active connection if not already
// holding one. Returns whether a connection is held.
bool NetCsrTrans::AcquireConn () {
    if (m_conn == nil)
        m_conn = GetConnIncRef("AcquireConn");
    return m_conn ? true : false;
}
//============================================================================
// Releases the reference taken by AcquireConn. No-op when none is held.
void NetCsrTrans::ReleaseConn () {
    if (!m_conn)
        return;

    m_conn->DecRef("AcquireConn");
    m_conn = nil;
}
/*****************************************************************************
*
* Module functions
*
***/
//============================================================================
void CsrInitialize () {
s_running = true;
NetMsgProtocolRegister(
kNetProtocolCli2Csr,
false,
s_send, arrsize(s_send),
s_recv, arrsize(s_recv),
kCsrDhGValue,
BigNum(sizeof(kCsrDhXData), kCsrDhXData),
BigNum(sizeof(kCsrDhNData), kCsrDhNData)
);
}
//============================================================================
// Shuts the module down.
// Clears s_running (which suppresses further error reports), cancels all
// Cli2Csr transactions with kNetErrRemoteShutdown, unregisters the protocol,
// and abandons every connection. When 'wait' is true, it keeps pumping
// NetTransUpdate() until every CliCsConn has been destroyed
// (kPerfConnCount == 0).
void CsrDestroy (bool wait) {
s_running = false;
NetTransCancelByProtocol(
kNetProtocolCli2Csr,
kNetErrRemoteShutdown
);
NetMsgProtocolDestroy(
kNetProtocolCli2Csr,
false
);
s_critsect.Enter();
{
while (CliCsConn * conn = s_conns.Head())
UnlinkAndAbandonConn_CS(conn);
s_active = nil;
}
s_critsect.Leave();
if (!wait)
return;
// Connections finish tearing down asynchronously; poll until all are gone.
while (s_perf[kPerfConnCount]) {
NetTransUpdate();
AsyncSleep(10);
}
}
//============================================================================
// True when there is an active connection and it has a NetCli attached.
bool CsrQueryConnected () {
    bool connected;
    s_critsect.Enter();
    {
        connected = (s_active != nil) && (s_active->cli != nil);
    }
    s_critsect.Leave();
    return connected;
}
//============================================================================
// Sequence id of the active connection, or 0 if there is none.
unsigned CsrGetConnId () {
    unsigned connId = 0;
    s_critsect.Enter();
    {
        if (s_active)
            connId = s_active->seq;
    }
    s_critsect.Leave();
    return connId;
}
} using namespace Ngl;
/*****************************************************************************
*
* Exports
*
***/
//============================================================================
void NetCliCsrStartConnect (
const wchar * addrList[],
unsigned addrCount,
FNetCliCsrConnectedCallback callback,
void * param
) {
// Only connect to one server
addrCount = min(addrCount, 1);
for (unsigned i = 0; i < addrCount; ++i) {
// Do we need to lookup the address?
const wchar * name = addrList[i];
while (unsigned ch = *name) {
++name;
if (!(isdigit(ch) || ch == L'.' || ch == L':')) {
ConnectParam * cp = NEW(ConnectParam);
cp->callback = callback;
cp->param = param;
AsyncCancelId cancelId;
AsyncAddressLookupName(
&cancelId,
AsyncLookupCallback,
addrList[i],
kNetDefaultClientPort,
cp
);
break;
}
}
if (!name[0]) {
NetAddress addr;
NetAddressFromString(&addr, addrList[i], kNetDefaultClientPort);
ConnectParam * cp = NEW(ConnectParam);
cp->callback = callback;
cp->param = param;
Connect(addr, cp);
}
}
}
//============================================================================
void NetCliCsrDisconnect () {
s_critsect.Enter();
{
while (CliCsConn * conn = s_conns.Head())
UnlinkAndAbandonConn_CS(conn);
s_active = nil;
}
s_critsect.Leave();
}
//============================================================================
// Queues a CSR login transaction. 'callback' receives the result along with
// 'param'.
void NetCliCsrLoginRequest (
    const wchar                 csrName[],
    const ShaDigest &           namePassHash,
    FNetCliCsrLoginCallback     callback,
    void *                      param
) {
    NetTransSend(NEW(LoginRequestTrans)(csrName, namePassHash, callback, param));
}