// Performs a single vectorized gradient descent step over one batch of training data
private void GradientDescentStep(Data trainingData, int samplesInTrainingData, ref float trainError, ref int trainHits)
{
    // forward pass
    FeedForward(trainingData.Inputs);
    var outputs = Model.GetOutputs();
    Tensor[] losses = new Tensor[outputs.Length];

    for (int i = 0; i < outputs.Length; ++i)
    {
        losses[i] = new Tensor(outputs[i].Shape);
        // accumulate loss and accuracy for reporting
        LossFuncs[i].Compute(trainingData.Outputs[i], outputs[i], losses[i]);
        trainError += losses[i].Sum() / outputs[i].BatchLength;
        trainHits += AccuracyFuncs != null ? AccuracyFuncs[i](trainingData.Outputs[i], outputs[i]) : 0;
        // overwrite losses with the loss derivative, which becomes the back-propagation input
        LossFuncs[i].Derivative(trainingData.Outputs[i], outputs[i], losses[i]);
    }

    // backward pass and parameter update
    BackProp(losses);
    Optimizer.Step(GetParametersAndGradients(), samplesInTrainingData);
}
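// The actual update rule lives inside Optimizer.Step and is not shown here. As a point of
// reference only, a minimal plain-SGD step over flat float[] parameters could look like the
// sketch below; the method name, the float[] layout and the learningRate default are
// assumptions for illustration, not the library's Tensor-based optimizer API.
static void SgdStepSketch(float[] weights, float[] gradients, int samplesInBatch, float learningRate = 0.01f)
{
    for (int i = 0; i < weights.Length; ++i)
    {
        // average the accumulated gradient over the batch, then move the weight against it
        weights[i] -= learningRate * gradients[i] / samplesInBatch;
        gradients[i] = 0; // gradients are typically cleared after every step
    }
}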
// Training method. When batchSize is -1, the whole training set is used for a single gradient
// descent step (in other words, the batch size equals the training set size).
public void Fit(List<Data> trainingData, int batchSize = -1, int epochs = 1, List<Data> validationData = null, int verbose = 1, Track trackFlags = Track.TrainError | Track.TestAccuracy, bool shuffle = true)
{
    int inputsBatchSize = trainingData[0].Inputs[0].BatchSize;
    bool trainingDataAlreadyBatched = inputsBatchSize > 1;

#if VALIDATION_ENABLED
    //for (int i = 0; i < trainingData.Count; ++i)
    //{
    //    Data d = trainingData[i];
    //    Debug.Assert(d.Inputs.BatchSize == d.Outputs.BatchSize, $"Training data set contains mismatched number of input and output batches for data at index {i}!");
    //    Debug.Assert(d.Inputs.BatchSize == trainingData[0].Inputs.BatchSize, "Training data set contains batches of different size!");
    //}
#endif

    if (batchSize < 0)
    {
        batchSize = trainingDataAlreadyBatched ? trainingData[0].Inputs[0].BatchSize : trainingData.Count;
    }

    string outFilename = $"{FilePrefix}_training_data_{Optimizer.GetType().Name.ToLower()}_b{batchSize}{(Seed > 0 ? ("_seed" + Seed) : "")}_{Tensor.CurrentOpMode}";

    ChartGenerator chartGen = null;
    if (trackFlags != Track.Nothing)
    {
        chartGen = new ChartGenerator($"{outFilename}", $"{Name}\nloss=[{string.Join(",", LossFuncs.Select(x => x.GetType().Name))}] optimizer={Optimizer} batch_size={batchSize}\nseed={(Seed > 0 ? Seed.ToString() : "None")} tensor_mode={Tensor.CurrentOpMode}", "Epoch");
    }

    if (trackFlags.HasFlag(Track.TrainError))
    {
        chartGen.AddSeries((int)Track.TrainError, "Error on train data\n(left Y axis)", Color.DarkRed);
    }
    if (trackFlags.HasFlag(Track.TestError))
    {
        chartGen.AddSeries((int)Track.TestError, "Error on test data\n(left Y axis)", Color.IndianRed);
    }
    if (trackFlags.HasFlag(Track.TrainAccuracy))
    {
        chartGen.AddSeries((int)Track.TrainAccuracy, "Accuracy on train data\n(right Y axis)", Color.DarkBlue, true);
    }
    if (trackFlags.HasFlag(Track.TestAccuracy))
    {
        chartGen.AddSeries((int)Track.TestAccuracy, "Accuracy on test\n(right Y axis)", Color.CornflowerBlue, true);
    }

    //var lastLayer = Layers.Last();
    int outputLayersCount = Model.GetOutputLayersCount();

    int batchesNum = trainingDataAlreadyBatched ? trainingData.Count : (trainingData.Count / batchSize);
    int totalTrainingSamples = trainingData.Count * inputsBatchSize;

    if (AccuracyFuncs == null && (trackFlags.HasFlag(Track.TrainAccuracy) || trackFlags.HasFlag(Track.TestAccuracy)))
    {
        AccuracyFuncs = new AccuracyFunc[outputLayersCount];

        for (int i = 0; i < outputLayersCount; ++i)
        {
            if (Model.GetOutputLayers().ElementAt(i).OutputShape.Length == 1)
            {
                AccuracyFuncs[i] = Tools.AccBinaryClassificationEquality;
            }
            else
            {
                AccuracyFuncs[i] = Tools.AccCategoricalClassificationEquality;
            }
        }
    }

    Stopwatch trainTimer = new Stopwatch();

    for (int e = 1; e <= epochs; ++e)
    {
        string output;

        if (verbose > 0)
        {
            LogLine($"Epoch {e}/{epochs}");
        }

        // no point shuffling when there is only a single batch
        if (batchesNum > 1 && shuffle)
        {
            trainingData.Shuffle();
        }

        List<Data> batchedTrainingData = trainingDataAlreadyBatched ? trainingData : Tools.MergeData(trainingData, batchSize);

        float trainTotalError = 0;
        int trainHits = 0;

        trainTimer.Restart();

        for (int b = 0; b < batchedTrainingData.Count; ++b)
        {
            // this will usually equal batchSize; the last batch may be smaller when the training
            // data count is not evenly divisible by the batch size
            int samples = batchedTrainingData[b].Inputs[0].BatchSize;
            GradientDescentStep(batchedTrainingData[b], samples, ref trainTotalError, ref trainHits);

            if (verbose == 2)
            {
                output = Tools.GetProgressString(b * batchSize + samples, totalTrainingSamples);
                Console.Write(output);
                Console.Write(new string('\b', output.Length));
            }
        }

        trainTimer.Stop();

        if (verbose == 2)
        {
            output = Tools.GetProgressString(totalTrainingSamples, totalTrainingSamples);
            LogLine(output);
        }

        float trainError = trainTotalError / totalTrainingSamples;

        chartGen?.AddData(e, trainError, (int)Track.TrainError);
        chartGen?.AddData(e, (float)trainHits / totalTrainingSamples / outputLayersCount, (int)Track.TrainAccuracy);

        if (verbose > 0)
        {
            string s = $" - loss: {Math.Round(trainError, 4)}";
            if (trackFlags.HasFlag(Track.TrainAccuracy))
            {
                s += $" - acc: {Math.Round((float)trainHits / totalTrainingSamples * 100, 4)}%";
            }
            s += " - eta: " + trainTimer.Elapsed.ToString(@"mm\:ss\.ffff"); // elapsed epoch time
            LogLine(s);
        }

        float testTotalError = 0;

        if (validationData != null)
        {
            int validationSamples = validationData.Count * validationData[0].Inputs[0].BatchSize;
            float testHits = 0;

            for (int n = 0; n < validationData.Count; ++n)
            {
                FeedForward(validationData[n].Inputs);
                var outputs = Model.GetOutputs();
                Tensor[] losses = new Tensor[outputs.Length];

                for (int i = 0; i < outputLayersCount; ++i)
                {
                    losses[i] = new Tensor(outputs[i].Shape); // allocate before computing into it (mirrors the training step)
                    LossFuncs[i].Compute(validationData[n].Outputs[i], outputs[i], losses[i]);
                    testTotalError += losses[i].Sum() / outputs[i].BatchLength;
                    testHits += AccuracyFuncs[i](validationData[n].Outputs[i], outputs[i]);
                }

                if (verbose == 2)
                {
                    string progress = " - validating: " + Math.Round(n / (float)validationData.Count * 100) + "%";
                    Console.Write(progress);
                    Console.Write(new string('\b', progress.Length));
                }
            }

            chartGen?.AddData(e, testTotalError / validationSamples, (int)Track.TestError);
            chartGen?.AddData(e, (float)testHits / validationSamples / outputLayersCount, (int)Track.TestAccuracy);
        }

        if ((ChartSaveInterval > 0 && (e % ChartSaveInterval == 0)) || e == epochs)
        {
            chartGen?.Save();
        }
    }

    if (verbose > 0)
    {
        File.WriteAllLines($"{outFilename}_log.txt", LogLines);
    }
}
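// Hypothetical usage sketch: only Fit's parameter list and the Track flags come from the code
// above; the wrapper method, the data variables and the chosen values are assumptions made to
// show a typical call. Passing batchSize: -1 instead would run full-batch gradient descent
// (one step per epoch over the whole training set).
public void TrainExample(List<Data> trainSet, List<Data> testSet)
{
    Fit(trainSet,
        batchSize: 32,                                       // -1 = full-batch training
        epochs: 10,
        validationData: testSet,                             // evaluated after every epoch
        verbose: 2,                                          // per-batch progress output
        trackFlags: Track.TrainError | Track.TestAccuracy,   // series recorded by ChartGenerator
        shuffle: true);                                      // reshuffle training data each epoch
}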