/// <summary>
/// End-to-end example: generates example sequence data with <c>GenerateExampleData</c>, builds an
/// LSTM model with <c>DefineModel_LSTM</c>, trains it on the first 8000 generated items and
/// evaluates it on the items after that.
/// </summary>
public void Run()
{
    var device = DeviceDescriptor.UseDefaultDevice();

    const int nbItems = 10000;
    // First 80 % of the generated items are used for training, the rest for evaluation.
    const int nbTrainingItems = nbItems * 4 / 5;
    const int numSweepsToTrainWith = 20;
    const uint minibatchSize = 100;
    int nbObservationsInItem = 5;
    int outputDim = 1;

    // NOTE(review): the meaning of the third argument (5) is defined by GenerateExampleData — TODO confirm and name it.
    List<Example_106_Data> allData = GenerateExampleData(nbItems, nbObservationsInItem, 5);
    List<Example_106_Data> trainingData = allData.Take(nbTrainingItems).ToList();
    List<Example_106_Data> evalData = allData.Skip(nbTrainingItems).ToList();

    // Each element fed through the input variable has shape { 1 }.
    int[] inputDim = new int[] { 1 };
    // NOTE(review): the input carries both the default dynamic axis and the batch axis, whereas
    // expectedOutput below only carries the batch axis (presumably one target per input sequence).
    // The original author was unsure about this choice — confirm against the CNTK sequence docs.
    Variable input = Variable.InputVariable(
        NDShape.CreateNDShape(inputDim),
        DataType.Double,
        "input",
        dynamicAxes: new List<Axis>() { Axis.DefaultDynamicAxis(), Axis.DefaultBatchAxis() });

    Function model = DefineModel_LSTM(input, nbObservationsInItem, outputDim);
    Function output = model.Output;

    Variable expectedOutput = Variable.InputVariable(
        NDShape.CreateNDShape(new int[] { outputDim }),
        DataType.Double,
        "expectedOutput",
        dynamicAxes: new List<Axis>() { Axis.DefaultBatchAxis() });

    Trainer trainer = MakeTrainer(expectedOutput, output, model, minibatchSize);

    // Train.
    int nbSamplesToUseForTraining = trainingData.Count;
    int numMinibatchesToTrain = nbSamplesToUseForTraining * numSweepsToTrainWith / (int)minibatchSize;
    // NOTE(review): the identity Select presumably exposes Observations as an IEnumerable<> for
    // GenericMinibatchSequenceSource's constructor — confirm before simplifying.
    var trainingInput = trainingData.Select(x => x.Observations.Select(y => y)).ToList();
    var trainingOutput = trainingData.Select(x => new[] { x.ExpectedPrediction }).ToList();
    var trainingMinibatchSource = new GenericMinibatchSequenceSource(
        input,
        trainingInput,
        expectedOutput,
        trainingOutput,
        nbSamplesToUseForTraining,
        numSweepsToTrainWith,
        minibatchSize,
        device);
    RunTraining(trainer, trainingMinibatchSource, numMinibatchesToTrain, device);

    // Evaluate on the held-out items.
    Evaluate(model, evalData, input, device);
}
/// <summary>
/// Trains for <paramref name="numMinibatchesToTrain"/> iterations, pulling one random minibatch per
/// iteration from <paramref name="minibatchSource"/>. After each iteration it writes the running,
/// sample-weighted average of the trainer's evaluation metric to the debug output.
/// </summary>
/// <param name="trainer">Trainer used to train each minibatch.</param>
/// <param name="minibatchSource">Source of the training minibatches.</param>
/// <param name="numMinibatchesToTrain">Number of minibatches to train on; zero or negative trains nothing.</param>
/// <param name="device">Device passed to <c>Trainer.TrainMinibatch</c>.</param>
private void RunTraining(Trainer trainer, GenericMinibatchSequenceSource minibatchSource, int numMinibatchesToTrain, DeviceDescriptor device)
{
    // Sum over minibatches of (minibatch average metric * minibatch sample count).
    double aggregateMetric = 0;
    for (int minibatchCount = 0; minibatchCount < numMinibatchesToTrain; minibatchCount++)
    {
        IDictionary<Variable, MinibatchData> data = minibatchSource.GetNextRandomMinibatch();
        trainer.TrainMinibatch(data, device);

        double samples = trainer.PreviousMinibatchSampleCount();
        double avg = trainer.PreviousMinibatchEvaluationAverage();
        aggregateMetric += avg * samples;

        double nbSampleSeen = trainer.TotalNumberOfSamplesSeen();
        // Avoid 0/0 (NaN) when the trainer has not seen any samples yet.
        double trainError = nbSampleSeen > 0 ? aggregateMetric / nbSampleSeen : 0;
        Debug.WriteLine($"{minibatchCount} Average training error: {trainError:p2}");
    }
}