/// <summary>
/// Collects training and test samples from the training files found under <paramref name="path"/>,
/// then trains, saves and evaluates three models: the action model, the XY model (fed by
/// <c>ActionEnum.Move</c> samples) and the angle model (fed by <c>ActionEnum.Attack</c> samples).
/// Progress and evaluation metrics are written to the console.
/// </summary>
/// <param name="type">
/// Model backend selector. "ml" (case-insensitive) builds <c>ModelMLNet</c> models; any other
/// value builds <c>ModelOpenCV</c> models. The value is also embedded in the saved file names
/// (e.g. <c>action.{type}.model</c>).
/// </param>
/// <param name="path">Directory that is scanned for training files and into which the models are saved.</param>
/// <returns>
/// 0 when all three models were trained, saved and evaluated; -1 when <paramref name="path"/> or
/// <paramref name="type"/> is null, empty or whitespace (previously a null <paramref name="type"/>
/// caused a NullReferenceException).
/// </returns>
public static int TrainAndEvaulate(string type, string path)
{
    if (string.IsNullOrWhiteSpace(path) || string.IsNullOrWhiteSpace(type))
    {
        return -1;
    }

    // Debugging aid: when > 0, reading stops once this counter reaches zero, i.e. only the
    // first (takeOnly - 1) non-empty files are read. 0 disables the limit.
    var takeOnly = 0;
    var rand = new Random();
    // TrainingCountMax is used as a remaining budget: it is decremented per training sample added.
    var actionData = new DataSet() { TrainingCountMax = 150000 };
    var xyData = new DataSet() { TrainingCountMax = 10000 };
    var angleData = new DataSet() { TrainingCountMax = 150000 };
    int trainingCount = 0, testCount = 0;
    var lastFile = "";
    var duplicates = new HashSet<int>();

    foreach (var kvp in AITraining.GetTrainingFiles(path))
    {
        var file = kvp.Key;
        var count = kvp.Value;
        lastFile = file;

        if (count <= 0)
        {
            continue;
        }

        // debugging limit (see takeOnly above)
        if (takeOnly > 0 && --takeOnly == 0)
        {
            break;
        }

        // once every training budget is used up, no further data can be added
        if (actionData.TrainingCountMax + xyData.TrainingCountMax + angleData.TrainingCountMax == 0)
        {
            break;
        }

        foreach (var d in AITraining.GetTraingingData(file))
        {
            if (!d.Result)
            {
                continue;
            }

            var dm = d.AsModelDataSet();
            var hash = dm.ComputeHash();

            // Avoid duplicates: a sample whose hash was already seen (a true duplicate or a hash
            // collision) is skipped, so it can never appear in both the training and test sets.
            if (!duplicates.Add(hash))
            {
                continue;
            }

            // hold out roughly 1 in 6 samples (~17%) for testing
            if (rand.Next() % 6 == 0)
            {
                // test data set (only collected while the matching training budget remains)
                testCount++;
                if (actionData.TrainingCountMax > 0)
                {
                    actionData.Test.Add(dm);
                }

                switch ((ActionEnum)d.Action)
                {
                    case ActionEnum.Attack:
                        if (angleData.TrainingCountMax > 0)
                        {
                            angleData.Test.Add(dm);
                        }
                        break;
                    case ActionEnum.Move:
                        if (xyData.TrainingCountMax > 0)
                        {
                            xyData.Test.Add(dm);
                        }
                        break;
                }
            }
            else
            {
                // training data set
                trainingCount++;
                if (actionData.TrainingCountMax > 0)
                {
                    actionData.Training.Add(dm);
                    actionData.TrainingCountMax--;
                }

                switch ((ActionEnum)d.Action)
                {
                    case ActionEnum.Attack:
                        if (angleData.TrainingCountMax > 0)
                        {
                            angleData.Training.Add(dm);
                            angleData.TrainingCountMax--;
                        }
                        break;
                    case ActionEnum.Move:
                        if (xyData.TrainingCountMax > 0)
                        {
                            xyData.Training.Add(dm);
                            xyData.TrainingCountMax--;
                        }
                        break;
                }
            }
        } // foreach training data
    } // foreach file

    Console.WriteLine("Last file considered {0}", lastFile);
    Console.WriteLine("Training data set ({0} items) and test data set ({1} items)", trainingCount, testCount);
    Console.WriteLine(" Training: Action({0}) XY({1}) Angle({2})", actionData.Training.Count, xyData.Training.Count, angleData.Training.Count);
    Console.WriteLine(" Test: Action({0}) XY({1}) Angle({2})", actionData.Test.Count, xyData.Test.Count, angleData.Test.Count);

    var useMl = string.Equals(type, "ml", StringComparison.OrdinalIgnoreCase);
    TrainSaveAndEvaluate(useMl, type, path, "action", "Actions", actionData, ModelValue.Action);
    TrainSaveAndEvaluate(useMl, type, path, "xy", "XY", xyData, ModelValue.XY);
    TrainSaveAndEvaluate(useMl, type, path, "angle", "Angle", angleData, ModelValue.Angle);

    return 0;
}

/// <summary>
/// Trains one model on <c>data.Training</c>, saves it to <c>{filePrefix}.{type}.model</c> under
/// <paramref name="path"/>, evaluates it on <c>data.Test</c> and prints "{label} RMS=... R^2=...".
/// </summary>
/// <param name="useMl">true builds a <c>ModelMLNet</c>; false builds a <c>ModelOpenCV</c>.</param>
/// <param name="type">Backend name embedded in the saved file name.</param>
/// <param name="path">Directory the model file is saved into.</param>
/// <param name="filePrefix">Leading part of the model file name (e.g. "action").</param>
/// <param name="label">Prefix of the console metrics line (e.g. "Actions").</param>
/// <param name="data">Training and test samples for this model.</param>
/// <param name="value">Which value the model predicts.</param>
private static void TrainSaveAndEvaluate(bool useMl, string type, string path, string filePrefix, string label, DataSet data, ModelValue value)
{
    shootMup.Bots.Model model;
    if (useMl)
    {
        model = new ModelMLNet(data.Training, value);
    }
    else
    {
        model = new ModelOpenCV(data.Training, value);
    }

    model.Save(Path.Combine(path, string.Format("{0}.{1}.model", filePrefix, type)));

    var eval = model.Evaluate(data.Test, value);
    Console.WriteLine("{0} RMS={1} R^2={2}", label, eval.RMS, eval.RSquared);
}