/// <summary>
/// Splits <c>input.Data</c> into a training view and a test view using a stratification column
/// produced by <c>SplitUtils.CreateStratificationColumn</c>.
/// </summary>
/// <remarks>
/// Training rows are those that a RangeFilter over the stratification column keeps for the range
/// Min = 0 .. Max = <c>input.Fraction</c>. Test rows are the complement of that same filter.
/// Bound inclusivity follows RangeFilter's defaults. The temporary stratification column is
/// dropped from both outputs.
/// </remarks>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Split arguments; <c>Fraction</c> must lie strictly between 0 and 1.</param>
/// <returns>An <c>Output</c> holding the training and test views.</returns>
public static Output Split(IHostEnvironment env, Input input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(ModuleName);
    host.CheckValue(input, nameof(input));
    host.Check(input.Fraction > 0 && input.Fraction < 1, "The fraction must be in the interval (0,1).");
    EntryPointUtils.CheckInputArgs(host, input);

    var source = input.Data;
    // May rewrite 'source' (passed by ref) to carry the stratification column.
    var splitColumn = SplitUtils.CreateStratificationColumn(host, ref source, input.StratificationColumn);

    // Both filters use the same range; only the Complement flag differs, so the two sides partition the rows.
    var keepArgs = new RangeFilter.Arguments { Column = splitColumn, Min = 0, Max = input.Fraction, Complement = false };
    var restArgs = new RangeFilter.Arguments { Column = splitColumn, Min = 0, Max = input.Fraction, Complement = true };

    IDataView keptRows = new RangeFilter(host, keepArgs, source);
    IDataView train = SelectColumnsTransform.CreateDrop(host, keptRows, splitColumn);

    IDataView remainingRows = new RangeFilter(host, restArgs, source);
    IDataView test = SelectColumnsTransform.CreateDrop(host, remainingRows, splitColumn);

    return new Output { TrainData = train, TestData = test };
}
/// <summary>
/// Splits <c>input.Data</c> into <c>input.NumFolds</c> cross-validation folds using a stratification
/// column produced by <c>SplitUtils.CreateStratificationColumn</c>.
/// </summary>
/// <remarks>
/// Fold <c>i</c> covers the stratification range Min = i/n .. Max = (i+1)/n.
/// <c>TestData[i]</c> is the RangeFilter over that range. <c>TrainData[i]</c> is its complement.
/// Bound inclusivity follows RangeFilter's defaults. The temporary stratification column is
/// dropped from every output view.
/// </remarks>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Split arguments; <c>NumFolds</c> must be greater than 1.</param>
/// <returns>An <c>Output</c> whose <c>TrainData</c>/<c>TestData</c> arrays each have <c>NumFolds</c> entries.</returns>
public static Output Split(IHostEnvironment env, Input input)
{
    Contracts.CheckValue(env, nameof(env));
    var host = env.Register(ModuleName);
    host.CheckValue(input, nameof(input));
    // Without this check:
    //   - 0 folds silently yields empty arrays (1.0 / 0 is Infinity);
    //   - negative counts overflow the array allocation;
    //   - a single fold leaves its training set empty.
    // The train/test split likewise rejects degenerate fractions.
    host.Check(input.NumFolds > 1, "The number of folds must be greater than 1.");
    EntryPointUtils.CheckInputArgs(host, input);

    var data = input.Data;
    // May rewrite 'data' (passed by ref) to carry the stratification column.
    var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn);

    int n = input.NumFolds;
    var output = new Output { TrainData = new IDataView[n], TestData = new IDataView[n] };

    // Construct per-fold datasets.
    for (int i = 0; i < n; i++)
    {
        // Divide instead of multiplying by a precomputed 1.0 / n. The product can drift,
        // e.g. 49 * (1.0 / 49) == 0.9999999999999999, which would stop the last fold short of 1.
        // Fold i's max and fold i+1's min use the same expression, so folds stay contiguous.
        double min = (double)i / n;
        double max = (double)(i + 1) / n;

        var trainData = new RangeFilter(host, new RangeFilter.Arguments { Column = stratCol, Min = min, Max = max, Complement = true }, data);
        output.TrainData[i] = new DropColumnsTransform(host, new DropColumnsTransform.Arguments { Column = new[] { stratCol } }, trainData);

        var testData = new RangeFilter(host, new RangeFilter.Arguments { Column = stratCol, Min = min, Max = max, Complement = false }, data);
        output.TestData[i] = new DropColumnsTransform(host, new DropColumnsTransform.Arguments { Column = new[] { stratCol } }, testData);
    }

    return output;
}