public void NeuralNetOptimizer_Reset_Does_Not_Throw()
{
    var parametersAndGradients = new List<ParametersAndGradients>
    {
        new ParametersAndGradients(new float[10], new float[10]),
        new ParametersAndGradients(new float[10], new float[10]),
    };

    foreach (OptimizerMethod optimizer in Enum.GetValues(typeof(OptimizerMethod)))
    {
        var sut = new NeuralNetOptimizer(0.001, 10, optimizerMethod: optimizer);

        sut.UpdateParameters(parametersAndGradients);
        sut.Reset();
    }
}
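// A minimal usage sketch (not part of the original tests) showing the calling pattern that the
// Learn method below relies on: the optimizer is constructed once, UpdateParameters is called on
// the shared parameter/gradient buffers after every backward pass, and Reset clears accumulated
// optimizer state before a new training run. The learning rate (0.001), the batch size (10), the
// zero-filled buffers, and the use of OptimizerMethod.Sgd are placeholder assumptions, not values
// taken from the library.
public void NeuralNetOptimizer_UsageSketch()
{
    // One entry per layer; parameters and gradients share length and are reused between updates.
    var parametersAndGradients = new List<ParametersAndGradients>
    {
        new ParametersAndGradients(new float[10], new float[10]),
    };

    var optimizer = new NeuralNetOptimizer(0.001, 10, optimizerMethod: OptimizerMethod.Sgd);

    // During training, the backward pass would write the gradients before each update.
    optimizer.UpdateParameters(parametersAndGradients);

    // Clear optimizer state before starting a fresh training run with the same instance.
    optimizer.Reset();
}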
/// <summary>
/// Learns a neural net based on the observations and targets.
/// The learning only uses the observations whose indices are present in indices.
/// ValidationObservations and ValidationTargets are used to track the validation loss per iteration.
/// The net from the iteration with the best validation loss is returned.
/// </summary>
/// <param name="observations">Training observations.</param>
/// <param name="targets">Training targets.</param>
/// <param name="indices">Indices of the observations to use for training.</param>
/// <param name="validationObservations">Observations used to track the validation loss. Can be null.</param>
/// <param name="validationTargets">Targets used to track the validation loss. Can be null.</param>
/// <returns>The learned neural net.</returns>
public NeuralNet Learn(F64Matrix observations, double[] targets, int[] indices,
    F64Matrix validationObservations, double[] validationTargets)
{
    Checks.VerifyObservationsAndTargets(observations, targets);
    Checks.VerifyIndices(indices, observations, targets);

    // Only check validation data if in use.
    if (validationObservations != null && validationTargets != null)
    {
        Checks.VerifyObservationsAndTargets(validationObservations, validationTargets);
    }

    // Encode targets as one-of-N targets.
    var oneOfNTargets = m_targetEncoder.Encode(targets);

    // Setup working parameters.
    var samples = indices.Length;
    var learningIndices = indices.ToArray();
    var numberOfBatches = samples / m_batchSize;

    var batchTargets = Matrix<float>.Build.Dense(m_batchSize, oneOfNTargets.ColumnCount);
    var batchObservations = Matrix<float>.Build.Dense(m_batchSize, observations.ColumnCount);

    // Check for size mismatch.
    if (m_batchSize > samples)
    {
        throw new ArgumentException("BatchSize: " + m_batchSize +
            " is larger than number of observations: " + samples);
    }

    var currentLoss = 0.0;

    // Initialize net.
    m_net.Initialize(m_batchSize, m_random);

    // Extract reference to parameters and gradients.
    var parametersAndGradients = m_net.GetParametersAndGradients();

    // Reset optimizer.
    m_optimizer.Reset();

    // Setup early stopping if validation data is provided.
    var earlyStopping = validationObservations != null && validationTargets != null;

    NeuralNet bestNeuralNet = null;
    Matrix<float> floatValidationObservations = null;
    Matrix<float> floatValidationTargets = null;
    Matrix<float> floatValidationPredictions = null;

    var bestLoss = double.MaxValue;

    if (earlyStopping)
    {
        var validationIndices = Enumerable.Range(0, validationTargets.Length).ToArray();

        floatValidationObservations = Matrix<float>.Build
            .Dense(validationObservations.RowCount, validationObservations.ColumnCount);
        CopyBatch(validationObservations, floatValidationObservations, validationIndices);

        floatValidationTargets = m_targetEncoder.Encode(validationTargets);
        floatValidationPredictions = Matrix<float>.Build
            .Dense(floatValidationTargets.RowCount, floatValidationTargets.ColumnCount);
    }

    var timer = new Stopwatch();

    // Train using stochastic gradient descent.
    for (int iteration = 0; iteration < m_iterations; iteration++)
    {
        timer.Restart();
        var accumulatedLoss = 0.0;

        learningIndices.Shuffle(m_random);

        for (int i = 0; i < numberOfBatches; i++)
        {
            var workIndices = learningIndices
                .Skip(i * m_batchSize)
                .Take(m_batchSize).ToArray();

            if (workIndices.Length != m_batchSize)
            {
                continue; // Only train with full batch size.
            }

            CopyBatchTargets(oneOfNTargets, batchTargets, workIndices);
            CopyBatch(observations, batchObservations, workIndices);

            // Forward pass.
            var predictions = m_net.Forward(batchObservations);

            // Loss.
            var batchLoss = m_loss.Loss(batchTargets, predictions);
            accumulatedLoss += batchLoss * m_batchSize;

            // Backwards pass.
            m_net.Backward(batchTargets);

            // Weight update.
            m_optimizer.UpdateParameters(parametersAndGradients);
        }

        currentLoss = accumulatedLoss / (double)indices.Length;

        if (earlyStopping)
        {
            var candidate = m_net.CopyNetForPredictionModel();
            candidate.Forward(floatValidationObservations, floatValidationPredictions);

            var validationLoss = m_loss.Loss(floatValidationTargets, floatValidationPredictions);

            timer.Stop();

            Trace.WriteLine(string.Format("Iteration: {0:000} - Loss {1:0.00000} - Validation: {2:0.00000} - Time (ms): {3}",
                (iteration + 1), currentLoss, validationLoss, timer.ElapsedMilliseconds));

            if (validationLoss < bestLoss)
            {
                bestLoss = validationLoss;
                bestNeuralNet = candidate;
            }
        }
        else
        {
            timer.Stop();

            Trace.WriteLine(string.Format("Iteration: {0:000} - Loss {1:0.00000} - Time (ms): {2}",
                (iteration + 1), currentLoss, timer.ElapsedMilliseconds));
        }

        if (double.IsNaN(currentLoss))
        {
            Trace.WriteLine("Loss is NaN, stopping...");
            break;
        }
    }

    if (earlyStopping)
    {
        return bestNeuralNet;
    }
    else
    {
        return m_net.CopyNetForPredictionModel();
    }
}
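// A hedged usage sketch of the Learn overload above. It assumes the method lives on a learner
// type, here called NeuralNetLearner, and that F64Matrix can be constructed from a flat double[]
// plus row/column counts; the data values, dimensions, and helper name LearnWithValidationSketch
// are placeholders, not part of the original code. The validation set drives early stopping: the
// net from the iteration with the lowest validation loss is the one returned.
void LearnWithValidationSketch(NeuralNetLearner learner)
{
    // Three training rows with two features each.
    var observations = new F64Matrix(new double[] { 0.1, 0.2, 0.3, 0.4, 0.5, 0.6 }, 3, 2);
    var targets = new double[] { 0.0, 1.0, 0.0 };
    var indices = new int[] { 0, 1, 2 }; // Train on all three rows.

    // One validation row used to track the validation loss per iteration.
    var validationObservations = new F64Matrix(new double[] { 0.7, 0.8 }, 1, 2);
    var validationTargets = new double[] { 1.0 };

    // Passing null for both validation arguments disables early stopping and returns the final net.
    var net = learner.Learn(observations, targets, indices,
        validationObservations, validationTargets);
}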