/// <summary>
/// Performs mini-batch stochastic gradient descent over <paramref name="trainingSet"/>,
/// applying the accumulated error once every <paramref name="batchSize"/> samples.
/// </summary>
/// <param name="trainingSet">The samples to train on. Must not be null.</param>
/// <param name="batchSize">
/// Number of samples per mini-batch. Clamped to the number of samples in
/// <paramref name="trainingSet"/>; must be at least 1 when the set is non-empty.
/// </param>
/// <remarks>
/// After each completed batch, <c>LastCost</c> holds the mean cost over that batch's samples
/// and <c>ApplyError</c> is called; <c>Hook</c> is then invoked with the 1-based batch number
/// unless <c>Abort</c> is set. If <c>Abort</c> is set after a sample is processed, training
/// stops immediately without applying the partially accumulated batch.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="trainingSet"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="batchSize"/> is less than 1 and <paramref name="trainingSet"/> is not empty.
/// </exception>
internal override void Learn(HashSet<TrainingData> trainingSet, int batchSize)
{
    if (trainingSet == null)
    {
        throw new ArgumentNullException(nameof(trainingSet));
    }

    if (trainingSet.Count == 0)
    {
        return; // Nothing to learn from.
    }

    // Without this guard a zero batch size reaches a modulo by zero / a division of the cost
    // by a non-positive value.
    if (batchSize < 1)
    {
        throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be at least 1.");
    }

    batchSize = Math.Min(batchSize, trainingSet.Count);

    int count = 0;       // samples accumulated in the current batch
    double cost = 0;     // running mean cost of the current batch
    int batchNumber = 0; // number of batches applied so far

    foreach (TrainingData td in trainingSet)
    {
        var output = Process(td.Data);
        cost += CostFunc.Of(td.Response, output) / batchSize;
        PropogateError(CostFunc.Derivative(td.Response, output), batchSize);
        count++;

        if (Abort)
        {
            return;
        }

        // count is reset after every batch, so it reaches batchSize exactly at each boundary.
        if (count == batchSize)
        {
            batchNumber++;
            LastCost = cost;
            ApplyError();

            // Best-effort re-check of Abort after ApplyError (no locking is involved);
            // the external hook is skipped if an abort was requested meanwhile.
            if (!Abort)
            {
                Hook?.Invoke(batchNumber, this);
            }

            count = 0;
            cost = 0;
        }
    }

    // NOTE(review): when trainingSet.Count is not a multiple of batchSize, the trailing
    // samples' error has been propagated but ApplyError is not called for it here. Confirm
    // whether that residual carries over into the next Learn call.
}
/// <summary>
/// Trains on a set of sequences. Short-term memory is wiped at the start of each sequence;
/// error is propagated each time <c>MaxMemory</c> further steps have been processed, and
/// applied at every multiple of <paramref name="batchSize"/> × <c>MaxMemory</c> steps.
/// </summary>
/// <param name="trainSet">The training sequences. Must not be null.</param>
/// <param name="batchSize">
/// Batch size (default 1). Clamped to the number of sequences in <paramref name="trainSet"/>;
/// must be at least 1 when the set is non-empty.
/// </param>
/// <remarks>
/// At each batch boundary, <c>LastCost</c> is set to the accumulated cost (each step's cost
/// divided by batchSize × MaxMemory), <c>Hook</c> is invoked, and then <c>ApplyError</c> runs.
/// If <c>Abort</c> is set, it is reset to false and training stops after the current step.
/// </remarks>
/// <exception cref="ArgumentNullException"><paramref name="trainSet"/> is null.</exception>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="batchSize"/> is less than 1 and <paramref name="trainSet"/> is not empty.
/// </exception>
internal override void Learn(HashSet<TrainingData> trainSet, int batchSize = 1)
{
    if (trainSet == null)
    {
        throw new ArgumentNullException(nameof(trainSet));
    }

    if (trainSet.Count == 0)
    {
        return; // Nothing to learn from.
    }

    // Without this guard a zero batch size reaches a modulo by zero below.
    if (batchSize < 1)
    {
        throw new ArgumentOutOfRangeException(nameof(batchSize), batchSize, "Batch size must be at least 1.");
    }

    // NOTE(review): this clamps by the number of sequences, although batchSize is used below
    // as a number of memory windows per batch. Confirm intent.
    batchSize = Math.Min(batchSize, trainSet.Count);

    double cost = 0;
    int nBatch = 0;

    foreach (TrainingData trainSeq in trainSet)
    {
        // Start over for each different training sequence provided.
        WipeMemory();
        PreviousResponse = null;

        for (var i = 0; i < trainSeq.Count; i++)
        {
            // Process the pair
            TrainingData.TrainingPair pair = trainSeq[i];
            var output = Process(pair.Data);
            cost += CostFunc.Of(pair.Response, output) / (batchSize * MaxMemory);

            // Once short-term memory has been completely overwritten, update based on how we
            // performed over the most recent window.
            if (i > 0 && i % MaxMemory == 0)
            {
                // Clamp the window start at 0. Math.Min here would always yield 0 (i >= MaxMemory
                // in this branch), making every propagation cover the whole prefix instead.
                // Assumes SubSequence takes (start, end) — TODO confirm against its definition.
                PropogateError(trainSeq.SubSequence(Math.Max(i - MaxMemory, 0), i), batchSize);
            }

            // Count batches by number of error propagations.
            // NOTE(review): this condition is also true at i == 0 (start of every sequence),
            // before anything has been propagated for it, so nBatch/Hook advance once per
            // sequence and ApplyError flushes whatever the previous sequence left pending.
            // Confirm whether that flush is relied upon before adding an `i > 0` guard.
            if (i % (batchSize * MaxMemory) == 0)
            {
                nBatch++;
                LastCost = cost;
                Hook?.Invoke(nBatch, this);
                ApplyError();
                cost = 0;
            }

            // Remember this step's response for the next step (presumably consumed by Process —
            // see Process). The original author flags this as prone to inconsistency if the
            // implementation is ever parallelised.
            if (ForceOutput)
            {
                PreviousResponse = pair.Response;
            }
            else
            {
                PreviousResponse = output;
            }

            if (Abort)
            {
                // Consume the abort request (reset the flag) and stop training.
                Abort = false;
                return;
            }
        }
    }
}