コード例 #1
0
        /// <summary>
        /// Exports this predictor to ONNX by delegating to the wrapped value mapper.
        /// </summary>
        /// <param name="ctx">The ONNX export context.</param>
        /// <param name="schema">
        /// Role-mapped schema. Its feature column feeds the mapper; when it has no feature column (Prior predictor),
        /// the label column is used instead.
        /// </param>
        /// <param name="outputNames">Expected to hold three names: predicted label, score and probability.</param>
        /// <returns>
        /// <c>false</c> when the feature column is not present in <paramref name="ctx"/>; otherwise the result of the
        /// mapper's own <c>SaveAsOnnx</c> call.
        /// </returns>
        private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            Contracts.CheckValue(schema, nameof(schema));

            var mapper = ValueMapper as ISingleCanSaveOnnx;
            Contracts.CheckValue(mapper, nameof(mapper));
            Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted label, score and probability.

            string mapperInputColumn;
            if (schema.Feature.HasValue)
            {
                mapperInputColumn = schema.Feature.Value.Name;

                // The feature column has to be part of the ONNX graph already, otherwise nothing can feed the mapper.
                if (!ctx.ContainsColumn(mapperInputColumn))
                {
                    return false;
                }
            }
            else
            {
                // Prior has no feature column; it determines predicted labels from the training label column.
                Contracts.Assert(schema.Label.HasValue);
                mapperInputColumn = schema.Label.Value.Name;
            }

            return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(mapperInputColumn));
        }
コード例 #2
0
        /// <summary>
        /// Exports the binary classifier scorer to ONNX. The base class emits its nodes first. When the third output
        /// column (Probability) was produced, the predicted label is derived from it with a Binarizer at threshold
        /// 0.5, followed by a Cast to Boolean.
        /// </summary>
        /// <param name="ctx">The ONNX export context.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);

            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);
            int delta = Bindings.DerivedColumnCount;

            Host.Assert(delta == 1);

            string[] outColumnNames = new string[Bindings.InfoCount]; // PredictedLabel, Score, Probability.
            for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo)
            {
                outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo));
            }

            // The label can only be predicted if the base class generated the "Probability" column.
            if (Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]))
            {
                // The ONNX Binarizer emits the same numeric type as its input. Writing it straight into the
                // PredictedLabel variable would give the graph output the wrong type, so binarize into an
                // intermediate variable and cast that to bool.
                // NOTE(review): assumes the PredictedLabel column is Boolean; a key-typed predicted label would need
                // a different cast target — confirm.
                string opType = "Binarizer";
                var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);
                var node = ctx.CreateNode(opType, new[] { ctx.GetVariableName(outColumnNames[2]) },
                                          new[] { binarizerOutput }, ctx.GetNodeName(opType));
                node.AddAttribute("threshold", 0.5);

                opType = "Cast";
                node = ctx.CreateNode(opType, new[] { binarizerOutput },
                                      new[] { ctx.GetVariableName(outColumnNames[0]) }, ctx.GetNodeName(opType), "");
                node.AddAttribute("to", typeof(bool));
            }
        }
コード例 #3
0
        /// <summary>
        /// Exports each bound column to ONNX.
        /// </summary>
        /// <remarks>
        /// A column is skipped when its input column is not in <paramref name="ctx"/>, or when the context already
        /// maps that name to a different variable. If the per-column <c>SaveAsOnnxCore</c> returns <c>false</c>, the
        /// input column is removed from the context.
        /// </remarks>
        /// <param name="ctx">The ONNX export context.</param>
        public void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(((ICanSaveOnnx)this).CanSaveOnnx(ctx));

            for (int iinfo = 0; iinfo < _bindings.ColumnTypes.Length; ++iinfo)
            {
                var    columnType      = _bindings.ColumnTypes[iinfo];
                string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name;
                if (!ctx.ContainsColumn(inputColumnName))
                {
                    continue;
                }

                // If there is already a column of this name, don't add this column as an OptionalColumn/Initializer
                var srcVariableName = ctx.GetVariableName(inputColumnName);
                if (srcVariableName != inputColumnName)
                {
                    continue;
                }

                // Use the already-fetched column type rather than re-indexing the bindings.
                if (!SaveAsOnnxCore(ctx, srcVariableName, columnType))
                {
                    ctx.RemoveColumn(inputColumnName, true);
                }
            }
        }
コード例 #4
0
        /// <summary>
        /// Exports this predictor to ONNX by handing the feature column's variable to the value mapper.
        /// </summary>
        /// <param name="ctx">The ONNX export context.</param>
        /// <param name="schema">Role-mapped schema; it is expected to carry a feature column.</param>
        /// <param name="outputNames">At most two names: predicted label and/or score.</param>
        /// <returns>
        /// <c>false</c> when the feature column is not in <paramref name="ctx"/>; otherwise whatever the mapper's
        /// <c>SaveAsOnnx</c> returns.
        /// </returns>
        private protected override bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            Contracts.CheckValue(schema, nameof(schema));
            Contracts.Assert(ValueMapper is ISingleCanSaveOnnx);
            Contracts.Assert(schema.Feature.HasValue);
            Contracts.Assert(Utils.Size(outputNames) <= 2); // PredictedLabel and/or Score.

            var onnxMapper = (ISingleCanSaveOnnx)ValueMapper;
            string featureColumnName = schema.Feature.Value.Name;

            if (ctx.ContainsColumn(featureColumnName))
            {
                return onnxMapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featureColumnName));
            }

            // The feature column is not in the graph, so there is nothing to feed the mapper.
            return false;
        }
コード例 #5
0
        /// <summary>
        /// Exports the binary classifier scorer to ONNX. After the base class emits its nodes, the predicted label is
        /// produced by a Binarizer followed by a Cast to Boolean. The Binarizer thresholds the third output column
        /// (Probability) at 0.5 when it exists, and otherwise thresholds the score column at 0.
        /// </summary>
        /// <param name="ctx">The ONNX export context.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);
            int derivedCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedCount == 1);

            // Output column names in order: PredictedLabel, Score, Probability.
            var outColumnNames = new string[Bindings.InfoCount];
            for (int i = 0; i < outColumnNames.Length; i++)
            {
                outColumnNames[i] = Bindings.GetColumnName(Bindings.MapIinfoToCol(i));
            }

            var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);

            // With a probability column, classify at probability 0.5; otherwise use the sign of the score.
            string binarizerInput;
            double threshold;
            if (Bindings.InfoCount >= 3)
            {
                Host.Assert(ctx.ContainsColumn(outColumnNames[2]));
                binarizerInput = ctx.GetVariableName(outColumnNames[2]);
                threshold = 0.5;
            }
            else
            {
                binarizerInput = ctx.GetVariableName(outColumnNames[1]);
                threshold = 0.0;
            }

            const string binarizerOp = "Binarizer";
            OnnxNode binarizerNode = ctx.CreateNode(binarizerOp, binarizerInput, binarizerOutput, ctx.GetNodeName(binarizerOp));
            binarizerNode.AddAttribute("threshold", threshold);

            // Convert the binarized value into the Boolean predicted label.
            const string castOp = "Cast";
            OnnxNode castNode = ctx.CreateNode(castOp, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(castOp), "");
            castNode.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType());
        }
コード例 #6
0
        /// <summary>
        /// Exports this predictor to ONNX by delegating to the value mapper, fed from the schema's feature column.
        /// </summary>
        /// <param name="ctx">The ONNX export context.</param>
        /// <param name="schema">Role-mapped schema supplying the feature column.</param>
        /// <param name="outputNames">Expected to hold three names: predicted label, score and probability.</param>
        /// <returns>
        /// <c>false</c> if the feature column is not in <paramref name="ctx"/>; otherwise the mapper's result.
        /// </returns>
        public override bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
        {
            Contracts.CheckValue(ctx, nameof(ctx));
            Contracts.CheckValue(schema, nameof(schema));

            var mapper = ValueMapper as ISingleCanSaveOnnx;
            Contracts.CheckValue(mapper, nameof(mapper));
            Contracts.AssertValue(schema.Feature);
            Contracts.Assert(Utils.Size(outputNames) == 3); // Predicted label, score and probability.

            string featureName = schema.Feature.Name;
            if (ctx.ContainsColumn(featureName))
            {
                return mapper.SaveAsOnnx(ctx, outputNames, ctx.GetVariableName(featureName));
            }

            return false;
        }
コード例 #7
0
        /// <summary>
        /// Exports the binary classifier scorer to ONNX. After the base class emits its nodes, the predicted label is
        /// built in three steps:
        /// a Binarizer over the score column using <c>_threshold</c>;
        /// an Add of 1 when the predicted label type is a key type;
        /// a Cast to the predicted label column's raw type.
        /// </summary>
        /// <param name="ctx">The ONNX export context.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);
            int derivedCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedCount == 1);

            // Output column names in order: PredictedLabel, Score, Probability.
            var outColumnNames = new string[Bindings.InfoCount];
            for (int i = 0; i < outColumnNames.Length; i++)
            {
                outColumnNames[i] = Bindings.GetColumnName(Bindings.MapIinfoToCol(i));
            }
            string predictedLabelName = outColumnNames[0];

            string scoreColumn = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name;

            const string binarizerOp = "Binarizer";
            var binarizerOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "BinarizerOutput", false);
            OnnxNode binarizerNode = ctx.CreateNode(binarizerOp, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(binarizerOp));
            binarizerNode.AddAttribute("threshold", _threshold);

            // Key-typed predicted labels get the binarized value shifted by one.
            // NOTE(review): presumably because key values are 1-based with 0 reserved for missing — confirm.
            string castInput = binarizerOutput;
            if (Bindings.PredColType is KeyDataViewType)
            {
                const string addOp = "Add";
                var one = ctx.AddInitializer(1.0f, "one");
                var shiftedOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "Add", false);
                ctx.CreateNode(addOp, new[] { binarizerOutput, one }, new[] { shiftedOutput }, ctx.GetNodeName(addOp), "");
                castInput = shiftedOutput;
            }

            const string castOp = "Cast";
            OnnxNode castNode = ctx.CreateNode(castOp, castInput, ctx.GetVariableName(predictedLabelName), ctx.GetNodeName(castOp), "");
            var predictedLabelCol = OutputSchema.GetColumnOrNull(predictedLabelName);
            Host.Assert(predictedLabelCol.HasValue);
            castNode.AddAttribute("to", predictedLabelCol.Value.Type.RawType);
        }
コード例 #8
0
        /// <summary>
        /// Exports the binary classifier scorer to ONNX. After the base class emits its nodes, a Binarizer with
        /// threshold <c>_threshold</c> feeds a Cast to the predicted label column's raw type.
        /// </summary>
        /// <remarks>
        /// The Binarizer reads the Score output column when the mapper's score column is named "Score", and the
        /// third output column (Probability) otherwise.
        /// </remarks>
        /// <param name="ctx">The ONNX export context.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);
            int derivedCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedCount == 1);

            // Output column names in order: PredictedLabel, Score, Probability.
            var outColumnNames = new string[Bindings.InfoCount];
            for (int i = 0; i < outColumnNames.Length; i++)
            {
                outColumnNames[i] = Bindings.GetColumnName(Bindings.MapIinfoToCol(i));
            }

            const string binarizerOp = "Binarizer";
            var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);

            bool thresholdOnScore = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name == "Score";
            if (!thresholdOnScore)
            {
                Host.Assert(Bindings.InfoCount >= 3);
            }
            string thresholdedColumn = thresholdOnScore ? outColumnNames[1] : outColumnNames[2];

            OnnxNode binarizerNode = ctx.CreateNode(binarizerOp, ctx.GetVariableName(thresholdedColumn), binarizerOutput, ctx.GetNodeName(binarizerOp));
            binarizerNode.AddAttribute("threshold", _threshold);

            const string castOp = "Cast";
            OnnxNode castNode = ctx.CreateNode(castOp, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(castOp), "");
            var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]);
            Host.Assert(predictedLabelCol.HasValue);
            castNode.AddAttribute("to", predictedLabelCol.Value.Type.RawType);
        }
コード例 #9
0
            /// <summary>
            /// Exports every column pair whose input column is present in <paramref name="ctx"/>.
            /// </summary>
            /// <remarks>
            /// For each such pair, an intermediate output variable of type <c>_type</c> is allocated and the work is
            /// delegated to the per-column <c>SaveAsOnnxCore</c>.
            /// </remarks>
            /// <param name="ctx">The ONNX export context.</param>
            public void SaveAsOnnx(OnnxContext ctx)
            {
                Host.CheckValue(ctx, nameof(ctx));

                for (int i = 0; i < _isSourceVector.Length; i++)
                {
                    var pair = _parent.ColumnPairs[i];
                    if (ctx.ContainsColumn(pair.inputColumnName))
                    {
                        string srcVariable = ctx.GetVariableName(pair.inputColumnName);
                        string dstVariable = ctx.AddIntermediateVariable(_type, pair.outputColumnName, true);
                        SaveAsOnnxCore(ctx, i, srcVariable, dstVariable);
                    }
                }
            }
コード例 #10
0
            /// <summary>
            /// Emits one ONNX Identity node per output column, copying it from its mapped input column.
            /// </summary>
            /// <remarks>
            /// Output columns that are hidden, or whose input column is not in <paramref name="ctx"/>, are skipped.
            /// </remarks>
            /// <param name="ctx">The ONNX export context.</param>
            public void SaveAsOnnx(OnnxContext ctx)
            {
                const string identityOp = "Identity";
                var outputToInputMap = _mapper.OutputToInputMap;

                for (int outIndex = 0; outIndex < outputToInputMap.Length; outIndex++)
                {
                    var source = InputSchema[outputToInputMap[outIndex]];
                    var destination = OutputSchema[outIndex];

                    if (ctx.ContainsColumn(source.Name) && !destination.IsHidden)
                    {
                        var srcVariable = ctx.GetVariableName(source.Name);
                        var dstVariable = ctx.AddIntermediateVariable(destination.Type, destination.Name);
                        ctx.CreateNode(identityOp, srcVariable, dstVariable, ctx.GetNodeName(identityOp), "");
                    }
                }
            }
コード例 #11
0
        /// <summary>
        /// Exports each bound column whose input column is present in <paramref name="ctx"/>.
        /// </summary>
        /// <remarks>
        /// Each output variable is allocated under the input column's name. If the per-column
        /// <c>SaveAsOnnxCore</c> returns <c>false</c>, that column is removed from the context.
        /// </remarks>
        /// <param name="ctx">The ONNX export context.</param>
        public void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(((ICanSaveOnnx)this).CanSaveOnnx(ctx));

            for (int iinfo = 0; iinfo < _bindings.ColumnTypes.Length; ++iinfo)
            {
                string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name;
                if (!ctx.ContainsColumn(inputColumnName))
                {
                    continue;
                }

                if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName),
                                    ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(iinfo)].Type, inputColumnName)))
                {
                    ctx.RemoveColumn(inputColumnName, true);
                }
            }
        }
コード例 #12
0
ファイル: TransformBase.cs プロジェクト: artemiusgreat/ML-NET
        /// <summary>
        /// Exports each transformed column to ONNX.
        /// </summary>
        /// <remarks>
        /// A column whose source is missing from <paramref name="ctx"/> is removed with <c>RemoveColumn(name, false)</c>.
        /// A column whose per-column <c>SaveAsOnnxCore</c> returns <c>false</c> is removed with
        /// <c>RemoveColumn(name, true)</c>.
        /// </remarks>
        /// <param name="ctx">The ONNX export context.</param>
        void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(((ICanSaveOnnx)this).CanSaveOnnx(ctx));

            for (int i = 0; i < Infos.Length; i++)
            {
                ColInfo columnInfo = Infos[i];
                string sourceName = Source.Schema[columnInfo.Source].Name;

                if (ctx.ContainsColumn(sourceName))
                {
                    var srcVariable = ctx.GetVariableName(sourceName);
                    var dstVariable = ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(i)].Type, columnInfo.Name);
                    if (!SaveAsOnnxCore(ctx, i, columnInfo, srcVariable, dstVariable))
                    {
                        ctx.RemoveColumn(columnInfo.Name, true);
                    }
                }
                else
                {
                    // The source is not in the graph, so the derived column cannot be exported.
                    ctx.RemoveColumn(columnInfo.Name, false);
                }
            }
        }
コード例 #13
0
        /// <summary>
        /// Exports each transformed column to ONNX.
        /// </summary>
        /// <remarks>
        /// A column whose source is missing from <paramref name="ctx"/> is removed with <c>RemoveColumn(name, false)</c>.
        /// A column whose per-column <c>SaveAsOnnxCore</c> returns <c>false</c> is removed with
        /// <c>RemoveColumn(name, true)</c>.
        /// </remarks>
        /// <param name="ctx">The ONNX export context.</param>
        public void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(CanSaveOnnx);

            for (int i = 0; i < Infos.Length; i++)
            {
                ColInfo columnInfo = Infos[i];
                string sourceName = Source.Schema.GetColumnName(columnInfo.Source);

                if (ctx.ContainsColumn(sourceName))
                {
                    var srcVariable = ctx.GetVariableName(sourceName);
                    var dstVariable = ctx.AddIntermediateVariable(Schema.GetColumnType(_bindings.MapIinfoToCol(i)), columnInfo.Name);
                    if (!SaveAsOnnxCore(ctx, i, columnInfo, srcVariable, dstVariable))
                    {
                        ctx.RemoveColumn(columnInfo.Name, true);
                    }
                }
                else
                {
                    // The source is not in the graph, so the derived column cannot be exported.
                    ctx.RemoveColumn(columnInfo.Name, false);
                }
            }
        }