/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011 Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA 99021

*==LICENSE==*/
/*****************************************************************************
*
*   $/Plasma20/Sources/Plasma/NucleusLib/pnNetDiag/pnNdTcp.cpp
*
*   TCP diagnostics: connects to SrvAuth / SrvFile, sends up to kMaxPings
*   request/reply round trips and reports timings and errors through the
*   caller's dump proc, then the final result through the test callback.
*
***/

#include "Pch.h"
#pragma hdrstop

namespace AuthKey {
// This file excluded from pre-compiled header because it is auto-generated by the build server.
#include "pnNetBase/pnNbAuthKey.hpp"
} // AuthKey


/*****************************************************************************
*
*   Local types
*
***/

// One TCP diagnostic run against SrvAuth. Kept alive by the "Connecting" /
// "Connected", "Pinging" and per-ping "Ping" references taken in this file;
// the destructor reports the result through 'callback'.
struct AuthConn : AtomicRef {
    NetDiag *               diag;
    FNetDiagDumpProc        dump;           // receives progress/log lines
    FNetDiagTestCallback    callback;       // called from ~AuthConn
    void *                  param;          // passed through to 'callback'
    AsyncSocket             sock;
    AsyncCancelId           cancelId;
    NetCli *                cli;            // Cli2Auth channel created on connect
    long                    pingsInRoute;   // AuthTrans objects currently alive
    long                    pingsCompleted; // AuthTrans objects destroyed
    bool                    done;           // pinging stopped or socket closed
    ENetError               error;          // first recorded error; zero until one occurs

    ~AuthConn ();
};

// One outstanding SrvAuth ping, keyed by its transaction id in s_authTrans.
struct AuthTrans : THashKeyVal<unsigned> {
    HASHLINK(AuthTrans)     link;
    AuthConn *              conn;           // owning connection (holds a "Ping" ref)
    unsigned                pingAtMs;       // TimeGetMs() when the ping was created

    AuthTrans (AuthConn * conn);
    ~AuthTrans ();
};

// One TCP diagnostic run against SrvFile; same lifetime scheme as AuthConn.
struct FileConn : AtomicRef {
    NetDiag *               diag;
    FNetDiagDumpProc        dump;           // receives progress/log lines
    FNetDiagTestCallback    callback;       // called from ~FileConn
    void *                  param;          // passed through to 'callback'
    AsyncSocket             sock;
    AsyncCancelId           cancelId;
    ARRAY(byte)             recvBuffer;     // unparsed bytes received from SrvFile
    long                    pingsInRoute;   // FileTrans objects currently alive
    long                    pingsCompleted; // FileTrans objects destroyed
    bool                    done;           // pinging stopped or socket closed
    ENetError               error;          // first recorded error; zero until one occurs

    ~FileConn ();
};

// One outstanding SrvFile ping (manifest request), keyed by transaction id.
struct FileTrans : THashKeyVal<unsigned> {
    HASHLINK(FileTrans)     link;
    FileConn *              conn;           // owning connection (holds a "Ping" ref)
    unsigned                pingAtMs;       // TimeGetMs() when the ping was created

    FileTrans (FileConn * conn);
    ~FileTrans ();
};


/*****************************************************************************
*
*   Local data
*
***/

static const unsigned   kPingTimeoutMs  = 5000;
static const unsigned   kTimeoutCheckMs = 100;
static const unsigned   kMaxPings       = 15;

static long             s_authProtocolRegistered;
static unsigned         s_transId;
static CCritSect        s_critsect;         // guards s_transId and both trans tables
static bool             s_shutdown;
static byte             s_payload[32];
static AsyncTimer *     s_timer;

// Outstanding pings of ALL connections; entries must be matched on ->conn.
static HASHTABLEDECL(
    AuthTrans,
    THashKeyVal<unsigned>,
    link
) s_authTrans;

static HASHTABLEDECL(
    FileTrans,
    THashKeyVal<unsigned>,
    link
) s_fileTrans;


/*****************************************************************************
*
*   Cli2Auth protocol
*
***/

//============================================================================
// Handles Auth2Cli_PingReply: completes the matching AuthTrans and logs the
// round-trip time. Replies shorter than Auth2Cli_PingReply (beta6 servers)
// carry no usable transId, so the table's tail entry is completed instead.
// Returns false when no outstanding ping matches.
static bool Recv_PingReply (
    const byte      msg[],
    unsigned        bytes,
    void *
) {
    ref(bytes);
    const Auth2Cli_PingReply & reply = *(const Auth2Cli_PingReply *)msg;

    AuthTrans * trans;
    s_critsect.Enter();
    {
        if (bytes < sizeof(Auth2Cli_PingReply)) {
            // beta6 compatibility
            if (nil != (trans = s_authTrans.Tail()))
                s_authTrans.Unlink(trans);
        }
        else if (nil != (trans = s_authTrans.Find(reply.transId))) {
            s_authTrans.Unlink(trans);
        }
    }
    s_critsect.Leave();

    if (trans) {
        unsigned replyAtMs = TimeGetMs();
        trans->conn->dump(L"[TCP] Reply from SrvAuth. ms=%u", replyAtMs - trans->pingAtMs);
        DEL(trans);
        return true;
    }
    else {
        return false;
    }
}

//============================================================================
#define MSG(s)  kNetMsg_Cli2Auth_##s
static NetMsgInitSend s_send[] = {
    { MSG(PingRequest) },
};
#undef MSG

#define MSG(s)  kNetMsg_Auth2Cli_##s, Recv_##s
static NetMsgInitRecv s_recv[] = {
    { MSG(PingReply) },
};
#undef MSG


/*****************************************************************************
*
*   Local functions
*
***/

//============================================================================
// Periodic timer: deletes every outstanding auth/file ping older than
// kPingTimeoutMs, recording kNetErrTimeout on its connection unless an error
// was already recorded. Returns the delay in ms until the next check.
static unsigned TimerCallback (void *) {

    unsigned timeMs = TimeGetMs();
    s_critsect.Enter();
    {
        ENetError error = kNetErrTimeout;
        {for (AuthTrans * next, * curr = s_authTrans.Head(); curr; curr = next) {
            next = s_authTrans.Next(curr);
            unsigned diff = timeMs - curr->pingAtMs;
            if (diff > kPingTimeoutMs) {
                if (!curr->conn->error)
                    curr->conn->error = error;
                curr->conn->dump(L"[TCP] No reply from SrvAuth: %u, %s (ms=%u)", error, NetErrorToString(error), diff);
                DEL(curr);
            }
        }}
        {for (FileTrans * next, * curr = s_fileTrans.Head(); curr; curr = next) {
            next = s_fileTrans.Next(curr);
            unsigned diff = timeMs - curr->pingAtMs;
            if (diff > kPingTimeoutMs) {
                if (!curr->conn->error)
                    curr->conn->error = error;
                curr->conn->dump(L"[TCP] No reply from SrvFile: %u, %s (ms=%u)", error, NetErrorToString(error), diff);
                DEL(curr);
            }
        }}
    }
    s_critsect.Leave();

    return kTimeoutCheckMs;
}

//============================================================================
// Thread proc (started via _beginthread): sends one Cli2Auth ping at a time
// until kMaxPings have completed or the connection is marked done, then
// disconnects, deletes the NetCli and drops the "Pinging" reference.
static void AuthPingProc (void * param) {

    AuthConn * conn = (AuthConn *)param;

    while (!conn->done && conn->pingsCompleted < kMaxPings) {

        if (!conn->pingsInRoute) {

            AuthTrans * trans = NEW(AuthTrans)(conn);
            trans->pingAtMs = TimeGetMs();

            s_critsect.Enter();
            if (conn->done) {
                // Socket already closed: stop pinging and discard this ping.
                conn->pingsCompleted = kMaxPings;
                DEL(trans);
            }
            else {
                // Transaction id 0 is skipped.
                while (++s_transId == 0)
                    NULL_STMT;
                trans->SetValue(s_transId);
                s_authTrans.Add(trans);

                const unsigned_ptr msg[] = {
                    kCli2Auth_PingRequest,
                    trans->pingAtMs,
                    trans->GetValue(),
                    sizeof(s_payload),
                    (unsigned_ptr) s_payload,
                };

                NetCliSend(conn->cli, msg, arrsize(msg));
                NetCliFlush(conn->cli);
            }
            s_critsect.Leave();
        }
        AsyncSleep(10);
    }

    s_critsect.Enter();
    {
        conn->done = true;
        AsyncSocketDisconnect(conn->sock, true);
        NetCliDelete(conn->cli, false);
        conn->cli = nil;
    }
    s_critsect.Leave();

    conn->DecRef("Pinging");
}

//============================================================================
// NetCli encryption-complete callback: on success starts the ping thread
// (which takes the "Pinging" ref). Returns whether encryption succeeded.
static bool AuthConnEncrypt (ENetError error, void * param) {

    AuthConn * conn = (AuthConn *)param;

    if (IS_NET_SUCCESS(error)) {
        conn->dump(L"[TCP] SrvAuth stream encrypted.");
        conn->dump(L"[TCP] Pinging SrvAuth with 32 bytes of data...");
        conn->IncRef("Pinging");
        _beginthread(AuthPingProc, 0, conn);
    }
    else {
        conn->dump(L"[TCP] SrvAuth stream encryption failed: %u, %s", error, NetErrorToString(error));
    }

    return IS_NET_SUCCESS(error);
}

//============================================================================
// Socket connected: convert the "Connecting" ref to "Connected" and begin the
// encrypted Cli2Auth handshake (AuthConnEncrypt is called when it finishes).
static void NotifyAuthConnSocketConnect (AuthConn * conn) {

    conn->dump(L"[TCP] SrvAuth socket established, encrypting stream...");

    conn->TransferRef("Connecting", "Connected");
    conn->cli = NetCliConnectAccept(
        conn->sock,
        kNetProtocolCli2Auth,
        false,
        AuthConnEncrypt,
        0,
        nil,
        conn
    );
}

//============================================================================
// Socket connect failed: record the error and drop the "Connecting" ref.
static void NotifyAuthConnSocketConnectFailed (AuthConn * conn) {

    conn->error = kNetErrConnectFailed;
    conn->cancelId = 0;
    conn->dump(L"[TCP] SrvAuth socket connection failed %u, %s", conn->error, NetErrorToString(conn->error));

    conn->DecRef("Connecting");
}

//============================================================================
// Socket closed: mark the connection done, fail THIS connection's outstanding
// pings and drop the "Connected" ref.
static void NotifyAuthConnSocketDisconnect (AuthConn * conn) {

    if (!conn->done && !conn->error)
        conn->error = kNetErrDisconnected;

    conn->cancelId = 0;
    conn->dump(L"[TCP] SrvAuth socket closed: %u, %s", conn->error, NetErrorToString(conn->error));

    HASHTABLEDECL(
        AuthTrans,
        THashKeyVal<unsigned>,
        link
    ) authTrans;

    s_critsect.Enter();
    {
        conn->done = true;
        // s_authTrans is shared by every AuthConn; only claim pings owned by
        // this connection so concurrent diagnostics are left untouched.
        for (AuthTrans * next, * curr = s_authTrans.Head(); curr; curr = next) {
            next = s_authTrans.Next(curr);
            if (curr->conn == conn)
                authTrans.Add(curr);    // relinks curr into the local table
        }
    }
    s_critsect.Leave();

    // Deleting an entry removes it from its table, which is what lets this
    // loop terminate.
    while (AuthTrans * trans = authTrans.Head()) {
        conn->dump(L"[TCP] No reply from SrvAuth: %u, %s", conn->error, NetErrorToString(conn->error));
        DEL(trans);
    }

    conn->DecRef("Connected");
}

//============================================================================
// Feeds received bytes to the NetCli, which dispatches to Recv_PingReply.
static bool NotifyAuthConnSocketRead (AuthConn * conn, AsyncNotifySocketRead * read) {

    NetCliDispatch(conn->cli, read->buffer, read->bytes, conn);
    read->bytesProcessed += read->bytes;
    return true;
}

//============================================================================
// Async socket notification router for SrvAuth connections. The AuthConn is
// taken from notify->param until connected, then from *userState.
static bool AuthSocketNotifyCallback (
    AsyncSocket         sock,
    EAsyncNotifySocket  code,
    AsyncNotifySocket * notify,
    void **             userState
) {
    bool result = true;
    AuthConn * conn;

    switch (code) {
        case kNotifySocketConnectSuccess:
            conn = (AuthConn *) notify->param;
            *userState = conn;
            conn->sock = sock;
            conn->cancelId = 0;
            NotifyAuthConnSocketConnect(conn);
        break;

        case kNotifySocketConnectFailed:
            conn = (AuthConn *) notify->param;
            NotifyAuthConnSocketConnectFailed(conn);
        break;

        case kNotifySocketDisconnect:
            conn = (AuthConn *) *userState;
            NotifyAuthConnSocketDisconnect(conn);
        break;

        case kNotifySocketRead:
            conn = (AuthConn *) *userState;
            result = NotifyAuthConnSocketRead(conn, (AsyncNotifySocketRead *) notify);
        break;

        default:
        break;
    }

    return result;
}

//============================================================================
// Handles File2Cli_ManifestReply: completes the matching FileTrans and logs
// the round-trip time. Returns false when no outstanding ping matches.
static bool Recv_File2Cli_ManifestReply (FileConn * conn, const File2Cli_ManifestReply & msg) {

    ref(conn);

    FileTrans * trans;
    s_critsect.Enter();
    {
        if (nil != (trans = s_fileTrans.Find(msg.transId)))
            s_fileTrans.Unlink(trans);
    }
    s_critsect.Leave();

    if (trans) {
        unsigned replyAtMs = TimeGetMs();
        trans->conn->dump(L"[TCP] Reply from SrvFile. ms=%u", replyAtMs - trans->pingAtMs);
        DEL(trans);
        return true;
    }
    else {
        return false;
    }
}

//============================================================================
// Thread proc (started via _beginthread): sends one manifest request at a
// time until kMaxPings have completed or the connection is marked done, then
// disconnects and drops the "Pinging" reference.
static void FilePingProc (void * param) {

    FileConn * conn = (FileConn *)param;

    while (!conn->done && conn->pingsCompleted < kMaxPings) {

        if (!conn->pingsInRoute) {

            FileTrans * trans = NEW(FileTrans)(conn);
            trans->pingAtMs = TimeGetMs();

            s_critsect.Enter();
            if (conn->done) {
                // Socket already closed: stop pinging and discard this ping.
                conn->pingsCompleted = kMaxPings;
                DEL(trans);
            }
            else {
                // Transaction id 0 is skipped.
                while (++s_transId == 0)
                    NULL_STMT;
                trans->SetValue(s_transId);
                s_fileTrans.Add(trans);

                Cli2File_ManifestRequest msg;
                StrCopy(msg.group, L"External", arrsize(msg.group));
                msg.messageId       = kCli2File_ManifestRequest;
                msg.transId         = trans->GetValue();
                msg.messageBytes    = sizeof(msg);
                msg.buildId         = 0;

                AsyncSocketSend(conn->sock, &msg, sizeof(msg));
            }
            s_critsect.Leave();
        }
        AsyncSleep(10);
    }

    s_critsect.Enter();
    {
        conn->done = true;
        AsyncSocketDisconnect(conn->sock, true);
    }
    s_critsect.Leave();

    conn->DecRef("Pinging");
}

//============================================================================
// Socket connected: convert "Connecting" to "Connected" and start pinging.
static void NotifyFileConnSocketConnect (FileConn * conn) {

    conn->TransferRef("Connecting", "Connected");
    conn->dump(L"[TCP] SrvFile socket established");
    conn->dump(L"[TCP] Pinging SrvFile...");

    conn->IncRef("Pinging");
    _beginthread(FilePingProc, 0, conn);
}

//============================================================================
// Socket connect failed: record the error and drop the "Connecting" ref.
static void NotifyFileConnSocketConnectFailed (FileConn * conn) {

    conn->error = kNetErrConnectFailed;
    conn->cancelId = 0;
    conn->dump(L"[TCP] SrvFile socket connection failed %u, %s", conn->error, NetErrorToString(conn->error));

    conn->DecRef("Connecting");
}

//============================================================================
// Socket closed: mark the connection done, fail THIS connection's outstanding
// pings and drop the "Connected" ref.
static void NotifyFileConnSocketDisconnect (FileConn * conn) {

    if (!conn->done && !conn->error)
        conn->error = kNetErrDisconnected;

    conn->cancelId = 0;
    conn->dump(L"[TCP] SrvFile socket closed: %u, %s", conn->error, NetErrorToString(conn->error));

    HASHTABLEDECL(
        FileTrans,
        THashKeyVal<unsigned>,
        link
    ) fileTrans;

    s_critsect.Enter();
    {
        conn->done = true;
        // s_fileTrans is shared by every FileConn; only claim pings owned by
        // this connection so concurrent diagnostics are left untouched.
        for (FileTrans * next, * curr = s_fileTrans.Head(); curr; curr = next) {
            next = s_fileTrans.Next(curr);
            if (curr->conn == conn)
                fileTrans.Add(curr);    // relinks curr into the local table
        }
    }
    s_critsect.Leave();

    // Deleting an entry removes it from its table, which is what lets this
    // loop terminate.
    while (FileTrans * trans = fileTrans.Head()) {
        conn->dump(L"[TCP] No reply from SrvFile: %u, %s", conn->error, NetErrorToString(conn->error));
        DEL(trans);
    }

    conn->DecRef("Connected");
}

//============================================================================
// Accumulates received bytes and processes every complete size-prefixed
// message in recvBuffer. Returns false on any message this test does not
// accept (malformed size, unexpected id, or reply to an unknown ping).
static bool NotifyFileConnSocketRead (FileConn * conn, AsyncNotifySocketRead * read) {

    conn->recvBuffer.Add(read->buffer, read->bytes);
    read->bytesProcessed += read->bytes;

    for (;;) {
        if (conn->recvBuffer.Count() < sizeof(dword))
            return true;

        dword msgSize = *(dword *)conn->recvBuffer.Ptr();

        // The size prefix comes straight off the wire. A frame shorter than
        // its own header is malformed: reading messageId/transId from it
        // would run past the frame (and possibly past the received data).
        if (msgSize < sizeof(Cli2File_MsgHeader)) {
            conn->dump(L"[TCP] SrvFile received malformed message. size: %u", msgSize);
            return false;
        }

        if (conn->recvBuffer.Count() < msgSize)
            return true;

        const Cli2File_MsgHeader * msg = (const Cli2File_MsgHeader *) conn->recvBuffer.Ptr();

        if (msg->messageId != kFile2Cli_ManifestReply) {
            conn->dump(L"[TCP] SrvFile received unexpected message. id: %u", msg->messageId);
            return false;
        }

        if (!Recv_File2Cli_ManifestReply(conn, *(const File2Cli_ManifestReply *)msg))
            return false;

        // Drop the consumed message from the front of the buffer.
        conn->recvBuffer.Move(0, msgSize, conn->recvBuffer.Count() - msgSize);
        conn->recvBuffer.ShrinkBy(msgSize);
    }
}

//============================================================================
// Async socket notification router for SrvFile connections. The FileConn is
// taken from notify->param until connected, then from *userState.
static bool FileSocketNotifyCallback (
    AsyncSocket         sock,
    EAsyncNotifySocket  code,
    AsyncNotifySocket * notify,
    void **             userState
) {
    bool result = true;
    FileConn * conn;

    switch (code) {
        case kNotifySocketConnectSuccess:
            conn = (FileConn *) notify->param;
            *userState = conn;
            conn->sock = sock;
            conn->cancelId = 0;
            NotifyFileConnSocketConnect(conn);
        break;

        case kNotifySocketConnectFailed:
            conn = (FileConn *) notify->param;
            NotifyFileConnSocketConnectFailed(conn);
        break;

        case kNotifySocketDisconnect:
            conn = (FileConn *) *userState;
            NotifyFileConnSocketDisconnect(conn);
        break;

        case kNotifySocketRead:
            conn = (FileConn *) *userState;
            result = NotifyFileConnSocketRead(conn, (AsyncNotifySocketRead *) notify);
        break;

        default:
        break;
    }

    return result;
}

//============================================================================
// Registers the Cli2Auth protocol on first use (also filling s_payload), then
// creates an AuthConn and starts an async connect to SrvAuth at 'addr'.
static void StartAuthTcpTest (
    NetDiag *               diag,
    const NetAddress &      addr,
    FNetDiagDumpProc        dump,
    FNetDiagTestCallback    callback,
    void *                  param
) {
    if (0 == AtomicSet(&s_authProtocolRegistered, 1)) {
        MemSet(
            s_payload,
            (byte)((unsigned_ptr)&s_payload >> 4),
            sizeof(s_payload)
        );
        NetMsgProtocolRegister(
            kNetProtocolCli2Auth,
            false,
            s_send, arrsize(s_send),
            s_recv, arrsize(s_recv),
            AuthKey::kDhGValue,
            BigNum(sizeof(AuthKey::kDhXData), AuthKey::kDhXData),
            BigNum(sizeof(AuthKey::kDhNData), AuthKey::kDhNData)
        );
    }

    wchar addrStr[128];
    NetAddressToString(addr, addrStr, arrsize(addrStr), kNetAddressFormatAll);
    dump(L"[TCP] Connecting to SrvAuth at %s...", addrStr);

    diag->IncRef("TCP");
    AuthConn * conn = NEWZERO(AuthConn);
    conn->diag      = diag;
    conn->dump      = dump;
    conn->callback  = callback;
    conn->param     = param;
    conn->IncRef("Connecting");

    Cli2Auth_Connect connect;
    connect.hdr.connType    = (byte) kConnTypeCliToAuth;
    connect.hdr.hdrBytes    = sizeof(connect.hdr);
    connect.hdr.buildId     = BuildId();
    connect.hdr.buildType   = BuildType();
    connect.hdr.branchId    = BranchId();
    connect.hdr.productId   = ProductId();
    connect.data.token      = kNilGuid;
    connect.data.dataBytes  = sizeof(connect.data);

    AsyncSocketConnect(
        &conn->cancelId,
        addr,
        AuthSocketNotifyCallback,
        conn,
        &connect,
        sizeof(connect),
        0,
        0
    );
}

//============================================================================
// Creates a FileConn and starts an async connect to SrvFile at 'addr'.
static void StartFileTcpTest (
    NetDiag *               diag,
    const NetAddress &      addr,
    FNetDiagDumpProc        dump,
    FNetDiagTestCallback    callback,
    void *                  param
) {
    wchar addrStr[128];
    NetAddressToString(addr, addrStr, arrsize(addrStr), kNetAddressFormatAll);
    dump(L"[TCP] Connecting to SrvFile at %s...", addrStr);

    diag->IncRef("TCP");
    FileConn * conn = NEWZERO(FileConn);
    conn->diag      = diag;
    conn->dump      = dump;
    conn->callback  = callback;
    conn->param     = param;
    conn->IncRef("Connecting");

    Cli2File_Connect connect;
    connect.hdr.connType        = kConnTypeCliToFile;
    connect.hdr.hdrBytes        = sizeof(connect.hdr);
    connect.hdr.buildId         = 0;
    connect.hdr.buildType       = BuildType();
    connect.hdr.branchId        = BranchId();
    connect.hdr.productId       = ProductId();
    connect.data.buildId        = BuildId();
    connect.data.serverType     = kSrvTypeNone;
    connect.data.dataBytes      = sizeof(connect.data);

    AsyncSocketConnect(
        &conn->cancelId,
        addr,
        FileSocketNotifyCallback,
        conn,
        &connect,
        sizeof(connect),
        0,
        0
    );
}


/*****************************************************************************
*
*   AuthConn
*
***/

//============================================================================
// Releases the NetCli/socket, reports the recorded result and releases the
// diag's "TCP" reference taken in StartAuthTcpTest.
AuthConn::~AuthConn () {
    if (cli)
        NetCliDelete(cli, false);
    if (sock)
        AsyncSocketDelete(sock);
    callback(diag, kNetProtocolCli2Auth, error, param);
    diag->DecRef("TCP");
}


/*****************************************************************************
*
*   AuthTrans
*
***/

//============================================================================
AuthTrans::AuthTrans (AuthConn * conn)
:   conn(conn)
{
    conn->IncRef("Ping");
    AtomicAdd(&conn->pingsInRoute, 1);
}

//============================================================================
// Any destroyed ping (answered, timed out or abandoned) counts as completed.
AuthTrans::~AuthTrans () {
    AtomicAdd(&conn->pingsCompleted, 1);
    AtomicAdd(&conn->pingsInRoute, -1);
    conn->DecRef("Ping");
}


/*****************************************************************************
*
*   FileConn
*
***/

//============================================================================
// Releases the socket, reports the recorded result and releases the diag's
// "TCP" reference taken in StartFileTcpTest.
FileConn::~FileConn () {
    if (sock)
        AsyncSocketDelete(sock);
    callback(diag, kNetProtocolCli2File, error, param);
    diag->DecRef("TCP");
}


/*****************************************************************************
*
*   FileTrans
*
***/

//============================================================================
FileTrans::FileTrans (FileConn * conn)
:   conn(conn)
{
    conn->IncRef("Ping");
    AtomicAdd(&conn->pingsInRoute, 1);
}

//============================================================================
// Any destroyed ping (answered, timed out or abandoned) counts as completed.
FileTrans::~FileTrans () {
    AtomicAdd(&conn->pingsCompleted, 1);
    AtomicAdd(&conn->pingsInRoute, -1);
    conn->DecRef("Ping");
}


/*****************************************************************************
*
*   Module functions
*
***/

//============================================================================
// Starts the ping-timeout timer.
void TcpStartup () {
    s_shutdown = false;
    AsyncTimerCreate(&s_timer, TimerCallback, 0, nil);
}

//============================================================================
// Stops the ping-timeout timer.
void TcpShutdown () {
    s_shutdown = true;
    AsyncTimerDeleteCallback(s_timer, TimerCallback);
    s_timer = nil;
}


/*****************************************************************************
*
*   Exports
*
***/

//============================================================================
// Runs the TCP diagnostic for 'protocol' against the node stored in 'diag'.
// Unsupported protocols report kNetErrNotSupported; a protocol with no node
// set reports kNetSuccess without connecting. Otherwise the result arrives
// later through 'callback'.
void NetDiagTcp (
    NetDiag *               diag,
    ENetProtocol            protocol,
    unsigned                port,
    FNetDiagDumpProc        dump,
    FNetDiagTestCallback    callback,
    void *                  param
) {
    ASSERT(diag);
    ASSERT(dump);
    ASSERT(callback);

    unsigned srv = NetProtocolToSrv(protocol);
    if (srv == kNumDiagSrvs) {
        dump(L"[TCP] Unsupported protocol: %s", NetProtocolToString(protocol));
        callback(diag, protocol, kNetErrNotSupported, param);
        return;
    }

    unsigned node;
    diag->critsect.Enter();
    {
        node = diag->nodes[srv];
    }
    diag->critsect.Leave();

    if (!node) {
        dump(L"[TCP] No address set for protocol: %s", NetProtocolToString(protocol));
        callback(diag, protocol, kNetSuccess, param);
        return;
    }

    NetAddress addr;
    NetAddressFromNode(node, port, &addr);

    switch (protocol) {
        case kNetProtocolCli2Auth:
            StartAuthTcpTest(diag, addr, dump, callback, param);
        break;

        case kNetProtocolCli2File:
            StartFileTcpTest(diag, addr, dump, callback, param);
        break;

        DEFAULT_FATAL(protocol);
    }
}