Example #1
        // Performs one vectorized gradient descent step on a single batch of training data
        private void GradientDescentStep(Data trainingData, int samplesInTrainingData, ref float trainError, ref int trainHits)
        {
            FeedForward(trainingData.Inputs);
            var outputs = Model.GetOutputs();

            Tensor[] losses = new Tensor[outputs.Length];
            for (int i = 0; i < outputs.Length; ++i)
            {
                losses[i] = new Tensor(outputs[i].Shape);
                LossFuncs[i].Compute(trainingData.Outputs[i], outputs[i], losses[i]);
                trainError += losses[i].Sum() / outputs[i].BatchLength;
                trainHits  += AccuracyFuncs != null ? AccuracyFuncs[i](trainingData.Outputs[i], outputs[i]) : 0;
                LossFuncs[i].Derivative(trainingData.Outputs[i], outputs[i], losses[i]);
            }
            BackProp(losses);
            Optimizer.Step(GetParametersAndGradients(), samplesInTrainingData);
        }
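
For context, the Optimizer.Step(...) call at the end is what actually applies the accumulated gradients to the weights. Below is a minimal sketch of what a plain SGD step could look like; it uses flat float arrays and a hypothetical ParamGrad pair instead of the library's Tensor type, since the actual return type of GetParametersAndGradients() is not shown in this snippet.

using System.Collections.Generic;

// Hypothetical parameter/gradient pairing; the real return type of
// GetParametersAndGradients() is not shown in the snippet above.
public class ParamGrad
{
    public float[] Parameters;
    public float[] Gradients;
}

public class SgdSketch
{
    public float LearningRate = 0.01f;

    // One plain SGD update: p -= lr * (g / samples). Gradients are assumed
    // to be summed over the batch, so dividing by the sample count yields
    // the mean gradient (matching the samplesInTrainingData argument).
    public void Step(List<ParamGrad> paramsAndGrads, int samples)
    {
        foreach (var pg in paramsAndGrads)
        {
            for (int i = 0; i < pg.Parameters.Length; ++i)
            {
                pg.Parameters[i] -= LearningRate * pg.Gradients[i] / samples;
                pg.Gradients[i] = 0f; // reset the accumulator for the next batch
            }
        }
    }
}
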
Example #2
        // Training method. When batchSize is -1, the whole training set is used for a single gradient descent step (i.e., batch size equals the training set size).
        public void Fit(List<Data> trainingData, int batchSize = -1, int epochs = 1, List<Data> validationData = null, int verbose = 1, Track trackFlags = Track.TrainError | Track.TestAccuracy, bool shuffle = true)
        {
            int  inputsBatchSize            = trainingData[0].Inputs[0].BatchSize;
            bool trainingDataAlreadyBatched = inputsBatchSize > 1;

#if VALIDATION_ENABLED
            for (int i = 0; i < trainingData.Count; ++i)
            {
                Data d = trainingData[i];
                Debug.Assert(d.Inputs[0].BatchSize == d.Outputs[0].BatchSize, $"Training data set contains mismatched number of input and output batches for data at index {i}!");
                Debug.Assert(d.Inputs[0].BatchSize == trainingData[0].Inputs[0].BatchSize, "Training data set contains batches of different size!");
            }
#endif

            if (batchSize < 0)
            {
                batchSize = trainingDataAlreadyBatched ? trainingData[0].Inputs[0].BatchSize : trainingData.Count;
            }

            string         outFilename = $"{FilePrefix}_training_data_{Optimizer.GetType().Name.ToLower()}_b{batchSize}{(Seed > 0 ? ("_seed" + Seed) : "")}_{Tensor.CurrentOpMode}";
            ChartGenerator chartGen    = null;
            if (trackFlags != Track.Nothing)
            {
                chartGen = new ChartGenerator($"{outFilename}", $"{Name}\nloss=[{string.Join(",", LossFuncs.Select(x => x.GetType().Name))}] optimizer={Optimizer} batch_size={batchSize}\nseed={(Seed > 0 ? Seed.ToString() : "None")} tensor_mode={Tensor.CurrentOpMode}", "Epoch");
            }

            if (trackFlags.HasFlag(Track.TrainError))
            {
                chartGen.AddSeries((int)Track.TrainError, "Error on train data\n(left Y axis)", Color.DarkRed);
            }
            if (trackFlags.HasFlag(Track.TestError))
            {
                chartGen.AddSeries((int)Track.TestError, "Error on test data\n(left Y axis)", Color.IndianRed);
            }
            if (trackFlags.HasFlag(Track.TrainAccuracy))
            {
                chartGen.AddSeries((int)Track.TrainAccuracy, "Accuracy on train data\n(right Y axis)", Color.DarkBlue, true);
            }
            if (trackFlags.HasFlag(Track.TestAccuracy))
            {
                chartGen.AddSeries((int)Track.TestAccuracy, "Accuracy on test\n(right Y axis)", Color.CornflowerBlue, true);
            }

            int outputLayersCount = Model.GetOutputLayersCount();

            int batchesNum           = trainingDataAlreadyBatched ? trainingData.Count : (trainingData.Count / batchSize);
            int totalTrainingSamples = trainingData.Count * inputsBatchSize;

            // Choose a default accuracy metric per output layer: binary classification
            // for scalar outputs, categorical classification otherwise
            if (AccuracyFuncs == null && (trackFlags.HasFlag(Track.TrainAccuracy) || trackFlags.HasFlag(Track.TestAccuracy)))
            {
                AccuracyFuncs = new AccuracyFunc[outputLayersCount];

                for (int i = 0; i < outputLayersCount; ++i)
                {
                    if (Model.GetOutputLayers().ElementAt(i).OutputShape.Length == 1)
                    {
                        AccuracyFuncs[i] = Tools.AccBinaryClassificationEquality;
                    }
                    else
                    {
                        AccuracyFuncs[i] = Tools.AccCategoricalClassificationEquality;
                    }
                }
            }

            Stopwatch trainTimer = new Stopwatch();

            for (int e = 1; e <= epochs; ++e)
            {
                string output;

                if (verbose > 0)
                {
                    LogLine($"Epoch {e}/{epochs}");
                }

                // no point shuffling when there is only a single batch
                if (batchesNum > 1 && shuffle)
                {
                    trainingData.Shuffle();
                }

                List<Data> batchedTrainingData = trainingDataAlreadyBatched ? trainingData : Tools.MergeData(trainingData, batchSize);

                float trainTotalError = 0;
                int   trainHits       = 0;

                trainTimer.Restart();

                for (int b = 0; b < batchedTrainingData.Count; ++b)
                {
                    // this will usually equal the batch size; the last batch may be smaller when the training set size is not evenly divisible by the batch size
                    int samples = batchedTrainingData[b].Inputs[0].BatchSize;
                    GradientDescentStep(batchedTrainingData[b], samples, ref trainTotalError, ref trainHits);

                    if (verbose == 2)
                    {
                        output = Tools.GetProgressString(b * batchSize + samples, totalTrainingSamples);
                        Console.Write(output);
                        Console.Write(new string('\b', output.Length));
                    }
                }

                trainTimer.Stop();

                if (verbose == 2)
                {
                    output = Tools.GetProgressString(totalTrainingSamples, totalTrainingSamples);
                    LogLine(output);
                }

                float trainError = trainTotalError / totalTrainingSamples;

                chartGen?.AddData(e, trainError, (int)Track.TrainError);
                chartGen?.AddData(e, (float)trainHits / totalTrainingSamples / outputLayersCount, (int)Track.TrainAccuracy);

                if (verbose > 0)
                {
                    string s = $" - loss: {Math.Round(trainError, 4)}";
                    if (trackFlags.HasFlag(Track.TrainAccuracy))
                    {
                        s += $" - acc: {Math.Round((float)trainHits / totalTrainingSamples * 100, 4)}%";
                    }
                    s += " - elapsed: " + trainTimer.Elapsed.ToString(@"mm\:ss\.ffff");

                    LogLine(s);
                }

                float testTotalError = 0;

                // Evaluate loss and accuracy on the validation set (forward pass only, no weight updates)
                if (validationData != null)
                {
                    int   validationSamples = validationData.Count * validationData[0].Inputs[0].BatchSize;
                    float testHits          = 0;

                    for (int n = 0; n < validationData.Count; ++n)
                    {
                        FeedForward(validationData[n].Inputs);
                        var      outputs = Model.GetOutputs();
                        Tensor[] losses  = new Tensor[outputs.Length];
                        for (int i = 0; i < outputLayersCount; ++i)
                        {
                            losses[i] = new Tensor(outputs[i].Shape); // must be allocated before Compute fills it
                            LossFuncs[i].Compute(validationData[n].Outputs[i], outputs[i], losses[i]);
                            testTotalError += losses[i].Sum() / outputs[i].BatchLength;
                            testHits       += AccuracyFuncs[i](validationData[n].Outputs[i], outputs[i]);
                        }

                        if (verbose == 2)
                        {
                            string progress = " - validating: " + Math.Round(n / (float)validationData.Count * 100) + "%";
                            Console.Write(progress);
                            Console.Write(new string('\b', progress.Length));
                        }
                    }

                    chartGen?.AddData(e, testTotalError / validationSamples, (int)Track.TestError);
                    chartGen?.AddData(e, (float)testHits / validationSamples / outputLayersCount, (int)Track.TestAccuracy);
                }

                if ((ChartSaveInterval > 0 && (e % ChartSaveInterval == 0)) || e == epochs)
                {
                    chartGen?.Save();
                }
            }

            if (verbose > 0)
            {
                File.WriteAllLines($"{outFilename}_log.txt", LogLines);
            }
        }
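
A usage sketch for Fit follows. The network instance and data-loading helpers below are placeholders invented for illustration; only the Fit signature and the Track flags come from the code above.

// Hypothetical call site illustrating the Fit parameters above.
// net, LoadTrainingData and LoadValidationData are placeholders,
// not part of the snippet.
List<Data> trainingData   = LoadTrainingData();   // each Data holds matching Inputs/Outputs tensors
List<Data> validationData = LoadValidationData(); // same layout as the training set

// Mini-batch training: 50 epochs, 32 samples per gradient descent step,
// tracking error and accuracy on both data sets.
net.Fit(trainingData,
        batchSize: 32,
        epochs: 50,
        validationData: validationData,
        verbose: 2,
        trackFlags: Track.TrainError | Track.TestError | Track.TrainAccuracy | Track.TestAccuracy);

// Full-batch training: batchSize = -1 turns each epoch into a single
// gradient descent step over the entire training set.
net.Fit(trainingData, batchSize: -1, epochs: 200);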