/// <summary>
/// Builds the loss and evaluation functions used for training.
/// With exactly one output variable both functions are created directly against
/// <paramref name="network"/>. With several output variables one loss term and one
/// evaluation term are created per output, and each set of terms is summed into an
/// overall function named "Overall_" followed by the function kind.
/// </summary>
/// <param name="lrParams">Learning parameters supplying the loss and evaluation function kinds.</param>
/// <param name="network">Network model whose output(s) are compared against <c>OutputVariables</c>.</param>
/// <returns>A tuple of (loss, evaluation) functions.</returns>
private (Function loss, Function eval) createLossEvalFunction(LearningParameters lrParams, Function network)
{
    // Single output: no aggregation needed.
    if (OutputVariables.Count == 1)
    {
        var onlyTarget = OutputVariables[0];
        var singleLoss = createFunction(lrParams.LossFunction, network, onlyTarget);
        var singleEval = createFunction(lrParams.EvaluationFunction, network, onlyTarget);
        return (singleLoss, singleEval);
    }

    // Multiple outputs: collect one loss/eval term per output, then sum each collection.
    // NOTE(review): assumes network.Outputs[idx] corresponds to OutputVariables[idx].
    var lossTerms = new VariableVector();
    var evalTerms = new VariableVector();
    for (int idx = 0; idx < OutputVariables.Count; idx++)
    {
        var target = OutputVariables[idx];
        var prediction = network.Outputs[idx];
        lossTerms.Add(createFunction(lrParams.LossFunction, prediction, target));
        evalTerms.Add(createFunction(lrParams.EvaluationFunction, prediction, target));
    }

    var totalLoss = CNTKLib.Sum(lossTerms, "Overall_" + lrParams.LossFunction.ToString());
    var totalEval = CNTKLib.Sum(evalTerms, "Overall_" + lrParams.EvaluationFunction.ToString());
    return (totalLoss, totalEval);
}
public void TestLearningParametersDefault()
{
    // A freshly constructed LearningParameters must not contribute any entries
    // when asked to add its parameters to a dictionary.
    var defaults = new LearningParameters();
    var collected = new Dictionary<string, string>();

    defaults.AddParameters(collected);

    Assert.Empty(collected);
}
/// <summary>
/// Stores the learning and objective parameters for use by derived trainers.
/// </summary>
/// <param name="lp">Learning parameters, assigned to <c>Learning</c>.</param>
/// <param name="op">Objective parameters, assigned to <c>Objective</c>.</param>
private protected TrainerBase(LearningParameters lp, ObjectiveParameters op)
{
    this.Learning = lp;
    this.Objective = op;

    // Parallel-training setup (ParallelTraining / InitParallelTraining) is currently disabled.
}
/// <summary>
/// Creates a trainer for ranking problems.
/// </summary>
/// <param name="lp">Learning parameters, forwarded to the base trainer.</param>
/// <param name="op">
/// Objective parameters, forwarded to the base trainer. <c>Objective</c> must be
/// <c>ObjectiveType.LambdaRank</c>. When <c>Metric</c> is <c>MetricType.DefaultMetric</c>
/// it is replaced in place (on the caller's instance) by <c>MetricType.Ndcg</c>.
/// </param>
/// <exception cref="ArgumentException">The objective is not LambdaRank.</exception>
public RankingTrainer(LearningParameters lp, ObjectiveParameters op) : base(lp, op)
{
    if (op.Objective != ObjectiveType.LambdaRank)
    {
        // Invalid argument: use the specific exception type (still catchable as Exception).
        throw new ArgumentException("Require Objective == ObjectiveType.LambdaRank", nameof(op));
    }

    if (op.Metric == MetricType.DefaultMetric)
    {
        op.Metric = MetricType.Ndcg;
    }
}
/// <summary>
/// Creates a trainer for binary classification problems.
/// </summary>
/// <param name="lp">Learning parameters, forwarded to the base trainer.</param>
/// <param name="op">
/// Objective parameters, forwarded to the base trainer. <c>Objective</c> must be
/// <c>ObjectiveType.Binary</c>. When <c>Metric</c> is <c>MetricType.DefaultMetric</c>
/// it is replaced in place (on the caller's instance) by <c>MetricType.BinaryLogLoss</c>.
/// </param>
/// <exception cref="ArgumentException">The objective is not Binary.</exception>
public BinaryTrainer(LearningParameters lp, ObjectiveParameters op) : base(lp, op)
{
    if (op.Objective != ObjectiveType.Binary)
    {
        // Invalid argument: use the specific exception type (still catchable as Exception).
        throw new ArgumentException("Require Objective == ObjectiveType.Binary", nameof(op));
    }

    if (op.Metric == MetricType.DefaultMetric)
    {
        op.Metric = MetricType.BinaryLogLoss;
    }
}
/// <summary>
/// Creates the trainer for <paramref name="network"/> and resets the stored "previous"
/// training/validation evaluation values to the worst value for the evaluation metric's direction.
/// When <c>trParams.ContinueTraining</c> is set and <paramref name="modelCheckPoint"/> names an
/// existing file, the trainer state is restored from that checkpoint and the training history is
/// loaded from <paramref name="historyPath"/>. Otherwise any existing checkpoint and history files
/// are deleted, so later checkpoint saves start from a clean state.
/// </summary>
/// <param name="network">Network model to be trained.</param>
/// <param name="lrParams">Learning parameters (loss, evaluation and learner settings).</param>
/// <param name="trParams">Training parameters; <c>ContinueTraining</c> selects restore vs. fresh start.</param>
/// <param name="modelCheckPoint">Path of the trainer checkpoint file.</param>
/// <param name="historyPath">Path of the training-history file.</param>
/// <returns>The created (and possibly restored) trainer.</returns>
/// <exception cref="InvalidOperationException">
/// Restoring the checkpoint or loading the history failed. The original error is preserved
/// in <see cref="Exception.InnerException"/>.
/// </exception>
public Trainer CreateTrainer(Function network, LearningParameters lrParams, TrainingParameters trParams, string modelCheckPoint, string historyPath)
{
    var trainer = createTrainer(network, lrParams, trParams);

    // Seed the "previous best" values with the worst possible score for the metric's direction.
    var worstEval = StatMetrics.IsGoalToMinimize(trainer.EvaluationFunction()) ? double.MaxValue : double.MinValue;
    m_PrevTrainingEval = worstEval;
    m_PrevValidationEval = worstEval;

    if (trParams.ContinueTraining && !string.IsNullOrEmpty(modelCheckPoint) && File.Exists(modelCheckPoint))
    {
        try
        {
            // Restoring fails if the network model no longer matches the checkpointed state.
            trainer.RestoreFromCheckpoint(modelCheckPoint);
            m_trainingHistory = loadTrainingHistory(historyPath);
        }
        catch (Exception ex)
        {
            // Keep the original failure as the inner exception so the root cause is not lost.
            throw new InvalidOperationException("The Trainer cannot be restored from the previous state probably because the network has changed." +
                "\n Uncheck 'Continue Training' and train the model from scratch.", ex);
        }
    }
    else
    {
        // Not continuing: remove any stale checkpoint and history so the next save starts fresh.
        if (File.Exists(modelCheckPoint))
        {
            File.Delete(modelCheckPoint);
        }
        if (File.Exists(historyPath))
        {
            File.Delete(historyPath);
        }
    }

    return trainer;
}
/// <summary>
/// Handles the Learn button. Gathers the network settings, problem type, error type and
/// training points from the UI and passes them to <c>Python.Learn</c>. Any exception thrown
/// while learning is shown to the user in a message box.
/// </summary>
private void LearnButton_Click(object sender, RoutedEventArgs e)
{
    NetworkParameters networkParameters = new NetworkParameters()
    {
        // Compare with true rather than casting: IsChecked is nullable, and casting a null
        // (indeterminate) value to bool throws InvalidOperationException.
        BiasesEnabled = BiasesCheckbox.IsChecked == true,
        ClassesNumber = (int)ClassNumberSlider.Value,
        LearningRate = (double)LearningRateSlider.Value,
        Seed = (double)SeedSider.Value,
        NumberOfEpochs = (int)NumberOfEpochsSlider.Value,
        Momentum = (double)MomentumSlider.Value,
        // Shallow copy: the list is new, but the NetworkLayer objects are shared with this.NetworkLayers.
        Layers = new List<NetworkLayer>(this.NetworkLayers)
    };

    // If no radio button is checked, Problem / ErrType keep whatever NetworkParameters initializes them to.
    if (ClassificationProblemTypeRadio.IsChecked == true)
    {
        networkParameters.Problem = ProblemType.CLASSIFICATION;
    }
    else if (RegressionProblemTypeRadio.IsChecked == true)
    {
        networkParameters.Problem = ProblemType.REGRESSION;
    }

    if (MseErrorRadio.IsChecked == true)
    {
        networkParameters.ErrType = ErrorType.MSE;
    }
    else if (MaeErrorRadio.IsChecked == true)
    {
        networkParameters.ErrType = ErrorType.MAE;
    }
    else if (ClassicErrorRadio.IsChecked == true)
    {
        networkParameters.ErrType = ErrorType.CLASSIC;
    }
    else if (CrossEntropyRadio.IsChecked == true)
    {
        networkParameters.ErrType = ErrorType.CROSS_ENTROPY;
    }

    LearningParameters learningParameters = new LearningParameters()
    {
        Points = new List<Structs.Point>(this.trainPoints)
    };

    try
    {
        Python.Learn(networkParameters, learningParameters);
    }
    catch (Exception exc)
    {
        MessageBox.Show(exc.Message);
    }
}
PrepareNNData(Dictionary <string, string> dicMParameters, CreateCustomModel customModel, DeviceDescriptor device)
{
    // Prepares the components needed to train a network from an ML configuration:
    // the ML factory, learning parameters, training parameters, the network model and
    // a minibatch source over the training (and optional validation) data.
    //
    // Parameters:
    //   dicMParameters - configuration sections keyed by name. This method reads the
    //                    "learning", "training", "paths", "root" and "network" entries.
    //   customModel    - passed through to CreateNetworkModel.
    //   device         - passed to NormalizeInputLayer and CreateNetworkModel.
    // Returns: the tuple (f, lrData, trData, nnModel, mbs).
    // The catch block only rethrows, so exceptions propagate unchanged.
    try
    {
        //create factory object
        MLFactory f = CreateMLFactory(dicMParameters);

        //create learning params
        var strLearning = dicMParameters["learning"];
        LearningParameters lrData = MLFactory.CreateLearningParameters(strLearning);

        //create training param
        var strTraining = dicMParameters["training"];
        TrainingParameters trData = MLFactory.CreateTrainingParameters(strTraining);

        //set model component locations
        var dicPath = MLFactory.GetMLConfigComponentPaths(dicMParameters["paths"]);
        // Component paths are built as "<root>\<relative path>" (backslash separator).
        trData.ModelTempLocation = $"{dicMParameters["root"]}\\{dicPath["TempModels"]}";
        trData.ModelFinalLocation = $"{dicMParameters["root"]}\\{dicPath["Models"]}";
        var strTrainPath = $"{dicMParameters["root"]}\\{dicPath["Training"]}";
        // The validation path is left empty when the configured "Validation" entry is empty or a single space.
        var strValidPath = (string.IsNullOrEmpty(dicPath["Validation"]) || dicPath["Validation"] == " ") ? "": $"{dicMParameters["root"]}\\{dicPath["Validation"]}";

        //data normalization in case the option is enabled
        //check if network contains Normalization layer and assign value to normalization parameter
        // (case-sensitive substring test on the network description; when it matches,
        //  trData.Normalization becomes the single group MLFactory.m_NumFeaturesGroupName)
        if (dicMParameters["network"].Contains("Normalization")) { trData.Normalization = new string[] { MLFactory.m_NumFeaturesGroupName } } ;

        //perform data normalization according to the normalization parameter
        List <Variable> networkInput = NormalizeInputLayer(trData, f, strTrainPath, strValidPath, device);

        //create network parameters
        Function nnModel = CreateNetworkModel(dicMParameters["network"], networkInput, f.OutputVariables, customModel, device);

        //create minibatch source over the training/validation files, with MinibatchSource.InfinitelyRepeat
        //and trData.RandomizeBatch passed through to MinibatchSourceEx
        var mbs = new MinibatchSourceEx(trData.Type, f.StreamConfigurations.ToArray(), strTrainPath, strValidPath, MinibatchSource.InfinitelyRepeat, trData.RandomizeBatch);

        //return ml parameters
        return(f, lrData, trData, nnModel, mbs);
    }
    catch (Exception)
    {
        throw;
    }
}
/// <summary>
/// Creates a CNTK trainer for <paramref name="network"/> from its loss/evaluation functions
/// (see <c>createLossEvalFunction</c>) and learners (see <c>createLearners</c>).
/// </summary>
/// <param name="network">Network model being trained.</param>
/// <param name="lrParams">Learning parameters used for the loss, evaluation and learner setup.</param>
/// <param name="trParams">Training parameters. Not used by this method; kept for signature compatibility.</param>
/// <returns>The created trainer.</returns>
private Trainer createTrainer(Function network, LearningParameters lrParams, TrainingParameters trParams)
{
    // (The previously built, never-used ParameterVector of network parameters is no longer allocated.)
    (Function loss, Function eval) = createLossEvalFunction(lrParams, network);
    var learners = createLearners(network, lrParams);
    return Trainer.CreateTrainer(network, loss, eval, learners);
}
/// <summary>
/// Creates a trainer for multiclass classification problems.
/// </summary>
/// <param name="lp">Learning parameters, forwarded to the base trainer.</param>
/// <param name="op">
/// Objective parameters, forwarded to the base trainer. <c>Objective</c> must be
/// <c>ObjectiveType.MultiClass</c> or <c>ObjectiveType.MultiClassOva</c>, and <c>NumClass</c>
/// must be greater than 1. When <c>Metric</c> is <c>MetricType.DefaultMetric</c> it is
/// replaced in place (on the caller's instance) by <c>MetricType.MultiLogLoss</c>.
/// </param>
/// <exception cref="ArgumentException">The objective is not multiclass, or <c>NumClass</c> is at most 1.</exception>
public MulticlassTrainer(LearningParameters lp, ObjectiveParameters op) : base(lp, op)
{
    if (!(op.Objective == ObjectiveType.MultiClass || op.Objective == ObjectiveType.MultiClassOva))
    {
        throw new ArgumentException("Require Objective == MultiClass or MultiClassOva", nameof(op));
    }

    if (op.NumClass <= 1)
    {
        throw new ArgumentException("Require NumClass > 1", nameof(op));
    }

    if (op.Metric == MetricType.DefaultMetric)
    {
        // TODO: the original author asked why this default was once MultiError;
        // confirm MultiLogLoss is the intended default.
        op.Metric = MetricType.MultiLogLoss;
    }
}
/// <summary>
/// Creates a trainer for regression problems.
/// </summary>
/// <param name="lp">Learning parameters, forwarded to the base trainer.</param>
/// <param name="op">
/// Objective parameters, forwarded to the base trainer. <c>Objective</c> must be one of
/// Regression, RegressionL1, Huber, Fair, Poisson, Quantile, Mape, Gamma or Tweedie.
/// When <c>Metric</c> is <c>MetricType.DefaultMetric</c> it is replaced in place
/// (on the caller's instance) by <c>MetricType.Mse</c>.
/// </param>
/// <exception cref="ArgumentException">The objective is not a regression objective.</exception>
public RegressionTrainer(LearningParameters lp, ObjectiveParameters op) : base(lp, op)
{
    bool isRegressionObjective =
        op.Objective == ObjectiveType.Regression ||
        op.Objective == ObjectiveType.RegressionL1 ||
        op.Objective == ObjectiveType.Huber ||
        op.Objective == ObjectiveType.Fair ||
        op.Objective == ObjectiveType.Poisson ||
        op.Objective == ObjectiveType.Quantile ||
        op.Objective == ObjectiveType.Mape ||
        op.Objective == ObjectiveType.Gamma ||
        op.Objective == ObjectiveType.Tweedie;

    if (!isRegressionObjective)
    {
        // Invalid argument: use the specific exception type (still catchable as Exception).
        throw new ArgumentException("Require regression ObjectiveType", nameof(op));
    }

    if (op.Metric == MetricType.DefaultMetric)
    {
        op.Metric = MetricType.Mse;
    }
}
/// <summary>
/// Creates the learner selected by <c>lrParams.LearnerType</c>, returned as a one-element list.
/// The learning rate comes from <c>lrParams.LearningRate</c> and the momentum schedule from
/// <c>lrParams.Momentum</c> (as a time-constant schedule). L1/L2 regularization weights are
/// applied only when positive.
/// ToDo: not all learner options are wired up. The FSAdaGrad and Adam learners are created
/// without the additional options, so L1/L2 regularization is not applied to them.
/// </summary>
/// <param name="network">Network model being trained; its parameters are what the learner updates.</param>
/// <param name="lrParams">Learning parameters.</param>
/// <returns>A list containing the single created learner.</returns>
/// <exception cref="NotSupportedException">The learner type is not one of the supported types.</exception>
private List<Learner> createLearners(Function network, LearningParameters lrParams)
{
    //learning rate and momentum values
    var lr = new TrainingParameterScheduleDouble(lrParams.LearningRate);
    var mm = CNTKLib.MomentumAsTimeConstantSchedule(lrParams.Momentum);

    // Regularization is only set when a positive weight is configured.
    var addParam = new AdditionalLearningOptions();
    if (lrParams.L1Regularizer > 0)
    {
        addParam.l1RegularizationWeight = lrParams.L1Regularizer;
    }
    if (lrParams.L2Regularizer > 0)
    {
        addParam.l2RegularizationWeight = lrParams.L2Regularizer;
    }

    Learner learner;
    switch (lrParams.LearnerType)
    {
        case LearnerType.MomentumSGDLearner:
            // SGD with momentum - rate, momentum and regularizers
            learner = Learner.MomentumSGDLearner(network.Parameters(), lr, mm, true, addParam);
            break;
        case LearnerType.SGDLearner:
            // plain SGD - rate and regularizers
            learner = Learner.SGDLearner(network.Parameters(), lr, addParam);
            break;
        case LearnerType.FSAdaGradLearner:
            // FSAdaGrad - rate and momentum (no regularizers, see ToDo above)
            learner = CNTKLib.FSAdaGradLearner(new ParameterVector(network.Parameters().ToList()), lr, mm);
            break;
        case LearnerType.AdamLearner:
            // Adam - rate and momentum (no regularizers, see ToDo above)
            learner = CNTKLib.AdamLearner(new ParameterVector(network.Parameters().ToList()), lr, mm);
            break;
        case LearnerType.AdaGradLearner:
            // AdaGrad - rate and regularizers
            learner = CNTKLib.AdaGradLearner(new ParameterVector(network.Parameters().ToList()), lr, false, addParam);
            break;
        default:
            throw new NotSupportedException("Learner type is not supported!");
    }

    return new List<Learner> { learner };
}
/// <summary>
/// Creates a LearningParameters object from its string representation.
/// </summary>
/// <remarks>
/// The string is split on <c>m_cntkSpearator</c> (empty entries removed) into "Key:value" entries.
/// Keys are matched with a case-sensitive ordinal prefix test; if a key occurs more than once,
/// the first entry wins. Required keys: Type, Loss, Eval, LRate, Momentum. Optional keys: L1 and
/// L2, which default to 0 when absent or empty. Enum values are parsed case-insensitively;
/// numbers are parsed with the invariant culture.
/// </remarks>
/// <param name="strLearning">Serialized learning parameters.</param>
/// <returns>The parsed learning parameters.</returns>
/// <exception cref="FormatException">A required entry is missing or empty, or a numeric value cannot be parsed.</exception>
/// <exception cref="ArgumentException">An enum value cannot be parsed.</exception>
public static LearningParameters CreateLearningParameters(string strLearning)
{
    var trParam = new LearningParameters();

    //parse entries
    var strParameters = strLearning.Split(m_cntkSpearator, StringSplitOptions.RemoveEmptyEntries);

    //learner type
    var type = findLearningParameterValue(strParameters, "Type:");
    if (string.IsNullOrEmpty(type))
    {
        throw new FormatException("Unsupported Learning type!");
    }
    trParam.LearnerType = (LearnerType)Enum.Parse(typeof(LearnerType), type, true);

    //loss function
    var loss = findLearningParameterValue(strParameters, "Loss:");
    if (string.IsNullOrEmpty(loss))
    {
        throw new FormatException("Unsupported Loss function!");
    }
    trParam.LossFunction = (EFunction)Enum.Parse(typeof(EFunction), loss, true);

    //eval function
    var eval = findLearningParameterValue(strParameters, "Eval:");
    if (string.IsNullOrEmpty(eval))
    {
        throw new FormatException("Unsupported Evaluation function!");
    }
    trParam.EvaluationFunction = (EFunction)Enum.Parse(typeof(EFunction), eval, true);

    //learning rate
    var lr = findLearningParameterValue(strParameters, "LRate:");
    if (string.IsNullOrEmpty(lr))
    {
        throw new FormatException("Unsupported Learning Rate function!");
    }
    trParam.LearningRate = double.Parse(lr, CultureInfo.InvariantCulture);

    //momentum (the same key is used for both the prefix test and the substring length)
    var momentum = findLearningParameterValue(strParameters, "Momentum:");
    if (string.IsNullOrEmpty(momentum))
    {
        throw new FormatException("Unsupported Momentum parameter!");
    }
    trParam.Momentum = double.Parse(momentum, CultureInfo.InvariantCulture);

    //optional L1 / L2 regularizers: missing or empty means 0
    var l1 = findLearningParameterValue(strParameters, "L1:");
    trParam.L1Regularizer = string.IsNullOrEmpty(l1) ? 0 : double.Parse(l1, CultureInfo.InvariantCulture);

    var l2 = findLearningParameterValue(strParameters, "L2:");
    trParam.L2Regularizer = string.IsNullOrEmpty(l2) ? 0 : double.Parse(l2, CultureInfo.InvariantCulture);

    return trParam;
}

/// <summary>
/// Returns the text following <paramref name="key"/> in the first entry that starts with it
/// (ordinal, case-sensitive comparison), or null when no entry starts with <paramref name="key"/>.
/// </summary>
private static string findLearningParameterValue(string[] entries, string key)
{
    return entries.Where(entry => entry.StartsWith(key, StringComparison.Ordinal))
                  .Select(entry => entry.Substring(key.Length))
                  .FirstOrDefault();
}
/// <summary>
/// Serializes learning parameters to text.
/// </summary>
/// <param name="lp">Learning parameters to serialize.</param>
/// <returns>The value of <c>lp.ToString()</c>.</returns>
private string saveLearningParameters(LearningParameters lp) => lp.ToString();
/// <summary>
/// Builds a multilayer perceptron through the Python modules "Python.mlp" and "Python.layers"
/// and fits it on the training points. Shows "ok" in a message box when fitting finishes.
/// If a layer has an unknown activation function, an error message box is shown and the
/// method returns without fitting.
/// </summary>
/// <param name="networkParameters">
/// Hyper-parameters and layer definitions. The last layer (ordered by <c>LayerNumber</c>) has its
/// <c>NeuronsNumber</c> overwritten in place: 1 for regression, <c>ClassesNumber</c> for classification.
/// </param>
/// <param name="learningParameters">
/// Training points. Classification fits on (X, Y) with <c>Class</c> as the label;
/// regression fits on X with Y as the target.
/// </param>
/// <exception cref="NotSupportedException">The problem type is neither REGRESSION nor CLASSIFICATION.</exception>
/// <exception cref="ArgumentException">No network layers are defined.</exception>
public void Learn(NetworkParameters networkParameters, LearningParameters learningParameters)
{
    // The "previous layer" of the first real layer stands in for the input:
    // regression feeds one value (x), classification feeds two (x, y).
    NetworkLayer previousLayer;
    if (networkParameters.Problem == ProblemType.REGRESSION)
    {
        previousLayer = new NetworkLayer() { NeuronsNumber = 1 };
    }
    else if (networkParameters.Problem == ProblemType.CLASSIFICATION)
    {
        previousLayer = new NetworkLayer() { NeuronsNumber = 2 };
    }
    else
    {
        // Without this guard previousLayer stays null and the layer loop fails with a NullReferenceException.
        throw new NotSupportedException($"Unsupported problem type: {networkParameters.Problem}");
    }

    NetworkLayer[] layers = networkParameters.Layers.OrderBy(x => x.LayerNumber).ToArray();
    if (layers.Length == 0)
    {
        throw new ArgumentException("At least one network layer is required.", nameof(networkParameters));
    }

    // The output layer's size is dictated by the problem type.
    if (networkParameters.Problem == ProblemType.REGRESSION)
    {
        layers[layers.Length - 1].NeuronsNumber = 1;
    }
    else
    {
        layers[layers.Length - 1].NeuronsNumber = networkParameters.ClassesNumber;
    }

    using (Py.GIL())
    {
        layersModule = Py.Import("Python.layers");
        mlpModule = Py.Import("Python.mlp");
        network = mlpModule.MultilayerPerceptron(
            networkParameters.NumberOfEpochs,
            networkParameters.Seed,
            networkParameters.LearningRate,
            networkParameters.Momentum,
            networkParameters.Problem,
            networkParameters.ErrType,
            networkParameters.BiasesEnabled
        );

        foreach (NetworkLayer layer in layers)
        {
            // Each Python layer is created with (its own size, size of the previous layer).
            dynamic curr_layer = null;
            switch (layer.ActFun)
            {
                case "Sigmoid":
                    curr_layer = layersModule.SigmoidLayer(layer.NeuronsNumber, previousLayer.NeuronsNumber);
                    break;
                case "Relu":
                    curr_layer = layersModule.ReluLayer(layer.NeuronsNumber, previousLayer.NeuronsNumber);
                    break;
                case "Linear":
                    curr_layer = layersModule.LinearLayer(layer.NeuronsNumber, previousLayer.NeuronsNumber);
                    break;
                case "Tanh":
                    curr_layer = layersModule.TanhLayer(layer.NeuronsNumber, previousLayer.NeuronsNumber);
                    break;
                default:
                    // Unknown activation function: report it to the user and stop without fitting.
                    System.Windows.MessageBox.Show(string.Format("Nieznana funkcja aktywacji {0}", layer.ActFun.ToString()), "Błąd", System.Windows.MessageBoxButton.OK, System.Windows.MessageBoxImage.Error);
                    return;
            }
            previousLayer = layer;
            network.add_layer(curr_layer);
        }

        if (networkParameters.Problem == ProblemType.CLASSIFICATION)
        {
            // X: point coordinates (x, y); y: class labels.
            network.fit(learningParameters.Points.Select(x => new double[] { x.X, x.Y }).ToArray(), learningParameters.Points.Select(x => x.Class).ToArray());
        }
        else if (networkParameters.Problem == ProblemType.REGRESSION)
        {
            // X: x values; y: y values.
            network.fit(learningParameters.Points.Select(x => x.X).ToArray(), learningParameters.Points.Select(x => x.Y).ToArray());
        }

        System.Windows.MessageBox.Show("ok");
    }
}