// Coroutine: waits (one frame per iteration) until the solver's trainingInProgress flag
// is cleared, then plots the supplied cost values and writes the network's prediction
// for every sample image plane into the matching prediction text element.
//   solver - solver that is polled for completion and then reused for evaluation
//   cost   - cost values passed straight through to debugger.PlotCost
IEnumerator PlotCostWhenComplete(Noedify_Solver solver, float[] cost)
{
    // Hold here until training has finished.
    while (solver.trainingInProgress)
    {
        yield return null;
    }

    // The first scale entry depends on the epoch count; the second is fixed at 5.
    float[] plotScale = new float[2] { 1.0f / no_epochs * 2.5f, 5 };
    debugger.PlotCost(cost, plotScale, costPlotOrigin);

    for (int planeIndex = 0; planeIndex < predictionTester.sampleImagePlanes.Length; planeIndex++)
    {
        // Placeholder array; it is passed by ref so ImportImageData can replace/fill it.
        float[,,] imageInput = new float[1, 1, 1];
        Noedify_Utils.ImportImageData(ref imageInput, predictionTester.sampleImageRandomSet[planeIndex], true);

        // Evaluate on the main thread, then read solver.prediction immediately afterwards.
        solver.Evaluate(net, imageInput, Noedify_Solver.SolverMethod.MainThread);

        int predictedValue = Noedify_Utils.ConvertOneHotToInt(solver.prediction);
        predictionTester.CNN_predictionText[planeIndex].text = predictedValue.ToString();
    }
}
// Start is called before the first frame update.
// Picks a shuffled selection of textures from sampleImageSet (one per sample image plane)
// and applies each selected texture to its plane's material.
void Start()
{
    sampleImageRandomSet = new List<Texture2D>();

    // Index table 0..N-1 over the available sample images.
    int[] shuffledIndices = new int[sampleImageSet.Length];
    for (int k = 0; k < shuffledIndices.Length; k++)
    {
        shuffledIndices[k] = k;
    }

    // Return value is not used, so Shuffle presumably permutes the array in place - TODO confirm.
    Noedify_Utils.Shuffle(shuffledIndices);

    // Take the first sampleImagePlanes.Length shuffled entries as the random set.
    for (int planeIndex = 0; planeIndex < sampleImagePlanes.Length; planeIndex++)
    {
        int imageIndex = shuffledIndices[planeIndex];
        sampleImageRandomSet.Add(sampleImageSet[imageIndex]);
    }

    // Show each chosen texture on its plane.
    for (int planeIndex = 0; planeIndex < sampleImagePlanes.Length; planeIndex++)
    {
        MeshRenderer planeRenderer = sampleImagePlanes[planeIndex].GetComponent<MeshRenderer>();
        planeRenderer.material.SetTexture("_MainTex", sampleImageRandomSet[planeIndex]);
    }
}
// Builds the training set from the ten MNIST_images0..9 texture arrays, selects the solver
// method from the optional UI toggle, starts training, and launches the coroutine that
// plots the cost once training has completed.
public void TrainModel()
{
    List<float[,,]> trainingData = new List<float[,,]>();
    List<float[]> outputData = new List<float[]>();

    // One texture array per MNIST_imagesN field, added in order 0..9.
    List<Texture2D[]> imageSets = new List<Texture2D[]>
    {
        MNIST_images0,
        MNIST_images1,
        MNIST_images2,
        MNIST_images3,
        MNIST_images4,
        MNIST_images5,
        MNIST_images6,
        MNIST_images7,
        MNIST_images8,
        MNIST_images9
    };
    Noedify_Utils.ImportImageData(ref trainingData, ref outputData, imageSets, true);

    debugger.net = net;

    // Main-thread training unless a toggle is present and switched on.
    bool useBackground = solverMethodToggle != null && solverMethodToggle.isOn;
    Noedify_Solver.SolverMethod solverMethod = useBackground
        ? Noedify_Solver.SolverMethod.Background
        : Noedify_Solver.SolverMethod.MainThread;

    // Create the solver on first use; later calls reuse it.
    if (solver == null)
    {
        solver = Noedify.CreateSolver();
    }
    solver.debug = new Noedify_Solver.DebugReport();

    sw.Start();

    // Optional solver settings (currently disabled):
    //solver.costThreshold = 0.01f;    // Add a cost threshold to end training early when a suitably low error is achieved
    //solver.suppressMessages = true;  // Suppress training messages from appearing in the editor console
    solver.TrainNetwork(net, trainingData, outputData, no_epochs, batch_size, trainingRate, costFunction, solverMethod, null, 8);

    // Hand the solver's cost report to the coroutine that plots it after training ends.
    float[] cost = solver.cost_report;
    StartCoroutine(PlotCostWhenComplete(solver, cost));
}
// Coroutine: if there is at least one entry in trainingSet, waits for any training run
// already in progress on trainingSolver to finish, then starts a new main-thread,
// mean-square-cost training run on the collected observations, decisions and weights.
// Does nothing when trainingSet is null or empty.
public IEnumerator TrainNetwork()
{
    if (trainingSet == null || trainingSet.Count <= 0)
    {
        yield break;
    }

    // Do not start a new run while the solver is still busy.
    while (trainingSolver.trainingInProgress)
    {
        yield return null;
    }

    List<float[,,]> observation_inputs = new List<float[,,]>();
    List<float[]> decision_outputs = new List<float[]>();
    List<float> trainingSetWeights = new List<float>();

    // Split each training entry into the three parallel lists the solver takes.
    for (int entryIndex = 0; entryIndex < trainingSet.Count; entryIndex++)
    {
        var observationInput = Noedify_Utils.AddTwoSingularDims(trainingSet[entryIndex].observation);
        observation_inputs.Add(observationInput);
        decision_outputs.Add(trainingSet[entryIndex].decision);
        trainingSetWeights.Add(trainingSet[entryIndex].weight);
    }

    trainingSolver.TrainNetwork(
        net,
        observation_inputs,
        decision_outputs,
        trainingEpochs,
        trainingBatchSize,
        trainingRate,
        Noedify_Solver.CostFunction.MeanSquare,
        Noedify_Solver.SolverMethod.MainThread,
        trainingSetWeights,
        N_threads);
}
// Evaluate the network to generate a decision.
// Runs the simulation controller's network on the given observation (main thread) and
// returns the evaluation solver's resulting prediction array.
float[] AIDecision(float[] observation)
{
    var network = simController.net;
    var networkInput = Noedify_Utils.AddTwoSingularDims(observation);

    evalSolver.Evaluate(network, networkInput, Noedify_Solver.SolverMethod.MainThread);

    // prediction is read directly after the main-thread Evaluate call returns.
    return evalSolver.prediction;
}