Beispiel #1
0
 /// <summary>
 /// Produces an <see cref="ISchemaBoundMapper"/> for the given role-mapped schema.
 /// The method is abstract, so each concrete subclass supplies the actual binding logic.
 /// </summary>
 /// <param name="env">The host environment handed to the binding.</param>
 /// <param name="schema">The role-mapped schema to bind against.</param>
 /// <returns>The mapper bound to <paramref name="schema"/>.</returns>
 public abstract ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema);
        /// <summary>
        /// Enumerates the input column roles reported by the base class, filtering out a label
        /// role whose column cannot be found in the schema. When <c>_calculateDbi</c> is set, the
        /// feature column role is appended as well.
        /// </summary>
        /// <param name="schema">The role-mapped schema whose columns are inspected.</param>
        /// <returns>A lazily evaluated sequence of (role, column name) pairs.</returns>
        /// <remarks>
        /// When <c>_calculateDbi</c> is set and the resolved feature column is missing, the
        /// exception built by <c>Host.ExceptUserArg</c> is thrown during enumeration.
        /// </remarks>
        private protected override IEnumerable<KeyValuePair<RoleMappedSchema.ColumnRole, string>> GetInputColumnRolesCore(RoleMappedSchema schema)
        {
            foreach (var role in base.GetInputColumnRolesCore(schema))
            {
                // Non-label roles always pass through; a label role passes only if its
                // column is actually present in the schema.
                bool isLabel = role.Key.Equals(RoleMappedSchema.ColumnRole.Label);
                if (!isLabel || schema.Schema.TryGetColumnIndex(role.Value, out int _))
                {
                    yield return role;
                }
            }

            if (!_calculateDbi)
            {
                yield break;
            }

            // The feature column is required from here on, so it must resolve.
            string featureName = EvaluateUtils.GetColName(_featureCol, schema.Feature, DefaultColumnNames.Features);
            bool featureFound = schema.Schema.TryGetColumnIndex(featureName, out int _);
            if (!featureFound)
            {
                throw Host.ExceptUserArg(nameof(Arguments.FeatureColumn), "Features column '{0}' not found", featureName);
            }

            yield return RoleMappedSchema.ColumnRole.Feature.Bind(featureName);
        }
Beispiel #3
0
 /// <summary>
 /// Creates a bound mapper for the given ensemble: forwards <paramref name="parent"/> and
 /// <paramref name="schema"/> to the base constructor and keeps the parent's combiner.
 /// </summary>
 /// <param name="parent">The ensemble this instance is bound from; its <c>Combiner</c> is stored.</param>
 /// <param name="schema">The role-mapped schema being bound.</param>
 public Bound(SchemaBindablePipelineEnsemble<T> parent, RoleMappedSchema schema)
     : base(parent, schema)
     => _combiner = parent.Combiner;
Beispiel #4
0
 /// <summary>
 /// Binds this mapper to <paramref name="schema"/> by wrapping it in a new <c>Bound</c> instance.
 /// </summary>
 /// <param name="env">The host environment; not used by this implementation.</param>
 /// <param name="schema">The role-mapped schema to bind against.</param>
 /// <returns>A new <c>Bound</c> created from this instance and <paramref name="schema"/>.</returns>
 public override ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
     => new Bound(this, schema);
 /// <summary>
 /// Binds this predictor to <paramref name="schema"/> using a scalar row mapper whose score
 /// schema is a float column of the ranking score kind.
 /// </summary>
 /// <param name="env">The host environment passed on to the row mapper.</param>
 /// <param name="schema">The role-mapped schema to bind against.</param>
 /// <returns>A new <c>XGBoostScalarRowMapper</c> for this predictor.</returns>
 public override ISchemaBoundMapper Bind(IHostEnvironment env, RoleMappedSchema schema)
 {
     var rankingScoreSchema = new ScoreMapperSchema(NumberType.Float, MetadataUtils.Const.ScoreColumnKind.Ranking);
     return new XGBoostScalarRowMapper(schema, this, env, rankingScoreSchema);
 }
Beispiel #6
0
        /// <summary>
        /// Returns a getter that reads the feature column of <paramref name="row"/> as a
        /// <c>VBuffer&lt;float&gt;</c>. The feature column is the one designated by
        /// <paramref name="schema"/>.
        /// </summary>
        /// <param name="row">The row to read from; must not be null.</param>
        /// <param name="schema">
        /// The role-mapped schema; must not be null, must wrap the same schema object as
        /// <paramref name="row"/>, and must have a feature column.
        /// </param>
        /// <returns>The getter obtained from <paramref name="row"/> for the feature column index.</returns>
        public static ValueGetter<VBuffer<float>> GetFeatureFloatVectorGetter(this IRow row, RoleMappedSchema schema)
        {
            // Argument validation, in the same order callers have always observed.
            Contracts.CheckValue(row, nameof(row));
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.CheckParam(schema.Schema == row.Schema, nameof(schema), "schemas don't match!");
            Contracts.CheckParam(schema.Feature != null, nameof(schema), "Missing feature column");

            var featureIndex = schema.Feature.Index;
            return row.GetGetter<VBuffer<float>>(featureIndex);
        }
        /// <summary>
        /// Builds a data view summarizing <paramref name="predictor"/> and, when the predictor
        /// can provide one, a second data view with its statistics.
        /// </summary>
        /// <param name="env">Host environment used for checks and for building data views.</param>
        /// <param name="predictor">
        /// The predictor to summarize. Calibrated wrappers are unwrapped until the innermost
        /// sub-model is reached.
        /// </param>
        /// <param name="schema">The role-mapped schema passed on to the summary providers.</param>
        /// <param name="stats">
        /// Receives the statistics data view when the predictor supplies a stats row; otherwise null.
        /// </param>
        /// <returns>
        /// The summary data view. If the predictor implements neither summary interface, a
        /// single-column view is built from its saved text summary, or from its type name
        /// when it cannot save a summary either.
        /// </returns>
        internal static IDataView GetSummaryAndStats(IHostEnvironment env, IPredictor predictor, RoleMappedSchema schema, out IDataView stats)
        {
            // Strip every calibrator layer so the innermost model is the one summarized.
            while (predictor is IWeaklyTypedCalibratedModelParameters wrapper)
            {
                predictor = wrapper.WeaklyTypedSubModel;
            }

            stats = null;
            IDataView result = null;

            var viewProvider = predictor as ICanGetSummaryAsIDataView;
            var rowProvider  = predictor as ICanGetSummaryAsIRow;

            if (viewProvider != null)
            {
                result = viewProvider.GetSummaryDataView(schema);
            }

            if (rowProvider != null)
            {
                var summaryRow = rowProvider.GetSummaryIRowOrNull(schema);

                // A predictor must not offer both a data-view summary and a row summary.
                env.Check(viewProvider == null || summaryRow == null,
                          "Predictor outputs two summary data views, don't know which one to choose");

                if (summaryRow != null)
                {
                    result = RowCursorUtils.RowAsDataView(env, summaryRow);
                }

                var statsRow = rowProvider.GetStatsIRowOrNull(schema);
                if (statsRow != null)
                {
                    stats = RowCursorUtils.RowAsDataView(env, statsRow);
                }
            }
            else if (viewProvider == null)
            {
                // Neither summary interface is available: fall back to a one-row, one-column view.
                var builder     = new ArrayDataViewBuilder(env);
                var textSummary = predictor as ICanSaveSummary;

                if (textSummary == null)
                {
                    builder.AddColumn("PredictorName", predictor.GetType().ToString());
                }
                else
                {
                    var text = new StringBuilder();
                    using (StringWriter writer = new StringWriter(text))
                    {
                        textSummary.SaveSummary(writer, schema);
                    }
                    builder.AddColumn("Summary", text.ToString());
                }

                result = builder.GetDataView();
            }

            env.AssertValue(result);
            return result;
        }
 internal virtual DataViewSchema.Annotations MakeStatisticsMetadata(RoleMappedSchema schema, in VBuffer <ReadOnlyMemory <char> > names)
Beispiel #9
0
 /// <summary>
 /// Explicit <c>ISchemaBindableMapper</c> implementation: binds this predictor to
 /// <paramref name="schema"/> with a scalar row mapper whose output schema is created
 /// from a <c>BinaryClassifierSchema</c>.
 /// </summary>
 /// <param name="env">The host environment passed on to the row mapper.</param>
 /// <param name="schema">The role-mapped schema to bind against.</param>
 /// <returns>A new <c>FieldAwareFactorizationMachineScalarRowMapper</c> for this predictor.</returns>
 ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
 {
     var outputSchema = Schema.Create(new BinaryClassifierSchema());
     return new FieldAwareFactorizationMachineScalarRowMapper(env, schema, outputSchema, this);
 }
Beispiel #10
0
 /// <summary>
 /// Factory that builds a <c>GenericScorer</c> from the supplied arguments, data, bound mapper
 /// and training schema. All parameters are forwarded to the constructor unchanged.
 /// </summary>
 /// <param name="env">The host environment.</param>
 /// <param name="args">The featurizer mapper arguments forwarded to the scorer.</param>
 /// <param name="data">The input data view to score.</param>
 /// <param name="mapper">The schema-bound mapper that produces the scores.</param>
 /// <param name="trainSchema">The role-mapped schema forwarded as the training schema.</param>
 /// <returns>The newly created scorer transform.</returns>
 public static IDataScorerTransform Create(
     IHostEnvironment env,
     TreeEnsembleFeaturizerBindableMapper.Arguments args,
     IDataView data,
     ISchemaBoundMapper mapper,
     RoleMappedSchema trainSchema)
     => new GenericScorer(env, args, data, mapper, trainSchema);
 /// <summary>
 /// Produces an <see cref="ISchemaBoundMapper"/> for the given role-mapped schema.
 /// The method is abstract, so each derived class supplies the actual binding logic.
 /// </summary>
 /// <param name="env">The host environment handed to the binding.</param>
 /// <param name="schema">The role-mapped schema to bind against.</param>
 /// <returns>The mapper bound to <paramref name="schema"/>.</returns>
 private protected abstract ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema);
 /// <summary>
 /// Explicit <c>ISchemaBindableMapper</c> implementation that forwards to <c>BindCore</c>.
 /// </summary>
 /// <param name="env">The host environment, forwarded unchanged.</param>
 /// <param name="schema">The role-mapped schema, forwarded unchanged.</param>
 /// <returns>Whatever <c>BindCore</c> returns for the same arguments.</returns>
 ISchemaBoundMapper ISchemaBindableMapper.Bind(IHostEnvironment env, RoleMappedSchema schema)
 {
     return BindCore(env, schema);
 }
 /// <summary>
 /// Binds this mapper to <paramref name="schema"/> by wrapping it in a new <c>Bound</c> instance.
 /// </summary>
 /// <param name="env">The host environment; not used by this implementation.</param>
 /// <param name="schema">The role-mapped schema to bind against.</param>
 /// <returns>A new <c>Bound</c> created from this instance and <paramref name="schema"/>.</returns>
 private protected override ISchemaBoundMapper BindCore(IHostEnvironment env, RoleMappedSchema schema)
     => new Bound(this, schema);
        /// <summary>
        /// Exports the data pipeline, plus the predictor's scorer when one is available and
        /// ONNX-capable, as an ONNX model.
        /// </summary>
        /// <remarks>
        /// Steps, in order:
        /// 1. Obtain the input view: from a freshly created loader, from a loaded model file,
        ///    or by applying <c>_model</c> to an empty data view of its input schema.
        /// 2. Collect the ONNX-convertible transform chain via <c>GetPipe</c>.
        /// 3. If a predictor was loaded, append its scorer when it can be saved as ONNX.
        /// 4. Declare graph inputs, emit one node group per transform, and declare graph outputs,
        ///    skipping columns listed in <c>_inputsToDrop</c> / <c>_outputsToDrop</c>.
        /// 5. Write the model to <c>_outputModelPath</c>, optionally an indented JSON rendering
        ///    to <c>_outputJsonModelPath</c>, and optionally save the loader to
        ///    <c>Args.OutputModelFile</c>.
        /// </remarks>
        /// <param name="ch">Channel used for the warning and trace messages emitted here.</param>
        private void Run(IChannel ch)
        {
            IDataLoader      loader      = null;
            IPredictor       rawPred     = null;
            IDataView        view;
            RoleMappedSchema trainSchema = null;

            if (_model == null)
            {
                if (string.IsNullOrEmpty(Args.InputModelFile))
                {
                    // No input model: only a loader is available, so there is no predictor
                    // and no training schema (both keep their initial null values).
                    loader = CreateLoader();
                    Host.CheckUserArg(Args.LoadPredictor != true, nameof(Args.LoadPredictor),
                                      "Cannot be set to true unless " + nameof(Args.InputModelFile) + " is also specified.");
                }
                else
                {
                    LoadModelObjects(ch, _loadPredictor, out rawPred, true, out trainSchema, out loader);
                }

                view = loader;
            }
            else
            {
                view = _model.Apply(Host, new EmptyDataView(Host, _model.InputSchema));
            }

            // Create the ONNX context for storing global information
            var assembly    = System.Reflection.Assembly.GetExecutingAssembly();
            var versionInfo = System.Diagnostics.FileVersionInfo.GetVersionInfo(assembly.Location);
            var ctx         = new OnnxContextImpl(Host, _name, ProducerName, versionInfo.FileVersion,
                                                  ModelVersion, _domain, Args.OnnxVersion);

            // Get the transform chain.
            GetPipe(ctx, ch, view, out IDataView source, out IDataView end, out LinkedList<ITransformCanSaveOnnx> transforms);
            Host.Assert(transforms.Count == 0 || transforms.Last.Value == end);

            // If we have a predictor, try to get the scorer for it.
            if (rawPred != null)
            {
                RoleMappedData data;
                if (trainSchema != null)
                {
                    data = new RoleMappedData(end, trainSchema.GetColumnRoleNames());
                }
                else
                {
                    // We had a predictor, but no roles stored in the model. Just suppose
                    // default column names are OK, if present.
                    data = new RoleMappedData(end, DefaultColumnNames.Label,
                                              DefaultColumnNames.Features, DefaultColumnNames.GroupId, DefaultColumnNames.Weight, DefaultColumnNames.Name, opt: true);
                }

                var scorePipe = ScoreUtils.GetScorer(rawPred, data, Host, trainSchema);
                var scoreOnnx = scorePipe as ITransformCanSaveOnnx;
                if (scoreOnnx?.CanSaveOnnx(ctx) == true)
                {
                    // The scorer becomes the new end of the pipeline.
                    Host.Assert(scorePipe.Source == end);
                    end = scorePipe;
                    transforms.AddLast(scoreOnnx);
                }
                else
                {
                    // Fail only if the user explicitly asked for the predictor; otherwise warn and go on.
                    Contracts.CheckUserArg(_loadPredictor != true,
                                           nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but we do not know how to save it as ONNX.");
                    ch.Warning("We do not know how to save the predictor as ONNX. Ignoring.");
                }
            }
            else
            {
                Contracts.CheckUserArg(_loadPredictor != true,
                                       nameof(Arguments.LoadPredictor), "We were explicitly told to load the predictor but one was not present.");
            }

            // Create graph inputs.
            for (int i = 0; i < source.Schema.ColumnCount; i++)
            {
                string colName = source.Schema.GetColumnName(i);
                if (_inputsToDrop.Contains(colName))
                {
                    continue;
                }

                ctx.AddInputVariable(source.Schema.GetColumnType(i), colName);
            }

            // Create graph nodes, outputs and intermediate values.
            foreach (var trans in transforms)
            {
                Host.Assert(trans.CanSaveOnnx(ctx));
                trans.SaveAsOnnx(ctx);
            }

            // Add graph outputs.
            for (int i = 0; i < end.Schema.ColumnCount; ++i)
            {
                if (end.Schema.IsHidden(i))
                {
                    continue;
                }

                var idataviewColumnName = end.Schema.GetColumnName(i);

                // Since the last IDataView also contains columns of the initial IDataView, last IDataView's columns found in
                // _inputToDrop should be removed too.
                if (_inputsToDrop.Contains(idataviewColumnName) || _outputsToDrop.Contains(idataviewColumnName))
                {
                    continue;
                }

                // Route the column's current variable through an Identity node into a new
                // variable, and register that new variable as the graph output.
                // NOTE(review): TryGetVariableName presumably can return null for a column that
                // no converted transform produced - confirm whether a guard is needed here.
                var variableName     = ctx.TryGetVariableName(idataviewColumnName);
                var trueVariableName = ctx.AddIntermediateVariable(null, idataviewColumnName, true);
                ctx.CreateNode("Identity", variableName, trueVariableName, ctx.GetNodeName("Identity"), "");
                ctx.AddOutputVariable(end.Schema.GetColumnType(i), trueVariableName);
            }

            var model = ctx.MakeModel();

            using (var file = Host.CreateOutputFile(_outputModelPath))
                using (var stream = file.CreateWriteStream())
                    model.WriteTo(stream);

            if (_outputJsonModelPath != null)
            {
                using (var file = Host.CreateOutputFile(_outputJsonModelPath))
                    using (var stream = file.CreateWriteStream())
                        using (var writer = new StreamWriter(stream))
                        {
                            // Round-trip through the JSON parser to get an indented rendering.
                            var parsedJson = JsonConvert.DeserializeObject(model.ToString());
                            writer.Write(JsonConvert.SerializeObject(parsedJson, Formatting.Indented));
                        }
            }

            if (!string.IsNullOrWhiteSpace(Args.OutputModelFile))
            {
                // NOTE(review): loader is only assigned in the `_model == null` branch above, so
                // this assert fails when a model object is supplied together with
                // Args.OutputModelFile - confirm callers never combine the two.
                Contracts.Assert(loader != null);

                ch.Trace("Saving the data pipe");
                // Should probably include "end"?
                SaveLoader(loader, Args.OutputModelFile);
            }
        }
Beispiel #15
0
        /// <summary>
        /// Trims the sorted-cluster columns of the per-instance metrics. For each of the
        /// <c>SortedClusters</c> and <c>SortedClusterScores</c> columns that is present and whose
        /// vector size exceeds <c>_numTopClusters</c>, the view is wrapped in a slot-dropping
        /// transform so that only the first <c>_numTopClusters</c> slots are kept.
        /// </summary>
        /// <param name="perInst">The per-instance metrics data view.</param>
        /// <param name="schema">The role-mapped schema; not used by this implementation.</param>
        /// <returns>The possibly wrapped per-instance data view.</returns>
        private protected override IDataView GetPerInstanceMetricsCore(IDataView perInst, RoleMappedSchema schema)
        {
            // Processed in this order; the second lookup runs against the output of the first wrap.
            var sortedColumns = new[]
            {
                ClusteringPerInstanceEvaluator.SortedClusters,
                ClusteringPerInstanceEvaluator.SortedClusterScores,
            };

            foreach (var columnName in sortedColumns)
            {
                if (!perInst.Schema.TryGetColumnIndex(columnName, out int columnIndex))
                {
                    continue;
                }

                var columnType = perInst.Schema[columnIndex].Type;
                if (_numTopClusters < columnType.GetVectorSize())
                {
                    // Wrap with a DropSlots transform to pick only the first _numTopClusters slots.
                    var dropper = new SlotsDroppingTransformer(Host, columnName, min: _numTopClusters);
                    perInst = dropper.Transform(perInst);
                }
            }

            return perInst;
        }