647 lines
21 KiB

/*==LICENSE==*
CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011 Cyan Worlds, Inc.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
Additional permissions under GNU GPL version 3 section 7
If you modify this Program, or any covered work, by linking or
combining it with any of RAD Game Tools Bink SDK, Autodesk 3ds Max SDK,
NVIDIA PhysX SDK, Microsoft DirectX SDK, OpenSSL library, Independent
JPEG Group JPEG library, Microsoft Windows Media SDK, or Apple QuickTime SDK
(or a modified version of those libraries),
containing parts covered by the terms of the Bink SDK EULA, 3ds Max EULA,
PhysX SDK EULA, DirectX SDK EULA, OpenSSL and SSLeay licenses, IJG
JPEG Library README, Windows Media SDK EULA, or QuickTime SDK EULA, the
licensors of this Program grant you additional
permission to convey the resulting work. Corresponding Source for a
non-source form of such a combination shall include the source code for
the parts of OpenSSL and IJG JPEG Library used as well as that of the covered
work.
You can contact Cyan Worlds, Inc. by email legal@cyan.com
or by snail mail at:
Cyan Worlds, Inc.
14617 N Newport Hwy
Mead, WA 99021
*==LICENSE==*/
#include "hsSTLStream.h"
#include "hsResMgr.h"
#include "plgDispatch.h"
#include "pnUtils/pnUtils.h"
#include "pnNetBase/pnNetBase.h"
#include "pnAsyncCore/pnAsyncCore.h"
#include "pnNetCli/pnNetCli.h"
#include "plNetGameLib/plNetGameLib.h"
#include "plFile/plFileUtils.h"
#include "plFile/plStreamSource.h"
#include "plNetCommon/plNetCommon.h"
#include "plProgressMgr/plProgressMgr.h"
#include "plMessage/plPreloaderMsg.h"
#include "plMessage/plNetCommMsgs.h"
#include "pfSecurePreloader.h"
#include "plNetClientComm/plNetClientComm.h"
extern hsBool gDataServerLocal;
// Max number of concurrent file downloads
static const unsigned kMaxConcurrency = 1;
pfSecurePreloader * pfSecurePreloader::fInstance;
///////////////////////////////////////////////////////////////////////////////
// Callback routines for the network code
// Called when a file's info is retrieved from the server
static void DefaultFileListRequestCallback(ENetError result, void* param, const NetCliAuthFileInfo infoArr[], unsigned infoCount)
{
bool success = !IS_NET_ERROR(result);
std::vector<std::wstring> filenames;
std::vector<UInt32> sizes;
if (success)
{
filenames.reserve(infoCount);
sizes.reserve(infoCount);
for (unsigned curFile = 0; curFile < infoCount; curFile++)
{
filenames.push_back(infoArr[curFile].filename);
sizes.push_back(infoArr[curFile].filesize);
}
}
((pfSecurePreloader*)param)->RequestFinished(filenames, sizes, success);
}
// Called when a file download is either finished, or failed
static void DefaultFileRequestCallback(ENetError result, void* param, const wchar filename[], hsStream* stream)
{
// Retry download unless shutting down or file not found
switch (result) {
case kNetSuccess:
((pfSecurePreloader*)param)->FinishedDownload(filename, true);
break;
case kNetErrFileNotFound:
case kNetErrRemoteShutdown:
((pfSecurePreloader*)param)->FinishedDownload(filename, false);
break;
default:
stream->Rewind();
NetCliAuthFileRequest(
filename,
stream,
&DefaultFileRequestCallback,
param
);
break;
}
}
///////////////////////////////////////////////////////////////////////////////
// Our custom stream for writing directly to disk securely, and updating the
// progress bar. Does NOT support reading (cause it doesn't need to)
class Direct2DiskStream : public hsUNIXStream
{
protected:
wchar * fWriteFileName;
pfSecurePreloader* fPreloader;
public:
Direct2DiskStream(pfSecurePreloader* preloader);
~Direct2DiskStream();
virtual hsBool Open(const char* name, const char* mode = "wb");
virtual hsBool Open(const wchar* name, const wchar* mode = L"wb");
virtual hsBool Close();
virtual UInt32 Read(UInt32 byteCount, void* buffer);
virtual UInt32 Write(UInt32 byteCount, const void* buffer);
};
Direct2DiskStream::Direct2DiskStream(pfSecurePreloader* preloader) :
fWriteFileName(nil),
fPreloader(preloader)
{}
Direct2DiskStream::~Direct2DiskStream()
{
Close();
}
hsBool Direct2DiskStream::Open(const char* name, const char* mode)
{
wchar* wName = hsStringToWString(name);
wchar* wMode = hsStringToWString(mode);
hsBool ret = Open(wName, wMode);
delete [] wName;
delete [] wMode;
return ret;
}
hsBool Direct2DiskStream::Open(const wchar* name, const wchar* mode)
{
if (0 != wcscmp(mode, L"wb")) {
hsAssert(0, "Unsupported open mode");
return false;
}
fWriteFileName = TRACKED_NEW(wchar[wcslen(name) + 1]);
wcscpy(fWriteFileName, name);
// LogMsg(kLogPerf, L"Opening disk file %S", fWriteFileName);
return hsUNIXStream::Open(name, mode);
}
hsBool Direct2DiskStream::Close()
{
delete [] fWriteFileName;
fWriteFileName = nil;
return hsUNIXStream::Close();
}
UInt32 Direct2DiskStream::Read(UInt32 bytes, void* buffer)
{
hsAssert(0, "not implemented");
return 0; // we don't read
}
UInt32 Direct2DiskStream::Write(UInt32 bytes, const void* buffer)
{
// LogMsg(kLogPerf, L"Writing %u bytes to disk file %S", bytes, fWriteFileName);
fPreloader->UpdateProgressBar(bytes);
return hsUNIXStream::Write(bytes, buffer);
}
///////////////////////////////////////////////////////////////////////////////
// secure preloader class implementation
// closes and deletes all streams
void pfSecurePreloader::ICleanupStreams()
{
if (fD2DStreams.size() > 0)
{
std::map<std::wstring, hsStream*>::iterator curStream;
for (curStream = fD2DStreams.begin(); curStream != fD2DStreams.end(); curStream++)
{
curStream->second->Close();
delete curStream->second;
curStream->second = nil;
}
fD2DStreams.clear();
}
}
// queues a single file to be preloaded (does nothing if already preloaded)
void pfSecurePreloader::RequestSingleFile(std::wstring filename)
{
fileRequest request;
ZERO(request);
request.fType = fileRequest::kSingleFile;
request.fPath = filename;
request.fExt = L"";
fRequests.push_back(request);
}
// queues a group of files to be preloaded (does nothing if already preloaded)
void pfSecurePreloader::RequestFileGroup(std::wstring dir, std::wstring ext)
{
fileRequest request;
ZERO(request);
request.fType = fileRequest::kFileList;
request.fPath = dir;
request.fExt = ext;
fRequests.push_back(request);
}
// preloads all requested files from the server (does nothing if already preloaded)
void pfSecurePreloader::Start()
{
if (gDataServerLocal) {
// using local data, don't do anything
plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg();
msg->fSuccess = true;
msg->Send();
return;
}
NetCliAuthGetEncryptionKey(fEncryptionKey, 4); // grab the encryption key from the server
fNetError = false;
// make sure we are all cleaned up
ICleanupStreams();
fTotalDataReceived = 0;
// update the progress bar for downloading
if (!fProgressBar)
fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fRequests.size()), "Getting file info...", plProgressMgr::kUpdateText, false, true);
for (unsigned curRequest = 0; curRequest < fRequests.size(); curRequest++)
{
fNumInfoRequestsRemaining++; // increment the counter
if (fRequests[curRequest].fType == fileRequest::kSingleFile)
{
#ifndef PLASMA_EXTERNAL_RELEASE
// in internal releases, we can use on-disk files if they exist
if (plFileUtils::FileExists(fRequests[curRequest].fPath.c_str()))
{
fileInfo info;
info.fOriginalNameAndPath = fRequests[curRequest].fPath;
info.fSizeInBytes = plFileUtils::GetFileSize(info.fOriginalNameAndPath.c_str());
info.fDownloading = false;
info.fDownloaded = false;
info.fLocal = true;
// generate garbled name
wchar_t pathBuffer[MAX_PATH + 1];
wchar_t filename[arrsize(pathBuffer)];
GetTempPathW(arrsize(pathBuffer), pathBuffer);
GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
info.fGarbledNameAndPath = filename;
fTotalDataDownload += info.fSizeInBytes;
fFileInfoMap[info.fOriginalNameAndPath] = info;
}
// internal client will still request it, even if it exists locally,
// so that things get updated properly
#endif // PLASMA_EXTERNAL_RELEASE
NetCliAuthFileListRequest(
fRequests[curRequest].fPath.c_str(),
nil,
&DefaultFileListRequestCallback,
(void*)this
);
}
else
{
#ifndef PLASMA_EXTERNAL_RELEASE
// in internal releases, we can use on-disk files if they exist
// Build the search string as "dir\\*.ext"
wchar searchStr[MAX_PATH];
PathAddFilename(searchStr, fRequests[curRequest].fPath.c_str(), L"*", arrsize(searchStr));
PathSetExtension(searchStr, searchStr, fRequests[curRequest].fExt.c_str(), arrsize(searchStr));
ARRAY(PathFind) paths;
PathFindFiles(&paths, searchStr, kPathFlagFile); // find all files that match
// convert it to our little file info array
PathFind* curFile = paths.Ptr();
PathFind* lastFile = paths.Term();
while (curFile != lastFile) {
fileInfo info;
info.fOriginalNameAndPath = curFile->name;
info.fSizeInBytes = (UInt32)curFile->fileLength;
info.fDownloading = false;
info.fDownloaded = false;
info.fLocal = true;
// generate garbled name
wchar_t pathBuffer[MAX_PATH + 1];
wchar_t filename[arrsize(pathBuffer)];
GetTempPathW(arrsize(pathBuffer), pathBuffer);
GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
info.fGarbledNameAndPath = filename;
fTotalDataDownload += info.fSizeInBytes;
fFileInfoMap[info.fOriginalNameAndPath] = info;
curFile++;
}
#endif // PLASMA_EXTERNAL_RELEASE
NetCliAuthFileListRequest(
fRequests[curRequest].fPath.c_str(),
fRequests[curRequest].fExt.c_str(),
&DefaultFileListRequestCallback,
(void*)this
);
}
}
}
// closes all file pointers and cleans up after itself
void pfSecurePreloader::Cleanup()
{
ICleanupStreams();
fRequests.clear();
fFileInfoMap.clear();
fNumInfoRequestsRemaining = 0;
fTotalDataDownload = 0;
fTotalDataReceived = 0;
DEL(fProgressBar);
fProgressBar = nil;
}
//============================================================================
void pfSecurePreloader::RequestFinished(const std::vector<std::wstring> & filenames, const std::vector<UInt32> & sizes, bool succeeded)
{
fNetError |= !succeeded;
if (succeeded)
{
unsigned count = 0;
for (int curFile = 0; curFile < filenames.size(); curFile++)
{
if (fFileInfoMap.find(filenames[curFile]) != fFileInfoMap.end())
continue; // if it is a duplicate, ignore it (the duplicate is probably one we found locally)
fileInfo info;
info.fOriginalNameAndPath = filenames[curFile];
info.fSizeInBytes = sizes[curFile];
info.fDownloading = false;
info.fDownloaded = false;
info.fLocal = false; // if we get here, it was retrieved remotely
// generate garbled name
wchar_t pathBuffer[MAX_PATH + 1];
wchar_t filename[arrsize(pathBuffer)];
GetTempPathW(arrsize(pathBuffer), pathBuffer);
GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
info.fGarbledNameAndPath = filename;
fTotalDataDownload += info.fSizeInBytes;
fFileInfoMap[info.fOriginalNameAndPath] = info;
++count;
}
LogMsg(kLogPerf, "Added %u files to secure download queue", count);
}
if (fProgressBar)
fProgressBar->Increment(1.f);
--fNumInfoRequestsRemaining; // even if we fail, decrement the counter
if (succeeded) {
DEL(fProgressBar);
fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fTotalDataDownload), "Downloading...", plProgressMgr::kUpdateText, false, true);
// Issue some file download requests (up to kMaxConcurrency)
IIssueDownloadRequests();
}
else {
IPreloadComplete();
}
}
//============================================================================
void pfSecurePreloader::IIssueDownloadRequests () {
std::map<std::wstring, fileInfo>::iterator curFile;
for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++)
{
// Skip files already downloaded or currently downloading
if (curFile->second.fDownloaded || curFile->second.fDownloading)
continue;
std::wstring filename = curFile->second.fOriginalNameAndPath;
#ifndef PLASMA_EXTERNAL_RELEASE
// in internal releases, we can use on-disk files if they exist
if (plFileUtils::FileExists(filename.c_str()))
{
// don't bother streaming, just make the secure stream using the local file
// a local key overrides the server-downloaded key
UInt32 localKey[4];
bool hasLocalKey = plFileUtils::GetSecureEncryptionKey(filename.c_str(), localKey, arrsize(localKey));
hsStream* stream = nil;
if (hasLocalKey)
stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, localKey);
else
stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, fEncryptionKey);
// add it to the stream source
bool added = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream);
if (!added)
DEL(stream); // wasn't added, so nuke our local copy
// and make sure the vars are set up right
curFile->second.fDownloaded = true;
curFile->second.fLocal = true;
}
else
#endif
{
// Enforce concurrency limit
if (fNumDownloadRequestsRemaining >= kMaxConcurrency)
break;
curFile->second.fDownloading = true;
curFile->second.fDownloaded = false;
curFile->second.fLocal = false;
// create and setup the stream
Direct2DiskStream* fileStream = TRACKED_NEW Direct2DiskStream(this);
fileStream->Open(curFile->second.fGarbledNameAndPath.c_str(), L"wb");
fD2DStreams[filename] = (hsStream*)fileStream;
// request the file from the server
LogMsg(kLogPerf, L"Requesting secure file:%s", filename.c_str());
++fNumDownloadRequestsRemaining;
NetCliAuthFileRequest(
filename.c_str(),
(hsStream*)fileStream,
&DefaultFileRequestCallback,
this
);
}
}
if (!fNumDownloadRequestsRemaining)
IPreloadComplete();
}
void pfSecurePreloader::UpdateProgressBar(UInt32 bytesReceived)
{
fTotalDataReceived += bytesReceived;
if (fTotalDataReceived > fTotalDataDownload)
fTotalDataReceived = fTotalDataDownload; // shouldn't happen... but just in case
if (fProgressBar)
fProgressBar->Increment((hsScalar)bytesReceived);
}
void pfSecurePreloader::FinishedDownload(std::wstring filename, bool succeeded)
{
for (;;)
{
if (fFileInfoMap.find(filename) == fFileInfoMap.end())
{
// file doesn't exist... abort
succeeded = false;
break;
}
fFileInfoMap[filename].fDownloading = false;
// close and delete the writer stream (even if we failed)
fD2DStreams[filename]->Close();
delete fD2DStreams[filename];
fD2DStreams.erase(fD2DStreams.find(filename));
if (succeeded)
{
// open a secure stream to that file
hsStream* stream = plSecureStream::OpenSecureFile(
fFileInfoMap[filename].fGarbledNameAndPath.c_str(),
plSecureStream::kRequireEncryption | plSecureStream::kDeleteOnExit, // force delete and encryption
fEncryptionKey
);
bool addedToSource = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream);
if (!addedToSource)
DEL(stream); // cleanup if it wasn't added
fFileInfoMap[filename].fDownloaded = true;
break;
}
// file download failed, clean up after it
// delete the temporary file
if (plFileUtils::FileExists(fFileInfoMap[filename].fGarbledNameAndPath.c_str()))
plFileUtils::RemoveFile(fFileInfoMap[filename].fGarbledNameAndPath.c_str(), true);
// and remove it from the info map
fFileInfoMap.erase(fFileInfoMap.find(filename));
break;
}
fNetError |= !succeeded;
--fNumDownloadRequestsRemaining;
LogMsg(kLogPerf, L"Received secure file:%s, success:%s", filename.c_str(), succeeded ? L"Yep" : L"Nope");
if (!succeeded)
IPreloadComplete();
else
// Issue some file download requests (up to kMaxConcurrency)
IIssueDownloadRequests();
}
//============================================================================
void pfSecurePreloader::INotifyAuthReconnected () {
// The secure file download network protocol will now just pick up downloading
// where it left off before the reconnect, so no need to reset in-progress files.
/*
std::map<std::wstring, fileInfo>::iterator curFile;
for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++) {
// Reset files that were currently downloading
if (curFile->second.fDownloading)
curFile->second.fDownloading = false;
}
if (fNumDownloadRequestsRemaining > 0) {
LogMsg(kLogPerf, L"pfSecurePreloader: Auth reconnected, resetting in-progress file downloads");
// Issue some file download requests (up to kMaxConcurrency)
IIssueDownloadRequests();
}
*/
}
//============================================================================
void pfSecurePreloader::IPreloadComplete () {
DEL(fProgressBar);
fProgressBar = nil;
plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg();
msg->fSuccess = !fNetError;
msg->Send();
}
//============================================================================
hsBool pfSecurePreloader::MsgReceive (plMessage * msg) {
if (plNetCommAuthConnectedMsg * authMsg = plNetCommAuthConnectedMsg::ConvertNoRef(msg)) {
INotifyAuthReconnected();
return true;
}
return hsKeyedObject::MsgReceive(msg);
}
//============================================================================
pfSecurePreloader * pfSecurePreloader::GetInstance () {
if (!fInstance) {
fInstance = NEWZERO(pfSecurePreloader);
fInstance->RegisterAs(kSecurePreloader_KEY);
}
return fInstance;
}
//============================================================================
bool pfSecurePreloader::IsInstanced () {
return fInstance != nil;
}
//============================================================================
void pfSecurePreloader::Init () {
if (!fInitialized) {
fInitialized = true;
plgDispatch::Dispatch()->RegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey());
}
}
//============================================================================
void pfSecurePreloader::Shutdown () {
if (fInitialized) {
fInitialized = false;
plgDispatch::Dispatch()->UnRegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey());
}
if (fInstance) {
fInstance->UnRegister();
fInstance = nil;
}
}
//============================================================================
pfSecurePreloader::pfSecurePreloader () {
}
//============================================================================
pfSecurePreloader::~pfSecurePreloader () {
Cleanup();
}