/// <summary>
/// Checks that column inference raises <see cref="ArgumentException"/> when it is asked to
/// use "Junk" as the label column of the downloaded UCI Adult dataset.
/// </summary>
public void IncorrectLabelColumnThrows()
{
    var adultDataFile = DatasetUtil.DownloadUciAdultDataset();
    var mlContext = new MLContext();

    // Inference is deferred into a delegate so the assertion can observe the exception.
    System.Action inferWithBogusLabel = () =>
        mlContext.Auto().InferColumns(adultDataFile, "Junk", groupColumns: false);

    Assert.Throws<ArgumentException>(inferWithBogusLabel);
}
/// <summary>
/// Infers columns for the UCI Adult dataset, selecting the label by column index 14 with
/// <c>hasHeader: true</c>, and checks that the header flag is set on the loader options and
/// that the column at that index is reported under the name "hours-per-week".
/// </summary>
public void IdentifyLabelColumnThroughIndexWithHeader()
{
    const int labelIndex = 14;
    const string expectedLabelName = "hours-per-week";

    var autoCatalog = new MLContext().Auto();
    var adultDataFile = DatasetUtil.DownloadUciAdultDataset();
    var inference = autoCatalog.InferColumns(adultDataFile, labelIndex, hasHeader: true);

    Assert.True(inference.TextLoaderOptions.HasHeader);

    // Locate the loader column whose first source range covers exactly the label index.
    var labelColumn = inference.TextLoaderOptions.Columns.First(
        column => column.Source[0].Min == labelIndex && column.Source[0].Max == labelIndex);

    Assert.Equal(expectedLabelName, labelColumn.Name);
    Assert.Equal(expectedLabelName, inference.ColumnInformation.LabelColumnName);
}
/// <summary>
/// Runs column inference on the UCI Adult dataset twice, with <c>groupColumns</c> off and on.
/// With grouping off, every inferred column must have a single source range whose
/// <c>Min</c> equals its <c>Max</c>; with grouping on, fewer columns must be produced.
/// </summary>
public void UnGroupReturnsMoreColumnsThanGroup()
{
    var adultDataFile = DatasetUtil.DownloadUciAdultDataset();
    var mlContext = new MLContext();

    var ungrouped = mlContext.Auto().InferColumns(adultDataFile, DatasetUtil.UciAdultLabel, groupColumns: false);
    foreach (var column in ungrouped.TextLoaderOptions.Columns)
    {
        // A column counts as grouped if it has several ranges or its one range is wider than a single index.
        var coversMoreThanOneIndex = column.Source.Length > 1
            || column.Source[0].Min != column.Source[0].Max;
        Assert.False(coversMoreThanOneIndex);
    }

    var grouped = mlContext.Auto().InferColumns(adultDataFile, DatasetUtil.UciAdultLabel, groupColumns: true);

    var groupedColumnCount = grouped.TextLoaderOptions.Columns.Count();
    var ungroupedColumnCount = ungrouped.TextLoaderOptions.Columns.Count();
    Assert.True(groupedColumnCount < ungroupedColumnCount);
}
/// <summary>
/// End-to-end check of a binary classification experiment on the UCI Adult dataset:
/// infers the columns, loads the data with the inferred loader options, executes the
/// experiment, and verifies that the best run has accuracy above 0.70 and exposes a
/// non-null estimator, model and trainer name.
/// </summary>
public void AutoFitBinaryTest()
{
    var mlContext = new MLContext();
    var adultDataFile = DatasetUtil.DownloadUciAdultDataset();

    // Build a text loader from the inferred schema and read the training data.
    var inference = mlContext.Auto().InferColumns(adultDataFile, DatasetUtil.UciAdultLabel);
    var loader = mlContext.Data.CreateTextLoader(inference.TextLoaderOptions);
    var trainingData = loader.Load(adultDataFile);

    // The experiment is created with an argument of 0, exactly as the original call does.
    var experiment = mlContext.Auto().CreateBinaryClassificationExperiment(0);
    var labelInfo = new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel };
    var experimentResult = experiment.Execute(trainingData, labelInfo);

    var bestRun = experimentResult.BestRun;
    Assert.True(bestRun.ValidationMetrics.Accuracy > 0.70);
    Assert.NotNull(bestRun.Estimator);
    Assert.NotNull(bestRun.Model);
    Assert.NotNull(bestRun.TrainerName);
}
/// <summary>
/// Checks that column inference raises <see cref="ArgumentOutOfRangeException"/> when the
/// label is requested by column index 100 on the UCI Adult dataset.
/// </summary>
public void LabelIndexOutOfBoundsThrows()
{
    Assert.Throws<ArgumentOutOfRangeException>(() =>
    {
        // Everything runs inside the delegate so any exception is seen by the assertion.
        var autoCatalog = new MLContext().Auto();
        var adultDataFile = DatasetUtil.DownloadUciAdultDataset();
        return autoCatalog.InferColumns(adultDataFile, 100);
    });
}
/// <summary>
/// Passes the downloaded UCI Adult dataset path to
/// <c>UserInputValidationUtil.ValidateInferColumnsArgs</c>. There is no explicit assertion;
/// the test succeeds as long as the call completes without throwing.
/// </summary>
public void ValidateInferColsPath()
{
    var adultDataFile = DatasetUtil.DownloadUciAdultDataset();
    UserInputValidationUtil.ValidateInferColumnsArgs(adultDataFile);
}