/// <summary>
/// Writes the trainer, training-algorithm, and loss-function parameters to the log.
/// </summary>
/// <param name="logFile">The log file to write to.</param>
/// <param name="trainer">The trainer whose parameters to write.</param>
/// <param name="algorithm">The training algorithm.</param>
/// <param name="loss">The loss function.</param>
private void WriteTrainerParameters(
    StreamWriter logFile,
    ClassificationNetworkTrainer trainer,
    ITrainingAlgorithm algorithm,
    ILoss<int[]> loss)
{
    this.WriteLine(logFile, "Trainer parameters:");
    this.WriteLine(logFile, " Batch Size: {0}", trainer.BatchSize);
    this.WriteLine(logFile, " L1 Rate: {0}", trainer.RateL1);
    this.WriteLine(logFile, " L2 Rate: {0}", trainer.RateL2);
    this.WriteLine(logFile, " Clip Value: {0}", trainer.ClipValue);

    this.WriteLine(logFile, "Algorithm parameters:");
    this.WriteLine(logFile, " Algorithm: {0}", algorithm.GetType().Name);

    // Algorithm-specific settings; only the algorithm types listed here are expanded,
    // any other implementation is logged by type name only.
    switch (algorithm)
    {
        case Adadelta adadelta:
            this.WriteLine(logFile, " Learning Rate: {0}", adadelta.LearningRate);
            this.WriteLine(logFile, " Decay: {0}", adadelta.Decay);
            this.WriteLine(logFile, " Rho: {0}", adadelta.Rho);
            this.WriteLine(logFile, " Eps: {0}", adadelta.Eps);
            break;

        case Adagrad adagrad:
            this.WriteLine(logFile, " Learning Rate: {0}", adagrad.LearningRate);
            this.WriteLine(logFile, " Eps: {0}", adagrad.Eps);
            break;

        case Adam adam:
            this.WriteLine(logFile, " Learning Rate: {0}", adam.LearningRate);
            this.WriteLine(logFile, " Beta1: {0}", adam.Beta1);
            this.WriteLine(logFile, " Beta2: {0}", adam.Beta2);
            this.WriteLine(logFile, " Eps: {0}", adam.Eps);
            break;

        case RMSProp rmsProp:
            this.WriteLine(logFile, " Learning Rate: {0}", rmsProp.LearningRate);
            this.WriteLine(logFile, " Rho: {0}", rmsProp.Rho);
            this.WriteLine(logFile, " Eps: {0}", rmsProp.Eps);
            break;

        case SGD sgd:
            this.WriteLine(logFile, " Learning Rate: {0}", sgd.LearningRate);
            this.WriteLine(logFile, " Decay: {0}", sgd.Decay);
            this.WriteLine(logFile, " Momentum: {0}", sgd.Momentum);

            // The original literal was split by a raw line break (CS1010: newline in constant).
            this.WriteLine(logFile, " Nesterov: {0}", sgd.Nesterov);
            break;
    }

    this.WriteLine(logFile, "Loss parameters:");
    this.WriteLine(logFile, " Loss: {0}", loss.GetType().Name);

    // Label-smoothing rate is only available on the log-likelihood loss.
    if (loss is LogLikelihoodLoss logLikelihoodLoss)
    {
        this.WriteLine(logFile, " LSR: {0}", logLikelihoodLoss.LSR);
    }
}
/// <summary>
/// Initializes a new instance of the <see cref="NetworkTopology"/> class with
/// the input, hidden, and output layers, the pre/post processors, the training
/// pre-processor, and the training algorithm all explicitly set to <c>null</c>.
/// </summary>
public NetworkTopology()
{
    // Layers.
    this.hiddenLayers = null;
    this.inputLayer = null;
    this.outputLayer = null;

    // Data processors.
    this.preProcessor = null;
    this.postProcessor = null;

    // Training configuration.
    this.TrainingPreProcessor = null;
    this.TrainingAlgorithm = null;
}
/// <summary>
/// Creates the training algorithm named in the task parameters and populates its
/// settings from the task's algorithm parameter object.
/// </summary>
/// <returns>The configured training algorithm.</returns>
/// <remarks>
/// The name match is case-sensitive. Any name other than "Adagrad", "Adam",
/// "RMSProp", or "SGD" (including a null name) yields <see cref="Adadelta"/>.
/// When no parameter object is supplied, the algorithm keeps its default settings.
/// </remarks>
private ITrainingAlgorithm CreateAlgorithm()
{
    var algorithmSettings = this.TaskParameters.Algorithm;

    ITrainingAlgorithm algorithm;
    switch (algorithmSettings.Name)
    {
        case "Adagrad":
            algorithm = new Adagrad();
            break;

        case "Adam":
            algorithm = new Adam();
            break;

        case "RMSProp":
            algorithm = new RMSProp();
            break;

        case "SGD":
            algorithm = new SGD();
            break;

        case "Adadelta":
        default:
            // Unknown names silently fall back to Adadelta (pre-existing behavior).
            algorithm = new Adadelta();
            break;
    }

    // JTokenReader's constructor rejects a null token, so only populate
    // when a parameter object was actually supplied.
    if (algorithmSettings.Parameters != null)
    {
        JsonSerializer jsonSerializer = new JsonSerializer();
        using (JTokenReader jtokenReader = new JTokenReader(algorithmSettings.Parameters))
        {
            jsonSerializer.Populate(jtokenReader, algorithm);
        }
    }

    return algorithm;
}