/// <summary>
/// Checks that <c>SplitUtil.CrossValSplit</c> throws an
/// <see cref="InvalidOperationException"/> when asked for 10 splits of a
/// data view whose columns were each built from a single value.
/// </summary>
public void CrossValSplitThrowsWhenNotEnoughData()
{
    // Arrange: a data view with one value in each of the "Number" and "Label" columns.
    var context = new MLContext(seed: 1);
    var builder = new ArrayDataViewBuilder(context);
    builder.AddColumn("Number", NumberDataViewType.Single, 0f);
    builder.AddColumn("Label", NumberDataViewType.Single, 0f);
    var tinyDataView = builder.GetDataView();

    // Act + Assert: requesting 10 splits from this data must be rejected.
    Assert.Throws<InvalidOperationException>(
        () => SplitUtil.CrossValSplit(context, tinyDataView, 10, null));
}
/// <summary>
/// Checks that <c>SplitUtil.CrossValSplit</c>, given a data view built from
/// 10000-element columns and a request for 10 splits, returns exactly that
/// many train datasets and validation datasets.
/// </summary>
public void CrossValSplitLargeDataView()
{
    const int rowCount = 10000;
    const int requestedNumSplits = 10;

    // Arrange: two Single columns, each backed by a zero-initialized array.
    var context = new MLContext(seed: 0);
    var builder = new ArrayDataViewBuilder(context);
    builder.AddColumn("Number", NumberDataViewType.Single, new float[rowCount]);
    builder.AddColumn("Label", NumberDataViewType.Single, new float[rowCount]);
    var largeDataView = builder.GetDataView();

    // Act
    var result = SplitUtil.CrossValSplit(context, largeDataView, requestedNumSplits, null);
    var trainSets = result.trainDatasets;
    var validationSets = result.validationDatasets;

    // Assert: some train data exists, and both sides match the requested split count.
    Assert.True(trainSets.Any());
    Assert.Equal(requestedNumSplits, trainSets.Count());
    Assert.Equal(requestedNumSplits, validationSets.Count());
}