Exemplo n.º 1
0
    /// <summary>
    /// Builds a training set from the ten per-digit MNIST texture arrays (MNIST_images0 .. MNIST_images9)
    /// and starts training <c>net</c> with the shared <c>solver</c>, creating the solver on first use.
    /// A coroutine is then started to plot the solver's cost report.
    /// </summary>
    public void TrainModel()
    {
        // One entry per digit, in order 0..9; Noedify_Utils.ImportImageData turns these into
        // input samples plus matching output vectors.
        var digitImageSets = new List<Texture2D[]>
        {
            MNIST_images0,
            MNIST_images1,
            MNIST_images2,
            MNIST_images3,
            MNIST_images4,
            MNIST_images5,
            MNIST_images6,
            MNIST_images7,
            MNIST_images8,
            MNIST_images9,
        };

        var trainingData = new List<float[, , ]>();
        var outputData = new List<float[]>();
        Noedify_Utils.ImportImageData(ref trainingData, ref outputData, digitImageSets, true);
        debugger.net = net;

        // Use the Background solver method only when the toggle exists and is switched on.
        // `!= null` is kept on purpose: Unity overloads it for destroyed objects.
        bool useBackgroundSolver = solverMethodToggle != null && solverMethodToggle.isOn;
        Noedify_Solver.SolverMethod solverMethod = useBackgroundSolver
            ? Noedify_Solver.SolverMethod.Background
            : Noedify_Solver.SolverMethod.MainThread;

        // Lazily create the solver, then give this run a fresh debug report.
        if (solver == null)
        {
            solver = Noedify.CreateSolver();
        }
        solver.debug = new Noedify_Solver.DebugReport();
        sw.Start();

        // Optional solver settings:
        //   solver.costThreshold = 0.01f;   // stop training early once the cost is low enough
        //   solver.suppressMessages = true; // keep training messages out of the editor console
        // NOTE(review): the trailing `null, 8` arguments look like "no per-sample weights" and a
        // thread count - confirm against Noedify_Solver.TrainNetwork.
        solver.TrainNetwork(net, trainingData, outputData, no_epochs, batch_size, trainingRate, costFunction, solverMethod, null, 8);

        float[] costReport = solver.cost_report;
        StartCoroutine(PlotCostWhenComplete(solver, costReport));
    }
Exemplo n.º 2
0
 /// <summary>
 /// Coroutine that trains <c>net</c> on the samples collected in <c>trainingSet</c>.
 /// It does nothing when the set is null or empty. Otherwise it yields <c>null</c> until
 /// <c>trainingSolver.trainingInProgress</c> is false. It then copies each sample's observation
 /// (passed through Noedify_Utils.AddTwoSingularDims), decision and weight into parallel lists.
 /// Finally it calls TrainNetwork with the MeanSquare cost, the MainThread solver method and
 /// those weights.
 /// </summary>
 public IEnumerator TrainNetwork()
 {
     if (trainingSet == null || trainingSet.Count <= 0)
     {
         yield break;
     }

     // Wait until any run already in progress on this solver has finished.
     while (trainingSolver.trainingInProgress)
     {
         yield return null;
     }

     var observationInputs = new List<float[, , ]>();
     var decisionOutputs = new List<float[]>();
     var sampleWeights = new List<float>();
     for (int i = 0; i < trainingSet.Count; i++)
     {
         var sample = trainingSet[i];
         observationInputs.Add(Noedify_Utils.AddTwoSingularDims(sample.observation));
         decisionOutputs.Add(sample.decision);
         sampleWeights.Add(sample.weight);
     }

     trainingSolver.TrainNetwork(
         net,
         observationInputs,
         decisionOutputs,
         trainingEpochs,
         trainingBatchSize,
         trainingRate,
         Noedify_Solver.CostFunction.MeanSquare,
         Noedify_Solver.SolverMethod.MainThread,
         sampleWeights,
         N_threads);
 }