/// <summary>
/// For every <see cref="TrainerName"/> except <see cref="TrainerName.Ova"/> and
/// <see cref="TrainerName.ImageClassification"/>, sets each hyperparameter sweep range to a raw
/// value of 1. It then asserts that the trainer extension can create a trainer instance from
/// those parameters and can create a pipeline node.
/// </summary>
/// <remarks>
/// ImageClassification is skipped here and exercised by TrainerExtensionTensorFlowInstanceTests,
/// which passes null sweep parameters.
/// </remarks>
public void TrainerExtensionInstanceTests()
{
    var context = new MLContext(1);
    var columnInfo = new ColumnInformation();

    var trainerNames = Enum.GetValues(typeof(TrainerName))
        .Cast<TrainerName>()
        .Except(new[] { TrainerName.Ova, TrainerName.ImageClassification });

    foreach (var trainerName in trainerNames)
    {
        var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);

        // Materialize the sweep ranges once. The RawValue assignments below must be seen by
        // CreateInstance. That is not guaranteed if the returned IEnumerable is lazily
        // re-evaluated and yields new SweepableParam objects on a second enumeration.
        var sweepParams = extension.GetHyperparamSweepRanges()?.ToList();
        Assert.NotNull(sweepParams);

        foreach (var sweepParam in sweepParams)
        {
            sweepParam.RawValue = 1;
        }

        var instance = extension.CreateInstance(context, sweepParams, columnInfo);
        Assert.NotNull(instance);

        // The pipeline node is built without sweep parameters, as in the original test.
        var pipelineNode = extension.CreatePipelineNode(null, columnInfo);
        Assert.NotNull(pipelineNode);
    }
}
/// <summary>
/// Checks the <see cref="TrainerName.ImageClassification"/> trainer extension. With no sweep
/// parameters (null), it must produce a non-null trainer instance and a non-null pipeline node.
/// </summary>
public void TrainerExtensionTensorFlowInstanceTests()
{
    var mlContext = new MLContext(1);
    var columns = new ColumnInformation();
    var imageClassificationExtension =
        TrainerExtensionCatalog.GetTrainerExtension(TrainerName.ImageClassification);

    Assert.NotNull(imageClassificationExtension.CreateInstance(mlContext, null, columns));
    Assert.NotNull(imageClassificationExtension.CreatePipelineNode(null, columns));
}