/*==LICENSE==*

CyanWorlds.com Engine - MMOG client, server and tools
Copyright (C) 2011  Cyan Worlds, Inc.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.

You can contact Cyan Worlds, Inc. by email legal@cyan.com
 or by snail mail at:
      Cyan Worlds, Inc.
      14617 N Newport Hwy
      Mead, WA   99021

*==LICENSE==*/
#include "hsSTLStream.h"
#include "hsResMgr.h"
#include "plgDispatch.h"
#include "pnUtils/pnUtils.h"
#include "pnNetBase/pnNetBase.h"
#include "pnAsyncCore/pnAsyncCore.h"
#include "pnNetCli/pnNetCli.h"
#include "plNetGameLib/plNetGameLib.h"
#include "plFile/plFileUtils.h"
#include "plFile/plStreamSource.h"
#include "plNetCommon/plNetCommon.h"
#include "plProgressMgr/plProgressMgr.h"
#include "plMessage/plPreloaderMsg.h"
#include "plMessage/plNetCommMsgs.h"
#include "pfSecurePreloader.h"

#include "plNetClientComm/plNetClientComm.h"

// Defined elsewhere. When non-zero, Start() does not contact the server and
// immediately sends a successful plPreloaderMsg.
extern	hsBool	gDataServerLocal;


// Max number of concurrent file downloads; IIssueDownloadRequests() stops
// issuing new requests once this many are outstanding.
static const unsigned kMaxConcurrency	= 1;

// The single preloader instance: created on demand by GetInstance() and
// reset to nil by Shutdown().
pfSecurePreloader * pfSecurePreloader::fInstance;

///////////////////////////////////////////////////////////////////////////////
// Callback routines for the network code

// Network callback for a file-list request. 'param' is the pfSecurePreloader
// that issued the request. On success the returned names and sizes are copied
// into parallel vectors; on failure both vectors are left empty. Either way
// the result is handed to pfSecurePreloader::RequestFinished().
static void DefaultFileListRequestCallback(ENetError result, void* param, const NetCliAuthFileInfo infoArr[], unsigned infoCount)
{
	pfSecurePreloader* preloader = (pfSecurePreloader*)param;
	const bool requestOk = !IS_NET_ERROR(result);

	std::vector<std::wstring> names;
	std::vector<UInt32> byteCounts;

	if (requestOk)
	{
		// One allocation each, since the final size is known up front
		names.reserve(infoCount);
		byteCounts.reserve(infoCount);

		unsigned idx = 0;
		while (idx < infoCount)
		{
			const NetCliAuthFileInfo& entry = infoArr[idx];
			names.push_back(entry.filename);
			byteCounts.push_back(entry.filesize);
			++idx;
		}
	}

	preloader->RequestFinished(names, byteCounts, requestOk);
}

// Network callback for a single file download. 'param' is the
// pfSecurePreloader that issued the request and 'stream' is the stream the
// file was being written to.
//  - kNetSuccess: report the file as downloaded.
//  - kNetErrFileNotFound / kNetErrRemoteShutdown: report the file as failed.
//  - anything else: rewind the stream and re-issue the same request.
static void DefaultFileRequestCallback(ENetError result, void* param, const wchar filename[], hsStream* stream)
{
	pfSecurePreloader* preloader = (pfSecurePreloader*)param;

	if (result == kNetSuccess)
	{
		preloader->FinishedDownload(filename, true);
		return;
	}

	// These two errors are final; retrying would not help
	if (result == kNetErrFileNotFound || result == kNetErrRemoteShutdown)
	{
		preloader->FinishedDownload(filename, false);
		return;
	}

	// Any other error: start the download over from the beginning of the stream
	stream->Rewind();
	NetCliAuthFileRequest(filename, stream, &DefaultFileRequestCallback, param);
}


///////////////////////////////////////////////////////////////////////////////
// Write-only stream that sends downloaded data straight to a file on disk and
// reports every write to the owning preloader so it can advance its progress
// bar. Reading is deliberately unsupported: Read() asserts and returns 0.
class Direct2DiskStream : public hsUNIXStream
{
protected:
	// Heap copy of the name passed to Open(); freed in Close()
	wchar *			fWriteFileName;

	// Preloader notified of each write; supplied to the constructor
	pfSecurePreloader* fPreloader;

public:
	Direct2DiskStream(pfSecurePreloader* preloader);
	~Direct2DiskStream();

	// Both overloads only accept the "wb" mode; the narrow overload converts
	// its arguments and forwards to the wide one.
	virtual hsBool Open(const char* name, const char* mode = "wb");
	virtual hsBool Open(const wchar* name, const wchar* mode = L"wb");
	virtual hsBool Close();
	// Not implemented: asserts and returns 0
	virtual UInt32 Read(UInt32 byteCount, void* buffer);
	// Reports byteCount to the preloader, then writes through hsUNIXStream
	virtual UInt32 Write(UInt32 byteCount, const void* buffer);
};


// Creates a closed stream (no file name yet) that will report its writes to
// 'preloader'.
Direct2DiskStream::Direct2DiskStream(pfSecurePreloader* preloader)
	: fWriteFileName(nil)
	, fPreloader(preloader)
{
}

// Closes the stream on destruction, which also frees the stored file name.
Direct2DiskStream::~Direct2DiskStream()
{
	// Inside a destructor the virtual call already resolves to this class's
	// Close(); the qualification just makes that explicit.
	Direct2DiskStream::Close();
}

// Narrow-string overload: converts both arguments to wide strings, forwards
// to the wide-string Open(), and frees the temporary conversions.
// Returns whatever the wide-string Open() returned.
hsBool Direct2DiskStream::Open(const char* name, const char* mode)
{
	wchar* wideName = hsStringToWString(name);
	wchar* wideMode = hsStringToWString(mode);

	const hsBool opened = Open(wideName, wideMode);

	delete [] wideName;
	delete [] wideMode;
	return opened;
}

// Opens 'name' for binary writing and remembers a private copy of the name.
//   name - path of the file to create; must not be nil
//   mode - must be exactly L"wb"; any other mode asserts and fails
// Returns false for a nil name or an unsupported mode, otherwise the result
// of hsUNIXStream::Open().
hsBool Direct2DiskStream::Open(const wchar* name, const wchar* mode)
{
	if (!name || !mode) {
		// wcslen/wcscmp below would dereference nil
		hsAssert(0, "Missing file name or open mode");
		return false;
	}

	if (0 != wcscmp(mode, L"wb")) {
		hsAssert(0, "Unsupported open mode");
		return false;
	}

	// If Open() is called again without an intervening Close(), release the
	// previous name first so it is not leaked.
	delete [] fWriteFileName;
	fWriteFileName = nil;

	fWriteFileName = TRACKED_NEW(wchar[wcslen(name) + 1]);
	wcscpy(fWriteFileName, name);

//	LogMsg(kLogPerf, L"Opening disk file %S", fWriteFileName);
	return hsUNIXStream::Open(name, mode);
}

// Frees the stored file name (safe to call when none is stored) and closes
// the underlying hsUNIXStream. Returns the base class's Close() result.
hsBool Direct2DiskStream::Close()
{
	wchar* oldName = fWriteFileName;
	fWriteFileName = nil;
	delete [] oldName;

	const hsBool closed = hsUNIXStream::Close();
	return closed;
}

// Reading is not supported by this stream: asserts and reports that zero
// bytes were read. Neither argument is used.
UInt32 Direct2DiskStream::Read(UInt32 bytes, void* buffer)
{
	hsAssert(0, "not implemented");

	const UInt32 kNothingRead = 0;
	return kNothingRead;
}

// Reports 'bytes' to the preloader's progress bar, then writes the buffer
// through hsUNIXStream. Note the progress update uses the requested byte
// count and happens before the write. Returns the base class's Write() result.
UInt32 Direct2DiskStream::Write(UInt32 bytes, const void* buffer)
{
	fPreloader->UpdateProgressBar(bytes);

	const UInt32 written = hsUNIXStream::Write(bytes, buffer);
	return written;
}


///////////////////////////////////////////////////////////////////////////////
// secure preloader class implementation

// closes and deletes all streams
void pfSecurePreloader::ICleanupStreams()
{
	if (fD2DStreams.size() > 0)
	{
		std::map<std::wstring, hsStream*>::iterator curStream;
		for (curStream = fD2DStreams.begin(); curStream != fD2DStreams.end(); curStream++)
		{
			curStream->second->Close();
			delete curStream->second;
			curStream->second = nil;
		}
		fD2DStreams.clear();
	}
}

// Queues a request for a single file. The request is only recorded here;
// nothing is fetched until Start() is called. No duplicate check is made.
//   filename - name (and path) of the file to request
void pfSecurePreloader::RequestSingleFile(std::wstring filename)
{
	// Value-initialize rather than ZERO() the struct: it holds std::wstring
	// members, and (assuming ZERO zero-fills the raw bytes -- TODO confirm
	// against its definition) overwriting those is undefined behaviour.
	fileRequest request = fileRequest();
	request.fType = fileRequest::kSingleFile;
	request.fPath = filename;
	request.fExt = L"";	// single-file requests carry no extension

	fRequests.push_back(request);
}

// Queues a request for a group of files identified by directory and
// extension. The request is only recorded here; nothing is fetched until
// Start() is called. No duplicate check is made.
//   dir - directory to list
//   ext - extension of the files wanted from that directory
void pfSecurePreloader::RequestFileGroup(std::wstring dir, std::wstring ext)
{
	// Value-initialize rather than ZERO() the struct: it holds std::wstring
	// members, and (assuming ZERO zero-fills the raw bytes -- TODO confirm
	// against its definition) overwriting those is undefined behaviour.
	fileRequest request = fileRequest();
	request.fType = fileRequest::kFileList;
	request.fPath = dir;
	request.fExt = ext;

	fRequests.push_back(request);
}

// preloads all requested files from the server (does nothing if already preloaded)
void pfSecurePreloader::Start()
{
	if (gDataServerLocal) {
		// using local data, don't do anything
		plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg();
		msg->fSuccess = true;
		msg->Send();
		return;
	}

	NetCliAuthGetEncryptionKey(fEncryptionKey, 4); // grab the encryption key from the server

	fNetError = false;

	// make sure we are all cleaned up
	ICleanupStreams();
	fTotalDataReceived = 0;

	// update the progress bar for downloading
	if (!fProgressBar)
		fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fRequests.size()), "Getting file info...", plProgressMgr::kUpdateText, false, true);
	
	for (unsigned curRequest = 0; curRequest < fRequests.size(); curRequest++)
	{
		fNumInfoRequestsRemaining++; // increment the counter
		if (fRequests[curRequest].fType == fileRequest::kSingleFile)
		{
#ifndef PLASMA_EXTERNAL_RELEASE
			// in internal releases, we can use on-disk files if they exist
			if (plFileUtils::FileExists(fRequests[curRequest].fPath.c_str()))
			{
				fileInfo info;
				info.fOriginalNameAndPath = fRequests[curRequest].fPath;
				info.fSizeInBytes = plFileUtils::GetFileSize(info.fOriginalNameAndPath.c_str());
				info.fDownloading = false;
				info.fDownloaded = false;
				info.fLocal = true;

				// generate garbled name
				wchar_t pathBuffer[MAX_PATH + 1];
				wchar_t filename[arrsize(pathBuffer)];
				GetTempPathW(arrsize(pathBuffer), pathBuffer);
				GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
				info.fGarbledNameAndPath = filename;

				fTotalDataDownload += info.fSizeInBytes;

				fFileInfoMap[info.fOriginalNameAndPath] = info;
			}
			// internal client will still request it, even if it exists locally,
			// so that things get updated properly
#endif // PLASMA_EXTERNAL_RELEASE
			NetCliAuthFileListRequest(
				fRequests[curRequest].fPath.c_str(),
				nil,
				&DefaultFileListRequestCallback,
				(void*)this
			);
		}
		else
		{
#ifndef PLASMA_EXTERNAL_RELEASE
			// in internal releases, we can use on-disk files if they exist
			// Build the search string as "dir\\*.ext"
			wchar searchStr[MAX_PATH];

			PathAddFilename(searchStr, fRequests[curRequest].fPath.c_str(), L"*", arrsize(searchStr));
			PathSetExtension(searchStr, searchStr, fRequests[curRequest].fExt.c_str(), arrsize(searchStr));

			ARRAY(PathFind) paths;
			PathFindFiles(&paths, searchStr, kPathFlagFile); // find all files that match

			// convert it to our little file info array
			PathFind* curFile = paths.Ptr();
			PathFind* lastFile = paths.Term();
			while (curFile != lastFile) {
				fileInfo info;
				info.fOriginalNameAndPath = curFile->name;
				info.fSizeInBytes = (UInt32)curFile->fileLength;
				info.fDownloading = false;
				info.fDownloaded = false;
				info.fLocal = true;

				// generate garbled name
				wchar_t pathBuffer[MAX_PATH + 1];
				wchar_t filename[arrsize(pathBuffer)];
				GetTempPathW(arrsize(pathBuffer), pathBuffer);
				GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
				info.fGarbledNameAndPath = filename;

				fTotalDataDownload += info.fSizeInBytes;

				fFileInfoMap[info.fOriginalNameAndPath] = info;
				curFile++;
			}
#endif // PLASMA_EXTERNAL_RELEASE

			NetCliAuthFileListRequest(
				fRequests[curRequest].fPath.c_str(),
				fRequests[curRequest].fExt.c_str(),
				&DefaultFileListRequestCallback,
				(void*)this
			);
		}
	}
}

// Returns the preloader to an idle state: closes and deletes any open writer
// streams, forgets all queued requests and known files, zeroes the info
// request counter and the byte totals, and destroys the progress bar.
void pfSecurePreloader::Cleanup()
{
	ICleanupStreams();

	// bookkeeping counters
	fTotalDataReceived = 0;
	fTotalDataDownload = 0;
	fNumInfoRequestsRemaining = 0;

	// queued work and per-file records
	fFileInfoMap.clear();
	fRequests.clear();

	// progress bar
	DEL(fProgressBar);
	fProgressBar = nil;
}

//============================================================================
void pfSecurePreloader::RequestFinished(const std::vector<std::wstring> & filenames, const std::vector<UInt32> & sizes, bool succeeded)
{
	fNetError |= !succeeded;
	
	if (succeeded)
	{
		unsigned count = 0;
		for (int curFile = 0; curFile < filenames.size(); curFile++)
		{
			if (fFileInfoMap.find(filenames[curFile]) != fFileInfoMap.end())
				continue; // if it is a duplicate, ignore it (the duplicate is probably one we found locally)

			fileInfo info;
			info.fOriginalNameAndPath = filenames[curFile];
			info.fSizeInBytes = sizes[curFile];
			info.fDownloading = false;
			info.fDownloaded = false;
			info.fLocal = false; // if we get here, it was retrieved remotely

			// generate garbled name
			wchar_t pathBuffer[MAX_PATH + 1];
			wchar_t filename[arrsize(pathBuffer)];
			GetTempPathW(arrsize(pathBuffer), pathBuffer);
			GetTempFileNameW(pathBuffer, L"CYN", 0, filename);
			info.fGarbledNameAndPath = filename;

			fTotalDataDownload += info.fSizeInBytes;

			fFileInfoMap[info.fOriginalNameAndPath] = info;
			++count;
		}
		LogMsg(kLogPerf, "Added %u files to secure download queue", count);
	}
	if (fProgressBar)
		fProgressBar->Increment(1.f);
		
	--fNumInfoRequestsRemaining;	// even if we fail, decrement the counter

	if (succeeded) {
		DEL(fProgressBar);
		fProgressBar = plProgressMgr::GetInstance()->RegisterOperation((hsScalar)(fTotalDataDownload), "Downloading...", plProgressMgr::kUpdateText, false, true);

		// Issue some file download requests (up to kMaxConcurrency)
		IIssueDownloadRequests();
	}
	else {
		IPreloadComplete();
	}	
}

//============================================================================
void pfSecurePreloader::IIssueDownloadRequests () {

	std::map<std::wstring, fileInfo>::iterator curFile;
	for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++)
	{
		// Skip files already downloaded or currently downloading
		if (curFile->second.fDownloaded || curFile->second.fDownloading)
			continue;
			
		std::wstring filename = curFile->second.fOriginalNameAndPath;
#ifndef PLASMA_EXTERNAL_RELEASE
		// in internal releases, we can use on-disk files if they exist
		if (plFileUtils::FileExists(filename.c_str()))
		{
			// don't bother streaming, just make the secure stream using the local file

			// a local key overrides the server-downloaded key
			UInt32 localKey[4];
			bool hasLocalKey = plFileUtils::GetSecureEncryptionKey(filename.c_str(), localKey, arrsize(localKey));
			hsStream* stream = nil;
			if (hasLocalKey)
				stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, localKey);
			else
				stream = plSecureStream::OpenSecureFile(filename.c_str(), 0, fEncryptionKey);

			// add it to the stream source
			bool added = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream);
			if (!added)
				DEL(stream); // wasn't added, so nuke our local copy

			// and make sure the vars are set up right
			curFile->second.fDownloaded = true;
			curFile->second.fLocal = true;
		}
		else
#endif
		{
			// Enforce concurrency limit
			if (fNumDownloadRequestsRemaining >= kMaxConcurrency)
				break;

			curFile->second.fDownloading = true;
			curFile->second.fDownloaded = false;
			curFile->second.fLocal = false;

			// create and setup the stream
			Direct2DiskStream* fileStream = TRACKED_NEW Direct2DiskStream(this);
			fileStream->Open(curFile->second.fGarbledNameAndPath.c_str(), L"wb");
			fD2DStreams[filename] = (hsStream*)fileStream;

			// request the file from the server
			LogMsg(kLogPerf, L"Requesting secure file:%s", filename.c_str());
			++fNumDownloadRequestsRemaining;
			NetCliAuthFileRequest(
				filename.c_str(),
				(hsStream*)fileStream, 
				&DefaultFileRequestCallback,
				this
			);
		}
	}
	
	if (!fNumDownloadRequestsRemaining)
		IPreloadComplete();
}

// Records that 'bytesReceived' more bytes have arrived. The running total is
// clamped to the expected download size, and the progress bar (if one exists)
// is advanced by the unclamped 'bytesReceived'.
void pfSecurePreloader::UpdateProgressBar(UInt32 bytesReceived)
{
	fTotalDataReceived += bytesReceived;

	// shouldn't overshoot... but clamp just in case
	if (fTotalDataReceived > fTotalDataDownload)
		fTotalDataReceived = fTotalDataDownload;

	if (!fProgressBar)
		return;

	fProgressBar->Increment((hsScalar)bytesReceived);
}

void pfSecurePreloader::FinishedDownload(std::wstring filename, bool succeeded)
{
	for (;;)
	{
		if (fFileInfoMap.find(filename) == fFileInfoMap.end())
		{
			// file doesn't exist... abort
			succeeded = false;
			break;
		}

		fFileInfoMap[filename].fDownloading = false;

		// close and delete the writer stream (even if we failed)
		fD2DStreams[filename]->Close();
		delete fD2DStreams[filename];
		fD2DStreams.erase(fD2DStreams.find(filename));

		if (succeeded)
		{
			// open a secure stream to that file
			hsStream* stream = plSecureStream::OpenSecureFile(
				fFileInfoMap[filename].fGarbledNameAndPath.c_str(),
				plSecureStream::kRequireEncryption | plSecureStream::kDeleteOnExit, // force delete and encryption
				fEncryptionKey
			);

			bool addedToSource = plStreamSource::GetInstance()->InsertFile(filename.c_str(), stream);
			if (!addedToSource)
				DEL(stream); // cleanup if it wasn't added

			fFileInfoMap[filename].fDownloaded = true;
			break;
		}
		
		// file download failed, clean up after it

		// delete the temporary file
		if (plFileUtils::FileExists(fFileInfoMap[filename].fGarbledNameAndPath.c_str()))
			plFileUtils::RemoveFile(fFileInfoMap[filename].fGarbledNameAndPath.c_str(), true);

		// and remove it from the info map
		fFileInfoMap.erase(fFileInfoMap.find(filename));
		break;
	}
		
	fNetError |= !succeeded;
	--fNumDownloadRequestsRemaining;
	LogMsg(kLogPerf, L"Received secure file:%s, success:%s", filename.c_str(), succeeded ? L"Yep" : L"Nope");

	if (!succeeded)
		IPreloadComplete();
	else
		// Issue some file download requests (up to kMaxConcurrency)
		IIssueDownloadRequests();
}

//============================================================================
// Called from MsgReceive() when a plNetCommAuthConnectedMsg arrives.
// Currently a no-op: the only code in the body is commented out.
void pfSecurePreloader::INotifyAuthReconnected () {

	// The secure file download network protocol will now just pick up downloading
	// where it left off before the reconnect, so no need to reset in-progress files.
	
	// The disabled code below is the earlier behavior, kept for reference: it
	// cleared the fDownloading flag on every in-progress file and re-issued
	// the download requests.
	/*
	std::map<std::wstring, fileInfo>::iterator curFile;
	for (curFile = fFileInfoMap.begin(); curFile != fFileInfoMap.end(); curFile++) {

		// Reset files that were currently downloading
		if (curFile->second.fDownloading)
			curFile->second.fDownloading = false;
	}

	if (fNumDownloadRequestsRemaining > 0) {

		LogMsg(kLogPerf, L"pfSecurePreloader: Auth reconnected, resetting in-progress file downloads");

		// Issue some file download requests (up to kMaxConcurrency)
		IIssueDownloadRequests();
	}
	*/
}

//============================================================================
void pfSecurePreloader::IPreloadComplete () {
	DEL(fProgressBar);
	fProgressBar = nil;
	
	plPreloaderMsg * msg = TRACKED_NEW plPreloaderMsg();
	msg->fSuccess = !fNetError;
	msg->Send();
}

//============================================================================
// Message handler. A plNetCommAuthConnectedMsg triggers
// INotifyAuthReconnected() and is reported as handled (returns true); every
// other message is passed on to hsKeyedObject::MsgReceive().
hsBool pfSecurePreloader::MsgReceive (plMessage * msg) {

	// Only the message's type matters here, so the converted pointer is not kept
	if (plNetCommAuthConnectedMsg::ConvertNoRef(msg)) {

		INotifyAuthReconnected();
		return true;
	}

	return hsKeyedObject::MsgReceive(msg);
}

//============================================================================
// Returns the single preloader instance, creating it and registering it under
// kSecurePreloader_KEY on first use.
pfSecurePreloader * pfSecurePreloader::GetInstance () {

	if (fInstance)
		return fInstance;

	// First call: build and register the instance
	fInstance = NEWZERO(pfSecurePreloader);
	fInstance->RegisterAs(kSecurePreloader_KEY);

	return fInstance;
}

//============================================================================
// Reports whether the instance currently exists, without creating it.
bool pfSecurePreloader::IsInstanced () {

	const bool exists = (fInstance != nil);
	return exists;
}

//============================================================================
// One-time setup: registers this object with the dispatcher for
// plNetCommAuthConnectedMsg. Calling it again while initialized does nothing.
void pfSecurePreloader::Init () {

	if (fInitialized)
		return;

	fInitialized = true;
	plgDispatch::Dispatch()->RegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey());
}

//============================================================================
// Undoes Init() and GetInstance(): if initialized, unregisters from
// plNetCommAuthConnectedMsg dispatch; then, if the instance exists,
// unregisters it and clears the instance pointer.
void pfSecurePreloader::Shutdown () {

	// Step 1: stop receiving auth-connected messages
	if (fInitialized) {
		fInitialized = false;
		plgDispatch::Dispatch()->UnRegisterForExactType(plNetCommAuthConnectedMsg::Index(), GetKey());
	}

	// Step 2: release the instance
	if (nil != fInstance) {
		fInstance->UnRegister();
		fInstance = nil;
	}
}

//============================================================================
// Nothing to set up here; no member is initialized by this constructor.
pfSecurePreloader::pfSecurePreloader ()
{
}

//============================================================================
// Releases streams, queued requests, file records and the progress bar.
pfSecurePreloader::~pfSecurePreloader ()
{
	Cleanup();
}