You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

631 lines
18 KiB

/*==LICENSE==*
CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011 Cyan Worlds, Inc.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
You can contact Cyan Worlds, Inc. by email legal@cyan.com
or by snail mail at:
Cyan Worlds, Inc.
14617 N Newport Hwy
Mead, WA 99021
*==LICENSE==*/
#include "hsSTLStream.h"
#include "hsResMgr.h"
#include "plgDispatch.h"
#include "../pnUtils/pnUtils.h"
#include "../pnNetBase/pnNetBase.h"
#include "../pnAsyncCore/pnAsyncCore.h"
#include "../pnNetCli/pnNetCli.h"
#include "../plNetGameLib/plNetGameLib.h"
#include "../plFile/plFileUtils.h"
#include "../plFile/plStreamSource.h"
#include "../plNetCommon/plNetCommon.h"
#include "../plProgressMgr/plProgressMgr.h"
#include "../plMessage/plPreloaderMsg.h"
#include "../plMessage/plNetCommMsgs.h"
#include "pfSecurePreloader.h"
#include "../plNetClientComm/plNetClientComm.h"
extern hsBool gDataServerLocal;
// Max number of concurrent file downloads
static const unsigned kMaxConcurrency = 1;
pfSecurePreloader * pfSecurePreloader::fInstance;
///////////////////////////////////////////////////////////////////////////////
// Callback routines for the network code
// Called when a file's info is retrieved from the server
static void DefaultFileListRequestCallback(ENetError result, void* param, const NetCliAuthFileInfo infoArr[], unsigned infoCount)
{
bool success = !IS_NET_ERROR(result);
std::vector<std::wstring> filenames;
std::vector<UInt32> sizes;
if (success)
{
filenames.reserve(infoCount);
sizes.reserve(infoCount);
for (unsigned curFile = 0; curFile < infoCount; curFile++)
{
filenames.push_back(infoArr[curFile].filename);
sizes.push_back(infoArr[curFile].filesize);
}
}
((pfSecurePreloader*)param)->RequestFinished(filenames, sizes, success);
}
// Called when a file download is either finished, or failed
static void DefaultFileRequestCallback(ENetError result, void* param, const wchar filename[], hsStream* stream)
{
// Retry download unless shutting down or file not found
switch (result) {
case kNetSuccess:
((pfSecurePreloader*)param)->FinishedDownload(filename, true);
break;
case kNetErrFileNotFound:
case kNetErrRemoteShutdown:
((pfSecurePreloader*)param)->FinishedDownload(filename, false);
break;
default:
stream->Rewind();
NetCliAuthFileRequest(
filename,
stream,
&DefaultFileRequestCallback,
param
);
break;
}
}
///////////////////////////////////////////////////////////////////////////////
// Our custom stream for writing directly to disk securely, and updating the
// progress bar. Does NOT support reading (cause it doesn't need to)
class Direct2DiskStream : public hsUNIXStream
{
protected:
wchar * fWriteFileName;
pfSecurePreloader* fPreloader;
public:
Direct2DiskStream(pfSecurePreloader* preloader);
~Direct2DiskStream();
virtual hsBool Open(const char* name, const char* mode = "wb");
virtual hsBool Open(const wchar* name, const wchar* mode = L"wb");
virtual hsBool Close();
virtual UInt32 Read(UInt32 byteCount, void* buffer);
virtual UInt32 Write(UInt32 byteCount, const void* buffer);
};
Direct2DiskStream::Direct2DiskStream(pfSecurePreloader* preloader) :
fWriteFileName(nil),
fPreloader(preloader)
{}
Direct2DiskStream::~Direct2DiskStream()
{
Close();
}
hsBool Direct2DiskStream::Open(const char* name, const char* mode)
{
wchar* wName = hsStringToWString(name);
wchar* wMode = hsStringToWString(mode);
hsBool ret = Open(wName, wMode);
delete [] wName;
delete [] wMode;
return ret;
}
hsBool Direct2DiskStream::Open(const wchar* name, const wchar* mode)
{
if (0 != wcscmp(mode, L"wb")) {
hsAssert(0, "Unsupported open mode");
return false;
}
fWriteFileName = TRACKED_NEW(wchar[wcslen(name) + 1]);
wcscpy(fWriteFileName, name);
// LogMsg(kLogPerf, L"Opening disk file %S", fWriteFileName);
return hsUNIXStream::Open(name, mode);
}
hsBool Direct2DiskStream::Close()
{
delete [] fWriteFileName;
fWriteFileName = nil;
return hsUNIXStream::Close();
}
UInt32 Direct2DiskStream::Read(UInt32 bytes, void* buffer)
{
hsAssert(0, "not implemented");
return 0; // we don't read
}
UInt32 Direct2DiskStream::Write(UInt32 bytes, const void* buffer)
{
// LogMsg(kLogPerf, L"Writing %u bytes to disk file %S", bytes, fWriteFileName);
fPreloader->UpdateProgressBar(bytes);
return hsUNIXStream::Write(bytes, buffer);
}
///////////////////////////////////////////////////////////////////////////////
// secure preloader class implementation
// closes and deletes all streams
void pfSecurePreloader::ICleanupStreams()
{
if (fD2DStreams.size() > 0)
{
std::map<std::wstring, hsStream*>::iterator curStream;
for (curStream = fD2DStreams.begin(); curStream != fD2DStreams.end(); curStream++)
{
curStream->second->Close();
delete curStream->second;
curStream->second = nil;
}
fD2DStreams.clear();
}
}
// queues a single file to be preloaded (does nothing if already preloaded)
void pfSecurePreloader::RequestSingleFile(std::wstring filename)
{
fileRequest request;
ZERO(request);
request.fType = fileRequest::kSingleFile;
request.fPath = filename;
request.fExt = L"";
fRequests.push_back(request);
}
// queues a group of files to be preloaded (does nothing if already preloaded)
void pfSecurePreloader::RequestFileGroup(std::wstring dir, std::wstring ext)
{
fileRequest request;
ZERO(request);
request.fType = fileRequest::kFileList;
request.fPath = dir;
request.fExt = ext;
fRequests.push_back(request);
}
// preloads all requested files from the server (does nothing if already preloaded)
void pfSecurePreloader::Start()
{
if (gDataServerLocal) {
// using local data, don't do anything
plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg();
msg->fSuccess = true;
msg->Send();
return;
}
NetCliAuthGetEncryptionKey(fEncryptionKey, 4); // grab the encryption key from the server
fNetError = false;
// make sure we are all cleaned up
ICleanupStreams();
fTotalDataReceived = 0;
// update the progress bar for downloading
if (!fProgressBar)
fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fRequests.size()), "Getting file info...", plProgressMgr::kUpdateText, false, true);
for (unsigned curRequest = 0; curRequest < fRequests.size(); curRequest++)
{
fNumInfoRequestsRemaining++; // increment the counter
if (fRequests[curRequest].fType == fileRequest::kSingleFile)
{
#ifndef PLASMA_EXTERNAL_RELEASE
// in internal releases, we can use on-disk files if they exist
if (plFileUtils::FileExists(fRequests[curRequest].fPath.c_str()))
{
fileInfo info;
info.fOriginalNameAndPath = fRequests[curRequest].fPath;
info.fSizeInBytes = plFileUtils::GetFileSize(info.fOriginalNameAndPath.c_str());
info.fDownloading = false;
info.fDownloaded = false;
info.fLocal = true;
// generate garbled name
wchar_t pathBuffer[MAX_PATH + 1];
wchar_t filename[arrsize(pathBuffer)];
GetTempPathW(arrsize(pathBuffer), pathBuffer);
GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
info.fGarbledNameAndPath = filename;
fTotalDataDownload += info.fSizeInBytes;
fFileInfoMap[info.fOriginalNameAndPath] = info;
}
// internal client will still request it, even if it exists locally,
// so that things get updated properly
#endif // PLASMA_EXTERNAL_RELEASE
NetCliAuthFileListRequest(
fRequests[curRequest].fPath.c_str(),
nil,
&DefaultFileListRequestCallback,
(void*)this
);
}
else
{
#ifndef PLASMA_EXTERNAL_RELEASE
// in internal releases, we can use on-disk files if they exist
// Build the search string as "dir\\*.ext"
wchar searchStr[MAX_PATH];
PathAddFilename(searchStr, fRequests[curRequest].fPath.c_str(), L"*", arrsize(searchStr));
PathSetExtension(searchStr, searchStr, fRequests[curRequest].fExt.c_str(), arrsize(searchStr));
ARRAY(PathFind) paths;
PathFindFiles(&paths, searchStr, kPathFlagFile); // find all files that match
// convert it to our little file info array
PathFind* curFile = paths.Ptr();
PathFind* lastFile = paths.Term();
while (curFile != lastFile) {
fileInfo info;
info.fOriginalNameAndPath = curFile->name;
info.fSizeInBytes = (UInt32)curFile->fileLength;
info.fDownloading = false;
info.fDownloaded = false;
info.fLocal = true;
// generate garbled name
wchar_t pathBuffer[MAX_PATH + 1];
wchar_t filename[arrsize(pathBuffer)];
GetTempPathW(arrsize(pathBuffer), pathBuffer);
GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
info.fGarbledNameAndPath = filename;
fTotalDataDownload += info.fSizeInBytes;
fFileInfoMap[info.fOriginalNameAndPath] = info;
curFile++;
}
#endif // PLASMA_EXTERNAL_RELEASE
NetCliAuthFileListRequest(
fRequests[curRequest].fPath.c_str(),
fRequests[curRequest].fExt.c_str(),
&DefaultFileListRequestCallback,
(void*)this
);
}
}
}
// closes all file pointers and cleans up after itself
void pfSecurePreloader::Cleanup()
{
ICleanupStreams();
fRequests.clear();
fFileInfoMap.clear();
fNumInfoRequestsRemaining = 0;
fTotalDataDownload = 0;
fTotalDataReceived = 0;
DEL(fProgressBar);
fProgressBar = nil;
}
//============================================================================
void pfSecurePreloader::RequestFinished(const std::vector<std::wstring> & filenames, const std::vector<UInt32> & sizes, bool succeeded)
{
fNetError |= !succeeded;
if (succeeded)
{
unsigned count = 0;
for (int curFile = 0; curFile < filenames.size(); curFile++)
{
if (fFileInfoMap.find(filenames[curFile]) != fFileInfoMap.end())
continue; // if it is a duplicate, ignore it (the duplicate is probably one we found locally)
fileInfo info;
info.fOriginalNameAndPath = filenames[curFile];
info.fSizeInBytes = sizes[curFile];
info.fDownloading = false;
info.fDownloaded = false;
info.fLocal = false; // if we get here, it was retrieved remotely
// generate garbled name
wchar_t pathBuffer[MAX_PATH + 1];
wchar_t filename[arrsize(pathBuffer)];
GetTempPathW(arrsize(pathBuffer), pathBuffer);
GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
info.fGarbledNameAndPath = filename;
fTotalDataDownload += info.fSizeInBytes;
fFileInfoMap[info.fOriginalNameAndPath] = info;
++count;
}
LogMsg(kLogPerf, "Added %u files to secure download queue", count);
}
if (fProgressBar)
fProgressBar->Increment(1.f);
--fNumInfoRequestsRemaining; // even if we fail, decrement the counter
if (succeeded) {
DEL(fProgressBar);
fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fTotalDataDownload), "Downloading...", plProgressMgr::kUpdateText, false, true);
// Issue some file download requests (up to kMaxConcurrency)
IIssueDownloadRequests();
}
else {
IPreloadComplete();
}
}
//============================================================================
void pfSecurePreloader::IIssueDownloadRequests () {
std::map<std::wstring, fileInfo>::iterator curFile;
for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++)
{
// Skip files already downloaded or currently downloading
if (curFile->second.fDownloaded || curFile->second.fDownloading)
continue;
std::wstring filename = curFile->second.fOriginalNameAndPath;
#ifndef PLASMA_EXTERNAL_RELEASE
// in internal releases, we can use on-disk files if they exist
if (plFileUtils::FileExists(filename.c_str()))
{
// don't bother streaming, just make the secure stream using the local file
// a local key overrides the server-downloaded key
UInt32 localKey[4];
bool hasLocalKey = plFileUtils::GetSecureEncryptionKey(filename.c_str(), localKey, arrsize(localKey));
hsStream* stream = nil;
if (hasLocalKey)
stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, localKey);
else
stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, fEncryptionKey);
// add it to the stream source
bool added = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream);
if (!added)
DEL(stream); // wasn't added, so nuke our local copy
// and make sure the vars are set up right
curFile->second.fDownloaded = true;
curFile->second.fLocal = true;
}
else
#endif
{
// Enforce concurrency limit
if (fNumDownloadRequestsRemaining >= kMaxConcurrency)
break;
curFile->second.fDownloading = true;
curFile->second.fDownloaded = false;
curFile->second.fLocal = false;
// create and setup the stream
Direct2DiskStream* fileStream = TRACKED_NEW Direct2DiskStream(this);
fileStream->Open(curFile->second.fGarbledNameAndPath.c_str(), L"wb");
fD2DStreams[filename] = (hsStream*)fileStream;
// request the file from the server
LogMsg(kLogPerf, L"Requesting secure file:%s", filename.c_str());
++fNumDownloadRequestsRemaining;
NetCliAuthFileRequest(
filename.c_str(),
(hsStream*)fileStream,
&DefaultFileRequestCallback,
this
);
}
}
if (!fNumDownloadRequestsRemaining)
IPreloadComplete();
}
void pfSecurePreloader::UpdateProgressBar(UInt32 bytesReceived)
{
fTotalDataReceived += bytesReceived;
if (fTotalDataReceived > fTotalDataDownload)
fTotalDataReceived = fTotalDataDownload; // shouldn't happen... but just in case
if (fProgressBar)
fProgressBar->Increment((hsScalar)bytesReceived);
}
void pfSecurePreloader::FinishedDownload(std::wstring filename, bool succeeded)
{
for (;;)
{
if (fFileInfoMap.find(filename) == fFileInfoMap.end())
{
// file doesn't exist... abort
succeeded = false;
break;
}
fFileInfoMap[filename].fDownloading = false;
// close and delete the writer stream (even if we failed)
fD2DStreams[filename]->Close();
delete fD2DStreams[filename];
fD2DStreams.erase(fD2DStreams.find(filename));
if (succeeded)
{
// open a secure stream to that file
hsStream* stream = plSecureStream::OpenSecureFile(
fFileInfoMap[filename].fGarbledNameAndPath.c_str(),
plSecureStream::kRequireEncryption | plSecureStream::kDeleteOnExit, // force delete and encryption
fEncryptionKey
);
bool addedToSource = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream);
if (!addedToSource)
DEL(stream); // cleanup if it wasn't added
fFileInfoMap[filename].fDownloaded = true;
break;
}
// file download failed, clean up after it
// delete the temporary file
if (plFileUtils::FileExists(fFileInfoMap[filename].fGarbledNameAndPath.c_str()))
plFileUtils::RemoveFile(fFileInfoMap[filename].fGarbledNameAndPath.c_str(), true);
// and remove it from the info map
fFileInfoMap.erase(fFileInfoMap.find(filename));
break;
}
fNetError |= !succeeded;
--fNumDownloadRequestsRemaining;
LogMsg(kLogPerf, L"Received secure file:%s, success:%s", filename.c_str(), succeeded ? L"Yep" : L"Nope");
if (!succeeded)
IPreloadComplete();
else
// Issue some file download requests (up to kMaxConcurrency)
IIssueDownloadRequests();
}
//============================================================================
void pfSecurePreloader::INotifyAuthReconnected () {
// The secure file download network protocol will now just pick up downloading
// where it left off before the reconnect, so no need to reset in-progress files.
/*
std::map<std::wstring, fileInfo>::iterator curFile;
for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++) {
// Reset files that were currently downloading
if (curFile->second.fDownloading)
curFile->second.fDownloading = false;
}
if (fNumDownloadRequestsRemaining > 0) {
LogMsg(kLogPerf, L"pfSecurePreloader: Auth reconnected, resetting in-progress file downloads");
// Issue some file download requests (up to kMaxConcurrency)
IIssueDownloadRequests();
}
*/
}
//============================================================================
void pfSecurePreloader::IPreloadComplete () {
DEL(fProgressBar);
fProgressBar = nil;
plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg();
msg->fSuccess = !fNetError;
msg->Send();
}
//============================================================================
hsBool pfSecurePreloader::MsgReceive (plMessage * msg) {
if (plNetCommAuthConnectedMsg * authMsg = plNetCommAuthConnectedMsg::ConvertNoRef(msg)) {
INotifyAuthReconnected();
return true;
}
return hsKeyedObject::MsgReceive(msg);
}
//============================================================================
pfSecurePreloader * pfSecurePreloader::GetInstance () {
if (!fInstance) {
fInstance = NEWZERO(pfSecurePreloader);
fInstance->RegisterAs(kSecurePreloader_KEY);
}
return fInstance;
}
//============================================================================
bool pfSecurePreloader::IsInstanced () {
return fInstance != nil;
}
//============================================================================
void pfSecurePreloader::Init () {
if (!fInitialized) {
fInitialized = true;
plgDispatch::Dispatch()->RegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey());
}
}
//============================================================================
void pfSecurePreloader::Shutdown () {
if (fInitialized) {
fInitialized = false;
plgDispatch::Dispatch()->UnRegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey());
}
if (fInstance) {
fInstance->UnRegister();
fInstance = nil;
}
}
//============================================================================
pfSecurePreloader::pfSecurePreloader () {
}
//============================================================================
pfSecurePreloader::~pfSecurePreloader () {
Cleanup();
}