400 lines
12 KiB

/*==LICENSE==*
CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011 Cyan Worlds, Inc.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
You can contact Cyan Worlds, Inc. by email legal@cyan.com
or by snail mail at:
Cyan Worlds, Inc.
14617 N Newport Hwy
Mead, WA 99021
*==LICENSE==*/
/*****************************************************************************
*
* $/Plasma20/Sources/Plasma/NucleusLib/pnNetCli/pnNcChannel.cpp
*
***/
#include "Pch.h"
#pragma hdrstop
namespace pnNetCli {
/*****************************************************************************
*
* Private
*
***/
// Wraps the critical section that guards the channel list. The "Safe"
// entry points consult m_init first and silently do nothing while it is
// false; the plain entry points always lock/unlock.
// NOTE(review): m_init presumably exists so that callers running before this
// object's constructor (static initialization order) skip the not-yet-built
// critical section -- confirm against the callers of EnterSafe/LeaveSafe.
struct ChannelCrit {
    ChannelCrit ();
    ~ChannelCrit ();

    // Unconditional lock / unlock of the wrapped critical section.
    void Enter () { m_critsect.Enter(); }
    void Leave () { m_critsect.Leave(); }

    // Lock only when the constructor has already flagged the object usable.
    void EnterSafe () {
        if (!m_init)
            return;
        m_critsect.Enter();
    }

    // Unlock only when the constructor has already flagged the object usable.
    void LeaveSafe () {
        if (!m_init)
            return;
        m_critsect.Leave();
    }

private:
    bool        m_init;         // set to true in the constructor body
    CCritSect   m_critsect;
};
// One registered protocol endpoint. A channel is identified by the pair
// (m_protocol, m_server) and is reference counted through its AtomicRef base;
// the channel list holds one reference tagged "ChannelLink" and
// NetMsgChannelLock adds references tagged "ChannelLock".
struct NetMsgChannel : AtomicRef {
// Link used to chain this channel into the s_channels list
LINK(NetMsgChannel) m_link;
// Protocol identifier and client/server flag; together they form the lookup key
unsigned m_protocol;
bool m_server;
// Message definitions
// m_largestRecv is the largest ValidateMsg() result over all registered
// receive messages; the two arrays are indexed by message id.
unsigned m_largestRecv;
ARRAY(NetMsgInitSend) m_sendMsgs;
ARRAY(NetMsgInitRecv) m_recvMsgs;
// Diffie-Hellman constants
unsigned m_dh_g;
BigNum m_dh_xa; // client: dh_x server: dh_a
BigNum m_dh_n;
};
// Lock guarding s_channels; its destructor also tears the channel list down.
static ChannelCrit s_channelCrit;
// List of registered channels. Created on demand by FindOrCreateChannel_CS
// and deleted (then reset to nil) by ChannelCrit's destructor.
static LIST(NetMsgChannel) * s_channels;
/****************************************************************************
*
* ChannelCrit
*
***/
//===========================================================================
// Flags the wrapper as usable. The flag is assigned in the constructor body,
// i.e. after the m_critsect member has been constructed, so EnterSafe() and
// LeaveSafe() only start touching m_critsect once it exists.
ChannelCrit::ChannelCrit () {
m_init = true;
}
//===========================================================================
// Tears down the channel list: every channel still linked is unlinked and
// has its "ChannelLink" reference released, then the list itself is deleted
// and the global pointer reset. The whole operation runs under the (safe)
// channel lock.
ChannelCrit::~ChannelCrit () {
    EnterSafe();

    if (s_channels) {
        // Drain from the head until the list reports no more entries
        for (;;) {
            NetMsgChannel * const head = s_channels->Head();
            if (!head)
                break;

            s_channels->Unlink(head);
            head->DecRef("ChannelLink");
        }

        DEL(s_channels);
        s_channels = nil;
    }

    LeaveSafe();
}
/*****************************************************************************
*
* Internal functions
*
***/
//===========================================================================
// Checks a message definition and returns the max size of the message in bytes.
//   msg - message definition; msg.fields and msg.count are both ASSERTed to be
//         non-zero.
// The size starts at sizeof(word) for the message id and grows per field:
// sizeof(qword) for an integer field, sizeof(double) for a real field, nothing
// for the two variable-length pointer field types, and count * size for every
// other recognized field type (the variable-length count field included).
// An unrecognized field type is handed to DEFAULT_FATAL.
static unsigned ValidateMsg (const NetMsg & msg) {
ASSERT(msg.fields);
ASSERT(msg.count);
unsigned maxBytes = sizeof(word); // for message id
// True only while examining the field immediately after a var-count field
bool prevFieldWasVarCount = false;
for (unsigned i = 0; i < msg.count; i++) {
const NetMsgField & field = msg.fields[i];
// The for(;;) below is a breakable block, not a loop: every path through it
// ends in "break".
// NOTE(review): gotVarCount and gotVarField are declared inside this block,
// so they are re-initialized to false for every field. As written, none of
// the FATALs guarded by them ("may only include one variable length field",
// "must be the last field") can ever fire; only the "must preceded by
// variable length count field" check is effective. Hoisting the flags out
// of the outer loop unchanged would make a valid count + pointer pair FATAL
// (gotVarCount would still be set when the pointer field is examined), so
// the conditions need rework too -- confirm intent before changing.
for (;;) {
bool gotVarCount = false;
bool gotVarField = false;
if (field.type == kNetMsgFieldVarCount) {
if (gotVarField || gotVarCount)
FATAL("Msg definition may only include one variable length field");
gotVarCount = true;
break;
}
if (field.type == kNetMsgFieldVarPtr || field.type == kNetMsgFieldRawVarPtr) {
if (gotVarField || gotVarCount)
FATAL("Msg definition may only include one variable length field");
// A variable-length pointer is only legal directly after its count field
if (!prevFieldWasVarCount)
FATAL("Variable length field must preceded by variable length count field");
gotVarField = true;
break;
}
if (gotVarField)
FATAL("Variable length field must be the last field in message definition");
break;
}
// Cleared for every field; the var-count case below sets it again so that
// it is visible to the next iteration only.
prevFieldWasVarCount = false;
switch (field.type) {
case kNetMsgFieldInteger:
maxBytes += sizeof(qword);
break;
case kNetMsgFieldReal:
maxBytes += sizeof(double);
break;
// Variable-length payloads contribute nothing to the computed maximum
case kNetMsgFieldVarPtr:
case kNetMsgFieldRawVarPtr:
break;
case kNetMsgFieldVarCount:
prevFieldWasVarCount = true;
// fall-thru...
case kNetMsgFieldString:
case kNetMsgFieldPtr:
case kNetMsgFieldRawPtr:
case kNetMsgFieldData:
case kNetMsgFieldRawData:
maxBytes += msg.fields[i].count * msg.fields[i].size;
break;
DEFAULT_FATAL(field.type);
}
}
return maxBytes;
}
//===========================================================================
// Returns the largest msg.messageId found in msgs[0 .. count-1], or 0 when
// count is 0. Each entry is ASSERTed to have a non-zero msg.count.
template<class T>
static unsigned MaxMsgId (const T msgs[], unsigned count) {
    unsigned highest = 0;

    const T * const end = msgs + count;
    for (const T * cur = msgs; cur != end; ++cur) {
        ASSERT(cur->msg.count);
        highest = max(cur->msg.messageId, highest);
    }

    return highest;
}
//===========================================================================
// Copies 'count' send-message definitions from 'src' into the channel's send
// table, storing each one at the index given by its message id and running
// ValidateMsg on the stored copy.
// The table is first grown via GrowToFit using the largest message id in 'src'.
// NOTE(review): the "_CS" suffix suggests the caller must hold the channel
// critical section -- confirm.
static void AddSendMsgs_CS (
    NetMsgChannel * channel,
    const NetMsgInitSend src[],
    unsigned count
) {
    channel->m_sendMsgs.GrowToFit(MaxMsgId(src, count), true);

    for (unsigned i = 0; i < count; ++i) {
        const NetMsgInitSend & def  = src[i];
        NetMsgInitSend &       slot = channel->m_sendMsgs[def.msg.messageId];

        // a slot that already has a field count means the id was registered twice
        ASSERT(!slot.msg.count);

        slot = def;
        ValidateMsg(slot.msg);
    }
}
//===========================================================================
// Copies 'count' receive-message definitions (including their handlers) from
// 'src' into the channel's receive table, indexed by message id, and raises
// channel->m_largestRecv to the biggest ValidateMsg() result seen.
// The table is first grown via GrowToFit using the largest message id in 'src'.
// NOTE(review): the "_CS" suffix suggests the caller must hold the channel
// critical section -- confirm.
static void AddRecvMsgs_CS (
    NetMsgChannel * channel,
    const NetMsgInitRecv src[],
    unsigned count
) {
    channel->m_recvMsgs.GrowToFit(MaxMsgId(src, count), true);

    for (unsigned i = 0; i < count; ++i) {
        const NetMsgInitRecv & def = src[i];

        // every receive definition must come with a handler
        ASSERT(def.recv);

        NetMsgInitRecv & slot = channel->m_recvMsgs[def.msg.messageId];

        // a slot that already has a field count means the id was registered twice
        ASSERT(!slot.msg.count);

        // store definition and handler together
        slot = def;

        const unsigned bytes = ValidateMsg(slot.msg);
        if (bytes > channel->m_largestRecv)
            channel->m_largestRecv = bytes;
    }
}
//===========================================================================
// Looks up the channel registered for (protocol, server).
// Returns the matching channel, or nil when the channel list has not been
// created yet or holds no match. No reference is added to the result.
static NetMsgChannel * FindChannel_CS (unsigned protocol, bool server) {
    if (s_channels) {
        for (NetMsgChannel * cur = s_channels->Head(); cur; cur = s_channels->Next(cur)) {
            if (cur->m_protocol != protocol)
                continue;
            if (cur->m_server != server)
                continue;
            return cur;
        }
    }

    return nil;
}
//===========================================================================
// Returns the channel for (protocol, server), creating it when necessary.
// The channel list itself is created on first use. A newly created channel
// starts with m_largestRecv == 0, is linked into the list and receives one
// reference tagged "ChannelLink"; its Diffie-Hellman fields are not set here.
static NetMsgChannel * FindOrCreateChannel_CS (unsigned protocol, bool server) {
    // lazily build the list
    if (!s_channels) {
        s_channels = NEW(LIST(NetMsgChannel));
        s_channels->SetLinkOffset(offsetof(NetMsgChannel, m_link));
    }

    // reuse an existing registration when there is one
    if (NetMsgChannel * const existing = FindChannel_CS(protocol, server))
        return existing;

    NetMsgChannel * const created = NEW(NetMsgChannel);
    created->m_protocol     = protocol;
    created->m_server       = server;
    created->m_largestRecv  = 0;

    s_channels->Link(created);
    created->IncRef("ChannelLink");
    return created;
}
/*****************************************************************************
*
* Module functions
*
***/
//============================================================================
// Finds the channel for (protocol, server) under the channel lock.
// On success the channel gains a "ChannelLock" reference and *largestRecv
// receives its m_largestRecv; otherwise nil is returned and *largestRecv is
// set to 0. 'largestRecv' is written unconditionally and must not be null.
NetMsgChannel * NetMsgChannelLock (
    unsigned        protocol,
    bool            server,
    unsigned *      largestRecv
) {
    s_channelCrit.Enter();

    NetMsgChannel * const found = FindChannel_CS(protocol, server);
    if (found) {
        *largestRecv = found->m_largestRecv;
        found->IncRef("ChannelLock");
    }
    else {
        *largestRecv = 0;
    }

    s_channelCrit.Leave();
    return found;
}
//============================================================================
// Releases the "ChannelLock" reference on 'channel' while holding the
// channel lock. 'channel' is dereferenced without a null check.
void NetMsgChannelUnlock (
    NetMsgChannel * channel
) {
    s_channelCrit.Enter();
    channel->DecRef("ChannelLock");
    s_channelCrit.Leave();
}
//============================================================================
// Returns the receive-message definition registered under 'messageId', or
// nil when the id is beyond the end of the receive table or the slot has no
// fields (msg.count == 0), i.e. was never registered.
const NetMsgInitRecv * NetMsgChannelFindRecvMessage (
    NetMsgChannel * channel,
    unsigned        messageId
) {
    if (messageId < channel->m_recvMsgs.Count()) {
        const NetMsgInitRecv & entry = channel->m_recvMsgs[messageId];

        // only slots that were actually registered carry a field count
        if (entry.msg.count)
            return &entry;
    }

    return nil;
}
//============================================================================
// Returns the send-message definition stored under 'messageId'.
// Unlike the receive lookup this never returns nil: an out-of-range id or an
// unregistered slot is only caught by the assertions below.
const NetMsgInitSend * NetMsgChannelFindSendMessage (
    NetMsgChannel * channel,
    unsigned        messageId
) {
    // id must fall inside the send table
    ASSERT(messageId < channel->m_sendMsgs.Count());

    const NetMsgInitSend & entry = channel->m_sendMsgs[messageId];

    // slot must have been registered
    ASSERTMSG(entry.msg.count, "NetMsg not found for send");

    return &entry;
}
//============================================================================
void NetMsgChannelGetDhConstants (
const NetMsgChannel * channel,
unsigned * dh_g,
const BigNum ** dh_xa,
const BigNum ** dh_n
) {
*dh_g = channel->m_dh_g;
*dh_xa = &channel->m_dh_xa;
*dh_n = &channel->m_dh_n;
}
} // namespace pnNetCli
/*****************************************************************************
*
* Exports
*
***/
//===========================================================================
void NetMsgProtocolRegister (
unsigned protocol,
bool server,
const NetMsgInitSend sendMsgs[],
unsigned sendMsgCount,
const NetMsgInitRecv recvMsgs[],
unsigned recvMsgCount,
unsigned dh_g,
const BigNum & dh_xa, // client: dh_x server: dh_a
const BigNum & dh_n
) {
s_channelCrit.EnterSafe();
{
NetMsgChannel * channel = FindOrCreateChannel_CS(protocol, server);
// make sure no connections have been established on this protocol, otherwise
// we'll be modifying a live data structure; NetCli's don't lock their protocol
// to operate on it once they have linked to it!
ASSERT(channel->GetRefCount() == 1);
channel->m_dh_g = dh_g;
channel->m_dh_xa = dh_xa;
channel->m_dh_n = dh_n;
if (sendMsgCount)
AddSendMsgs_CS(channel, sendMsgs, sendMsgCount);
if (recvMsgCount)
AddRecvMsgs_CS(channel, recvMsgs, recvMsgCount);
}
s_channelCrit.LeaveSafe();
}
//===========================================================================
void NetMsgProtocolDestroy (unsigned protocol, bool server) {
s_channelCrit.EnterSafe();
if (NetMsgChannel * channel = FindChannel_CS(protocol, server)) {
s_channels->Unlink(channel);
channel->DecRef("ChannelLink");
}
s_channelCrit.LeaveSafe();
}