You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
418 lines
12 KiB
418 lines
12 KiB
/*==LICENSE==* |
|
|
|
CyanWorlds.com Engine - MMOG client, server and tools |
|
Copyright (C) 2011 Cyan Worlds, Inc. |
|
|
|
This program is free software: you can redistribute it and/or modify |
|
it under the terms of the GNU General Public License as published by |
|
the Free Software Foundation, either version 3 of the License, or |
|
(at your option) any later version. |
|
|
|
This program is distributed in the hope that it will be useful, |
|
but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
GNU General Public License for more details. |
|
|
|
You should have received a copy of the GNU General Public License |
|
along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
|
|
Additional permissions under GNU GPL version 3 section 7 |
|
|
|
If you modify this Program, or any covered work, by linking or |
|
combining it with any of RAD Game Tools Bink SDK, Autodesk 3ds Max SDK, |
|
NVIDIA PhysX SDK, Microsoft DirectX SDK, OpenSSL library, Independent |
|
JPEG Group JPEG library, Microsoft Windows Media SDK, or Apple QuickTime SDK |
|
(or a modified version of those libraries), |
|
containing parts covered by the terms of the Bink SDK EULA, 3ds Max EULA, |
|
PhysX SDK EULA, DirectX SDK EULA, OpenSSL and SSLeay licenses, IJG |
|
JPEG Library README, Windows Media SDK EULA, or QuickTime SDK EULA, the |
|
licensors of this Program grant you additional |
|
permission to convey the resulting work. Corresponding Source for a |
|
non-source form of such a combination shall include the source code for |
|
the parts of OpenSSL and IJG JPEG Library used as well as that of the covered |
|
work. |
|
|
|
You can contact Cyan Worlds, Inc. by email legal@cyan.com |
|
or by snail mail at: |
|
Cyan Worlds, Inc. |
|
14617 N Newport Hwy |
|
Mead, WA 99021 |
|
|
|
*==LICENSE==*/ |
|
/***************************************************************************** |
|
* |
|
* $/Plasma20/Sources/Plasma/NucleusLib/pnNetCli/pnNcChannel.cpp |
|
* |
|
***/ |
|
|
|
#include "Pch.h" |
|
#include <list> |
|
#include <mutex> |
|
#include "hsRefCnt.h" |
|
#pragma hdrstop |
|
|
|
|
|
namespace pnNetCli { |
|
|
|
/***************************************************************************** |
|
* |
|
* Private |
|
* |
|
***/ |
|
|
|
struct ChannelCrit { |
|
~ChannelCrit(); |
|
ChannelCrit() : m_init(true) { } |
|
|
|
inline void lock() |
|
{ |
|
hsAssert(m_init, "Bad things have happened."); |
|
m_critsect.lock(); |
|
} |
|
|
|
inline void unlock() |
|
{ |
|
hsAssert(m_init, "Bad things have happened."); |
|
m_critsect.unlock(); |
|
} |
|
|
|
private: |
|
bool m_init; |
|
std::mutex m_critsect; |
|
}; |
|
|
|
struct NetMsgChannel : hsRefCnt { |
|
NetMsgChannel() : hsRefCnt(0) { } |
|
|
|
uint32_t m_protocol; |
|
bool m_server; |
|
|
|
// Message definitions |
|
uint32_t m_largestRecv; |
|
ARRAY(NetMsgInitSend) m_sendMsgs; |
|
ARRAY(NetMsgInitRecv) m_recvMsgs; |
|
|
|
// Diffie-Hellman constants |
|
uint32_t m_dh_g; |
|
plBigNum m_dh_xa; // client: dh_x server: dh_a |
|
plBigNum m_dh_n; |
|
}; |
|
|
|
static ChannelCrit s_channelCrit; |
|
static std::list<NetMsgChannel*>* s_channels; |
|
|
|
|
|
/**************************************************************************** |
|
* |
|
* ChannelCrit |
|
* |
|
***/ |
|
|
|
//=========================================================================== |
|
ChannelCrit::~ChannelCrit () { |
|
std::lock_guard<ChannelCrit> lock(*this); |
|
|
|
if (s_channels) { |
|
while (s_channels->size()) { |
|
NetMsgChannel* const channel = s_channels->front(); |
|
s_channels->remove(channel); |
|
channel->UnRef("ChannelLink"); |
|
} |
|
|
|
delete s_channels; |
|
s_channels = nil; |
|
} |
|
} |
|
|
|
|
|
/***************************************************************************** |
|
* |
|
* Internal functions |
|
* |
|
***/ |
|
|
|
//=========================================================================== |
|
// Returns max size of message in bytes |
|
static unsigned ValidateMsg (const NetMsg & msg) { |
|
ASSERT(msg.fields); |
|
ASSERT(msg.count); |
|
|
|
unsigned maxBytes = sizeof(uint16_t); // for message id |
|
bool prevFieldWasVarCount = false; |
|
|
|
for (unsigned i = 0; i < msg.count; i++) { |
|
const NetMsgField & field = msg.fields[i]; |
|
|
|
for (;;) { |
|
bool gotVarCount = false; |
|
bool gotVarField = false; |
|
if (field.type == kNetMsgFieldVarCount) { |
|
if (gotVarField || gotVarCount) |
|
FATAL("Msg definition may only include one variable length field"); |
|
gotVarCount = true; |
|
break; |
|
} |
|
if (field.type == kNetMsgFieldVarPtr || field.type == kNetMsgFieldRawVarPtr) { |
|
if (gotVarField || gotVarCount) |
|
FATAL("Msg definition may only include one variable length field"); |
|
if (!prevFieldWasVarCount) |
|
FATAL("Variable length field must preceded by variable length count field"); |
|
gotVarField = true; |
|
break; |
|
} |
|
if (gotVarField) |
|
FATAL("Variable length field must be the last field in message definition"); |
|
break; |
|
} |
|
|
|
prevFieldWasVarCount = false; |
|
switch (field.type) { |
|
case kNetMsgFieldInteger: |
|
maxBytes += sizeof(uint64_t); |
|
break; |
|
|
|
case kNetMsgFieldReal: |
|
maxBytes += sizeof(double); |
|
break; |
|
|
|
case kNetMsgFieldVarPtr: |
|
case kNetMsgFieldRawVarPtr: |
|
break; |
|
|
|
case kNetMsgFieldVarCount: |
|
prevFieldWasVarCount = true; |
|
// fall-thru... |
|
case kNetMsgFieldString: |
|
case kNetMsgFieldPtr: |
|
case kNetMsgFieldRawPtr: |
|
case kNetMsgFieldData: |
|
case kNetMsgFieldRawData: |
|
maxBytes += msg.fields[i].count * msg.fields[i].size; |
|
break; |
|
|
|
DEFAULT_FATAL(field.type); |
|
} |
|
} |
|
|
|
return maxBytes; |
|
} |
|
|
|
|
|
//===========================================================================
// Returns the highest msg.messageId among msgs[0..count), or 0 when count is
// 0. Every definition is asserted to have at least one field.
template<class T>
static unsigned MaxMsgId (const T msgs[], unsigned count) {
    unsigned highest = 0;

    const T * const end = msgs + count;
    for (const T * cur = msgs; cur != end; ++cur) {
        ASSERT(cur->msg.count);
        if (cur->msg.messageId > highest)
            highest = cur->msg.messageId;
    }

    return highest;
}
|
|
|
//=========================================================================== |
|
static void AddSendMsgs_CS ( |
|
NetMsgChannel * channel, |
|
const NetMsgInitSend src[], |
|
unsigned count |
|
) { |
|
channel->m_sendMsgs.GrowToFit(MaxMsgId(src, count), true); |
|
|
|
for (const NetMsgInitSend * term = src + count; src < term; ++src) { |
|
NetMsgInitSend * const dst = &channel->m_sendMsgs[src[0].msg.messageId]; |
|
|
|
// check to ensure that the message id isn't already used |
|
ASSERT(!dst->msg.count); |
|
|
|
*dst = *src; |
|
ValidateMsg(dst->msg); |
|
} |
|
} |
|
|
|
//=========================================================================== |
|
static void AddRecvMsgs_CS ( |
|
NetMsgChannel * channel, |
|
const NetMsgInitRecv src[], |
|
unsigned count |
|
) { |
|
channel->m_recvMsgs.GrowToFit(MaxMsgId(src, count), true); |
|
|
|
for (const NetMsgInitRecv * term = src + count; src < term; ++src) { |
|
ASSERT(src->recv); |
|
NetMsgInitRecv * const dst = &channel->m_recvMsgs[src[0].msg.messageId]; |
|
|
|
// check to ensure that the message id isn't already used |
|
ASSERT(!dst->msg.count); |
|
|
|
// copy the message handler |
|
*dst = *src; |
|
|
|
const uint32_t bytes = ValidateMsg(dst->msg); |
|
channel->m_largestRecv = std::max(channel->m_largestRecv, bytes); |
|
} |
|
} |
|
|
|
//=========================================================================== |
|
static NetMsgChannel* FindChannel_CS (uint32_t protocol, bool server) { |
|
if (!s_channels) |
|
return nil; |
|
|
|
std::list<NetMsgChannel*>::iterator it = s_channels->begin(); |
|
for (; it != s_channels->end(); ++it) { |
|
if (((*it)->m_protocol == protocol) && ((*it)->m_server == server)) |
|
return *it; |
|
} |
|
|
|
return nil; |
|
} |
|
|
|
//=========================================================================== |
|
static NetMsgChannel* FindOrCreateChannel_CS (uint32_t protocol, bool server) { |
|
if (!s_channels) { |
|
s_channels = new std::list<NetMsgChannel*>(); |
|
} |
|
|
|
// find or create protocol |
|
NetMsgChannel * channel = FindChannel_CS(protocol, server); |
|
if (!channel) { |
|
channel = new NetMsgChannel(); |
|
channel->m_protocol = protocol; |
|
channel->m_server = server; |
|
channel->m_largestRecv = 0; |
|
|
|
s_channels->push_back(channel); |
|
channel->Ref("ChannelLink"); |
|
} |
|
|
|
return channel; |
|
} |
|
|
|
|
|
/***************************************************************************** |
|
* |
|
* Module functions |
|
* |
|
***/ |
|
|
|
//============================================================================ |
|
NetMsgChannel * NetMsgChannelLock ( |
|
unsigned protocol, |
|
bool server, |
|
uint32_t * largestRecv |
|
) { |
|
NetMsgChannel * channel; |
|
std::lock_guard<ChannelCrit> lock(s_channelCrit); |
|
if (nullptr != (channel = FindChannel_CS(protocol, server))) { |
|
*largestRecv = channel->m_largestRecv; |
|
channel->Ref("ChannelLock"); |
|
} |
|
else { |
|
*largestRecv = 0; |
|
} |
|
return channel; |
|
} |
|
|
|
//============================================================================ |
|
void NetMsgChannelUnlock ( |
|
NetMsgChannel * channel |
|
) { |
|
std::lock_guard<ChannelCrit> lock(s_channelCrit); |
|
|
|
channel->UnRef("ChannelLock"); |
|
} |
|
|
|
//============================================================================ |
|
const NetMsgInitRecv * NetMsgChannelFindRecvMessage ( |
|
NetMsgChannel * channel, |
|
unsigned messageId |
|
) { |
|
// Is message in range? |
|
if (messageId >= channel->m_recvMsgs.Count()) |
|
return nil; |
|
|
|
// Is message defined? |
|
const NetMsgInitRecv * recvMsg = &channel->m_recvMsgs[messageId]; |
|
if (!recvMsg->msg.count) |
|
return nil; |
|
|
|
// Success! |
|
return recvMsg; |
|
} |
|
|
|
//============================================================================ |
|
const NetMsgInitSend * NetMsgChannelFindSendMessage ( |
|
NetMsgChannel * channel, |
|
unsigned messageId |
|
) { |
|
// Is message in range? |
|
ASSERT(messageId < channel->m_sendMsgs.Count()); |
|
|
|
// Is message defined? |
|
const NetMsgInitSend * sendMsg = &channel->m_sendMsgs[messageId]; |
|
ASSERTMSG(sendMsg->msg.count, "NetMsg not found for send"); |
|
|
|
return sendMsg; |
|
} |
|
|
|
//============================================================================ |
|
void NetMsgChannelGetDhConstants ( |
|
const NetMsgChannel * channel, |
|
uint32_t * dh_g, |
|
const plBigNum** dh_xa, |
|
const plBigNum** dh_n |
|
) { |
|
if (dh_g) *dh_g = channel->m_dh_g; |
|
if (dh_xa) *dh_xa = &channel->m_dh_xa; |
|
if (dh_n) *dh_n = &channel->m_dh_n; |
|
} |
|
|
|
|
|
} // namespace pnNetCli |
|
|
|
|
|
/***************************************************************************** |
|
* |
|
* Exports |
|
* |
|
***/ |
|
|
|
//=========================================================================== |
|
void NetMsgProtocolRegister ( |
|
uint32_t protocol, |
|
bool server, |
|
const NetMsgInitSend sendMsgs[], |
|
uint32_t sendMsgCount, |
|
const NetMsgInitRecv recvMsgs[], |
|
uint32_t recvMsgCount, |
|
uint32_t dh_g, |
|
const plBigNum& dh_xa, // client: dh_x server: dh_a |
|
const plBigNum& dh_n |
|
) { |
|
std::lock_guard<ChannelCrit> lock(s_channelCrit); |
|
|
|
NetMsgChannel * channel = FindOrCreateChannel_CS(protocol, server); |
|
|
|
// make sure no connections have been established on this protocol, otherwise |
|
// we'll be modifying a live data structure; NetCli's don't lock their protocol |
|
// to operate on it once they have linked to it! |
|
ASSERT(channel->RefCnt() == 1); |
|
|
|
channel->m_dh_g = dh_g; |
|
channel->m_dh_xa = dh_xa; |
|
channel->m_dh_n = dh_n; |
|
|
|
if (sendMsgCount) |
|
AddSendMsgs_CS(channel, sendMsgs, sendMsgCount); |
|
if (recvMsgCount) |
|
AddRecvMsgs_CS(channel, recvMsgs, recvMsgCount); |
|
} |
|
|
|
//=========================================================================== |
|
void NetMsgProtocolDestroy (uint32_t protocol, bool server) { |
|
std::lock_guard<ChannelCrit> lock(s_channelCrit); |
|
|
|
if (NetMsgChannel* channel = FindChannel_CS(protocol, server)) { |
|
s_channels->remove(channel); |
|
channel->UnRef("ChannelLink"); |
|
} |
|
}
|
|
|