/// <summary>
/// Trains the model on an <c>XYFrame</c> for a fixed number of epochs and, when validation
/// data is supplied, evaluates the loss and metric on it after every epoch.
/// </summary>
/// <param name="trainData">Training data; must be an <c>XYFrame</c>.</param>
/// <param name="validationData">Validation data as an <c>XYFrame</c>, or null to skip validation.</param>
/// <param name="epoches">Number of epochs to run; epochs are numbered starting at 1.</param>
/// <param name="batchSize">Batch size handed to <c>NextBatch</c> for both the training and the validation frame.</param>
/// <param name="OnEpochStart">Called with the epoch number before each epoch; may be null.</param>
/// <param name="OnEpochEnd">Called after each epoch with the epoch number, the trainer's total samples seen, the epoch loss value and the epoch's metric dictionary; may be null.</param>
/// <param name="onBatchStart">Called with the epoch number and the 1-based mini-batch number before each training mini-batch; may be null.</param>
/// <param name="OnBatchEnd">Called after each training mini-batch with the epoch number, mini-batch number, total samples seen, the mini-batch loss and a dictionary holding the mini-batch metric; may be null.</param>
/// <param name="shuffle">When true, <c>train.Shuffle()</c> is called at the start of every epoch.</param>
/// <returns>
/// Per-epoch history. "loss" and the metric name hold the values read from the trainer
/// (<c>PreviousMinibatchLossAverage</c> / <c>PreviousMinibatchEvaluationAverage</c>) once the
/// epoch's mini-batch loop has finished. When validation data is supplied, "val_loss" and
/// "val_" + metric name hold the mean of the per-batch validation values, or NaN if the
/// validation frame produced no batches. The dictionary is empty when no epoch runs.
/// </returns>
/// <exception cref="System.ArgumentNullException"><paramref name="trainData"/> is null.</exception>
public Dictionary<string, List<double>> Train(object trainData, object validationData, int epoches, int batchSize, On_Epoch_Start OnEpochStart, On_Epoch_End OnEpochEnd, On_Batch_Start onBatchStart, On_Batch_End OnBatchEnd, bool shuffle = false)
{
    if (trainData == null)
    {
        throw new System.ArgumentNullException("trainData");
    }

    XYFrame train = (XYFrame)trainData;
    // Casting a null reference yields null, so no separate null branch is required.
    XYFrame validation = (XYFrame)validationData;

    Dictionary<string, List<double>> result = new Dictionary<string, List<double>>();
    var trainer = Trainer.CreateTrainer(Model, lossFunc, metricFunc, learners);

    for (int currentEpoch = 1; currentEpoch <= epoches; currentEpoch++)
    {
        if (shuffle)
        {
            train.Shuffle();
        }

        // A fresh dictionary per epoch: it is handed to OnEpochEnd, which may keep the reference.
        Dictionary<string, double> metricsList = new Dictionary<string, double>();

        if (OnEpochStart != null)
        {
            OnEpochStart(currentEpoch);
        }

        int miniBatchCount = 1;
        while (train.NextBatch(miniBatchCount, batchSize))
        {
            if (onBatchStart != null)
            {
                onBatchStart(currentEpoch, miniBatchCount);
            }

            Value features = DataFrameUtil.GetValueBatch(train.CurrentBatch.XFrame);
            Value labels = DataFrameUtil.GetValueBatch(train.CurrentBatch.YFrame);

            trainer.TrainMinibatch(
                new Dictionary<Variable, Value>() { { featureVariable, features }, { labelVariable, labels } },
                GlobalParameters.Device);

            if (OnBatchEnd != null)
            {
                OnBatchEnd(
                    currentEpoch,
                    miniBatchCount,
                    trainer.TotalNumberOfSamplesSeen(),
                    trainer.PreviousMinibatchLossAverage(),
                    new Dictionary<string, double>() { { metricName, trainer.PreviousMinibatchEvaluationAverage() } });
            }

            miniBatchCount++;
        }

        if (!result.ContainsKey("loss"))
        {
            result.Add("loss", new List<double>());
        }

        if (!result.ContainsKey(metricName))
        {
            result.Add(metricName, new List<double>());
        }

        // Both values are read after the batch loop, i.e. they are whatever the trainer
        // reports for its previous mini-batch at that point.
        double lossValue = trainer.PreviousMinibatchLossAverage();
        double metricValue = trainer.PreviousMinibatchEvaluationAverage();
        result["loss"].Add(lossValue);
        result[metricName].Add(metricValue);
        metricsList.Add(metricName, metricValue);

        if (validation != null)
        {
            if (!result.ContainsKey("val_loss"))
            {
                result.Add("val_loss", new List<double>());
            }

            if (!result.ContainsKey("val_" + metricName))
            {
                result.Add("val_" + metricName, new List<double>());
            }

            // The evaluation functions depend only on the label variable and the configured
            // loss/metric names, so they are built once per epoch instead of once per batch.
            Variable actualVariable = CNTKLib.InputVariable(labelVariable.Shape, DataType.Float);
            var evalLossFunc = Losses.Get(lossName, labelVariable, actualVariable);
            var evalMetricFunc = Metrics.Get(metricName, labelVariable, actualVariable);

            int evalMiniBatchCount = 1;
            List<double> totalEvalBatchLossList = new List<double>();
            List<double> totalEvalMetricValueList = new List<double>();

            while (validation.NextBatch(evalMiniBatchCount, batchSize))
            {
                Value actual = EvaluateInternal(validation.CurrentBatch.XFrame);
                Value expected = DataFrameUtil.GetValueBatch(validation.CurrentBatch.YFrame);

                // The same inputs feed both the loss and the metric evaluation.
                var inputDataMap = new Dictionary<Variable, Value>() { { labelVariable, expected }, { actualVariable, actual } };

                var lossOutputMap = new Dictionary<Variable, Value>() { { evalLossFunc.Output, null } };
                evalLossFunc.Evaluate(inputDataMap, lossOutputMap, GlobalParameters.Device);
                var evalLoss = lossOutputMap[evalLossFunc.Output].GetDenseData<float>(evalLossFunc.Output).Select(x => x.First());
                totalEvalBatchLossList.Add(evalLoss.Average());

                var metricOutputMap = new Dictionary<Variable, Value>() { { evalMetricFunc.Output, null } };
                evalMetricFunc.Evaluate(inputDataMap, metricOutputMap, GlobalParameters.Device);
                var evalMetric = metricOutputMap[evalMetricFunc.Output].GetDenseData<float>(evalMetricFunc.Output).Select(x => x.First());
                totalEvalMetricValueList.Add(evalMetric.Average());

                evalMiniBatchCount++;
            }

            // Average() throws on an empty sequence; record NaN when no validation batch was produced.
            double valLoss = totalEvalBatchLossList.Count > 0 ? totalEvalBatchLossList.Average() : double.NaN;
            double valMetric = totalEvalMetricValueList.Count > 0 ? totalEvalMetricValueList.Average() : double.NaN;

            result["val_loss"].Add(valLoss);
            metricsList.Add("val_loss", valLoss);
            result["val_" + metricName].Add(valMetric);
            metricsList.Add("val_" + metricName, valMetric);
        }

        if (OnEpochEnd != null)
        {
            OnEpochEnd(currentEpoch, trainer.TotalNumberOfSamplesSeen(), lossValue, metricsList);
        }
    }

    return result;
}
/// <summary>
/// Trains the model from an <c>ImageDataGenerator</c> for a fixed number of epochs and, when a
/// validation generator is supplied, evaluates the loss and metric on it after every epoch.
/// </summary>
/// <param name="trainData">Training data; must be an <c>ImageDataGenerator</c>.</param>
/// <param name="validationData">Validation data as an <c>ImageDataGenerator</c>, or null to skip validation.</param>
/// <param name="epoches">Number of epochs to run; epochs are numbered starting at 1.</param>
/// <param name="batchSize">Batch size handed to <c>NextBatch</c> for both generators.</param>
/// <param name="OnEpochStart">Called with the epoch number before each epoch; may be null.</param>
/// <param name="OnEpochEnd">Called after each epoch with the epoch number, the trainer's total samples seen, the epoch loss value and the epoch's metric dictionary; may be null.</param>
/// <param name="onBatchStart">Called with the epoch number and the 1-based mini-batch number before each training mini-batch; may be null.</param>
/// <param name="OnBatchEnd">Called after each training mini-batch with the epoch number, mini-batch number, total samples seen, the mini-batch loss and a dictionary holding the mini-batch metric; may be null.</param>
/// <param name="shuffle">Not read by this implementation.</param>
/// <returns>
/// Per-epoch history. "loss" holds the mean of the mini-batch losses collected during the
/// epoch (NaN if no mini-batch ran); the metric name holds the value read from
/// <c>trainer.PreviousMinibatchEvaluationAverage()</c> after the batch loop. When validation
/// data is supplied, "val_loss" and "val_" + metric name hold the mean of the per-batch
/// validation values, or NaN if no validation batch was processed. The dictionary is empty
/// when no epoch runs.
/// </returns>
/// <exception cref="System.ArgumentNullException"><paramref name="trainData"/> is null.</exception>
public Dictionary<string, List<double>> Train(object trainData, object validationData, int epoches, int batchSize, On_Epoch_Start OnEpochStart, On_Epoch_End OnEpochEnd, On_Batch_Start onBatchStart, On_Batch_End OnBatchEnd, bool shuffle = false)
{
    if (trainData == null)
    {
        throw new System.ArgumentNullException("trainData");
    }

    ImageDataGenerator train = (ImageDataGenerator)trainData;
    // Casting a null reference yields null, so no separate null branch is required.
    ImageDataGenerator validation = (ImageDataGenerator)validationData;

    Dictionary<string, List<double>> result = new Dictionary<string, List<double>>();
    var trainer = Trainer.CreateTrainer(Model, lossFunc, metricFunc, learners);

    if (train.GenType == ImageGenType.FromTextFile)
    {
        train.LoadTextData(featureVariable, labelVariable);
        if (validation != null)
        {
            validation.LoadTextData(featureVariable, labelVariable);
        }
    }

    for (int currentEpoch = 1; currentEpoch <= epoches; currentEpoch++)
    {
        // A fresh dictionary per epoch: it is handed to OnEpochEnd, which may keep the
        // reference, so it must not be cleared and reused by the next epoch.
        Dictionary<string, double> metricsList = new Dictionary<string, double>();

        if (OnEpochStart != null)
        {
            OnEpochStart(currentEpoch);
        }

        int miniBatchCount = 1;
        List<double> miniBatchLosses = new List<double>();

        // NOTE(review): this loop runs while NextBatch returns false, whereas the validation
        // loop below runs while it returns true. NextBatch is not visible here, so the
        // polarity is kept as-is — confirm which of the two matches its contract.
        while (!train.NextBatch(batchSize))
        {
            if (onBatchStart != null)
            {
                onBatchStart(currentEpoch, miniBatchCount);
            }

            trainer.TrainMinibatch(
                new Dictionary<Variable, Value> { { featureVariable, train.CurrentBatchX }, { labelVariable, train.CurrentBatchY } },
                true,
                GlobalParameters.Device);

            if (OnBatchEnd != null)
            {
                OnBatchEnd(
                    currentEpoch,
                    miniBatchCount,
                    trainer.TotalNumberOfSamplesSeen(),
                    trainer.PreviousMinibatchLossAverage(),
                    new Dictionary<string, double>() { { metricName, trainer.PreviousMinibatchEvaluationAverage() } });
            }

            miniBatchLosses.Add(trainer.PreviousMinibatchLossAverage());
            miniBatchCount++;
        }

        if (!result.ContainsKey("loss"))
        {
            result.Add("loss", new List<double>());
        }

        if (!result.ContainsKey(metricName))
        {
            result.Add(metricName, new List<double>());
        }

        // Average() throws on an empty sequence; record NaN when no mini-batch was trained.
        double lossValue = miniBatchLosses.Count > 0 ? miniBatchLosses.Average() : double.NaN;
        double metricValue = trainer.PreviousMinibatchEvaluationAverage();
        result["loss"].Add(lossValue);
        result[metricName].Add(metricValue);
        metricsList.Add(metricName, metricValue);

        if (validation != null)
        {
            if (!result.ContainsKey("val_loss"))
            {
                result.Add("val_loss", new List<double>());
            }

            if (!result.ContainsKey("val_" + metricName))
            {
                result.Add("val_" + metricName, new List<double>());
            }

            // The evaluation functions depend only on the label variable and the configured
            // loss/metric names, so they are built once per epoch instead of once per batch.
            Variable actualVariable = CNTKLib.InputVariable(labelVariable.Shape, DataType.Float);
            var evalLossFunc = Losses.Get(lossName, labelVariable, actualVariable);
            var evalMetricFunc = Metrics.Get(metricName, labelVariable, actualVariable);

            List<double> totalEvalBatchLossList = new List<double>();
            List<double> totalEvalMetricValueList = new List<double>();

            while (validation.NextBatch(batchSize))
            {
                Value actual = EvaluateInternal(validation.CurrentBatchX);
                Value expected = validation.CurrentBatchY;

                // The same inputs feed both the loss and the metric evaluation.
                var inputDataMap = new Dictionary<Variable, Value>() { { labelVariable, expected }, { actualVariable, actual } };

                var lossOutputMap = new Dictionary<Variable, Value>() { { evalLossFunc.Output, null } };
                evalLossFunc.Evaluate(inputDataMap, lossOutputMap, GlobalParameters.Device);
                var evalLoss = lossOutputMap[evalLossFunc.Output].GetDenseData<float>(evalLossFunc.Output).Select(x => x.First());
                totalEvalBatchLossList.Add(evalLoss.Average());

                var metricOutputMap = new Dictionary<Variable, Value>() { { evalMetricFunc.Output, null } };
                evalMetricFunc.Evaluate(inputDataMap, metricOutputMap, GlobalParameters.Device);
                var evalMetric = metricOutputMap[evalMetricFunc.Output].GetDenseData<float>(evalMetricFunc.Output).Select(x => x.First());
                totalEvalMetricValueList.Add(evalMetric.Average());
            }

            // Average() throws on an empty sequence; record NaN when no validation batch was processed.
            double valLoss = totalEvalBatchLossList.Count > 0 ? totalEvalBatchLossList.Average() : double.NaN;
            double valMetric = totalEvalMetricValueList.Count > 0 ? totalEvalMetricValueList.Average() : double.NaN;

            result["val_loss"].Add(valLoss);
            metricsList.Add("val_loss", valLoss);
            result["val_" + metricName].Add(valMetric);
            metricsList.Add("val_" + metricName, valMetric);
        }

        if (OnEpochEnd != null)
        {
            OnEpochEnd(currentEpoch, trainer.TotalNumberOfSamplesSeen(), lossValue, metricsList);
        }
    }

    return result;
}