import * as rlConfig from "./rl_config"
import * as rlHyperParams from "./rl_hyperparameters"
import * as rlQtable from "./rl_q_table"
import * as rlStat from "./rl_stats"
import * as reward from "./reward_function"

// For Q-table: state-space bounds passed to rlQtable.resetQTable() in reset().
// Populated by initConfig() from rlConfig (minVelocity is fixed at 0 there).
let maxVelocity;
let minVelocity;
let maxDistance;
let minDistance;

// Number of discrete actions, i.e. entries in rlConfig.availableBrakingForceFactorArray.
let availableActionCount;
// Action indexes repeated according to rlConfig.availableBrakingForceFactorWeightArray.
// getActionIndex() samples uniformly from this array, which yields a weighted action choice.
let weightedAvailableBrakingForceFactorIndexArray = [];

// In this file, only referenced by the disabled TensorFlow (rl_mode 0) code.
let model;
// Recomputed by updateDecayFactors() from the number of finished training episodes.
let explorationRate;
// Only assigned by updateDecayFactors() when rl_mode is 1 (Q-table mode).
let learningRate;

// Set by startSimulation(): true for a training run, false for an evaluation run.
let isCurrentTraining = false;
// Episode budget for the current run; postEndEpisode() only auto-stops when this is > 0.
let totalNoOfEpisode = 0;

// Episode Data: per-step records of the episode in progress, cleared by resetEpisodeData().
// episodeStates holds [velocity, remainDistance] pairs; episodeActionIndexes holds the
// action indexes actually taken, episodeNonExploreActionIndexes the greedy ones.
let episodeStates = [];
let episodeActionIndexes = [];
let episodeNonExploreActionIndexes = [];
let episodeRewards = [];
// remainDistance seen by the most recent chooseAction() call; null before the first step.
let lastDistance = null;

// Training Accumulate Data: converted samples from several episodes, consumed and cleared
// when postEndEpisode() runs a training round.
let xs_array_waitingToTrain = [];
let ys_array_waitingToTrain = [];

// Braking force factor returned by the previous chooseAction(); reset to null by startEpisode().
let lastAction = null;

/**
 * Copies the environment bounds from rlConfig into module state and builds the weighted
 * action-index pool used for exploration.
 *
 * The pool is rebuilt on every call, so calling initConfig() more than once does not keep
 * appending duplicate entries. If no action has a positive weight, every action gets
 * weight 1, so exploration never samples from an empty pool.
 *
 * @returns {Array} [true, ...configuration values reported back to the caller].
 */
export function initConfig() {
	maxDistance = rlConfig.maxInitialDistance;
	minDistance = rlConfig.episodeTerminationDistance;
	maxVelocity = rlConfig.maxInitialVelocity;
	minVelocity = 0;

	availableActionCount = rlConfig.availableBrakingForceFactorArray.length;

	// Clear in place so repeated initialisation starts from an empty pool.
	weightedAvailableBrakingForceFactorIndexArray.length = 0;

	const weightArray = rlConfig.availableBrakingForceFactorWeightArray;
	for (let i = 0; i < availableActionCount; i++) {
		// Actions without a configured weight get weight 0 and are never explored.
		const weight = i < weightArray.length ? weightArray[i] : 0;
		for (let j = 0; j < weight; j++) {
			weightedAvailableBrakingForceFactorIndexArray.push(i);
		}
	}

	// An empty pool would make random exploration return an undefined action index;
	// fall back to sampling every action uniformly.
	if (weightedAvailableBrakingForceFactorIndexArray.length === 0) {
		for (let i = 0; i < availableActionCount; i++) {
			weightedAvailableBrakingForceFactorIndexArray.push(i);
		}
	}

	return [
		true,
		rlConfig.maxGameSpeed,
		rlConfig.carBrakingMode,
		rlConfig.actionPerSecond,
		rlConfig.maxInitialDistance,
		rlConfig.minInitialDistance,
		rlConfig.maxInitialVelocity,
		rlConfig.minInitialVelocity,
		rlConfig.episodeTerminationDistance,
		rlHyperParams.noOfEpisodeBeforeTrain,
		rlConfig.recommendedNoOfTrainToDo,
		rlConfig.noOfEvaluationTrial,
		rlConfig.maxSuccessDistance,
		rlConfig.minSuccessDistance,
		rlConfig.largestReward,
		rlConfig.smallestReward,
		rlConfig.rewardItemArray,
	];
}

/**
 * Re-initialises the learner for the configured rl_mode, then clears statistics, pending
 * training samples and the current episode, and recomputes the decay factors.
 *
 * @returns {boolean} Always true.
 */
export function reset() {
	switch (rlConfig.rl_mode) {
		case 0:
			// TensorFlow model creation is disabled while tf is unavailable:
			// model = tf.sequential();
			// model.add(tf.layers.dense({ units: 32, inputShape: [stateInputNo], activation: 'relu' }));
			// model.add(tf.layers.dense({ units: availableActionCount, activation: 'relu' }));
			// model.compile({ optimizer: tf.train.adam(), loss: 'meanSquaredError' });
			break;
		case 1:
			rlQtable.resetQTable(maxVelocity, minVelocity, maxDistance, minDistance, availableActionCount);
			break;
		default:
			break;
	}

	rlStat.resetStats(true, true);
	resetTrainingData();
	resetEpisodeData();
	updateDecayFactors();

	return true;
}

/** Discards all converted samples waiting for the next training round. */
function resetTrainingData() {
	[xs_array_waitingToTrain, ys_array_waitingToTrain] = [[], []];
}

/**
 * Clears every per-step record of the current episode and forgets the last distance.
 *
 * The original body had a nested `if` for rl_mode 0 whose only content was a commented-out
 * `tf.dispose(episodeStates)`. That branch did nothing and has been removed. If the TensorFlow
 * path is restored, dispose the episode tensors here before replacing the arrays.
 */
function resetEpisodeData() {
	lastDistance = null;

	episodeStates = [];
	episodeActionIndexes = [];
	episodeNonExploreActionIndexes = [];
	episodeRewards = [];
}

/**
 * Recomputes the exploration rate (and, in Q-table mode, the learning rate) from the number
 * of finished training episodes, then logs both values.
 *
 * learningRate is only assigned when rl_mode is 1, so in any other mode it may still be
 * undefined. Calling toFixed() on it would throw a TypeError; "N/A" is logged instead.
 */
function updateDecayFactors() {
	explorationRate = getDecayedExplorationRate(rlStat.noOfFinishedEpisodeInTraining);
	if (rlConfig.rl_mode === 1) {
		learningRate = getDecayedLearningRate(rlStat.noOfFinishedEpisodeInTraining);
	}

	const learningRateText = typeof learningRate === "number" ? learningRate.toFixed(2) : "N/A";
	console.log("explorationRate   ", explorationRate.toFixed(2), "  ;  learningRate   ", learningRateText);
}

/**
 * Exponentially decays the exploration rate towards minExplorationRate.
 *
 * @param {number} decayIndex - Number of finished training episodes.
 * @returns {number} maxExplorationRate when decayIndex <= 0; otherwise
 *   min + (1 - min) * exp(-(decayIndex + 1) * decayRate), capped at maxExplorationRate.
 */
function getDecayedExplorationRate(decayIndex) {
	const { minExplorationRate, maxExplorationRate, explorationRateDecayRate } = rlHyperParams;

	if (decayIndex <= 0) {
		return maxExplorationRate;
	}

	const decayFactor = Math.exp(-((decayIndex + 1) * explorationRateDecayRate));
	const decayed = minExplorationRate + (1 - minExplorationRate) * decayFactor;
	return Math.min(maxExplorationRate, decayed);
}

/**
 * Logarithmically decays the learning rate.
 *
 * @param {number} decayIndex - Number of finished training episodes.
 * @returns {number} maxLearningRate when decayIndex <= 0; otherwise
 *   1 - log10((decayIndex + 1) * decayRate), clamped to [minLearningRate, maxLearningRate].
 */
function getDecayedLearningRate(decayIndex) {
	const { minLearningRate, maxLearningRate, learningRateDecayRate } = rlHyperParams;

	if (decayIndex <= 0) {
		return maxLearningRate;
	}

	const logDecayed = 1 - Math.log10((decayIndex + 1) * learningRateDecayRate);
	return Math.max(minLearningRate, Math.min(maxLearningRate, logDecayed));
}

// Remark (Tracy 20200131): commented this out to avoid errors, since tf is currently undefined.
// function getStateTensor(remainDistance, velocity) {
// 	let state = [remainDistance, velocity];
// 	// let state = [Math.round(remainDistance * 10) / 10, Math.round(velocity * 10) / 10];
// 	// console.log(remainDistance.toFixed(2), "   ", velocity.toFixed(2));
// 	return tf.tensor2d(state, [1, stateInputNo]);
// }

// TODO : To simplify the isSaveData logic, Unity may need to pass an isEpisodeFinished parameter when calling the "chooseEvaluationAction" JS action.
/**
 * Chooses the braking force factor for the current step and records the reward earned by
 * the previous step's action.
 *
 * When rlConfig.carBrakingMode is 0, only the first state of an episode is recorded
 * (lastAction is still null at that point). In any other mode, every non-terminal state is
 * recorded.
 *
 * @param {number} remainDistance - Remaining distance reported for this step.
 * @param {number} velocity - Current velocity reported for this step.
 * @param {boolean} isEpisodeFinished - Whether this call reports the end of the episode.
 * @returns {Array} [true, brakingForceFactor].
 */
export function chooseAction(remainDistance, velocity, isEpisodeFinished) {
	const isSaveData = rlConfig.carBrakingMode === 0 ? lastAction === null : !isEpisodeFinished;
	const action = getActionAddSaveData(remainDistance, velocity, isCurrentTraining, isSaveData);

	// The first call of an episode has no preceding action, so there is no reward to record yet.
	const hasPreviousAction = lastAction !== null;
	if (hasPreviousAction) {
		episodeRewards.push(reward.rewardFunction(remainDistance, velocity, isEpisodeFinished));
	}

	lastAction = action;
	lastDistance = remainDistance;

	return [true, action];
}

/**
 * Looks up the greedy action for a state, applies exploration, and optionally records the
 * state and the chosen action index for the current episode.
 *
 * The Q-value debug logging now follows the isTrainingMode parameter. Previously it read the
 * module-level isCurrentTraining, which silently ignored the parameter; the visible caller
 * passes isCurrentTraining, so its behaviour is unchanged.
 *
 * @param {number} remainDistance - Remaining distance for this step.
 * @param {number} velocity - Current velocity for this step.
 * @param {boolean} isTrainingMode - Whether exploration (and no Q-value logging) applies.
 * @param {boolean} isSaveData - Whether to record the state and action index.
 * @returns {number} The braking force factor for the chosen action index.
 */
function getActionAddSaveData(remainDistance, velocity, isTrainingMode, isSaveData) {
	if (isSaveData) {
		if (!isTrainingMode) {
			console.log("Q value array : ", rlQtable.getQTableValueArray(velocity, remainDistance));
		}
		// States are stored as [velocity, remainDistance], matching the Q-table lookup order.
		episodeStates.push([velocity, remainDistance]);
	}

	const nonExploreActionIndex = rlQtable.getQTableBestAction(velocity, remainDistance, true);
	const actionIndex = getActionIndex(nonExploreActionIndex, isTrainingMode, isSaveData);
	if (isSaveData) {
		episodeActionIndexes.push(actionIndex);
	}

	return rlConfig.availableBrakingForceFactorArray[actionIndex];
}

/**
 * Applies the exploration policy to a greedy action index.
 *
 * Outside training the greedy index is always used. In training, rlConfig.carBrakingMode 0
 * always explores; other modes explore with probability explorationRate. Exploration draws
 * from the weighted index pool built by initConfig().
 *
 * @param {number} nonExploreActionIndex - Greedy action index from the Q-table.
 * @param {boolean} isTrainingMode - Whether exploration is allowed.
 * @param {boolean} isSaveData - Whether to record the greedy index for this episode.
 * @returns {number} The action index to take.
 */
function getActionIndex(nonExploreActionIndex, isTrainingMode, isSaveData) {
	if (isSaveData) {
		episodeNonExploreActionIndexes.push(nonExploreActionIndex);
	}

	if (!isTrainingMode) {
		return nonExploreActionIndex;
	}

	const shouldExplore = rlConfig.carBrakingMode === 0 || Math.random() < explorationRate;
	if (!shouldExplore) {
		return nonExploreActionIndex;
	}

	const pool = weightedAvailableBrakingForceFactorIndexArray;
	return pool[Math.floor(Math.random() * pool.length)];
}

/**
 * Returns the index of the first largest element.
 *
 * @param {Array<number>} array - Values to scan.
 * @returns {number|null} Index of the first maximum, or null for a missing or empty array.
 */
function getMaxValueIndex(array) {
	if (!array || array.length === 0) {
		return null;
	}

	let bestIndex = 0;
	for (let candidate = 1; candidate < array.length; candidate++) {
		// Strict comparison keeps the earliest index on ties.
		if (array[candidate] > array[bestIndex]) {
			bestIndex = candidate;
		}
	}
	return bestIndex;
}

/**
 * Starts a training or evaluation run.
 *
 * @param {boolean} isTraining - true for training, false for evaluation.
 * @param {number} noOfEpisodeToSimulate - Episode budget; values <= 0 disable the automatic stop.
 * @returns {boolean} Always true.
 */
export function startSimulation(isTraining, noOfEpisodeToSimulate) {
	totalNoOfEpisode = noOfEpisodeToSimulate;
	isCurrentTraining = isTraining;

	// Flags presumably select which finished-episode counters (training, evaluation) to reset.
	const isEvaluation = !isCurrentTraining;
	rlStat.resetNoOfFinishedEpisode(isCurrentTraining, isEvaluation);
	return true;
}

/**
 * Begins a new episode.
 *
 * @returns {Array} [true, episodeNo], where episodeNo is one more than the finished-episode
 *   count for the current mode (training or evaluation).
 */
export function startEpisode() {
	const finishedCount = isCurrentTraining
		? rlStat.noOfFinishedEpisodeInTraining
		: rlStat.noOfFinishedEpisodeInEvaluation;

	// Clearing lastAction makes the next chooseAction() call behave as the episode's first step.
	lastAction = null;

	return [true, finishedCount + 1];
}

/**
 * Finishes the current episode: logs it, records its statistics, and, when training,
 * converts its data into samples for the next training round.
 *
 * Fixes two crashes:
 * - If the episode ended before any chooseAction() call, lastDistance is still null and
 *   toFixed() would throw.
 * - The converter returns undefined for rl_mode values it does not implement, and
 *   destructuring undefined would throw.
 *
 * @returns {Array} [true, isCurrentTraining, totalReward, lastDistance].
 */
export function endEpisode() {
	const totalReward = episodeRewards.reduce((sum, stepReward) => sum + stepReward, 0);

	if (isCurrentTraining) {
		console.log("Raw Action : ", episodeNonExploreActionIndexes);
	}
	console.log("Real Action : ", episodeActionIndexes);
	const lastDistanceText = typeof lastDistance === "number" ? lastDistance.toFixed(2) : String(lastDistance);
	console.log("Total Reward : ", totalReward.toFixed(2), " , lastDistance : " + lastDistanceText);

	if (isCurrentTraining) {
		rlStat.addTrainingEpisodeStats(totalReward, lastDistance, checkIsGoingToTrain(), explorationRate);

		const converted = convertSingleEpisodeDateToTrainFormat(episodeStates, episodeActionIndexes, episodeRewards);
		if (converted) {
			const [xs_array_singleEpisode, ys_array_singleEpisode] = converted;
			xs_array_waitingToTrain = xs_array_waitingToTrain.concat(xs_array_singleEpisode);
			ys_array_waitingToTrain = ys_array_waitingToTrain.concat(ys_array_singleEpisode);
		}
	} else {
		rlStat.addEvaluationEpisodeStats(totalReward, lastDistance);
	}

	return [true, isCurrentTraining, totalReward, lastDistance];
}

/**
 * Decides whether the episode being finished completes a training batch.
 *
 * Callers in this file invoke this before addFinishedEpisodeInTrainingCount(), so the
 * episode being finished is counted with the + 1.
 *
 * @returns {boolean} true only in training mode, once every noOfEpisodeBeforeTrain episodes.
 */
function checkIsGoingToTrain() {
	if (!isCurrentTraining) {
		return false;
	}

	const upcomingFinishedCount = rlStat.noOfFinishedEpisodeInTraining + 1;
	return upcomingFinishedCount % rlHyperParams.noOfEpisodeBeforeTrain === 0;
}

/**
 * Runs after endEpisode(): clears episode data, trains when a batch is complete, updates the
 * finished-episode count, and decides whether the simulation should stop.
 *
 * @param {boolean} isRequestStopSimulation - When truthy, the simulation always stops.
 * @returns {Array} [true, isStopSimulation].
 */
export function postEndEpisode(isRequestStopSimulation) {
	resetEpisodeData();

	if (checkIsGoingToTrain()) {
		switch (rlConfig.rl_mode) {
			case 0:
				// TensorFlow training path, disabled while tf is unavailable:
				// let xs = tf.tensor2d(xs_array_waitingToTrain, [xs_array_waitingToTrain.length / stateInputNo, stateInputNo]);
				// let ys = tf.tensor2d(ys_array_waitingToTrain, [ys_array_waitingToTrain.length / availableActionCount, availableActionCount]);

				// // xs.print();
				// // ys.print();

				// function onBatchEnd(batch, logs) {
				// 	// console.log("Training Log", logs);
				// 	addTrainingLog(logs);
				// }

				// model.fit(xs, ys, {
				// 	callbacks: { onBatchEnd }
				// }).then(() => {
				// 	console.log("training finished");
				// 	tf.dispose(xs);
				// 	tf.dispose(ys);
				// 	// showLastTrainingStats();

				// 	addTrainedRoundCount();
				// });
				break;
			case 1:
				rlQtable.updateQTableOnBatch(xs_array_waitingToTrain, ys_array_waitingToTrain);
				// TODO : Implement real loss function
				rlStat.addTrainingLog([{
					loss: -1,
				}]);
				rlStat.addTrainedRoundCount();
				break;
			default:
				break;
		}

		resetTrainingData();
		updateDecayFactors();
	}

	// NOTE(review): this also runs for evaluation episodes. Whether rlStat ignores the call
	// outside training is not visible here — confirm.
	rlStat.addFinishedEpisodeInTrainingCount();

	if (isRequestStopSimulation) {
		return [true, true];
	}

	if (totalNoOfEpisode > 0) {
		const finishedCount = isCurrentTraining
			? rlStat.noOfFinishedEpisodeInTraining
			: rlStat.noOfFinishedEpisodeInEvaluation;
		return [true, finishedCount >= totalNoOfEpisode];
	}

	return [true, false];
}

/**
 * Converts one episode's recorded data into training samples.
 *
 * In Q-table mode (rl_mode 1), this is a Monte Carlo update. Walking backwards through the
 * episode, it accumulates the discounted return G and produces
 * Q + learningRate * (G - Q) for each recorded (state, action) pair. The Q values are
 * read at conversion time.
 *
 * For any other rl_mode (the TensorFlow path is disabled), empty sample arrays are returned.
 * Previously the function returned undefined, which made the caller's destructuring throw.
 *
 * @param {Array<Array<number>>} singleEpisodeStateArray - [velocity, remainDistance] per step.
 * @param {Array<number>} singleEpisodeActionIndexArray - Action index taken per step.
 * @param {Array<number>} singleEpisodeRewardArray - Reward received after each step's action.
 * @returns {Array} [stateAndActionIndexArray, updateQValueArray], both ordered from the last
 *   step to the first.
 */
function convertSingleEpisodeDateToTrainFormat(singleEpisodeStateArray, singleEpisodeActionIndexArray, singleEpisodeRewardArray) {
	if (rlConfig.rl_mode === 0) {
		// return tf.tidy(() => {
		// 	let xs = tf.stack(singleEpisodeStateArray);
		// 	xs = xs.reshape([-1, stateInputNo]);

		// 	const original_ys = model.predictOnBatch(xs);
		// 	let data = original_ys.dataSync();

		// 	for (let i = 0; i < singleEpisodeActionIndexArray.length; i++) {

		// 		let index = availableActionCount * i + singleEpisodeActionIndexArray[i];

		// 		let maxQNext = 0;
		// 		let nextStateStartIndex = -1;

		// 		if (i != singleEpisodeActionIndexArray.length - 1 && singleEpisodeActionIndexArray.length > 1) {	// To ensure there is next state
		// 			nextStateStartIndex = availableActionCount * (i + 1);
		// 		}

		// 		if (nextStateStartIndex != -1) {
		// 			let nextStateArray = data.slice(nextStateStartIndex, nextStateStartIndex + availableActionCount);
		// 			maxQNext = Math.max(...nextStateArray);
		// 		}

		// 		// console.log(data[index], "     ", maxQNext, "     ", singleEpisodeRewardArray[i], "     ", data[index] + learningRate * (singleEpisodeRewardArray[i] + discountRate * maxQNext - data[index]));

		// 		data[index] = singleEpisodeRewardArray[i] + discountRate * maxQNext;
		// 	}
		// 	// console.log(data);
		// 	// const ys = tf.tensor2d(data, original_ys.shape);

		// 	return [Array.prototype.slice.call(xs.dataSync()), Array.prototype.slice.call(data)];
		// });
	} else if (rlConfig.rl_mode === 1) {
		// Monte Carlo
		const stateAndActionIndexArray = [];
		const updateQValueArray = [];

		// NOTE(review): this assumes one reward per recorded state. With carBrakingMode 0 only
		// the first state is recorded, so only the first reward is used. A missing reward would
		// produce NaN — confirm with the callers.
		let discountedAccumulatedReward = 0;
		for (let i = singleEpisodeStateArray.length - 1; i >= 0; i--) {
			discountedAccumulatedReward = singleEpisodeRewardArray[i] + rlHyperParams.discountRate * discountedAccumulatedReward;
			const [stateVelocity, stateDistance] = singleEpisodeStateArray[i];
			const originalQValue = rlQtable.getQTableValue(stateVelocity, stateDistance, singleEpisodeActionIndexArray[i]);
			const updateQValue = originalQValue + learningRate * (discountedAccumulatedReward - originalQValue);

			updateQValueArray.push(updateQValue);
			stateAndActionIndexArray.push([singleEpisodeStateArray[i], singleEpisodeActionIndexArray[i]]);
		}

		return [stateAndActionIndexArray, updateQValueArray];
	}

	// No converter is implemented for this rl_mode; contribute no samples.
	return [[], []];
}