/// <summary>
/// Splits <c>input.Data</c> into a training view and a test view.
/// </summary>
/// <remarks>
/// A stratification column is obtained from <c>SplitUtils.CreateStratificationColumn</c>, which
/// receives the data by reference and may replace it. The training view keeps the rows that
/// <c>RangeFilter</c> accepts for the range [0, <c>input.Fraction</c>]; the test view keeps the
/// complement of that same range. The stratification column is dropped from both views.
/// </remarks>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="input">Split arguments; must not be null, and <c>Fraction</c> must lie strictly between 0 and 1.</param>
/// <returns>An <see cref="Output"/> holding the training and test views.</returns>
public static Output Split(IHostEnvironment env, Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(ModuleName);

            // Argument validation: input must exist and the fraction must leave both sides non-trivial.
            host.CheckValue(input, nameof(input));
            host.Check(0 < input.Fraction && input.Fraction < 1, "The fraction must be in the interval (0,1).");
            EntryPointUtils.CheckInputArgs(host, input);

            // The helper may swap in a view that carries the generated stratification column.
            var source = input.Data;
            var splitColumn = SplitUtils.CreateStratificationColumn(host, ref source, input.StratificationColumn);

            // Training side: rows RangeFilter keeps for [0, Fraction] (endpoint inclusion follows RangeFilter's defaults).
            var trainRange = new RangeFilter.Arguments
            {
                Column = splitColumn,
                Min = 0,
                Max = input.Fraction,
                Complement = false
            };
            IDataView trainView = new RangeFilter(host, trainRange, source);
            trainView = SelectColumnsTransform.CreateDrop(host, trainView, splitColumn);

            // Test side: the complement of the very same range.
            var testRange = new RangeFilter.Arguments
            {
                Column = splitColumn,
                Min = 0,
                Max = input.Fraction,
                Complement = true
            };
            IDataView testView = new RangeFilter(host, testRange, source);
            testView = SelectColumnsTransform.CreateDrop(host, testView, splitColumn);

            var result = new Output();
            result.TrainData = trainView;
            result.TestData = testView;
            return result;
        }
Example #2
0
        /// <summary>
        /// Splits <c>input.Data</c> into <c>input.NumFolds</c> cross-validation folds.
        /// </summary>
        /// <remarks>
        /// A stratification column is obtained from <c>SplitUtils.CreateStratificationColumn</c>, which
        /// receives the data by reference and may replace it. The interval [0, 1] is cut into
        /// <c>NumFolds</c> equal slices (NOTE(review): assumes the stratification values lie in [0, 1] —
        /// confirm against SplitUtils). For fold <c>i</c>, the test view keeps the rows that
        /// <c>RangeFilter</c> accepts for slice <c>i</c>, and the train view keeps the complement.
        /// The stratification column is dropped from every returned view.
        /// </remarks>
        /// <param name="env">Host environment; must not be null.</param>
        /// <param name="input">Split arguments; must not be null, and <c>NumFolds</c> must be greater than 1.</param>
        /// <returns>
        /// An <see cref="Output"/> whose <c>TrainData[i]</c> and <c>TestData[i]</c> are the train and
        /// test views of fold <c>i</c>.
        /// </returns>
        public static Output Split(IHostEnvironment env, Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(ModuleName);

            host.CheckValue(input, nameof(input));
            // Validate before allocating the per-fold arrays. A negative count would make the allocation
            // throw an unhelpful OverflowException, zero would silently produce no folds, and a single
            // fold would leave its training view empty.
            host.Check(input.NumFolds > 1, "The number of folds must be greater than 1.");

            EntryPointUtils.CheckInputArgs(host, input);

            var data = input.Data;

            var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn);

            int n = input.NumFolds;
            var output = new Output
            {
                TrainData = new IDataView[n],
                TestData  = new IDataView[n]
            };

            // Construct per-fold datasets.
            for (int i = 0; i < n; i++)
            {
                // Compute each boundary as k / n rather than k * (1.0 / n). The division is correctly
                // rounded, so the last fold ends at exactly 1.0 (49 * (1.0 / 49) is 0.9999999999999999).
                // The boundary shared by folds i - 1 and i is also the identical value in both.
                double min = (double)i / n;
                double max = (double)(i + 1) / n;

                var trainData = new RangeFilter(host, new RangeFilter.Arguments
                {
                    Column = stratCol, Min = min, Max = max, Complement = true
                }, data);
                output.TrainData[i] = new DropColumnsTransform(host, new DropColumnsTransform.Arguments
                {
                    Column = new[] { stratCol }
                }, trainData);

                var testData = new RangeFilter(host, new RangeFilter.Arguments
                {
                    Column = stratCol, Min = min, Max = max, Complement = false
                }, data);
                output.TestData[i] = new DropColumnsTransform(host, new DropColumnsTransform.Arguments
                {
                    Column = new[] { stratCol }
                }, testData);
            }

            return output;
        }