コード例 #1
0
        /// <summary>
        /// Deserialization constructor. After the base constructor has consumed its part of
        /// <paramref name="ctx"/>, reads the feature column name and the three output column names,
        /// then rebuilds the bindable mapper and the scorer from them.
        /// </summary>
        /// <param name="host">Environment used to register this transformer's host; checked for null.</param>
        /// <param name="ctx">Load context the serialized fields are read from.</param>
        private TreeEnsembleFeaturizationTransformer(IHostEnvironment host, ModelLoadContext ctx)
            : base(Contracts.CheckRef(host, nameof(host)).Register(nameof(TreeEnsembleFeaturizationTransformer)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // string: name of the feature column.
            // string: name of the column where tree prediction values are stored.
            // string: name of the column where the trees' leaves are stored.
            // string: name of the column where the trees' paths are stored.

            // The reads below follow the serialized order and must not be rearranged.
            string featureName = ctx.LoadString();
            var featureColumn = TrainSchema[featureName];
            _featureDetachedColumn = new DataViewSchema.DetachedColumn(featureColumn);

            _treesColumnName = ctx.LoadStringOrNull();
            _leavesColumnName = ctx.LoadStringOrNull();
            _pathsColumnName = ctx.LoadStringOrNull();

            // Output column names are forwarded to both the mapper and the scorer through one argument object.
            _scorerArgs = new TreeEnsembleFeaturizerBindableMapper.Arguments
            {
                TreesColumnName = _treesColumnName,
                LeavesColumnName = _leavesColumnName,
                PathsColumnName = _pathsColumnName
            };

            // The bindable mapper carries the core computation; binding it to a schema yields the row mapper.
            BindableMapper = new TreeEnsembleFeaturizerBindableMapper(host, _scorerArgs, Model);

            // Build the scorer over an empty view that only carries the training schema.
            var schema = MakeFeatureRoleMappedSchema(TrainSchema);
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, schema);
            Scorer = new GenericScorer(Host, _scorerArgs, emptyInput, boundMapper, schema);
        }
コード例 #2
0
        /// <summary>
        /// Deserializes a <see cref="FeatureNameCollection"/> from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">Load context; it is checked to be at the model start and to match this type's version info.</param>
        /// <returns>
        /// The collection produced by the array-based <c>Create</c> overload for the dense layout, or by the
        /// dictionary-based overload for the sparse layout.
        /// </returns>
        public static FeatureNameCollection Create(ModelLoadContext ctx)
        {
            Contracts.AssertValue(ctx);
            ctx.CheckAtModel();
            ctx.CheckVersionInfo(GetVersionInfo());

            // *** Binary format ***
            // int: number of features (size)
            // int: number of indices (-1 if dense)
            // int[]: indices (if sparse), strictly increasing and each in [0, size)
            // int[]: ids of names (as many as features if dense, as many as indices if sparse)
            var size = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(size >= 0);

            var isize = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(isize >= -1);

            if (isize < 0)
            {
                // Dense case: one name per feature; empty names are normalized to null.
                var names = new string[size];
                for (int i = 0; i < size; i++)
                {
                    var name = ctx.LoadStringOrNull();
                    names[i] = string.IsNullOrEmpty(name) ? null : name;
                }
                return Create(size, names);
            }

            // Sparse case. Indices must be strictly increasing and below size, so there can never be
            // more of them than features. Reject an oversized count before allocating for it; such
            // input would fail the checks below anyway.
            Contracts.CheckDecode(isize <= size);

            var dict = new Dictionary<int, string>();
            var indices = new int[isize];
            var prev = -1;

            for (int ii = 0; ii < isize; ii++)
            {
                indices[ii] = ctx.Reader.ReadInt32();
                // Starting from -1 also forces the first index to be non-negative.
                Contracts.CheckDecode(prev < indices[ii]);
                prev = indices[ii];
            }
            Contracts.CheckDecode(prev < size);

            for (int ii = 0; ii < isize; ii++)
            {
                // Null or empty names are simply not recorded for their index.
                var name = ctx.LoadStringOrNull();
                if (!string.IsNullOrEmpty(name))
                {
                    dict.Add(indices[ii], name);
                }
            }
            return Create(size, dict);
        }
コード例 #3
0
            /// <summary>
            /// Reads one column description from <paramref name="ctx"/> and, when a languages column
            /// name was stored, binds that column against <paramref name="input"/>.
            /// </summary>
            /// <param name="ctx">Load context the language and the languages column name are read from.</param>
            /// <param name="input">Schema used to resolve the languages column.</param>
            public ColInfoEx(ModelLoadContext ctx, ISchema input)
            {
                Contracts.AssertValue(ctx);
                Contracts.AssertValue(input);

                // *** Binary format ***
                // int: the stopwords list language
                // int: the id of languages column name

                var storedLang = (Language)ctx.Reader.ReadInt32();
                Contracts.CheckDecode(Enum.IsDefined(typeof(Language), storedLang));

                // Models written as version 0x00010001 get Norwegian_Bokmal remapped to Norwegian_Bokmal_v1.
                bool remapBokmal = storedLang == Language.Norwegian_Bokmal &&
                                   ctx.Header.ModelVerWritten == 0x00010001;
                Lang = remapBokmal ? Language.Norwegian_Bokmal_v1 : storedLang;

                _langsColName = ctx.LoadStringOrNull();
                if (_langsColName == null)
                {
                    // No languages column was stored.
                    LangsColIndex = -1;
                }
                else
                {
                    Bind(input, _langsColName, out LangsColIndex, false);
                    Contracts.Assert(LangsColIndex >= 0);
                }
            }
コード例 #4
0
        /// <summary>
        /// Reads the transform list from <paramref name="ctx"/>, keeps the transforms whose tag passes
        /// <paramref name="isTransformTagAccepted"/>, applies them in order on top of
        /// <paramref name="srcLoader"/>, and returns the resulting (composite) data loader.
        /// When no transform is accepted, <paramref name="srcLoader"/> itself is returned.
        /// </summary>
        private static IDataLoader LoadTransforms(ModelLoadContext ctx, IDataLoader srcLoader, IHost host, Func<string, bool> isTransformTagAccepted)
        {
            Contracts.AssertValue(host, "host");
            host.AssertValue(srcLoader);
            host.AssertValue(ctx);

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            int floatSize = ctx.Reader.ReadInt32();
            host.CheckDecode(floatSize == sizeof(Float));

            int transformCount = ctx.Reader.ReadInt32();
            host.CheckDecode(transformCount >= 0);

            bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;

            // The two lists are filled in lockstep: entry k of acceptedDirIds is the original
            // position (and therefore the sub-model directory number) of entry k of acceptedTags.
            var acceptedTags = new List<KeyValuePair<string, string>>();
            var acceptedDirIds = new List<int>();

            for (int position = 0; position < transformCount; position++)
            {
                // Without tags nothing is read per transform and the predicate sees an empty tag.
                string tag = hasTags ? ctx.LoadNonEmptyString() : "";
                string argsString = hasTags ? ctx.LoadStringOrNull() : null;

                if (isTransformTagAccepted(tag))
                {
                    acceptedDirIds.Add(position);
                    acceptedTags.Add(new KeyValuePair<string, string>(tag, argsString));
                }
            }

            host.Assert(acceptedTags.Count == acceptedDirIds.Count);
            if (acceptedTags.Count == 0)
            {
                return srcLoader;
            }

            return ApplyTransformsCore(host, srcLoader, acceptedTags.ToArray(),
                (h, index, data) =>
                {
                    // Map the index within the accepted list back to the directory the transform was saved in.
                    string dir = string.Format(TransformDirTemplate, acceptedDirIds[index]);
                    IDataTransform xf;
                    ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(host, out xf, dir, data);
                    return xf;
                });
        }
コード例 #5
0
        /// <summary>
        /// Reads the train label column name from <paramref name="ctx"/> and then calls <c>SetScorer</c>.
        /// </summary>
        /// <param name="ctx">Load context to read from.</param>
        /// <param name="trainLabelColumn">Receives the loaded label column name (may be null).</param>
        private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn)
        {
            // *** Binary format ***
            // <base info>
            // id of string: train label column

            string loadedLabel = ctx.LoadStringOrNull();
            trainLabelColumn = loadedLabel;

            SetScorer();
        }
コード例 #6
0
            /// <summary>
            /// Deserializes the bindings from <paramref name="ctx"/>, resolving each group's source column
            /// and optional languages column against <paramref name="input"/>.
            /// </summary>
            /// <param name="ctx">Load context to read from.</param>
            /// <param name="input">Schema the source and languages columns are bound against.</param>
            /// <param name="ch">Channel used for assertions and decode checks.</param>
            /// <returns>Bindings holding one group per source column and one or two output columns per group.</returns>
            public static Bindings Create(ModelLoadContext ctx, ISchema input, IChannel ch)
            {
                Contracts.AssertValue(ch);
                ch.AssertValue(ctx);

                // *** Binary format ***
                // int: count of group column infos (ie, count of source columns)
                // For each group column info
                //     int: the tokenizer language
                //     int: the id of source column name
                //     int: the id of languages column name
                //     bool: whether the types output is required
                //     For each column info that belongs to this group column info
                //     (either one column info for tokens or two for tokens and types)
                //          int: the id of the column name

                int groupCount = ctx.Reader.ReadInt32();
                ch.CheckDecode(groupCount > 0);

                var groups = new ColGroupInfo[groupCount];
                var outputNames = new List<string>();
                var outputInfos = new List<ColInfo>();

                for (int g = 0; g < groupCount; g++)
                {
                    int lang = ctx.Reader.ReadInt32();
                    ch.CheckDecode(Enum.IsDefined(typeof(Language), lang));

                    string srcName = ctx.LoadNonEmptyString();
                    int srcIdx;
                    ColumnType srcType;
                    Bind(input, srcName, t => t.ItemType.IsText, SrcTypeName, out srcIdx, out srcType, false);

                    // The languages column is optional; -1 marks its absence.
                    string langsName = ctx.LoadStringOrNull();
                    int langsIdx = -1;
                    if (langsName != null)
                    {
                        ColumnType langsType;
                        Bind(input, langsName, t => t.IsText, LangTypeName, out langsIdx, out langsType, false);
                    }

                    bool requireTypes = ctx.Reader.ReadBoolByte();
                    groups[g] = new ColGroupInfo((Language)lang, srcIdx, srcName, srcType, langsIdx, langsName, requireTypes);

                    // Every group has a tokens column; a types column follows only when requested.
                    outputInfos.Add(new ColInfo(g));
                    outputNames.Add(ctx.LoadNonEmptyString());
                    if (requireTypes)
                    {
                        outputInfos.Add(new ColInfo(g, isTypes: true));
                        outputNames.Add(ctx.LoadNonEmptyString());
                    }
                }

                return new Bindings(groups, outputInfos.ToArray(), input, false, outputNames.ToArray());
            }
コード例 #7
0
        /// <summary>
        /// Deserialization constructor. The base constructor reads its own data from
        /// <paramref name="ctx"/> first; this constructor then reads the train label column name
        /// and calls <c>SetScorer</c>.
        /// </summary>
        /// <param name="env">Environment used to register the host; checked for null.</param>
        /// <param name="ctx">Load context to read from.</param>
        public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // id of string: train label column

            string loadedLabel = ctx.LoadStringOrNull();
            _trainLabelColumn = loadedLabel;

            SetScorer();
        }
コード例 #8
0
        /// <summary>
        /// Loads all transforms from the <paramref name="ctx"/> that pass the <paramref name="isTransformTagAccepted"/> test,
        /// chains them one after another on top of <paramref name="srcView"/>, and returns the last view in the chain.
        /// If no transform in <paramref name="ctx"/> is accepted, the original <paramref name="srcView"/> is returned.
        /// The difference from the <c>Create</c> method above is that:
        /// - it doesn't wrap the results into a loader, just returns the last transform in the chain.
        /// - it accepts <see cref="IDataView"/> as input, not necessarily a loader.
        /// - it throws away the tag information.
        /// - it doesn't throw if the context is not representing a <see cref="LegacyCompositeDataLoader"/>: in this case it's assumed that no transforms
        ///   meet the test, and the <paramref name="srcView"/> is returned.
        /// Essentially, this is a helper method for the LoadTransform class.
        /// </summary>
        public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView srcView, IHostEnvironment env, Func<string, bool> isTransformTagAccepted)
        {
            Contracts.CheckValue(env, nameof(env));
            var host = env.Register(RegistrationName);

            host.CheckValue(ctx, nameof(ctx));
            // The reader must sit exactly at the start of the model.
            host.Check(ctx.Reader.BaseStream.Position == ctx.FpMin + ctx.Header.FpModel);

            var expectedVersion = GetVersionInfo();
            bool signatureMatches = ctx.Header.ModelSignature == expectedVersion.ModelSignature;
            if (!signatureMatches)
            {
                // A different signature means this context holds no transform list: report and pass the input through.
                using (var ch = host.Start("ModelCheck"))
                {
                    ch.Info("The data model doesn't contain transforms.");
                }
                return srcView;
            }
            ModelHeader.CheckVersionInfo(ref ctx.Header, expectedVersion);

            host.CheckValue(srcView, nameof(srcView));
            host.CheckValue(isTransformTagAccepted, nameof(isTransformTagAccepted));

            // *** Binary format ***
            // int: sizeof(Float)
            // int: number of transforms
            // foreach transform: (starting from version VersionAddedTags)
            //     string: tag
            //     string: args string

            int floatSize = ctx.Reader.ReadInt32();
            host.CheckDecode(floatSize == sizeof(float));

            int transformCount = ctx.Reader.ReadInt32();
            host.CheckDecode(transformCount >= 0);

            bool hasTags = ctx.Header.ModelVerReadable >= VersionAddedTags;
            var result = srcView;
            for (int position = 0; position < transformCount; position++)
            {
                string tag = "";
                if (hasTags)
                {
                    tag = ctx.LoadNonEmptyString();
                    // The args string still has to be consumed even though it is not used here.
                    ctx.LoadStringOrNull();
                }

                if (isTransformTagAccepted(tag))
                {
                    // Each accepted transform is loaded on top of the view produced so far.
                    string dir = string.Format(TransformDirTemplate, position);
                    IDataTransform xf;
                    ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out xf, dir, result);
                    result = xf;
                }
            }

            return result;
        }
コード例 #9
0
        /// <summary>
        /// Reads the train label column name and, for models written as version 0x00010002 or later,
        /// the score and predicted label column names from <paramref name="ctx"/>; then calls <c>SetScorer</c>.
        /// </summary>
        /// <param name="ctx">Load context to read from.</param>
        /// <param name="trainLabelColumn">Receives the loaded label column name (may be null).</param>
        /// <param name="scoreColumn">Receives the loaded score column name, or the default for older models.</param>
        /// <param name="predictedLabelColumn">Receives the loaded predicted label column name, or the default for older models.</param>
        private void InitializationLogic(ModelLoadContext ctx, out string trainLabelColumn, out string scoreColumn, out string predictedLabelColumn)
        {
            // *** Binary format ***
            // <base info>
            // id of string: train label column
            // id of string: score column (only when written as version 0x00010002 or later)
            // id of string: predicted label column (only when written as version 0x00010002 or later)

            trainLabelColumn = ctx.LoadStringOrNull();

            bool writtenBeforeColumnNames = ctx.Header.ModelVerWritten < 0x00010002;
            if (writtenBeforeColumnNames)
            {
                // Older models did not store these names; fall back to the defaults.
                scoreColumn = AnnotationUtils.Const.ScoreValueKind.Score;
                predictedLabelColumn = DefaultColumnNames.PredictedLabel;
            }
            else
            {
                scoreColumn = ctx.LoadStringOrNull();
                predictedLabelColumn = ctx.LoadStringOrNull();
            }

            SetScorer();
        }
コード例 #10
0
ファイル: TransformBase.cs プロジェクト: artemiusgreat/ML-NET
            /// <summary>
            /// Deserializes the column bindings of a one-to-one transform from <paramref name="ctx"/>,
            /// resolving every source column by name against <paramref name="inputSchema"/>.
            /// </summary>
            /// <param name="parent">Owning transform; its host is used for checks and exceptions.</param>
            /// <param name="ctx">Load context to read from.</param>
            /// <param name="inputSchema">Schema in which the source columns are looked up.</param>
            /// <param name="transposeInput">Optional transposed view of the input; when non-null it is asked for the slot type of each source column.</param>
            /// <param name="testType">Optional validator: returns null to accept a source column type, or a reason string to reject it.</param>
            /// <returns>Bindings with one <c>ColInfo</c> per added column.</returns>
            public static Bindings Create(OneToOneTransformBase parent, ModelLoadContext ctx, DataViewSchema inputSchema,
                                          ITransposeDataView transposeInput, Func<DataViewType, string> testType)
            {
                Contracts.AssertValue(parent);
                var host = parent.Host;

                host.CheckValue(ctx, nameof(ctx));
                host.AssertValue(inputSchema);
                host.AssertValueOrNull(transposeInput);
                host.AssertValueOrNull(testType);

                // *** Binary format ***
                // int: number of added columns
                // for each added column
                //   int: id of output column name
                //   int: id of input column name
                int cinfo = ctx.Reader.ReadInt32();

                host.CheckDecode(cinfo > 0);

                var names = new string[cinfo];
                var infos = new ColInfo[cinfo];

                for (int i = 0; i < cinfo; i++)
                {
                    string dst = ctx.LoadNonEmptyString();
                    names[i] = dst;

                    // Note that in old files, the source name may be null indicating that
                    // the source column has the same name as the added column.
                    string tmp = ctx.LoadStringOrNull();
                    string src = tmp ?? dst;
                    host.CheckDecode(!string.IsNullOrEmpty(src));

                    int colSrc;
                    if (!inputSchema.TryGetColumnIndex(src, out colSrc))
                    {
                        throw host.ExceptSchemaMismatch(nameof(inputSchema), "source", src);
                    }
                    var type = inputSchema[colSrc].Type;
                    if (testType != null)
                    {
                        string reason = testType(type);
                        if (reason != null)
                        {
                            throw host.Except(InvalidTypeErrorFormat, src, type, reason);
                        }
                    }

                    // Ask for the slot type of the source column. The loop counter only numbers the
                    // added columns and is unrelated to column positions in the input, so using it
                    // here could return the slot type of a different column or go out of range.
                    var slotType = transposeInput?.GetSlotType(colSrc);
                    infos[i] = new ColInfo(dst, colSrc, type, slotType as VectorDataViewType);
                }

                return new Bindings(parent, infos, inputSchema, false, names);
            }
コード例 #11
0
        /// <summary>
        /// Deserialization constructor: reads the weights, every sub-model with its selected features
        /// and metrics, and finally the output combiner from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="env">Environment forwarded to the base constructor.</param>
        /// <param name="name">Name forwarded to the base constructor.</param>
        /// <param name="ctx">Load context to read from.</param>
        private protected EnsembleModelParametersBase(IHostEnvironment env, string name, ModelLoadContext ctx)
            : base(env, name, ctx)
        {
            // *** Binary format ***
            // int: model count
            // int: weight count (0 or model count)
            // Float[]: weights
            // for each model:
            //   int: number of SelectedFeatures (in bits)
            //   byte[]: selected features (as many as needed for number of bits == (numSelectedFeatures + 7) / 8)
            //   int: number of Metric values
            //   for each Metric:
            //     Float: metric value
            //     int: metric name (id of the metric name in the string table)
            //     in version 0x00010002:
            //       bool: is the metric averaged

            int modelCount = ctx.Reader.ReadInt32();
            Host.CheckDecode(modelCount > 0);

            int weightCount = ctx.Reader.ReadInt32();
            Host.CheckDecode(weightCount == 0 || weightCount == modelCount);
            Weights = ctx.Reader.ReadFloatArray(weightCount);

            Models = new FeatureSubsetModel<TOutput>[modelCount];

            // Files written as VerOld carry one extra bool per metric, which is read and discarded.
            bool skipExtraBool = ctx.Header.ModelVerWritten == VerOld;

            for (int m = 0; m < modelCount; m++)
            {
                string subDir = string.Format(SubPredictorFmt, m);
                ctx.LoadModel<IPredictor, SignatureLoadModel>(Host, out IPredictor p, subDir);
                var predictor = p as IPredictorProducing<TOutput>;
                Host.Check(predictor != null, "Inner predictor type not compatible with the ensemble type.");

                var features = ctx.Reader.ReadBitArray();

                int metricCount = ctx.Reader.ReadInt32();
                Host.CheckDecode(metricCount >= 0);

                var metrics = new KeyValuePair<string, double>[metricCount];
                for (int k = 0; k < metricCount; k++)
                {
                    var value = ctx.Reader.ReadFloat();
                    var metricName = ctx.LoadStringOrNull();
                    if (skipExtraBool)
                    {
                        ctx.Reader.ReadBoolByte();
                    }
                    metrics[k] = new KeyValuePair<string, double>(metricName, value);
                }

                Models[m] = new FeatureSubsetModel<TOutput>(predictor, features, metrics);
            }

            ctx.LoadModel<IOutputCombiner<TOutput>, SignatureLoadModel>(Host, out Combiner, @"Combiner");
        }
コード例 #12
0
        /// <summary>
        /// Deserialization constructor: reads the feature column name from <paramref name="ctx"/>,
        /// resolves its type in the training schema, and creates the bindable mapper.
        /// </summary>
        /// <param name="host">Host forwarded to the base constructor.</param>
        /// <param name="ctx">Load context to read from.</param>
        /// <param name="model">Model forwarded to the base constructor.</param>
        private protected SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx, TModel model)
            : base(host, ctx, model)
        {
            FeatureColumnName = ctx.LoadStringOrNull();

            if (FeatureColumnName != null)
            {
                // A stored name must exist in the training schema.
                if (!TrainSchema.TryGetColumnIndex(FeatureColumnName, out int col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumnName), "feature", FeatureColumnName);
                }

                FeatureColumnType = TrainSchema[col].Type;
            }
            else
            {
                // No feature column was stored, so there is no type either.
                FeatureColumnType = null;
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, ModelAsPredictor);
        }
コード例 #13
0
        /// <summary>
        /// Deserialization constructor. The base constructor reads its own data from
        /// <paramref name="ctx"/> first; this constructor then reads the train label column name
        /// and builds a multiclass scorer over an empty view with the training schema.
        /// </summary>
        /// <param name="env">Environment used to register the host; checked for null.</param>
        /// <param name="ctx">Load context to read from.</param>
        public MulticlassPredictionTransformer(IHostEnvironment env, ModelLoadContext ctx)
            : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MulticlassPredictionTransformer<TModel>)), ctx)
        {
            // *** Binary format ***
            // <base info>
            // id of string: train label column

            _trainLabelColumn = ctx.LoadStringOrNull();

            var roleMapped = new RoleMappedSchema(TrainSchema, _trainLabelColumn, FeatureColumn);
            var scorerArgs = new MultiClassClassifierScorer.Arguments();
            var emptyInput = new EmptyDataView(Host, TrainSchema);
            var boundMapper = BindableMapper.Bind(Host, roleMapped);

            _scorer = new MultiClassClassifierScorer(Host, scorerArgs, emptyInput, boundMapper, roleMapped);
        }
コード例 #14
0
        /// <summary>
        /// Deserialization constructor: reads the score and label column names from
        /// <paramref name="ctx"/> and resolves their indices in <paramref name="schema"/>.
        /// </summary>
        /// <param name="env">Environment used to register the host.</param>
        /// <param name="ctx">Load context to read from.</param>
        /// <param name="schema">Schema in which the columns are looked up.</param>
        protected PerInstanceEvaluatorBase(IHostEnvironment env, ModelLoadContext ctx, ISchema schema)
        {
            Host = env.Register("PerInstanceRowMapper");

            // *** Binary format ***
            // int: Id of the score column name
            // int: Id of the label column name

            ScoreCol = ctx.LoadNonEmptyString();
            LabelCol = ctx.LoadStringOrNull();

            // The label column is only looked up when a non-empty name was stored.
            bool hasLabel = !string.IsNullOrEmpty(LabelCol);
            if (hasLabel)
            {
                bool labelFound = schema.TryGetColumnIndex(LabelCol, out LabelIndex);
                if (!labelFound)
                {
                    throw Host.Except($"Did not find the label column '{LabelCol}'");
                }
            }

            // The score column is always looked up.
            bool scoreFound = schema.TryGetColumnIndex(ScoreCol, out ScoreIndex);
            if (!scoreFound)
            {
                throw Host.Except($"Did not find column '{ScoreCol}'");
            }
        }
コード例 #15
0
        /// <summary>
        /// Deserialization constructor: reads the feature column name from <paramref name="ctx"/>,
        /// resolves its type in the training schema, and creates the bindable mapper.
        /// </summary>
        /// <param name="host">Host forwarded to the base constructor.</param>
        /// <param name="ctx">Load context to read from.</param>
        internal SingleFeaturePredictionTransformerBase(IHost host, ModelLoadContext ctx)
            : base(host, ctx)
        {
            FeatureColumn = ctx.LoadStringOrNull();

            if (FeatureColumn != null)
            {
                // A stored name must exist in the training schema.
                if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
                }

                FeatureColumnType = TrainSchema[col].Type;
            }
            else
            {
                // No feature column was stored, so there is no type either.
                FeatureColumnType = null;
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, Model);
        }
コード例 #16
0
        /// <summary>
        /// Deserialization constructor: loads the model, recovers the training schema from the stored
        /// binary stream, reads the feature column name and resolves its type, then creates the
        /// bindable mapper.
        /// </summary>
        /// <param name="host">Host assigned to <c>Host</c> and used for loading.</param>
        /// <param name="ctx">Load context to read from.</param>
        internal PredictionTransformerBase(IHost host, ModelLoadContext ctx)
        {
            Host = host;

            // *** Binary format ***
            // model: prediction model.
            // stream: empty data view that contains train schema.
            // id of string: feature column.

            ctx.LoadModel<TModel, SignatureLoadModel>(host, out TModel model, DirModel);
            Model = model;

            // Copy the stream holding the schema into memory so a loader can be opened on it.
            // NOTE(review): the result of TryLoadBinaryStream is ignored and the MemoryStream is
            // not disposed here; confirm both are intended.
            var schemaStream = new MemoryStream();
            ctx.TryLoadBinaryStream(DirTransSchema, reader => reader.BaseStream.CopyTo(schemaStream));
            schemaStream.Position = 0;

            var schemaLoader = new BinaryLoader(host, new BinaryLoader.Arguments(), schemaStream);
            TrainSchema = schemaLoader.Schema;

            FeatureColumn = ctx.LoadStringOrNull();
            if (FeatureColumn != null)
            {
                // A stored name must exist in the training schema.
                if (!TrainSchema.TryGetColumnIndex(FeatureColumn, out int col))
                {
                    throw Host.ExceptSchemaMismatch(nameof(FeatureColumn), RoleMappedSchema.ColumnRole.Feature.Value, FeatureColumn);
                }

                FeatureColumnType = TrainSchema.GetColumnType(col);
            }
            else
            {
                // No feature column was stored, so there is no type either.
                FeatureColumnType = null;
            }

            BindableMapper = ScoreUtils.GetSchemaBindableMapper(Host, model);
        }
コード例 #17
0
        /// <summary>
        /// Deserialization constructor: reads the trees, the bias and the input initialization
        /// content string from <paramref name="ctx"/>.
        /// </summary>
        /// <param name="ctx">Load context to read from.</param>
        /// <param name="usingDefaultValues">Forwarded unchanged to <c>RegressionTree.Load</c>.</param>
        /// <param name="categoricalSplits">Forwarded unchanged to <c>RegressionTree.Load</c>.</param>
        public TreeEnsemble(ModelLoadContext ctx, bool usingDefaultValues, bool categoricalSplits)
        {
            // REVIEW: Verify the contents of the ensemble, both during building,
            // and during deserialization.

            // *** Binary format ***
            // int: Number of trees
            // Regression trees (num trees of these)
            // double: Bias
            // int: Id to InputInitializationContent string, currently ignored

            _trees = new List<RegressionTree>();

            int treeCount = ctx.Reader.ReadInt32();
            Contracts.CheckDecode(treeCount >= 0);

            int loaded = 0;
            while (loaded < treeCount)
            {
                var tree = RegressionTree.Load(ctx, usingDefaultValues, categoricalSplits);
                AddTree(tree);
                loaded++;
            }

            Bias = ctx.Reader.ReadDouble();
            _firstInputInitializationContent = ctx.LoadStringOrNull();
        }
コード例 #18
0
            /// <summary>
            /// Deserialization constructor: reads every column description (name, item type and
            /// optional segments) from <paramref name="ctx"/> and then computes the output schema.
            /// </summary>
            /// <param name="ctx">Load context to read from.</param>
            /// <param name="parent">Owning loader; not referenced in this constructor body.</param>
            public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
            {
                Contracts.AssertValue(ctx);

                // *** Binary format ***
                // int: number of columns
                // foreach column:
                //   int: id of column name
                //   byte: DataKind
                //   byte: bool of whether this is a key type
                //   for a key type:
                //     ulong: count for key range
                //   int: number of segments
                //   foreach segment:
                //     string id: name
                //     int: min
                //     int: lim
                //     byte: force vector (verWrittenCur: verIsVectorSupported)
                int columnCount = ctx.Reader.ReadInt32();
                Contracts.CheckDecode(columnCount > 0);
                Infos = new ColInfo[columnCount];

                for (int c = 0; c < columnCount; c++)
                {
                    string name = ctx.LoadNonEmptyString();

                    var kind = (InternalDataKind)ctx.Reader.ReadByte();
                    Contracts.CheckDecode(Enum.IsDefined(typeof(InternalDataKind), kind));

                    PrimitiveDataViewType itemType;
                    bool isKey = ctx.Reader.ReadBoolByte();
                    if (!isKey)
                    {
                        itemType = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
                    }
                    else
                    {
                        // Key types additionally store the size of the key range, which must be positive.
                        Contracts.CheckDecode(KeyDataViewType.IsValidDataType(kind.ToType()));

                        ulong keyCount = ctx.Reader.ReadUInt64();
                        Contracts.CheckDecode(0 < keyCount);

                        itemType = new KeyDataViewType(kind.ToType(), keyCount);
                    }

                    // A zero segment count is represented by a null array; negative counts are invalid.
                    int segCount = ctx.Reader.ReadInt32();
                    Contracts.CheckDecode(segCount >= 0);
                    Segment[] segs = segCount == 0 ? null : new Segment[segCount];

                    for (int s = 0; s < segCount; s++)
                    {
                        string columnName = ctx.LoadStringOrNull();
                        int min = ctx.Reader.ReadInt32();
                        int lim = ctx.Reader.ReadInt32();
                        Contracts.CheckDecode(0 <= min && min < lim);
                        bool forceVector = ctx.Reader.ReadBoolByte();

                        // The range is validated for every segment, but named segments do not use it.
                        if (columnName is null)
                        {
                            segs[s] = new Segment(min, lim, forceVector);
                        }
                        else
                        {
                            segs[s] = new Segment(columnName, forceVector);
                        }
                    }

                    // Note that this will throw if the segments are ill-structured, including the case
                    // of multiple variable segments (since those segments will overlap and overlapping
                    // segments are illegal).
                    Infos[c] = ColInfo.Create(name, itemType, segs, false);
                }

                OutputSchema = ComputeOutputSchema();
            }