/*==LICENSE==* CyanWorlds.com Engine - MMOG client, server and tools Copyright (C) 2011 Cyan Worlds, Inc. This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . You can contact Cyan Worlds, Inc. by email legal@cyan.com or by snail mail at: Cyan Worlds, Inc. 14617 N Newport Hwy Mead, WA 99021 *==LICENSE==*/ #include "hsSTLStream.h" #include "hsResMgr.h" #include "plgDispatch.h" #include "../pnUtils/pnUtils.h" #include "../pnNetBase/pnNetBase.h" #include "../pnAsyncCore/pnAsyncCore.h" #include "../pnNetCli/pnNetCli.h" #include "../plNetGameLib/plNetGameLib.h" #include "../plFile/plFileUtils.h" #include "../plFile/plStreamSource.h" #include "../plNetCommon/plNetCommon.h" #include "../plProgressMgr/plProgressMgr.h" #include "../plMessage/plPreloaderMsg.h" #include "../plMessage/plNetCommMsgs.h" #include "pfSecurePreloader.h" #include "../plNetClientComm/plNetClientComm.h" extern hsBool gDataServerLocal; // Max number of concurrent file downloads static const unsigned kMaxConcurrency = 1; pfSecurePreloader * pfSecurePreloader::fInstance; /////////////////////////////////////////////////////////////////////////////// // Callback routines for the network code // Called when a file's info is retrieved from the server static void DefaultFileListRequestCallback(ENetError result, void* param, const NetCliAuthFileInfo infoArr[], unsigned infoCount) { bool success = !IS_NET_ERROR(result); std::vector filenames; std::vector sizes; if (success) { filenames.reserve(infoCount); sizes.reserve(infoCount); for (unsigned curFile = 0; curFile < infoCount; curFile++) { filenames.push_back(infoArr[curFile].filename); sizes.push_back(infoArr[curFile].filesize); } } ((pfSecurePreloader*)param)->RequestFinished(filenames, sizes, success); } // Called when a file download is either finished, or failed static void DefaultFileRequestCallback(ENetError result, void* param, const wchar filename[], hsStream* stream) { // Retry download unless shutting down or file not found switch (result) { case kNetSuccess: ((pfSecurePreloader*)param)->FinishedDownload(filename, true); break; case kNetErrFileNotFound: case kNetErrRemoteShutdown: ((pfSecurePreloader*)param)->FinishedDownload(filename, false); break; default: stream->Rewind(); NetCliAuthFileRequest( filename, stream, &DefaultFileRequestCallback, param ); break; } } /////////////////////////////////////////////////////////////////////////////// // Our custom stream for writing directly to disk securely, and updating the // progress bar. Does NOT support reading (cause it doesn't need to) class Direct2DiskStream : public hsUNIXStream { protected: wchar * fWriteFileName; pfSecurePreloader* fPreloader; public: Direct2DiskStream(pfSecurePreloader* preloader); ~Direct2DiskStream(); virtual hsBool Open(const char* name, const char* mode = "wb"); virtual hsBool Open(const wchar* name, const wchar* mode = L"wb"); virtual hsBool Close(); virtual UInt32 Read(UInt32 byteCount, void* buffer); virtual UInt32 Write(UInt32 byteCount, const void* buffer); }; Direct2DiskStream::Direct2DiskStream(pfSecurePreloader* preloader) : fWriteFileName(nil), fPreloader(preloader) {} Direct2DiskStream::~Direct2DiskStream() { Close(); } hsBool Direct2DiskStream::Open(const char* name, const char* mode) { wchar* wName = hsStringToWString(name); wchar* wMode = hsStringToWString(mode); hsBool ret = Open(wName, wMode); delete [] wName; delete [] wMode; return ret; } hsBool Direct2DiskStream::Open(const wchar* name, const wchar* mode) { if (0 != wcscmp(mode, L"wb")) { hsAssert(0, "Unsupported open mode"); return false; } fWriteFileName = TRACKED_NEW(wchar[wcslen(name) + 1]); wcscpy(fWriteFileName, name); // LogMsg(kLogPerf, L"Opening disk file %S", fWriteFileName); return hsUNIXStream::Open(name, mode); } hsBool Direct2DiskStream::Close() { delete [] fWriteFileName; fWriteFileName = nil; return hsUNIXStream::Close(); } UInt32 Direct2DiskStream::Read(UInt32 bytes, void* buffer) { hsAssert(0, "not implemented"); return 0; // we don't read } UInt32 Direct2DiskStream::Write(UInt32 bytes, const void* buffer) { // LogMsg(kLogPerf, L"Writing %u bytes to disk file %S", bytes, fWriteFileName); fPreloader->UpdateProgressBar(bytes); return hsUNIXStream::Write(bytes, buffer); } /////////////////////////////////////////////////////////////////////////////// // secure preloader class implementation // closes and deletes all streams void pfSecurePreloader::ICleanupStreams() { if (fD2DStreams.size() > 0) { std::map::iterator curStream; for (curStream = fD2DStreams.begin(); curStream != fD2DStreams.end(); curStream++) { curStream->second->Close(); delete curStream->second; curStream->second = nil; } fD2DStreams.clear(); } } // queues a single file to be preloaded (does nothing if already preloaded) void pfSecurePreloader::RequestSingleFile(std::wstring filename) { fileRequest request; ZERO(request); request.fType = fileRequest::kSingleFile; request.fPath = filename; request.fExt = L""; fRequests.push_back(request); } // queues a group of files to be preloaded (does nothing if already preloaded) void pfSecurePreloader::RequestFileGroup(std::wstring dir, std::wstring ext) { fileRequest request; ZERO(request); request.fType = fileRequest::kFileList; request.fPath = dir; request.fExt = ext; fRequests.push_back(request); } // preloads all requested files from the server (does nothing if already preloaded) void pfSecurePreloader::Start() { if (gDataServerLocal) { // using local data, don't do anything plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg(); msg->fSuccess = true; msg->Send(); return; } NetCliAuthGetEncryptionKey(fEncryptionKey, 4); // grab the encryption key from the server fNetError = false; // make sure we are all cleaned up ICleanupStreams(); fTotalDataReceived = 0; // update the progress bar for downloading if (!fProgressBar) fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fRequests.size()), "Getting file info...", plProgressMgr::kUpdateText, false, true); for (unsigned curRequest = 0; curRequest < fRequests.size(); curRequest++) { fNumInfoRequestsRemaining++; // increment the counter if (fRequests[curRequest].fType == fileRequest::kSingleFile) { #ifndef PLASMA_EXTERNAL_RELEASE // in internal releases, we can use on-disk files if they exist if (plFileUtils::FileExists(fRequests[curRequest].fPath.c_str())) { fileInfo info; info.fOriginalNameAndPath = fRequests[curRequest].fPath; info.fSizeInBytes = plFileUtils::GetFileSize(info.fOriginalNameAndPath.c_str()); info.fDownloading = false; info.fDownloaded = false; info.fLocal = true; // generate garbled name wchar_t pathBuffer[MAX_PATH + 1]; wchar_t filename[arrsize(pathBuffer)]; GetTempPathW(arrsize(pathBuffer), pathBuffer); GetTempFileNameW(pathBuffer, L"CYN", 0, filename); info.fGarbledNameAndPath = filename; fTotalDataDownload += info.fSizeInBytes; fFileInfoMap[info.fOriginalNameAndPath] = info; } // internal client will still request it, even if it exists locally, // so that things get updated properly #endif // PLASMA_EXTERNAL_RELEASE NetCliAuthFileListRequest( fRequests[curRequest].fPath.c_str(), nil, &DefaultFileListRequestCallback, (void*)this ); } else { #ifndef PLASMA_EXTERNAL_RELEASE // in internal releases, we can use on-disk files if they exist // Build the search string as "dir\\*.ext" wchar searchStr[MAX_PATH]; PathAddFilename(searchStr, fRequests[curRequest].fPath.c_str(), L"*", arrsize(searchStr)); PathSetExtension(searchStr, searchStr, fRequests[curRequest].fExt.c_str(), arrsize(searchStr)); ARRAY(PathFind) paths; PathFindFiles(&paths, searchStr, kPathFlagFile); // find all files that match // convert it to our little file info array PathFind* curFile = paths.Ptr(); PathFind* lastFile = paths.Term(); while (curFile != lastFile) { fileInfo info; info.fOriginalNameAndPath = curFile->name; info.fSizeInBytes = (UInt32)curFile->fileLength; info.fDownloading = false; info.fDownloaded = false; info.fLocal = true; // generate garbled name wchar_t pathBuffer[MAX_PATH + 1]; wchar_t filename[arrsize(pathBuffer)]; GetTempPathW(arrsize(pathBuffer), pathBuffer); GetTempFileNameW(pathBuffer, L"CYN", 0, filename); info.fGarbledNameAndPath = filename; fTotalDataDownload += info.fSizeInBytes; fFileInfoMap[info.fOriginalNameAndPath] = info; curFile++; } #endif // PLASMA_EXTERNAL_RELEASE NetCliAuthFileListRequest( fRequests[curRequest].fPath.c_str(), fRequests[curRequest].fExt.c_str(), &DefaultFileListRequestCallback, (void*)this ); } } } // closes all file pointers and cleans up after itself void pfSecurePreloader::Cleanup() { ICleanupStreams(); fRequests.clear(); fFileInfoMap.clear(); fNumInfoRequestsRemaining = 0; fTotalDataDownload = 0; fTotalDataReceived = 0; DEL(fProgressBar); fProgressBar = nil; } //============================================================================ void pfSecurePreloader::RequestFinished(const std::vector & filenames, const std::vector & sizes, bool succeeded) { fNetError |= !succeeded; if (succeeded) { unsigned count = 0; for (int curFile = 0; curFile < filenames.size(); curFile++) { if (fFileInfoMap.find(filenames[curFile]) != fFileInfoMap.end()) continue; // if it is a duplicate, ignore it (the duplicate is probably one we found locally) fileInfo info; info.fOriginalNameAndPath = filenames[curFile]; info.fSizeInBytes = sizes[curFile]; info.fDownloading = false; info.fDownloaded = false; info.fLocal = false; // if we get here, it was retrieved remotely // generate garbled name wchar_t pathBuffer[MAX_PATH + 1]; wchar_t filename[arrsize(pathBuffer)]; GetTempPathW(arrsize(pathBuffer), pathBuffer); GetTempFileNameW(pathBuffer, L"CYN", 0, filename); info.fGarbledNameAndPath = filename; fTotalDataDownload += info.fSizeInBytes; fFileInfoMap[info.fOriginalNameAndPath] = info; ++count; } LogMsg(kLogPerf, "Added %u files to secure download queue", count); } if (fProgressBar) fProgressBar->Increment(1.f); --fNumInfoRequestsRemaining; // even if we fail, decrement the counter if (succeeded) { DEL(fProgressBar); fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fTotalDataDownload), "Downloading...", plProgressMgr::kUpdateText, false, true); // Issue some file download requests (up to kMaxConcurrency) IIssueDownloadRequests(); } else { IPreloadComplete(); } } //============================================================================ void pfSecurePreloader::IIssueDownloadRequests () { std::map::iterator curFile; for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++) { // Skip files already downloaded or currently downloading if (curFile->second.fDownloaded || curFile->second.fDownloading) continue; std::wstring filename = curFile->second.fOriginalNameAndPath; #ifndef PLASMA_EXTERNAL_RELEASE // in internal releases, we can use on-disk files if they exist if (plFileUtils::FileExists(filename.c_str())) { // don't bother streaming, just make the secure stream using the local file // a local key overrides the server-downloaded key UInt32 localKey[4]; bool hasLocalKey = plFileUtils::GetSecureEncryptionKey(filename.c_str(), localKey, arrsize(localKey)); hsStream* stream = nil; if (hasLocalKey) stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, localKey); else stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, fEncryptionKey); // add it to the stream source bool added = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream); if (!added) DEL(stream); // wasn't added, so nuke our local copy // and make sure the vars are set up right curFile->second.fDownloaded = true; curFile->second.fLocal = true; } else #endif { // Enforce concurrency limit if (fNumDownloadRequestsRemaining >= kMaxConcurrency) break; curFile->second.fDownloading = true; curFile->second.fDownloaded = false; curFile->second.fLocal = false; // create and setup the stream Direct2DiskStream* fileStream = TRACKED_NEW Direct2DiskStream(this); fileStream->Open(curFile->second.fGarbledNameAndPath.c_str(), L"wb"); fD2DStreams[filename] = (hsStream*)fileStream; // request the file from the server LogMsg(kLogPerf, L"Requesting secure file:%s", filename.c_str()); ++fNumDownloadRequestsRemaining; NetCliAuthFileRequest( filename.c_str(), (hsStream*)fileStream, &DefaultFileRequestCallback, this ); } } if (!fNumDownloadRequestsRemaining) IPreloadComplete(); } void pfSecurePreloader::UpdateProgressBar(UInt32 bytesReceived) { fTotalDataReceived += bytesReceived; if (fTotalDataReceived > fTotalDataDownload) fTotalDataReceived = fTotalDataDownload; // shouldn't happen... but just in case if (fProgressBar) fProgressBar->Increment((hsScalar)bytesReceived); } void pfSecurePreloader::FinishedDownload(std::wstring filename, bool succeeded) { for (;;) { if (fFileInfoMap.find(filename) == fFileInfoMap.end()) { // file doesn't exist... abort succeeded = false; break; } fFileInfoMap[filename].fDownloading = false; // close and delete the writer stream (even if we failed) fD2DStreams[filename]->Close(); delete fD2DStreams[filename]; fD2DStreams.erase(fD2DStreams.find(filename)); if (succeeded) { // open a secure stream to that file hsStream* stream = plSecureStream::OpenSecureFile( fFileInfoMap[filename].fGarbledNameAndPath.c_str(), plSecureStream::kRequireEncryption | plSecureStream::kDeleteOnExit, // force delete and encryption fEncryptionKey ); bool addedToSource = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream); if (!addedToSource) DEL(stream); // cleanup if it wasn't added fFileInfoMap[filename].fDownloaded = true; break; } // file download failed, clean up after it // delete the temporary file if (plFileUtils::FileExists(fFileInfoMap[filename].fGarbledNameAndPath.c_str())) plFileUtils::RemoveFile(fFileInfoMap[filename].fGarbledNameAndPath.c_str(), true); // and remove it from the info map fFileInfoMap.erase(fFileInfoMap.find(filename)); break; } fNetError |= !succeeded; --fNumDownloadRequestsRemaining; LogMsg(kLogPerf, L"Received secure file:%s, success:%s", filename.c_str(), succeeded ? L"Yep" : L"Nope"); if (!succeeded) IPreloadComplete(); else // Issue some file download requests (up to kMaxConcurrency) IIssueDownloadRequests(); } //============================================================================ void pfSecurePreloader::INotifyAuthReconnected () { // The secure file download network protocol will now just pick up downloading // where it left off before the reconnect, so no need to reset in-progress files. /* std::map::iterator curFile; for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++) { // Reset files that were currently downloading if (curFile->second.fDownloading) curFile->second.fDownloading = false; } if (fNumDownloadRequestsRemaining > 0) { LogMsg(kLogPerf, L"pfSecurePreloader: Auth reconnected, resetting in-progress file downloads"); // Issue some file download requests (up to kMaxConcurrency) IIssueDownloadRequests(); } */ } //============================================================================ void pfSecurePreloader::IPreloadComplete () { DEL(fProgressBar); fProgressBar = nil; plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg(); msg->fSuccess = !fNetError; msg->Send(); } //============================================================================ hsBool pfSecurePreloader::MsgReceive (plMessage * msg) { if (plNetCommAuthConnectedMsg * authMsg = plNetCommAuthConnectedMsg::ConvertNoRef(msg)) { INotifyAuthReconnected(); return true; } return hsKeyedObject::MsgReceive(msg); } //============================================================================ pfSecurePreloader * pfSecurePreloader::GetInstance () { if (!fInstance) { fInstance = NEWZERO(pfSecurePreloader); fInstance->RegisterAs(kSecurePreloader_KEY); } return fInstance; } //============================================================================ bool pfSecurePreloader::IsInstanced () { return fInstance != nil; } //============================================================================ void pfSecurePreloader::Init () { if (!fInitialized) { fInitialized = true; plgDispatch::Dispatch()->RegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey()); } } //============================================================================ void pfSecurePreloader::Shutdown () { if (fInitialized) { fInitialized = false; plgDispatch::Dispatch()->UnRegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey()); } if (fInstance) { fInstance->UnRegister(); fInstance = nil; } } //============================================================================ pfSecurePreloader::pfSecurePreloader () { } //============================================================================ pfSecurePreloader::~pfSecurePreloader () { Cleanup(); }