/// <summary>
/// Coroutine that waits for <paramref name="solver"/> to finish training, then hands the
/// cost values to <c>debugger.PlotCost</c> and runs one prediction per sample image,
/// writing the predicted class index into the matching prediction text element.
/// </summary>
/// <param name="solver">Solver whose <c>trainingInProgress</c> flag is polled.</param>
/// <param name="cost">Cost values forwarded unchanged to <c>debugger.PlotCost</c>.</param>
IEnumerator PlotCostWhenComplete(Noedify_Solver solver, float[] cost)
{
    // Yield once per iteration while training runs; if it has already finished,
    // execution falls straight through without yielding at all.
    while (solver.trainingInProgress)
    {
        yield return null;
    }

    // Second PlotCost argument: first element shrinks as no_epochs grows, second is fixed at 5.
    // NOTE(review): presumably per-axis plot scaling — confirm against PlotCost.
    var plotScale = new[] { 1.0f / no_epochs * 2.5f, 5f };
    debugger.PlotCost(cost, plotScale, costPlotOrigin);

    // NOTE(review): the bound comes from sampleImagePlanes, but sampleImageRandomSet and
    // CNN_predictionText are indexed too — assumes they are at least as long. Confirm.
    for (int sampleIdx = 0; sampleIdx < predictionTester.sampleImagePlanes.Length; sampleIdx++)
    {
        // 1x1x1 placeholder passed by ref; presumably ImportImageData replaces it with the
        // properly sized image tensor.
        float[,,] sampleInput = new float[1, 1, 1];
        Noedify_Utils.ImportImageData(ref sampleInput, predictionTester.sampleImageRandomSet[sampleIdx], true);

        // Synchronous evaluation, then decode the one-hot output to an integer class.
        solver.Evaluate(net, sampleInput, Noedify_Solver.SolverMethod.MainThread);
        int predictedClass = Noedify_Utils.ConvertOneHotToInt(solver.prediction);

        predictionTester.CNN_predictionText[sampleIdx].text = predictedClass.ToString();
    }
}
Ejemplo n.º 2
0
    /// <summary>
    /// Builds training inputs/labels from the ten MNIST texture arrays, prepares the solver
    /// (creating it on first use), starts training <c>net</c>, and launches
    /// <c>PlotCostWhenComplete</c> to plot the cost and run sample predictions afterwards.
    /// </summary>
    public void TrainModel()
    {
        // Texture arrays in order MNIST_images0 .. MNIST_images9.
        // NOTE(review): the list position presumably determines each image's label —
        // confirm against Noedify_Utils.ImportImageData.
        var textureSets = new List<Texture2D[]>
        {
            MNIST_images0,
            MNIST_images1,
            MNIST_images2,
            MNIST_images3,
            MNIST_images4,
            MNIST_images5,
            MNIST_images6,
            MNIST_images7,
            MNIST_images8,
            MNIST_images9,
        };

        // Filled in by ImportImageData through the ref parameters.
        var inputs = new List<float[,,]>();
        var labels = new List<float[]>();
        Noedify_Utils.ImportImageData(ref inputs, ref labels, textureSets, true);

        debugger.net = net;

        // Background only when the toggle exists and is on; otherwise MainThread.
        // Explicit "!= null" (not "?.") keeps Unity's overloaded null comparison in effect.
        Noedify_Solver.SolverMethod method =
            (solverMethodToggle != null && solverMethodToggle.isOn)
                ? Noedify_Solver.SolverMethod.Background
                : Noedify_Solver.SolverMethod.MainThread;

        // Create the solver on first use; later calls reuse the existing instance.
        if (solver == null)
        {
            solver = Noedify.CreateSolver();
        }

        solver.debug = new Noedify_Solver.DebugReport();
        sw.Start();

        // Optional settings, disabled by default:
        //   solver.costThreshold = 0.01f;   // stop early once cost drops below this value
        //   solver.suppressMessages = true; // hide training messages in the editor console
        solver.TrainNetwork(net, inputs, labels, no_epochs, batch_size, trainingRate, costFunction, method, null, 8);

        // NOTE(review): cost_report is read immediately after TrainNetwork returns. If training
        // continues asynchronously and the solver later swaps in a new array, the coroutine
        // would plot the stale one — confirm the array is filled in place.
        StartCoroutine(PlotCostWhenComplete(solver, solver.cost_report));
    }