/// <summary>
/// Initializes the machine-learning manager: loads the model stored in
/// <paramref name="modelFileName"/> (or initializes a new one if it cannot be loaded) and
/// loads the recorded wave data from <paramref name="waveDataFileName"/>, falling back to an
/// empty data set.
/// </summary>
/// <param name="modelFileName">Path of the saved model; remembered in <c>_modelFileName</c>.</param>
/// <param name="waveDataFileName">
/// Path of the binary-serialized <see cref="WaveData"/>; remembered in <c>_waveDataFileName</c>.
/// The file is optional.
/// </param>
public static void LoadData(string modelFileName, string waveDataFileName)
{
    LocalLogManager.AddLine("Machine Learning manager initialization");

    _modelFileName = modelFileName;
    _waveDataFileName = waveDataFileName;

    InputCount = (MaxPathingPoints * CountPerPathing)
               + (MaxWaterPoints * CountPerWater)
               + (MaxTowers * CountPerTower)
               + (MaxEnemies * CountPerEnemy);

    _machineLearningModel = new DeepBeliefNetworkModel();

    WaveData loadedWaveData = null;
    lock (FileLocker)
    {
        var canLoadExistingModel = _machineLearningModel.Load(_modelFileName);
        if (!canLoadExistingModel)
        {
            // NOTE(review): second argument is presumably the hidden-layer size (80% of the inputs).
            _machineLearningModel.Initialize(InputCount, (int)(InputCount * 0.8));
        }

        // The wave-data file is read under the same lock that guards the model file so both
        // file reads are serialized the same way.
        if (FileManager.FileExists(_waveDataFileName))
        {
            loadedWaveData = FileManager.BinaryDeserialize(typeof(WaveData), _waveDataFileName) as WaveData;
            if (loadedWaveData == null)
            {
                LocalLogManager.AddLine($"Wave data in '{_waveDataFileName}' could not be read; starting with an empty data set.");
            }
        }
    }

    // Always replace the static field so wave data from an earlier LoadData call (possibly for a
    // different file) never leaks into this one.
    _waveData = loadedWaveData ?? new WaveData
    {
        WaveInputs = new List<double[]>(),
        WaveScores = new List<double>(),
    };
}
/// <summary>
/// Builds a learning curve: trains a freshly initialized model on "wavedata.csv" one row at a
/// time and writes the error metrics after each row to "ModelData.csv".
/// </summary>
/// <remarks>
/// Each line of "wavedata.csv" is tab-separated: the first field is the score (model output),
/// the remaining fields are the model inputs. After every row the model is retrained via
/// <c>LearnAll</c> on all rows read so far. Blank lines are skipped.
/// </remarks>
private static void LoadTestData()
{
    CSV = new StringBuilder();
    CSV.Append("SampleSize, AVMMSE, MSE");
    CSV.Append(Environment.NewLine);

    // Fixed hyperparameters for this run. NOTE(review): the names presumably mean supervised
    // learning rate / momentum (slr, sm) and unsupervised learning rate / momentum / decay
    // (uslr, usm, usd) — confirm against DeepBeliefNetworkModel.Initialize.
    var slr = 0.8;
    var sm = 0.8;
    var uslr = 0.9;
    var usm = 0.2;
    var usd = 0.21;

    InputCount = (MaxPathingPoints * CountPerPathing)
               + (MaxWaterPoints * CountPerWater)
               + (MaxTowers * CountPerTower)
               + (MaxEnemies * CountPerEnemy);

    _machineLearningModel = new DeepBeliefNetworkModel();
    _machineLearningModel.Initialize(InputCount, (int)(InputCount * 0.8), uslr, usm, usd, slr, sm);

    var outputList = new List<double>();
    var inputList = new List<double[]>();

    // _waveData wraps these lists directly, so rows appended below are visible to LearnAll.
    _waveData = new WaveData
    {
        WaveInputs = inputList,
        WaveScores = outputList,
    };

    using (var reader = new StreamReader(@"wavedata.csv"))
    {
        var sampleSize = 0;
        while (!reader.EndOfStream)
        {
            var line = reader.ReadLine();
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            var values = line.Split('\t');
            var output = double.Parse(values[0]);

            // One input slot per field after the score; every one of those fields is parsed.
            var inputArray = new double[values.Length - 1];
            for (var i = 1; i < values.Length; i++)
            {
                // A trailing tab yields an empty last field; leave that slot at 0.
                if (i == values.Length - 1 && string.IsNullOrWhiteSpace(values[i]))
                {
                    break;
                }

                inputArray[i - 1] = double.Parse(values[i]);
            }

            // Add both only after the whole row parsed, keeping the lists aligned.
            outputList.Add(output);
            inputList.Add(inputArray);

            _machineLearningModel.LearnAll(_waveData);

            CSV.Append($"{++sampleSize}, {_machineLearningModel.AVMMSE}, {_machineLearningModel.LastMSE}");
            CSV.Append(Environment.NewLine);
        }
    }

    var fileName = "ModelData.csv";
    File.WriteAllText(fileName, CSV.ToString());
}
/// <summary>
/// Trains five fresh models, each on a growing random sample (drawn with replacement) from
/// "wavedata.csv". After every added sample, the error metrics and learn time are recorded.
/// Each model's results go to its own file, "Player5Data.csv" through "Player9Data.csv".
/// </summary>
/// <remarks>
/// Each line of "wavedata.csv" is tab-separated: score first, then the inputs. Blank lines are
/// skipped.
/// </remarks>
/// <exception cref="InvalidOperationException">"wavedata.csv" contains no samples.</exception>
private static void IncrementalTest()
{
    // NOTE(review): presumably supervised learning rate / momentum (slr, sm) and unsupervised
    // learning rate / momentum / decay (uslr, usm, usd) — confirm against Initialize.
    var slr = 0.8;
    var sm = 0.8;
    var uslr = 0.9;
    var usm = 0.2;
    var usd = 0.21;

    InputCount = (MaxPathingPoints * CountPerPathing)
               + (MaxWaterPoints * CountPerWater)
               + (MaxTowers * CountPerTower)
               + (MaxEnemies * CountPerEnemy);

    var outputList = new List<double>();
    var inputList = new List<double[]>();
    using (var reader = new StreamReader(@"wavedata.csv"))
    {
        while (!reader.EndOfStream)
        {
            var line = reader.ReadLine();
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            var values = line.Split('\t');
            var output = double.Parse(values[0]);

            // One input slot per field after the score; every one of those fields is parsed.
            var inputArray = new double[values.Length - 1];
            for (var i = 1; i < values.Length; i++)
            {
                // A trailing tab yields an empty last field; leave that slot at 0.
                if (i == values.Length - 1 && string.IsNullOrWhiteSpace(values[i]))
                {
                    break;
                }

                inputArray[i - 1] = double.Parse(values[i]);
            }

            outputList.Add(output);
            inputList.Add(inputArray);
        }
    }

    if (outputList.Count == 0)
    {
        throw new InvalidOperationException("wavedata.csv does not contain any samples.");
    }

    for (var modelTest = 5; modelTest <= 9; modelTest++)
    {
        _machineLearningModel = new DeepBeliefNetworkModel();
        _machineLearningModel.Initialize(InputCount, (int)(InputCount * 0.8), uslr, usm, usd, slr, sm);

        CSV = new StringBuilder();
        CSV.Append("SampleSize, AVMMSE, MSE, LearnTime");
        CSV.Append(Environment.NewLine);

        var partialOutput = new List<double>();
        var partialInput = new List<double[]>();

        // _waveData wraps the partial lists, so each sample added below is seen by LearnAll.
        _waveData = new WaveData
        {
            WaveInputs = partialInput,
            WaveScores = partialOutput,
        };

        // The inclusive bound means sampleSize + 1 samples are drawn.
        var sampleSize = FlatRedBallServices.Random.Next(150, 630);
        for (var i = 0; i <= sampleSize; i++)
        {
            // Random.Next's upper bound is exclusive (assuming FlatRedBallServices.Random is a
            // System.Random), so passing Count makes every row, including the last, eligible.
            var randomIndex = FlatRedBallServices.Random.Next(0, outputList.Count);
            partialInput.Add(inputList[randomIndex]);
            partialOutput.Add(outputList[randomIndex]);

            _machineLearningModel.LearnAll(_waveData);

            CSV.Append($"{i + 1}, {_machineLearningModel.AVMMSE}, {_machineLearningModel.LastMSE}, {_machineLearningModel.LastLearnTime}");
            CSV.Append(Environment.NewLine);
        }

        var fileName = $"Player{modelTest}Data.csv";
        File.WriteAllText(fileName, CSV.ToString());
    }
}
/// <summary>
/// Random search over the model's hyperparameters. Each trial picks a not-yet-tried grid
/// combination at random, trains on "wavedata.csv", and appends the combination and its MSE to
/// "CollectedData.csv".
/// </summary>
/// <remarks>
/// Every trial starts from the model saved in "BaseModel.model". A trial whose MSE is below
/// 0.03 is also saved as "BESTMODEL{mse}SAVE.model". The search ends once every combination on
/// the grid has been tried. Each line of "wavedata.csv" is tab-separated: score first, then the
/// inputs. Blank lines are skipped.
/// </remarks>
private static void SearchTest()
{
    var fileName = "CollectedData.csv";
    InputCount = (MaxPathingPoints * CountPerPathing)
               + (MaxWaterPoints * CountPerWater)
               + (MaxTowers * CountPerTower)
               + (MaxEnemies * CountPerEnemy);

    // Truncate the results file; each trial appends one row below.
    CSV = new StringBuilder();
    File.WriteAllText(fileName, CSV.ToString());

    // Grids are generated from integer steps so every value is an exact grid point; repeatedly
    // adding 0.1 drifts (0.30000000000000004, 0.9999999999999999, ...).
    var slrList = new List<double>();
    var smList = new List<double>();
    var uslrList = new List<double>();
    var usmList = new List<double>();
    var usdList = new List<double>();
    for (var step = 0; step <= 10; step++)
    {
        var value = step / 10.0; // 0.0, 0.1, ..., 1.0
        slrList.Add(value);
        smList.Add(value);
        uslrList.Add(value);
        usmList.Add(value);
    }

    for (var step = 0; 0.01 + (step * 0.025) <= 0.3; step++)
    {
        usdList.Add(0.01 + (step * 0.025)); // 0.01, 0.035, ..., 0.285
    }

    var totalCombos = slrList.Count * smList.Count * uslrList.Count * usmList.Count * usdList.Count;
    var triedCombos = new HashSet<string>();

    var outputList = new List<double>();
    var inputList = new List<double[]>();
    _waveData = new WaveData
    {
        WaveInputs = inputList,
        WaveScores = outputList,
    };

    using (var reader = new StreamReader(@"wavedata.csv"))
    {
        while (!reader.EndOfStream)
        {
            var line = reader.ReadLine();
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            var values = line.Split('\t');
            var output = double.Parse(values[0]);

            // One input slot per field after the score; every one of those fields is parsed.
            var inputArray = new double[values.Length - 1];
            for (var i = 1; i < values.Length; i++)
            {
                // A trailing tab yields an empty last field; leave that slot at 0.
                if (i == values.Length - 1 && string.IsNullOrWhiteSpace(values[i]))
                {
                    break;
                }

                inputArray[i - 1] = double.Parse(values[i]);
            }

            outputList.Add(output);
            inputList.Add(inputArray);
        }
    }

    // Save one set of starting weights so every trial begins from the same point.
    DeepBeliefNetworkModel baseModel = new DeepBeliefNetworkModel();
    baseModel.Initialize(InputCount, (int)(InputCount * 0.8), 0, 0, 0, 0, 0);
    baseModel.Save("BaseModel.model");

    // Same instance as baseModel; it is re-initialized and reloaded at the start of every trial.
    DeepBeliefNetworkModel newModel = baseModel;

    // Stop once the grid is exhausted; otherwise the loop would spin forever on duplicates.
    while (triedCombos.Count < totalCombos)
    {
        var slr = FlatRedBallServices.Random.In(slrList);
        var sm = FlatRedBallServices.Random.In(smList);
        var uslr = FlatRedBallServices.Random.In(uslrList);
        var usm = FlatRedBallServices.Random.In(usmList);
        var usd = FlatRedBallServices.Random.In(usdList);

        // The separator keeps keys unambiguous; HashSet.Add returns false for a repeat.
        var combo = $"{slr}|{sm}|{uslr}|{usm}|{usd}";
        if (!triedCombos.Add(combo))
        {
            continue;
        }

        CSV.Clear();

        // NOTE(review): Load runs after Initialize. If Load also restores learning parameters
        // stored in BaseModel.model (saved with all zeros), the sampled hyperparameters would be
        // discarded — confirm against DeepBeliefNetworkModel.Load.
        newModel.Initialize(InputCount, (int)(InputCount * 0.8), uslr, usm, usd, slr, sm);
        newModel.Load("BaseModel.model");

        CSV.Append($"{slr},{sm},{uslr},{usm},{usd},");
        newModel.LearnAll(_waveData);
        CSV.Append(newModel.LastMSE);
        CSV.Append(Environment.NewLine);
        File.AppendAllText(fileName, CSV.ToString());
        CSV.Clear();

        if (newModel.LastMSE < 0.03)
        {
            newModel.Save("BESTMODEL" + newModel.LastMSE + "SAVE.model");
        }
    }
}