Пример #1
0
        /// <summary>
        /// Runs k-fold cross-validation: fits <paramref name="estimator"/> on each fold's training split
        /// and scores that fold's held-out test split with the resulting model.
        /// </summary>
        /// <param name="data">Input data to partition into folds; validated with <c>Environment.CheckValue</c>.</param>
        /// <param name="estimator">Estimator fitted once per fold; validated with <c>Environment.CheckValue</c>.</param>
        /// <param name="numFolds">Number of folds; must be greater than 1.</param>
        /// <param name="samplingKeyColumn">Optional column name forwarded to the split-column helper; may be null.</param>
        /// <param name="seed">Optional seed forwarded to the split-column helper.</param>
        /// <returns>
        /// An array of length <paramref name="numFolds"/>. Entry <c>i</c> holds the model fitted on the i-th split,
        /// that model's transform of the i-th test set, and the fold index <c>i</c>.
        /// NOTE(review): this assumes <c>CrossValidationSplit</c> yields exactly <paramref name="numFolds"/> splits — confirm.
        /// </returns>
        private protected CrossValidationResult[] CrossValidateTrain(IDataView data, IEstimator<ITransformer> estimator,
            int numFolds, string samplingKeyColumn, int? seed = null)
        {
            // Argument checks, kept in their original order so the first reported problem does not change.
            Environment.CheckValue(data, nameof(data));
            Environment.CheckValue(estimator, nameof(estimator));
            Environment.CheckParam(numFolds > 1, nameof(numFolds), "Must be more than 1");
            Environment.CheckValueOrNull(samplingKeyColumn);

            // `data` is passed by ref, so the helper may replace it; the (possibly updated) view is what gets split below.
            var splitColumn = DataOperationsCatalog.CreateSplitColumn(Environment, ref data, samplingKeyColumn, seed, fallbackInEnvSeed: true);

            var results = new CrossValidationResult[numFolds];
            var foldIndex = 0;

            // Folds are trained one after another.
            // REVIEW: running folds in parallel would require a separate host for each fold.
            var folds = DataOperationsCatalog.CrossValidationSplit(Environment, data, splitColumn, numFolds);
            foreach (var currentFold in folds)
            {
                var trainedModel = estimator.Fit(currentFold.TrainSet);
                results[foldIndex] = new CrossValidationResult(trainedModel, trainedModel.Transform(currentFold.TestSet), foldIndex);
                foldIndex += 1;
            }

            return results;
        }