Пример #1
0
        /// <summary>
        /// Entry point that builds a <c>RangeFilter</c> over <c>input.Data</c> and returns it both as the
        /// output data and wrapped in a <c>TransformModel</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before anything else happens.</param>
        /// <param name="input">Filter arguments (including the source data); checked for null and passed through <c>EntryPointUtils.CheckInputArgs</c>.</param>
        /// <returns>A transform output whose <c>Model</c> wraps the new filter and whose <c>OutputData</c> is the filter itself.</returns>
        public static CommonOutputs.TransformOutput FilterByRange(IHostEnvironment env, RangeFilter.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            // The host is registered under the filter's own loader signature.
            var host = env.Register(RangeFilter.LoaderSignature);
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var filtered = new RangeFilter(host, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, filtered, input.Data);
            result.OutputData = filtered;
            return result;
        }
        /// <summary>
        /// Entry point that constructs a <c>SelectColumnsTransform</c> from the keep/drop settings in
        /// <paramref name="input"/> and applies it to <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies <c>KeepColumns</c>, <c>DropColumns</c>, <c>KeepHidden</c>, <c>IgnoreMissing</c> and the source data.</param>
        /// <returns>A transform output carrying the transformed view and a <c>TransformModel</c> built from it.</returns>
        public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, SelectColumnsTransform.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("SelectColumns");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // Build the transformer first, then run it over the incoming data.
            var selector = new SelectColumnsTransform(env, input.KeepColumns, input.DropColumns, input.KeepHidden, input.IgnoreMissing);
            var selected = selector.Transform(input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, selected, input.Data);
            result.OutputData = selected;
            return result;
        }
        /// <summary>
        /// Entry point that creates a transform via <c>ConcatTransform.Create</c> over <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Concat arguments (including the source data); checked for null and passed through <c>EntryPointUtils.CheckInputArgs</c>.</param>
        /// <returns>A transform output carrying the created view and a <c>TransformModel</c> built from it.</returns>
        public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ConcatTransform.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ConcatColumns");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var concatenated = ConcatTransform.Create(env, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, concatenated, input.Data);
            result.OutputData = concatenated;
            return result;
        }
Пример #4
0
        /// <summary>
        /// Entry point that creates a transform via <c>MutualInformationFeatureSelectionTransform.Create</c>
        /// over <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Feature-selection arguments (including the source data); checked for null and passed through <c>EntryPointUtils.CheckInputArgs</c>.</param>
        /// <returns>A transform output carrying the created view and a <c>TransformModel</c> built from it.</returns>
        public static CommonOutputs.TransformOutput MutualInformationSelect(IHostEnvironment env, MutualInformationFeatureSelectionTransform.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("MutualInformationSelect");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // The transform is created on the registered host; the model is created on the original environment.
            var selected = MutualInformationFeatureSelectionTransform.Create(host, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, selected, input.Data);
            result.OutputData = selected;
            return result;
        }
Пример #5
0
        /// <summary>
        /// Entry point that scores <c>input.Data</c> by applying <c>input.TransformModel</c> to it.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the transform model and the data to apply it to.</param>
        /// <returns>An output whose <c>ScoredData</c> is the applied result and whose <c>ScoringTransform</c> is left null.</returns>
        public static Output ScoreUsingTransform(IHostEnvironment env, InputTransformScorer input)
        {
            Contracts.CheckValue(env, nameof(env));
            // The registered host is only used for argument validation here.
            var host = env.Register("ScoreModelUsingTransform");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var scored = input.TransformModel.Apply(env, input.Data);

            var result = new Output();
            result.ScoredData = scored;
            result.ScoringTransform = null;
            return result;
        }
Пример #6
0
        /// <summary>
        /// Entry point that builds a <c>DropColumnsTransform</c> from keep-style arguments over <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Keep arguments (including the source data); checked for null and passed through <c>EntryPointUtils.CheckInputArgs</c>.</param>
        /// <returns>A transform output carrying the new transform and a <c>TransformModel</c> built from it.</returns>
        public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, DropColumnsTransform.KeepArguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("SelectColumns");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // An empty Columns array is allowed on purpose: it means every column should be dropped,
            // so no non-empty check is performed here.
            var kept = new DropColumnsTransform(env, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, kept, input.Data);
            result.OutputData = kept;
            return result;
        }
Пример #7
0
        /// <summary>
        /// Entry point that produces a summary and statistics for the predictor held by <c>input.PredictorModel</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the predictor model to summarize.</param>
        /// <returns>
        /// A summary output whose <c>Summary</c> is the return value of <c>GetSummaryAndStats</c> and whose
        /// <c>Stats</c> is filled through that call's out parameter.
        /// </returns>
        public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("LinearRegressionPredictor");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // No real data is supplied: an empty view over the transform model's input schema is
            // handed to PrepareData to obtain the role-mapped data and the predictor.
            var emptyView = new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema);
            input.PredictorModel.PrepareData(host, emptyView, out RoleMappedData roleMapped, out IPredictor predictor);

            var result = new CommonOutputs.SummaryOutput();
            var summary = GetSummaryAndStats(host, predictor, roleMapped.Schema, out result.Stats);
            result.Summary = summary;
            return result;
        }
Пример #8
0
        /// <summary>
        /// Entry point that splits <c>input.Data</c> into <c>input.NumFolds</c> train/test pairs.
        /// </summary>
        /// <remarks>
        /// A stratification column is obtained from <c>SplitUtils.CreateStratificationColumn</c>. Fold <c>i</c>
        /// uses the range [i/n, (i+1)/n) bounds on that column: the test set is the range filter with
        /// <c>Complement = false</c>, the train set is the same range with <c>Complement = true</c>.
        /// The stratification column is dropped from every resulting view.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the data, the fold count and the stratification column name.</param>
        /// <returns>An output holding one train view and one test view per fold.</returns>
        public static Output Split(IHostEnvironment env, Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(ModuleName);
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var view = input.Data;
            var stratCol = SplitUtils.CreateStratificationColumn(host, ref view, input.StratificationColumn);

            int foldCount = input.NumFolds;
            var trainFolds = new IDataView[foldCount];
            var testFolds = new IDataView[foldCount];
            var result = new Output { TrainData = trainFolds, TestData = testFolds };

            // Every fold covers an equally wide slice of the stratification column.
            double width = 1.0 / foldCount;

            for (int fold = 0; fold < foldCount; fold++)
            {
                double lower = fold * width;
                double upper = (fold + 1) * width;

                // Train: everything outside the fold's slice.
                var trainRange = new RangeFilter.Arguments { Column = stratCol, Min = lower, Max = upper, Complement = true };
                var trainFiltered = new RangeFilter(host, trainRange, view);
                var trainDrop = new DropColumnsTransform.Arguments { Column = new[] { stratCol } };
                trainFolds[fold] = new DropColumnsTransform(host, trainDrop, trainFiltered);

                // Test: exactly the fold's slice.
                var testRange = new RangeFilter.Arguments { Column = stratCol, Min = lower, Max = upper, Complement = false };
                var testFiltered = new RangeFilter(host, testRange, view);
                var testDrop = new DropColumnsTransform.Arguments { Column = new[] { stratCol } };
                testFolds[fold] = new DropColumnsTransform(host, testDrop, testFiltered);
            }

            return result;
        }
        /// <summary>
        /// Entry point that makes sure the label column is usable as a classification label.
        /// </summary>
        /// <remarks>
        /// When the label column's type is already a key or a boolean, the data is passed through
        /// <c>NopTransform.CreateIfNeeded</c>. Otherwise the column is mapped in place (same source and
        /// output name) with <c>ValueToKeyMappingTransformer</c>, sorted by value.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the data, the label column name and the <c>TextKeyValues</c> setting.</param>
        /// <returns>A transform output carrying the resulting view and a model built from it.</returns>
        public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvironment env, ClassificationLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareClassificationLabel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out int labelCol))
            {
                throw host.Except($"Column '{input.LabelColumn}' not found.");
            }

            var labelType = input.Data.Schema[labelCol].Type;
            bool alreadySuitable = labelType.IsKey || labelType.IsBool;
            if (alreadySuitable)
            {
                // Nothing to convert: hand the data back unchanged.
                var passThrough = NopTransform.CreateIfNeeded(env, input.Data);
                var unchanged = new CommonOutputs.TransformOutput();
                unchanged.Model = new TransformModelImpl(env, passThrough, input.Data);
                unchanged.OutputData = passThrough;
                return unchanged;
            }

            // Map the label column onto keys, overwriting it under its own name.
            var labelMapping = new ValueToKeyMappingTransformer.Column()
            {
                Name = input.LabelColumn,
                Source = input.LabelColumn,
                TextKeyValues = input.TextKeyValues,
                Sort = ValueToKeyMappingTransformer.SortOrder.Value
            };
            var mappingArgs = new ValueToKeyMappingTransformer.Arguments()
            {
                Column = new[] { labelMapping }
            };
            var mapped = ValueToKeyMappingTransformer.Create(host, mappingArgs, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModelImpl(env, mapped, input.Data);
            result.OutputData = mapped;
            return result;
        }
        /// <summary>
        /// Entry point that converts a numeric label column to R4 when it is not R4 already.
        /// </summary>
        /// <remarks>
        /// When the label column's type is <c>NumberType.R4</c>, or is not a number type at all, the data is
        /// passed through <c>NopTransform.CreateIfNeeded</c>. Otherwise the column is converted in place
        /// (same input and output name) to <c>DataKind.R4</c> with <c>TypeConvertingTransformer</c>.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the data and the label column name.</param>
        /// <returns>A transform output carrying the resulting view and a model built from it.</returns>
        public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareRegressionLabel");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out int labelCol))
            {
                throw host.Except($"Column '{input.LabelColumn}' not found.");
            }
            var labelType = input.Data.Schema[labelCol].Type;

            if (labelType == NumberType.R4 || !labelType.IsNumber)
            {
                // Either already R4 or not numeric: no conversion is attempted.
                var nop = NopTransform.CreateIfNeeded(env, input.Data);
                return new CommonOutputs.TransformOutput
                {
                    Model = new TransformModelImpl(env, nop, input.Data),
                    OutputData = nop
                };
            }

            // The transformer is configured directly through ColumnInfo; the previously built
            // (and never consumed) Arguments object has been removed as dead code.
            var conversion = new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, input.LabelColumn, DataKind.R4);
            var xf = new TypeConvertingTransformer(host, conversion).Transform(input.Data);

            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModelImpl(env, xf, input.Data),
                OutputData = xf
            };
        }
Пример #11
0
        /// <summary>
        /// Entry point that folds <c>input.Models</c> into a single transform model.
        /// </summary>
        /// <remarks>
        /// Starts from the last model in the array and calls <c>Apply</c> on the running result with each
        /// earlier model in turn, from the second-to-last down to the first.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the models to combine; the array is checked to be non-empty.</param>
        /// <returns>An output whose <c>OutputModel</c> is the combined model.</returns>
        public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironment env, CombineTransformModelsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineTransformModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.Models, nameof(input.Models));

            var models = input.Models;
            int index = models.Length - 1;
            ITransformModel combined = models[index];

            // Walk backwards over the remaining models.
            while (--index >= 0)
            {
                combined = combined.Apply(env, models[index]);
            }

            var result = new CombineTransformModelsOutput();
            result.OutputModel = combined;
            return result;
        }
Пример #12
0
        /// <summary>
        /// Entry point that folds <c>input.TransformModels</c> into one transform model and applies it to
        /// <c>input.PredictorModel</c>.
        /// </summary>
        /// <remarks>
        /// The fold starts from the last transform model and calls <c>Apply</c> on the running result with
        /// each earlier model in turn, from the second-to-last down to the first.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the transform models (checked to be non-empty) and the predictor model.</param>
        /// <returns>An output whose <c>PredictorModel</c> is the predictor model with the combined transform applied.</returns>
        public static PredictorModelOutput CombineModels(IHostEnvironment env, PredictorModelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.TransformModels, nameof(input.TransformModels));

            var chain = input.TransformModels;
            TransformModel combined = chain[chain.Length - 1];

            // i is the position just after the model being applied, so the loop stops before index 0 underflows.
            for (int i = chain.Length - 1; i > 0; i--)
            {
                combined = combined.Apply(env, chain[i - 1]);
            }

            var result = new PredictorModelOutput();
            result.PredictorModel = input.PredictorModel.Apply(env, combined);
            return result;
        }
Пример #13
0
        /// <summary>
        /// Entry point that converts a predicted-label column through <c>KeyToValueTransform</c> unless it
        /// is already numeric or boolean.
        /// </summary>
        /// <remarks>
        /// When the column's type is a number or a boolean, the data is passed through
        /// <c>NopTransform.CreateIfNeeded</c>. Otherwise the column is transformed in place (same source
        /// and output name).
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the data and the predicted-label column name.</param>
        /// <returns>A transform output carrying the resulting view and a <c>TransformModel</c> built from it.</returns>
        public static CommonOutputs.TransformOutput ConvertPredictedLabel(IHostEnvironment env, PredictedLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ConvertPredictedLabel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            if (!input.Data.Schema.TryGetColumnIndex(input.PredictedLabelColumn, out int predictedLabelCol))
            {
                throw host.Except($"Column '{input.PredictedLabelColumn}' not found.");
            }

            var columnType = input.Data.Schema.GetColumnType(predictedLabelCol);
            bool needsNoConversion = columnType.IsNumber || columnType.IsBool;
            if (needsNoConversion)
            {
                var passThrough = NopTransform.CreateIfNeeded(env, input.Data);
                var unchanged = new CommonOutputs.TransformOutput();
                unchanged.Model = new TransformModel(env, passThrough, input.Data);
                unchanged.OutputData = passThrough;
                return unchanged;
            }

            // Overwrite the predicted-label column under its own name.
            var labelColumn = new KeyToValueTransform.Column()
            {
                Name = input.PredictedLabelColumn,
                Source = input.PredictedLabelColumn,
            };
            var convertArgs = new KeyToValueTransform.Arguments()
            {
                Column = new[] { labelColumn }
            };
            var converted = new KeyToValueTransform(host, convertArgs, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, converted, input.Data);
            result.OutputData = converted;
            return result;
        }
Пример #14
0
        /// <summary>
        /// Entry point that builds a scoring pipeline for <c>input.PredictorModel</c> over <c>input.Data</c>.
        /// </summary>
        /// <remarks>
        /// The predictor and role-mapped data come from <c>PrepareData</c>. A schema-bindable mapper is
        /// obtained for the predictor, bound to the role-mapped schema, and the scorer component derived
        /// from it is instantiated with a <c>suffix={...}</c> sub-component setting taken from
        /// <c>input.Suffix</c>.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the predictor model, the data to score and the column suffix.</param>
        /// <returns>
        /// An output whose <c>ScoredData</c> is the scoring pipeline and whose <c>ScoringTransform</c> is a
        /// <c>TransformModel</c> built from that pipeline and the original input data.
        /// </returns>
        public static Output Score(IHostEnvironment env, Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ScoreModel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var rawData = input.Data;
            input.PredictorModel.PrepareData(host, rawData, out RoleMappedData roleMapped, out IPredictor predictor);

            IDataView pipeline;
            using (var ch = host.Start("Creating scoring pipeline"))
            {
                ch.Trace("Creating pipeline");

                var bindableMapper = ScoreUtils.GetSchemaBindableMapper(host, predictor, scorerSettings: null);
                ch.AssertValue(bindableMapper);

                var boundMapper = bindableMapper.Bind(host, roleMapped.Schema);
                var scorerComponent = ScoreUtils.GetScorerComponent(boundMapper);

                // The component is expected to arrive without settings; the suffix is the only one supplied.
                Contracts.Assert(string.IsNullOrEmpty(scorerComponent.SubComponentSettings));
                scorerComponent.SubComponentSettings = string.Format("suffix={{{0}}}", input.Suffix);

                var trainingSchema = input.PredictorModel.GetTrainingSchema(host);
                pipeline = scorerComponent.CreateInstance(host, roleMapped.Data, boundMapper, trainingSchema);
                ch.Done();
            }

            var result = new Output();
            result.ScoredData = pipeline;
            result.ScoringTransform = new TransformModel(host, pipeline, rawData);
            return result;
        }
Пример #15
0
        /// <summary>
        /// Entry point that splits <c>input.Data</c> into one train view and one test view.
        /// </summary>
        /// <remarks>
        /// A stratification column is obtained from <c>SplitUtils.CreateStratificationColumn</c>. The train
        /// view is the range filter from 0 to <c>input.Fraction</c> on that column; the test view is the
        /// complement of the same range. The stratification column is dropped from both views.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the data, the train fraction (checked to lie strictly between 0 and 1) and the stratification column name.</param>
        /// <returns>An output holding the train and test views.</returns>
        public static Output Split(IHostEnvironment env, Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(ModuleName);
            host.CheckValue(input, nameof(input));

            bool fractionInRange = input.Fraction > 0 && input.Fraction < 1;
            host.Check(fractionInRange, "The fraction must be in the interval (0,1).");
            EntryPointUtils.CheckInputArgs(host, input);

            var view = input.Data;
            var stratCol = SplitUtils.CreateStratificationColumn(host, ref view, input.StratificationColumn);

            // Train: rows whose stratification value falls inside the range.
            var trainRange = new RangeFilter.Arguments { Column = stratCol, Min = 0, Max = input.Fraction, Complement = false };
            var trainFiltered = new RangeFilter(host, trainRange, view);
            var trainDrop = new DropColumnsTransform.Arguments { Column = new[] { stratCol } };
            IDataView trainData = new DropColumnsTransform(host, trainDrop, trainFiltered);

            // Test: the complement of the same range.
            var testRange = new RangeFilter.Arguments { Column = stratCol, Min = 0, Max = input.Fraction, Complement = true };
            var testFiltered = new RangeFilter(host, testRange, view);
            var testDrop = new DropColumnsTransform.Arguments { Column = new[] { stratCol } };
            IDataView testData = new DropColumnsTransform(host, testDrop, testFiltered);

            var result = new Output();
            result.TrainData = trainData;
            result.TestData = testData;
            return result;
        }
        /// <summary>
        /// Entry point that renames the score columns of a binary classifier's output by appending the
        /// second label name to each column name.
        /// </summary>
        /// <remarks>
        /// The renaming only happens when the predictor's kind is <c>BinaryClassification</c> and
        /// <c>GetLabelInfo</c> returns exactly two label names. In that case every column that is not hidden,
        /// is accepted by <c>ShouldAddColumn</c> (defined elsewhere) and is not marked as the predicted label
        /// through its <c>ScoreValueKind</c> metadata is copied to <c>source + "." + labelNames[1]</c>, and
        /// the original column is dropped. In every other case the data is passed through
        /// <c>NopTransform.CreateIfNeeded</c>.
        /// </remarks>
        /// <param name="env">Host environment; checked for null before use.</param>
        /// <param name="input">Supplies the predictor model and the scored data.</param>
        /// <returns>A transform output carrying the resulting view and a <c>TransformModel</c> built from it.</returns>
        public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(IHostEnvironment env,
                                                                                       RenameBinaryPredictionScoreColumnsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ScoreModel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            if (input.PredictorModel.Predictor.PredictionKind == PredictionKind.BinaryClassification)
            {
                var classNames = input.PredictorModel.GetLabelInfo(host, out ColumnType labelType);
                if (classNames != null && classNames.Length == 2)
                {
                    var positiveClass = classNames[1];
                    var schema = input.Data.Schema;

                    // Collect (source, new name) pairs for every score column that has to be renamed.
                    var maxScoreId = schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
                    var renames = new List<(string Source, string Name)>();
                    for (int col = 0; col < schema.ColumnCount; col++)
                    {
                        if (schema.IsHidden(col) || !ShouldAddColumn(schema, col, null, maxScoreId))
                        {
                            continue;
                        }

                        // The PredictedLabel column keeps its name.
                        ReadOnlyMemory<char> valueKind = default;
                        bool hasValueKind = schema.TryGetMetadata(TextType.Instance, MetadataUtils.Kinds.ScoreValueKind, col, ref valueKind);
                        if (hasValueKind && ReadOnlyMemoryUtils.EqualsStr(MetadataUtils.Const.ScoreValueKind.PredictedLabel, valueKind))
                        {
                            continue;
                        }

                        var source = schema.GetColumnName(col);
                        renames.Add((source, source + "." + positiveClass));
                    }

                    // Rename = copy to the new name, then drop the original.
                    var copied = new CopyColumnsTransform(env, renames.ToArray()).Transform(input.Data);
                    var dropArgs = new DropColumnsTransform.Arguments()
                    {
                        Column = renames.Select(pair => pair.Source).ToArray()
                    };
                    var renamed = new DropColumnsTransform(env, dropArgs, copied);

                    var renamedOutput = new CommonOutputs.TransformOutput();
                    renamedOutput.Model = new TransformModel(env, renamed, input.Data);
                    renamedOutput.OutputData = renamed;
                    return renamedOutput;
                }
            }

            // Not a binary classifier, or no usable pair of label names: leave the data as it is.
            var passThrough = NopTransform.CreateIfNeeded(env, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModel(env, passThrough, input.Data);
            result.OutputData = passThrough;
            return result;
        }