示例#1
0
        /// <summary>
        /// Entry point that runs a <see cref="ColumnSelectingTransformer"/> over <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Arguments carrying the keep/drop column lists, the keep-hidden and
        /// ignore-missing flags, and the data to transform.</param>
        /// <returns>The transformed data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env, ColumnSelectingTransformer.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("SelectColumns");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // Build the transformer first, then apply it to the incoming data.
            var selector = new ColumnSelectingTransformer(env, input.KeepColumns, input.DropColumns, input.KeepHidden, input.IgnoreMissing);
            var selected = selector.Transform(input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModelImpl(env, selected, input.Data);
            result.OutputData = selected;
            return result;
        }
示例#2
0
        /// <summary>
        /// Entry point that creates a <see cref="ColumnConcatenatingTransformer"/> over <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Concatenation arguments, including the data to transform.</param>
        /// <returns>The transformed data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput ConcatColumns(IHostEnvironment env, ColumnConcatenatingTransformer.Arguments input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ConcatColumns");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var concatenated = ColumnConcatenatingTransformer.Create(env, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModelImpl(env, concatenated, input.Data);
            result.OutputData = concatenated;
            return result;
        }
示例#3
0
        /// <summary>
        /// Entry point that wraps <c>input.Data</c> in a <see cref="RangeFilter"/> configured by <paramref name="input"/>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Range filter options, including the data to filter.</param>
        /// <returns>The filtered data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput FilterByRange(IHostEnvironment env, RangeFilter.Options input)
        {
            Contracts.CheckValue(env, nameof(env));
            // The host is registered under the filter's loader signature.
            var host = env.Register(RangeFilter.LoaderSignature);
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var filtered = new RangeFilter(host, input, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModelImpl(env, filtered, input.Data);
            result.OutputData = filtered;
            return result;
        }
示例#4
0
        /// <summary>
        /// Entry point that applies <c>input.TransformModel</c> to <c>input.Data</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the transform model and the data to score.</param>
        /// <returns>An <see cref="Output"/> whose <c>ScoredData</c> is the applied result and whose
        /// <c>ScoringTransform</c> is left null.</returns>
        public static Output ScoreUsingTransform(IHostEnvironment env, InputTransformScorer input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ScoreModelUsingTransform");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var result = new Output();
            result.ScoredData = input.TransformModel.Apply(env, input.Data);
            // No scoring transform is produced by this entry point.
            result.ScoringTransform = null;
            return result;
        }
        /// <summary>
        /// Entry point that produces a summary (and stats) for the predictor held in <c>input.PredictorModel</c>.
        /// </summary>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the predictor model to summarize.</param>
        /// <returns>A <see cref="CommonOutputs.SummaryOutput"/> with <c>Summary</c> and <c>Stats</c>
        /// filled in by <c>GetSummaryAndStats</c>.</returns>
        public static CommonOutputs.SummaryOutput Summarize(IHostEnvironment env, SummarizePredictor.Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("LinearRegressionPredictor");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            // An empty view over the model's input schema is passed to PrepareData,
            // which yields the role-mapped data and the predictor.
            var emptyView = new EmptyDataView(host, input.PredictorModel.TransformModel.InputSchema);
            RoleMappedData roleMapped;
            IPredictor model;
            input.PredictorModel.PrepareData(host, emptyView, out roleMapped, out model);

            var result = new CommonOutputs.SummaryOutput();
            result.Summary = GetSummaryAndStats(host, model, roleMapped.Schema, out result.Stats);
            return result;
        }
        /// <summary>
        /// Entry point that prepares the label column named by <c>input.LabelColumn</c> for classification.
        /// </summary>
        /// <remarks>
        /// When the label column's type is already a <see cref="KeyType"/> or a
        /// <see cref="BooleanDataViewType"/>, the data goes through <c>NopTransform.CreateIfNeeded</c>.
        /// Otherwise the column is replaced in place (same name) by a
        /// <see cref="ValueToKeyMappingTransformer"/> with keys ordered by value.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the data, the label column name and the text-key-values setting.</param>
        /// <returns>The resulting data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        /// <exception cref="System.Exception">A schema-mismatch exception created by the host when the label column is not in the schema.</exception>
        public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvironment env, ClassificationLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareClassificationLabel");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var labelCol = input.Data.Schema.GetColumnOrNull(input.LabelColumn);

            if (!labelCol.HasValue)
            {
                // The missing column is the label, not a predicted label; name the role accordingly.
                throw host.ExceptSchemaMismatch(nameof(input), "label", input.LabelColumn);
            }

            var labelType = labelCol.Value.Type;

            if (labelType is KeyType || labelType is BooleanDataViewType)
            {
                // Key and boolean labels need no conversion.
                var nop = NopTransform.CreateIfNeeded(env, input.Data);
                return new CommonOutputs.TransformOutput
                {
                    Model = new TransformModelImpl(env, nop, input.Data),
                    OutputData = nop
                };
            }

            var args = new ValueToKeyMappingTransformer.Options()
            {
                Columns = new[]
                {
                    new ValueToKeyMappingTransformer.Column()
                    {
                        // Source and destination share a name, so the label column is overwritten.
                        Name          = input.LabelColumn,
                        Source        = input.LabelColumn,
                        TextKeyValues = input.TextKeyValues,
                        Sort          = ValueToKeyMappingEstimator.KeyOrdinality.ByValue
                    }
                }
            };
            var xf = ValueToKeyMappingTransformer.Create(host, args, input.Data);

            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModelImpl(env, xf, input.Data),
                OutputData = xf
            };
        }
示例#7
0
        /// <summary>
        /// Entry point that prepares the label column named by <c>input.LabelColumn</c> for classification.
        /// </summary>
        /// <remarks>
        /// Key-typed and boolean labels go through <c>NopTransform.CreateIfNeeded</c>; any other label
        /// type is mapped to keys in place by a <see cref="ValueToKeyMappingTransformer"/> sorted by value.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the data, the label column name and the text-key-values setting.</param>
        /// <returns>The resulting data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput PrepareClassificationLabel(IHostEnvironment env, ClassificationLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareClassificationLabel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            int labelIndex;
            bool found = input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelIndex);
            if (!found)
                throw host.Except($"Column '{input.LabelColumn}' not found.");

            var labelType = input.Data.Schema[labelIndex].Type;
            bool needsKeyMapping = !labelType.IsKey && !labelType.IsBool;

            if (!needsKeyMapping)
            {
                var passThrough = NopTransform.CreateIfNeeded(env, input.Data);

                var unchanged = new CommonOutputs.TransformOutput();
                unchanged.Model = new TransformModelImpl(env, passThrough, input.Data);
                unchanged.OutputData = passThrough;
                return unchanged;
            }

            // Source and destination share a name, so the label column is overwritten.
            var labelMapping = new ValueToKeyMappingTransformer.Column();
            labelMapping.Name = input.LabelColumn;
            labelMapping.Source = input.LabelColumn;
            labelMapping.TextKeyValues = input.TextKeyValues;
            labelMapping.Sort = ValueToKeyMappingTransformer.SortOrder.Value;

            var mappingArgs = new ValueToKeyMappingTransformer.Arguments();
            mappingArgs.Column = new[] { labelMapping };

            var mapped = ValueToKeyMappingTransformer.Create(host, mappingArgs, input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModelImpl(env, mapped, input.Data);
            result.OutputData = mapped;
            return result;
        }
示例#8
0
        /// <summary>
        /// Entry point that prepares the label column named by <c>input.LabelColumn</c> for regression.
        /// </summary>
        /// <remarks>
        /// When the label is already <c>NumberType.R4</c>, or is not a number type at all, the data goes
        /// through <c>NopTransform.CreateIfNeeded</c>. Any other numeric label is converted in place
        /// (same column name) to <c>DataKind.R4</c> by a <see cref="TypeConvertingTransformer"/>.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the data and the label column name.</param>
        /// <returns>The resulting data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironment env, RegressionLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("PrepareRegressionLabel");

            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            int labelCol;

            if (!input.Data.Schema.TryGetColumnIndex(input.LabelColumn, out labelCol))
            {
                throw host.Except($"Column '{input.LabelColumn}' not found.");
            }
            var labelType = input.Data.Schema[labelCol].Type;

            if (labelType == NumberType.R4 || !labelType.IsNumber)
            {
                // Nothing to convert: already R4, or not numeric.
                var nop = NopTransform.CreateIfNeeded(env, input.Data);
                return new CommonOutputs.TransformOutput
                {
                    Model = new TransformModelImpl(env, nop, input.Data),
                    OutputData = nop
                };
            }

            // The transformer is configured directly from a ColumnInfo; the previously
            // constructed (and never read) Arguments object has been removed.
            var xf = new TypeConvertingTransformer(host, new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, input.LabelColumn, DataKind.R4)).Transform(input.Data);

            return new CommonOutputs.TransformOutput
            {
                Model = new TransformModelImpl(env, xf, input.Data),
                OutputData = xf
            };
        }
        /// <summary>
        /// Entry point that folds <c>input.Models</c> into a single transform model.
        /// </summary>
        /// <remarks>
        /// Starts from the last model in the array and walks towards the first, calling
        /// <c>Apply(env, input.Models[i])</c> on the accumulated model at each step.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the non-empty array of models to combine.</param>
        /// <returns>An output whose <c>OutputModel</c> is the combined model.</returns>
        public static CombineTransformModelsOutput CombineTransformModels(IHostEnvironment env, CombineTransformModelsInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("CombineTransformModels");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);
            host.CheckNonEmpty(input.Models, nameof(input.Models));

            int index = input.Models.Length - 1;
            TransformModel combined = input.Models[index];

            // Walk backwards over the remaining models.
            while (--index >= 0)
            {
                combined = combined.Apply(env, input.Models[index]);
            }

            var result = new CombineTransformModelsOutput();
            result.OutputModel = combined;
            return result;
        }
示例#10
0
        /// <summary>
        /// Entry point that converts the predicted label column named by <c>input.PredictedLabelColumn</c>.
        /// </summary>
        /// <remarks>
        /// A column whose type is a <see cref="NumberType"/> or <see cref="BoolType"/> goes through
        /// <c>NopTransform.CreateIfNeeded</c>; any other type is run through a
        /// <see cref="KeyToValueMappingTransformer"/>.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the data and the predicted label column name.</param>
        /// <returns>The resulting data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput ConvertPredictedLabel(IHostEnvironment env, PredictedLabelInput input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register("ConvertPredictedLabel");
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var column = input.Data.Schema.GetColumnOrNull(input.PredictedLabelColumn);
            if (!column.HasValue)
            {
                throw host.ExceptSchemaMismatch(nameof(input), "PredictedLabel", input.PredictedLabelColumn);
            }

            var columnType = column.Value.Type;
            bool needsKeyToValue = !(columnType is NumberType) && !(columnType is BoolType);

            if (!needsKeyToValue)
            {
                var passThrough = NopTransform.CreateIfNeeded(env, input.Data);

                var unchanged = new CommonOutputs.TransformOutput();
                unchanged.Model = new TransformModelImpl(env, passThrough, input.Data);
                unchanged.OutputData = passThrough;
                return unchanged;
            }

            var mapper = new KeyToValueMappingTransformer(host, input.PredictedLabelColumn);
            var mapped = mapper.Transform(input.Data);

            var result = new CommonOutputs.TransformOutput();
            result.Model = new TransformModelImpl(env, mapped, input.Data);
            result.OutputData = mapped;
            return result;
        }
示例#11
0
        /// <summary>
        /// Entry point that combines the feature-role columns of <c>input.Data</c> into a single
        /// column named <c>DefaultColumnNames.Features</c>.
        /// </summary>
        /// <remarks>
        /// The feature columns are analysed by <c>ConvertFeatures</c>; the data is then passed through
        /// <c>ApplyConvert</c> and <c>ApplyKeyToVec</c> before the concatenation is created.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the data and the column role assignments.</param>
        /// <returns>The resulting data plus a <see cref="TransformModelImpl"/> built from it and the original input data.</returns>
        public static CommonOutputs.TransformOutput PrepareFeatures(IHostEnvironment env, FeatureCombinerInput input)
        {
            const string featureCombiner = "FeatureCombiner";
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(featureCombiner);
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            using (var ch = host.Start(featureCombiner))
            {
                var view = input.Data;
                var roleSchema = new RoleMappedSchema(view.Schema, input.GetRoles());
                var featureCols = roleSchema.GetColumns(RoleMappedSchema.ColumnRole.Feature);
                if (Utils.Size(featureCols) == 0)
                {
                    throw ch.Except("No feature columns specified");
                }

                var seenNames = new HashSet<string>();
                var concatSources = new List<KeyValuePair<string, string>>();
                List<TypeConvertingTransformer.ColumnInfo> conversions;
                int invalidCount;
                var keyToVec = ConvertFeatures(featureCols.ToArray(), seenNames, concatSources, ch, out conversions, out invalidCount);
                Contracts.Assert(seenNames.Count > 0);
                Contracts.Assert(concatSources.Count == seenNames.Count);
                if (invalidCount > 0)
                {
                    throw ch.Except("Encountered {0} invalid training column(s)", invalidCount);
                }

                view = ApplyConvert(conversions, view, host);
                view = ApplyKeyToVec(keyToVec, view, host);

                // REVIEW: What about column name conflicts? Eg, what if someone uses the group id column
                // (a key type) as a feature column. We convert that column to a vector so it is no longer valid
                // as a group id. That's just one example - you get the idea.
                var featuresColumn = new ColumnConcatenatingTransformer.TaggedColumn();
                featuresColumn.Name = DefaultColumnNames.Features;
                featuresColumn.Source = concatSources.ToArray();

                var concatArgs = new ColumnConcatenatingTransformer.TaggedArguments();
                concatArgs.Column = new[] { featuresColumn };

                view = ColumnConcatenatingTransformer.Create(host, concatArgs, view);

                var result = new CommonOutputs.TransformOutput();
                result.Model = new TransformModelImpl(env, view, input.Data);
                result.OutputData = view;
                return result;
            }
        }
示例#12
0
        /// <summary>
        /// Entry point that splits <c>input.Data</c> into <c>input.NumFolds</c> train/test pairs.
        /// </summary>
        /// <remarks>
        /// A stratification column is obtained from <c>SplitUtils.CreateStratificationColumn</c>.
        /// For fold <c>i</c> of <c>n</c>, the test set is a <see cref="RangeFilter"/> on that column
        /// with <c>Min = i / n</c> and <c>Max = (i + 1) / n</c>; the train set uses the same bounds
        /// with <c>Complement = true</c>. The stratification column is dropped from both outputs.
        /// </remarks>
        /// <param name="env">Host environment; checked for null.</param>
        /// <param name="input">Carries the data, the fold count and the stratification column name.</param>
        /// <returns>An <see cref="Output"/> with one train and one test view per fold.</returns>
        public static Output Split(IHostEnvironment env, Input input)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(ModuleName);
            host.CheckValue(input, nameof(input));
            EntryPointUtils.CheckInputArgs(host, input);

            var data = input.Data;
            var stratCol = SplitUtils.CreateStratificationColumn(host, ref data, input.StratificationColumn);

            int foldCount = input.NumFolds;
            var result = new Output();
            result.TrainData = new IDataView[foldCount];
            result.TestData = new IDataView[foldCount];

            // Each fold covers an equal slice of the stratification range.
            double fraction = 1.0 / foldCount;

            for (int fold = 0; fold < foldCount; fold++)
            {
                double lower = fold * fraction;
                double upper = (fold + 1) * fraction;

                // Train: everything outside the fold's slice.
                var trainOptions = new RangeFilter.Options();
                trainOptions.Column = stratCol;
                trainOptions.Min = lower;
                trainOptions.Max = upper;
                trainOptions.Complement = true;
                var trainData = new RangeFilter(host, trainOptions, data);
                result.TrainData[fold] = ColumnSelectingTransformer.CreateDrop(host, trainData, stratCol);

                // Test: the fold's slice itself.
                var testOptions = new RangeFilter.Options();
                testOptions.Column = stratCol;
                testOptions.Min = lower;
                testOptions.Max = upper;
                testOptions.Complement = false;
                var testData = new RangeFilter(host, testOptions, data);
                result.TestData[fold] = ColumnSelectingTransformer.CreateDrop(host, testData, stratCol);
            }

            return result;
        }