コード例 #1
0
            /// <summary>
            /// Appends the nodes that turn a per-class score tensor into the predicted-label and score outputs.
            /// </summary>
            /// <param name="ctx">Export context that receives the new nodes.</param>
            /// <param name="inputName">Name of the score tensor produced earlier in the graph.</param>
            /// <param name="outputNames">
            /// Output names (at least two). Index 0 receives ArgMax(axis 1) + 1 cast to UInt32; index 1
            /// receives the output of a single-input "Max" node over <paramref name="inputName"/>.
            /// </param>
            public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] outputNames)
            {
                Contracts.Assert(outputNames.Length >= 2);

                // Index of the highest score along axis 1; keepdims leaves that axis in place.
                const string argMaxOp = "ArgMax";
                var bestIndex = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput");
                var argMax = ctx.CreateNode(argMaxOp, inputName, bestIndex, ctx.GetNodeName(argMaxOp), "");
                argMax.AddAttribute("keepdims", 1);
                argMax.AddAttribute("axis", 1);

                // Shift the zero-based ArgMax index up by one.
                const string addOp = "Add";
                var one = ctx.AddInitializer(1);
                var shiftedIndex = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput");
                ctx.CreateNode(addOp, new[] { bestIndex, one }, new[] { shiftedIndex }, ctx.GetNodeName(addOp), "");

                // Cast the shifted Int64 index into the UInt32 label output.
                const string castOp = "Cast";
                var toUInt32 = ctx.CreateNode(castOp, shiftedIndex, outputNames[0], ctx.GetNodeName(castOp), "");
                toUInt32.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType());

                // ONNX "Max" is element-wise across its inputs; with a single input it presumably acts as a
                // pass-through copy of the scores into the second output. NOTE(review): confirm intent.
                const string maxOp = "Max";
                ctx.CreateNode(maxOp, inputName, outputNames[1], ctx.GetNodeName(maxOp), "");
            }
コード例 #2
0
            /// <summary>
            /// Exports character tokenization of column <paramref name="iinfo"/> as
            /// Tokenizer -> Squeeze -> LabelEncoder -> Cast(UInt16).
            /// </summary>
            /// <param name="ctx">Export context; the opset must be 9 or later.</param>
            /// <param name="iinfo">Index of the column being exported.</param>
            /// <param name="srcVariableName">ONNX variable holding the input text.</param>
            /// <param name="dstVariableName">ONNX variable that receives the UInt16 character codes.</param>
            private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                const int minimumOpSetVersion = 9;

                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                string opType = "Tokenizer";
                DataViewType dataViewType;

                if (_isSourceVector[iinfo])
                {
                    dataViewType = new VectorDataViewType(TextDataViewType.Instance, _sourceVectorLength[iinfo]);
                }
                else
                {
                    dataViewType = TextDataViewType.Instance;
                }

                // com.microsoft Tokenizer with a single empty-string separator and mincharnum = 1
                // (presumably character-level splitting per the contrib op docs).
                string tokenizerOutput = ctx.AddIntermediateVariable(dataViewType, "TokenizerOutput", true);
                var node = ctx.CreateNode(opType, srcVariableName, tokenizerOutput, ctx.GetNodeName(opType), "com.microsoft");

                node.AddAttribute("mark", _parent._useMarkerChars);
                node.AddAttribute("mincharnum", 1);
                node.AddAttribute("pad_value", "");
                node.AddAttribute("separators", new string[] { "" });

                opType = "Squeeze";
                var squeezeOutput = ctx.AddIntermediateVariable(dataViewType, "SqueezeOutput");

                node = ctx.CreateNode(opType, tokenizerOutput, squeezeOutput, ctx.GetNodeName(opType), "");
                node.AddAttribute("axes", new long[] { 1 });

                // Map each single-character string to its UTF-16 code unit value.
                opType = "LabelEncoder";
                var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");

                node = ctx.CreateNode(opType, squeezeOutput, labelEncoderOutput, ctx.GetNodeName(opType));

                // Cover every code unit 0x0000..0xFFFF inclusive. A count of 65535 would stop at 0xFFFE and
                // leave U+FFFF to fall through to the LabelEncoder default value.
                // NOTE(review): surrogate code units (0xD800-0xDFFF) become lone-surrogate strings here;
                // confirm they survive string serialization as intended.
                const int charCount = char.MaxValue + 1;
                IEnumerable<string> charStrings = Enumerable.Range(0, charCount).Select(x => ((char)x).ToString());
                IEnumerable<long> charValues = Enumerable.Range(0, charCount).Select(x => (long)x);

                node.AddAttribute("keys_strings", charStrings);
                node.AddAttribute("values_int64s", charValues);

                opType = "Cast";
                var castNode = ctx.CreateNode(opType, labelEncoderOutput, dstVariableName, ctx.GetNodeName(opType), "");
                var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType();

                castNode.AddAttribute("to", t);
            }
コード例 #3
0
        /// <summary>
        /// Exports a key-typed column as a OneHotEncoder node whose categories are the key values 1..KeyCount.
        /// </summary>
        /// <returns>Always <c>true</c>.</returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            const string opType = "OneHotEncoder";
            var encoder = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));

            var keyCount = info.TypeSrc.ItemType.KeyCount;
            var categories = Enumerable.Range(1, keyCount).Select(value => (long)value);

            encoder.AddAttribute("cats_int64s", categories);
            encoder.AddAttribute("zeros", true);
            return true;
        }
コード例 #4
0
        /// <summary>
        /// Appends <c>output = Log(input) * isFeaturePresent</c> to the graph.
        /// </summary>
        /// <param name="ctx">Export context.</param>
        /// <param name="input">Variable whose logarithm is taken.</param>
        /// <param name="isFeaturePresent">Variable multiplied element-wise with the log (presumably a presence indicator).</param>
        /// <param name="output">Variable that receives the product.</param>
        private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
        {
            const string logOp = "Log";
            const string mulOp = "Mul";

            // The intermediate log result is registered without an explicit type (null).
            var logged = ctx.AddIntermediateVariable(null, "LogOutput", true);
            ctx.CreateNode(logOp, input, logged, ctx.GetNodeName(logOp), "");

            var factors = new[] { logged, isFeaturePresent };
            ctx.CreateNode(mulOp, factors, new[] { output }, ctx.GetNodeName(mulOp), "");
        }
コード例 #5
0
        /// <summary>
        /// Appends <c>output = Log(input) * isFeaturePresent</c>, typing the log result as a Single vector
        /// whose length is that of <c>_featureHistogram[0]</c>.
        /// </summary>
        /// <param name="ctx">Export context.</param>
        /// <param name="input">Variable whose logarithm is taken.</param>
        /// <param name="isFeaturePresent">Variable multiplied element-wise with the log (presumably a presence indicator).</param>
        /// <param name="output">Variable that receives the product.</param>
        private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
        {
            const string logOp = "Log";
            const string mulOp = "Mul";

            var width = _featureHistogram[0].Length;
            var logType = new VectorDataViewType(NumberDataViewType.Single, width);
            var logged = ctx.AddIntermediateVariable(logType, "LogOutput");
            ctx.CreateNode(logOp, input, logged, ctx.GetNodeName(logOp), "");

            var factors = new[] { logged, isFeaturePresent };
            ctx.CreateNode(mulOp, factors, new[] { output }, ctx.GetNodeName(mulOp), "");
        }
コード例 #6
0
        /// <summary>
        /// Exports the tree ensemble, then applies Exp to its raw output to produce the prediction.
        /// </summary>
        /// <returns>Always <c>true</c>; any value returned by the base export is ignored.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // The base ensemble writes its raw score into this intermediate variable.
            var rawScore = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
            base.SaveAsOnnx(ctx, new[] { rawScore }, featureColumn);

            // prediction = exp(raw score)
            const string expOp = "Exp";
            var expInputs = new[] { rawScore };
            ctx.CreateNode(expOp, expInputs, outputNames, ctx.GetNodeName(expOp), "");
            return true;
        }
コード例 #7
0
        /// <summary>
        /// Exports the tree ensemble, then divides its raw output by the number of trees.
        /// </summary>
        /// <returns>Always <c>true</c>; any value returned by the base export is ignored.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // Raw ensemble output (presumably the sum over trees) and the tree count as a float constant.
            var rawScore = ctx.AddIntermediateVariable(null, "FastTreeOutput", true);
            var treeCount = ctx.AddInitializer((float)TrainedEnsemble.NumTrees, "NumTrees");

            base.SaveAsOnnx(ctx, new[] { rawScore }, featureColumn);

            // prediction = raw score / number of trees
            const string divOp = "Div";
            var divInputs = new[] { rawScore, treeCount };
            ctx.CreateNode(divOp, divInputs, outputNames, ctx.GetNodeName(divOp), "");
            return true;
        }
コード例 #8
0
            /// <summary>
            /// Exports text normalization as Squeeze -> StringNormalizer -> Unsqueeze.
            /// </summary>
            /// <param name="ctx">Export context.</param>
            /// <param name="srcVariableName">Input text variable.</param>
            /// <param name="dstVariableName">Variable that receives the normalized text.</param>
            private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
            {
                // StringNormalizer only accepts inputs of shape [C] or [1, C], so axis 1 is squeezed away
                // first (to support inferred shapes such as [-1, C]) and restored afterwards.
                const string squeezeOp = "Squeeze";
                var squeezed = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
                var squeeze = ctx.CreateNode(squeezeOp, srcVariableName, squeezed, ctx.GetNodeName(squeezeOp), "");
                squeeze.AddAttribute("axes", new long[] { 1 });

                const string normalizerOp = "StringNormalizer";
                var normalized = ctx.AddIntermediateVariable(null, "NormalizerOutput", true);
                var normalizer = ctx.CreateNode(normalizerOp, squeezed, normalized, ctx.GetNodeName(normalizerOp), "");

                // Translate the configured case mode into StringNormalizer's case_change_action value.
                string caseChangeAction;
                switch (_parent._caseMode)
                {
                    case TextNormalizingEstimator.CaseMode.Lower:
                        caseChangeAction = "LOWER";
                        break;
                    case TextNormalizingEstimator.CaseMode.Upper:
                        caseChangeAction = "UPPER";
                        break;
                    default:
                        caseChangeAction = "NONE";
                        break;
                }

                normalizer.AddAttribute("case_change_action", caseChangeAction);

                const string unsqueezeOp = "Unsqueeze";
                var unsqueeze = ctx.CreateNode(unsqueezeOp, normalized, dstVariableName, ctx.GetNodeName(unsqueezeOp), "");
                unsqueeze.AddAttribute("axes", new long[] { 1 });
            }
コード例 #9
0
        /// <summary>
        /// Exports the multiclass linear model as a single ONNX-ML LinearClassifier node with integer
        /// class labels 0.._numClasses-1.
        /// </summary>
        /// <returns>Always <c>true</c>.</returns>
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
        {
            Host.CheckValue(ctx, nameof(ctx));

            const string opType = "LinearClassifier";
            string[] inputs = { featureColumn };
            var classifier = ctx.CreateNode(opType, inputs, outputs, ctx.GetNodeName(opType));

            // No output transform. Allowed values: 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT'.
            classifier.AddAttribute("post_transform", "NONE");
            classifier.AddAttribute("multi_class", true);
            // Per-class weight vectors, concatenated in class order.
            classifier.AddAttribute("coefficients", _weights.SelectMany(w => w.DenseValues()));
            classifier.AddAttribute("intercepts", _biases);
            var classLabels = Enumerable.Range(0, _numClasses).Select(label => (long)label);
            classifier.AddAttribute("classlabels_ints", classLabels);
            return true;
        }
コード例 #10
0
        /// <summary>
        /// Exports the linear model as a single ONNX-ML LinearRegressor node with one target.
        /// </summary>
        /// <returns>Always <c>true</c>.</returns>
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Check(Utils.Size(outputs) == 1);

            const string opType = "LinearRegressor";
            string[] inputs = { featureColumn };
            var regressor = ctx.CreateNode(opType, inputs, outputs, ctx.GetNodeName(opType));

            // No output transform. Allowed values: 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT'.
            regressor.AddAttribute("post_transform", "NONE");
            regressor.AddAttribute("targets", 1);
            regressor.AddAttribute("coefficients", Weight.DenseValues());
            // intercepts is a list attribute, so the scalar bias is wrapped in a one-element array.
            regressor.AddAttribute("intercepts", new float[] { Bias });
            return true;
        }
コード例 #11
0
        /// <summary>
        /// Exports NaN replacement for a single-precision column as an ONNX-ML Imputer node.
        /// </summary>
        /// <returns><c>false</c> when the column's raw item kind is not R4 (nothing is exported); otherwise <c>true</c>.</returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            var srcType = Infos[iinfo].TypeSrc;

            // Resolve the raw item kind: vector item type, key raw type, or the type itself.
            DataKind itemKind = srcType.IsVector ? srcType.AsVector.ItemType.RawKind
                              : srcType.IsKey ? srcType.AsKey.RawKind
                              : srcType.RawKind;

            if (itemKind != DataKind.R4)
            {
                return false;
            }

            const string opType = "Imputer";
            var imputer = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));

            imputer.AddAttribute("replaced_value_float", Single.NaN);

            // Vector columns with a non-null _repIsDefault entry store an array of replacement values
            // (hence the float[] cast); every other case uses a single replacement value.
            bool hasValueArray = srcType.IsVector && _repIsDefault[iinfo] != null;
            if (hasValueArray)
            {
                imputer.AddAttribute("imputed_value_floats", (float[])_repValues[iinfo]);
            }
            else
            {
                imputer.AddAttribute("imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1));
            }

            return true;
        }
コード例 #12
0
        /// <summary>
        /// Exports the linear model as a single ONNX-ML LinearRegressor node with one target.
        /// </summary>
        /// <returns>Always <c>true</c>.</returns>
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Check(Utils.Size(outputs) == 1);

            string opType = "LinearRegressor";
            var node = OnnxUtils.MakeNode(opType, new List<string> { featureColumn }, new List<string>(outputs), ctx.GetNodeName(opType));

            // post_transform is a string attribute: one of 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT'.
            // An integer 0 does not match the declared attribute type.
            OnnxUtils.NodeAddAttributes(node, "post_transform", "NONE");
            OnnxUtils.NodeAddAttributes(node, "targets", 1);
            OnnxUtils.NodeAddAttributes(node, "coefficients", Weight.DenseValues());
            // intercepts is a float-list attribute (one entry per target), so wrap the scalar bias.
            OnnxUtils.NodeAddAttributes(node, "intercepts", new float[] { Bias });
            ctx.AddNode(node);
            return true;
        }
コード例 #13
0
        /// <summary>
        /// Exports the multiclass linear model as a single ONNX-ML LinearClassifier node whose class
        /// labels are the strings in <c>_labelNames</c>.
        /// </summary>
        /// <returns>Always <c>true</c>.</returns>
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
        {
            Host.CheckValue(ctx, nameof(ctx));

            string opType = "LinearClassifier";
            var node = OnnxUtils.MakeNode(opType, new List<string> { featureColumn }, new List<string>(outputs), ctx.GetNodeName(opType));

            // post_transform is a string attribute: one of 'NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO', 'PROBIT'.
            // An integer 0 does not match the declared attribute type.
            OnnxUtils.NodeAddAttributes(node, "post_transform", "NONE");
            OnnxUtils.NodeAddAttributes(node, "multi_class", true);
            // Per-class weight vectors, concatenated in class order.
            OnnxUtils.NodeAddAttributes(node, "coefficients", _weights.SelectMany(w => w.DenseValues()));
            OnnxUtils.NodeAddAttributes(node, "intercepts", _biases);
            OnnxUtils.NodeAddAttributes(node, "classlabels_strings", _labelNames);
            ctx.AddNode(node);
            return true;
        }
コード例 #14
0
            /// <summary>
            /// Exports the column mapping as one Identity node per output column. Each node copies the mapped
            /// input column's variable into a new variable named after the output column.
            /// </summary>
            /// <param name="ctx">Export context.</param>
            public void SaveAsOnnx(OnnxContext ctx)
            {
                var outputToInputMap = _mapper.OutputToInputMap;

                for (int i = 0; i < outputToInputMap.Length; i++)
                {
                    var srcCol = InputSchema[outputToInputMap[i]];

                    // A source column that is not part of the exported graph has no variable to copy from;
                    // skip it rather than failing in GetVariableName.
                    if (!ctx.ContainsColumn(srcCol.Name))
                    {
                        continue;
                    }

                    var dstCol = OutputSchema[i];
                    var srcVariable = ctx.GetVariableName(srcCol.Name);
                    var dstVariable = ctx.AddIntermediateVariable(dstCol.Type, dstCol.Name, true);
                    string opType = "Identity";
                    ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "");
                }
            }
コード例 #15
0
            /// <summary>
            /// Exports the column mapping as one Identity node per visible output column whose source
            /// column is present in the export context.
            /// </summary>
            /// <param name="ctx">Export context.</param>
            public void SaveAsOnnx(OnnxContext ctx)
            {
                var map = _mapper.OutputToInputMap;

                for (int outIndex = 0; outIndex < map.Length; outIndex++)
                {
                    var source = InputSchema[map[outIndex]];
                    var target = OutputSchema[outIndex];

                    // Copy only when the source is in the graph and the target column is not hidden.
                    bool exportable = ctx.ContainsColumn(source.Name) && !target.IsHidden;
                    if (exportable)
                    {
                        const string opType = "Identity";
                        var srcVariable = ctx.GetVariableName(source.Name);
                        var dstVariable = ctx.AddIntermediateVariable(target.Type, target.Name);
                        ctx.CreateNode(opType, srcVariable, dstVariable, ctx.GetNodeName(opType), "");
                    }
                }
            }
コード例 #16
0
        /// <summary>
        /// Exports the term dictionary of a text column as a LabelEncoder node.
        /// </summary>
        /// <returns><c>false</c> for non-text columns (nothing is exported); otherwise <c>true</c>.</returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            bool isText = info.TypeSrc.ItemType.IsText;
            if (!isText)
            {
                return false;
            }

            // Collect the dictionary terms for this column.
            var termMap = (TermMap<DvText>)_termMap[iinfo].Map;
            var terms = default(VBuffer<DvText>);
            termMap.GetTerms(ref terms);

            const string opType = "LabelEncoder";
            var encoder = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));

            encoder.AddAttribute("classes_strings", terms.DenseValues());
            // Value emitted for strings not found in classes_strings.
            encoder.AddAttribute("default_int64", -1);
            encoder.AddAttribute("default_string", DvText.Empty);
            return true;
        }
コード例 #17
0
        /// <summary>
        /// Exports the type conversion of column <paramref name="iinfo"/> as a custom "CSharp" node.
        /// </summary>
        /// <returns>Always <c>true</c>.</returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            const string opType = "CSharp";

            // Only the conversion for this column applies (assumes _exes is indexed per column like
            // iinfo). Looping over every entry emitted one node per entry, all reading srcVariableName
            // and writing dstVariableName. That gives one ONNX output several producers.
            var ex = _exes[iinfo];
            var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));
            node.AddAttribute("type", LoaderSignature);
            node.AddAttribute("to", (byte)ex.Kind);
            if (ex.HasKeyRange)
            {
                var key = ex.TypeDst.ItemType.AsKey;
                node.AddAttribute("min", key.Min);
                // NOTE(review): "max" receives the key Count rather than Min + Count - 1; confirm the consumer expects a count.
                node.AddAttribute("max", key.Count);
                node.AddAttribute("contiguous", key.Contiguous);
            }

            return true;
        }
コード例 #18
0
        /// <summary>
        /// Exports the term dictionary of a text column as a LabelEncoder node.
        /// </summary>
        /// <returns><c>false</c> for non-text columns (nothing is exported); otherwise <c>true</c>.</returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            bool isText = info.TypeSrc.ItemType.IsText;
            if (!isText)
            {
                return false;
            }

            // Collect the dictionary terms for this column.
            var termMap = (TermMap<DvText>)_termMap[iinfo].Map;
            var terms = default(VBuffer<DvText>);
            termMap.GetTerms(ref terms);

            const string opType = "LabelEncoder";
            var encoder = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));

            encoder.AddAttribute("classes_strings", terms.DenseValues());
            // Value emitted for strings not found in classes_strings.
            encoder.AddAttribute("default_int64", -1);
            // default_string should be an empty string, but a bug in Lotus raises a validation error
            // when it is empty. As a workaround it is set to a single space.
            encoder.AddAttribute("default_string", " ");
            return true;
        }
コード例 #19
0
        /// <summary>
        /// Exports the base scorer columns. When the Probability column was produced, it also adds a
        /// Binarizer node that thresholds the probability at 0.5 into the predicted-label variable.
        /// </summary>
        /// <param name="ctx">Export context.</param>
        public override void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);

            // Nothing is exported unless the features column is already in the graph.
            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnx(ctx);
            int derivedCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedCount == 1);

            // Output column names in binding order (PredictedLabel, Score, Probability).
            var outColumnNames = new string[Bindings.InfoCount];
            for (int iinfo = 0; iinfo < outColumnNames.Length; iinfo++)
            {
                var col = Bindings.MapIinfoToCol(iinfo);
                outColumnNames[iinfo] = Bindings.GetColumnName(col);
            }

            // The label can only be predicted if the base class produced the Probability column.
            bool hasProbability = Bindings.InfoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]);
            if (!hasProbability)
            {
                return;
            }

            const string opType = "Binarizer";
            var probabilityVariable = ctx.GetVariableName(outColumnNames[2]);
            var labelVariable = ctx.GetVariableName(outColumnNames[0]);
            var node = OnnxUtils.MakeNode(opType,
                new List<string> { probabilityVariable },
                new List<string> { labelVariable },
                ctx.GetNodeName(opType));

            OnnxUtils.NodeAddAttributes(node, "threshold", 0.5);
            ctx.AddNode(node);
        }
コード例 #20
0
        /// <summary>
        /// Exports the normalizer for column <paramref name="iinfo"/> as an ONNX-ML Scaler node.
        /// The node is added to the graph only if the column's function fills in its attributes.
        /// </summary>
        /// <returns><c>true</c> when the node was added; <c>false</c> for zero-length columns or when the function declines.</returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            Contracts.AssertValue(ctx);
            Contracts.Assert(0 <= iinfo && iinfo < Infos.Length);
            Contracts.Assert(Infos[iinfo] == info);
            Contracts.Assert(CanSaveOnnx);

            var valueCount = info.TypeSrc.ValueCount;
            if (valueCount == 0)
            {
                return false;
            }

            const string opType = "Scaler";
            var node = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));

            // The column's function supplies the Scaler attributes and reports whether it could.
            bool described = _functions[iinfo].OnnxInfo(ctx, new OnnxUtils.NodeProtoWrapper(node), valueCount);
            if (described)
            {
                ctx.AddNode(node);
            }

            return described;
        }
コード例 #21
0
        /// <summary>
        /// Exports k-means scoring: squared Euclidean distances from every example to every centroid,
        /// plus the index of the nearest centroid.
        /// </summary>
        /// <param name="ctx">Export context.</param>
        /// <param name="outputNames">Index 0 receives the ArgMin index L; index 1 receives the distance matrix Y.</param>
        /// <param name="featureColumn">Name of the input feature variable X.</param>
        /// <returns>Always <c>true</c>.</returns>
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // Computation graph of squared distances to all centroids for a batch of examples. A centroid is
            // the center of a cluster. We use [] to denote the shape of a variable; for example, X [3, 2] means
            // that X is a 3-by-2 tensor. For a matrix X, X^T denotes its transpose.
            //
            // Symbols:
            // l: # of examples.
            // n: # of features per input example.
            // k: # of centroids.
            // X: input examples, l-by-n tensor.
            // C: centroids, k-by-n tensor.
            // C^2: squared 2-norm of every centroid vector, its shape is [k].
            // Y: squared 2-norm of the difference between examples and centroids, l-by-k tensor. The value at
            //    row i and column j, Y[i, j], is the squared distance from example i to centroid j.
            // L: the index of the nearest centroid for each input example, its shape is [l, 1] (keepdims=1).
            //
            // .------------------------------------------------------.
            // |                                                      |
            // |                                                      v
            // X [l, n] --> ReduceSumSquare --> X^2 [l, 1]          Gemm (alpha=-2, transB=1) <-- C [k, n]
            //                                   |                    |
            //                                   |                    v
            //                                   `------> Add <---- -2XC^T [l, k]
            //                                             |
            //                                             v
            //                                             Z [l, k] ----------> Add <------------C^2 [k]
            //                                                                   |
            //                                                                   v
            //                                        L [l, 1] <--- ArgMin <---  Y [l, k]

            // Allocate C, which is a constant tensor in the prediction phase.
            var shapeC = new long[] { _centroids.Length, _centroids[0].Length };
            var tensorC = new List<float>();

            foreach (var centroid in _centroids)
            {
                tensorC.AddRange(centroid.DenseValues());
            }
            var nameC = ctx.AddInitializer(tensorC, shapeC, "C");

            // Save C^2 as an initializer because it's a constant.
            var shapeC2 = new long[] { _centroidL2s.Length };
            var nameC2 = ctx.AddInitializer(_centroidL2s, shapeC2, "C2");

            // Retrieve the name of X.
            var nameX = featureColumn;

            // Compute X^2 from X. Without "axes", ReduceSumSquare reduces over every dimension, which would sum
            // over the whole batch. Reduce over the feature axis only, keeping it ([l, 1]), so each example has
            // its own squared norm that broadcasts against [l, k].
            var nameX2 = ctx.AddIntermediateVariable(null, "X2", true);
            var reduceNodeX2 = ctx.CreateNode("ReduceSumSquare", nameX, nameX2, ctx.GetNodeName("ReduceSumSquare"), "");

            reduceNodeX2.AddAttribute("axes", new long[] { 1 });
            reduceNodeX2.AddAttribute("keepdims", 1);

            // Compute -2XC^T. Gemm always takes three inputs; since we only have two here,
            // a dummy one, named zero, is created.
            var zeroName = ctx.AddInitializer(new Float[] { 0f }, null, "zero");
            var nameXC2 = ctx.AddIntermediateVariable(null, "XC2", true);
            var gemmNodeXC2 = ctx.CreateNode("Gemm", new[] { nameX, nameC, zeroName }, new[] { nameXC2 }, ctx.GetNodeName("Gemm"), "");

            gemmNodeXC2.AddAttribute("alpha", -2f);
            gemmNodeXC2.AddAttribute("transB", 1);

            // Compute Z = X^2 - 2XC^T.
            var nameZ = ctx.AddIntermediateVariable(null, "Z", true);
            ctx.CreateNode("Add", new[] { nameX2, nameXC2 }, new[] { nameZ }, ctx.GetNodeName("Add"), "");

            // Compute Y = Z + C^2.
            var nameY = outputNames[1];
            ctx.CreateNode("Add", new[] { nameZ, nameC2 }, new[] { nameY }, ctx.GetNodeName("Add"), "");

            // Compute the most-matched cluster index, L.
            // NOTE(review): ArgMin yields a 0-based index; if the predicted-label column is a 1-based key,
            // an offset may be required — confirm against the scorer.
            var nameL = outputNames[0];
            var predictNodeL = ctx.CreateNode("ArgMin", nameY, nameL, ctx.GetNodeName("ArgMin"), "");

            predictNodeL.AddAttribute("axis", 1);
            predictNodeL.AddAttribute("keepdims", 1);

            return true;
        }
コード例 #22
0
            /// <summary>
            /// Exports the whitening transform of column <paramref name="iinfo"/> as Gemm followed by Transpose.
            /// </summary>
            /// <param name="ctx">Export context.</param>
            /// <param name="iinfo">Index of the column being exported.</param>
            /// <param name="srcVariableName">Input variable.</param>
            /// <param name="dstVariableName">Variable that receives the whitened values.</param>
            private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                var model = _parent._models[iinfo];
                int dimension = _srcTypes[iinfo].GetValueCount();

                Host.Assert(model.Length == dimension * dimension);

                var parameters = _parent._columns[iinfo];

                Host.Assert(parameters.Kind == WhiteningKind.PrincipalComponentAnalysis || parameters.Kind == WhiteningKind.ZeroPhaseComponentAnalysis);

                // Only PCA with a positive configured rank truncates; otherwise the full dimension is kept.
                int rank = dimension;
                if (parameters.Kind == WhiteningKind.PrincipalComponentAnalysis && parameters.Rank > 0)
                {
                    rank = parameters.Rank;
                }

                Host.CheckParam(rank <= dimension, nameof(rank), "Rank must be at most the dimension of untransformed data.");

                // The first rank * dimension model values form a rank-by-dimension matrix.
                var modelName = ctx.AddInitializer(model.Take(rank * dimension), new long[] { rank, dimension }, "model");
                var zeroValueName = ctx.AddInitializer((float)0);

                // Gemm(model, src, 0) with transB = 1, then Transpose of the result into the destination.
                // NOTE(review): presumably src is [batch, dimension], giving [rank, batch] before the transpose.
                const string gemmOp = "Gemm";
                var gemmOutput = ctx.AddIntermediateVariable(null, "GemmOutput", true);
                var gemmInputs = new[] { modelName, srcVariableName, zeroValueName };
                var gemm = ctx.CreateNode(gemmOp, gemmInputs, new[] { gemmOutput }, ctx.GetNodeName(gemmOp), "");

                gemm.AddAttribute("transB", 1);

                const string transposeOp = "Transpose";
                ctx.CreateNode(transposeOp, new[] { gemmOutput }, new[] { dstVariableName }, ctx.GetNodeName(transposeOp), "");
            }
コード例 #23
0
        /// <summary>
        /// Exports the base scorer columns, then a Binarizer (threshold <c>_threshold</c>) on the score
        /// column, then a Cast into the predicted-label column's raw type.
        /// </summary>
        /// <param name="ctx">Export context.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            // Nothing is exported unless the features column is already in the graph.
            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);
            int derivedCount = Bindings.DerivedColumnCount;

            Host.Assert(derivedCount == 1);

            // Output column names in binding order (PredictedLabel, Score, Probability).
            var outColumnNames = new string[Bindings.InfoCount];
            for (int iinfo = 0; iinfo < outColumnNames.Length; iinfo++)
            {
                var col = Bindings.MapIinfoToCol(iinfo);
                outColumnNames[iinfo] = Bindings.GetColumnName(col);
            }

            const string binarizerOp = "Binarizer";
            var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);

            // Threshold the second bound column when the mapper's score column is named "Score";
            // otherwise threshold the third bound column.
            string scoreColumn;
            bool scoreIsNamedScore = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name == "Score";
            if (scoreIsNamedScore)
            {
                scoreColumn = outColumnNames[1];
            }
            else
            {
                Host.Assert(Bindings.InfoCount >= 3);
                scoreColumn = outColumnNames[2];
            }

            var binarizer = ctx.CreateNode(binarizerOp, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(binarizerOp));
            binarizer.AddAttribute("threshold", _threshold);

            // Cast the binarized value to the predicted-label column's raw type.
            const string castOp = "Cast";
            var cast = ctx.CreateNode(castOp, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(castOp), "");
            var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]);

            Host.Assert(predictedLabelCol.HasValue);
            cast.AddAttribute("to", predictedLabelCol.Value.Type.RawType);
        }
コード例 #24
0
        /// <summary>
        /// Exports the base scorer columns, then these nodes:
        /// a Binarizer (threshold <c>_threshold</c>) on the score column;
        /// an Add of 1 when the predicted-label type is a key type;
        /// a Cast into the predicted-label column's raw type.
        /// </summary>
        /// <param name="ctx">Export context.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            base.SaveAsOnnxCore(ctx);
            int derivedCount = Bindings.DerivedColumnCount;

            Host.Assert(derivedCount == 1);

            // Output column names in binding order (PredictedLabel, Score, Probability).
            var outColumnNames = new string[Bindings.InfoCount];
            for (int iinfo = 0; iinfo < outColumnNames.Length; iinfo++)
            {
                var col = Bindings.MapIinfoToCol(iinfo);
                outColumnNames[iinfo] = Bindings.GetColumnName(col);
            }

            string scoreColumn = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name;

            // Compare the score against _threshold; the result is registered as a Single variable.
            const string binarizerOp = "Binarizer";
            var binarizerOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "BinarizerOutput", false);
            var binarizer = ctx.CreateNode(binarizerOp, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(binarizerOp));
            binarizer.AddAttribute("threshold", _threshold);

            // Key-typed predicted labels get the binarized value plus one (presumably because key
            // values are 1-based).
            var castInput = binarizerOutput;
            if (Bindings.PredColType is KeyDataViewType)
            {
                var one = ctx.AddInitializer(1.0f, "one");
                var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "Add", false);
                const string addOp = "Add";
                ctx.CreateNode(addOp, new[] { binarizerOutput, one }, new[] { addOutput }, ctx.GetNodeName(addOp), "");
                castInput = addOutput;
            }

            // Cast to the predicted-label column's raw type.
            const string castOp = "Cast";
            var cast = ctx.CreateNode(castOp, castInput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(castOp), "");
            var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]);

            Host.Assert(predictedLabelCol.HasValue);
            cast.AddAttribute("to", predictedLabelCol.Value.Type.RawType);
        }
コード例 #25
0
            private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
            {
                // Converts 1 column that is taken as input to the transform into one column of output
                //
                // Missing words are mapped to k for finding average, k + 1 for finding min, and k + 2 for finding max
                // Those spots in the dictionary contain a vector of 0s, max floats, and min floats, respectively
                //
                // Symbols:
                // j: length of latent vector of every word in the pretrained model
                // n: length of input tensor (number of words)
                // X: word input, a tensor with n elements.
                // k: # of words in pretrained model (known when transform is created)
                // S: word labels, k tensor (known when transform is created)
                // D: word embeddings, (k + 3)-by-j tensor(known when transform is created). The extra three embeddings
                //      at the end are used for out of vocab words.
                // F: location value representing missing words, equal to k
                // P: output, a j * 3 tensor
                //
                //                                                      X [n]
                //                                                       |
                //                                                     nameX
                //                                                       |
                //                           LabelEncoder (classes_strings = S [k], default_int64 = k)
                //                                                       |
                //                            /----------------------- nameY -----------------------\
                //                           /   |                       |                           \
                //     Initialize (F)-------/----|------ nameF ------> Equal                          \
                //                         /     |                       |                             \
                //                        /      |                     nameA                            \
                //                       /       |                     / |  \                            \
                //                      /        '-------------|      /  |   \                            \
                //                     /                 ------|-----/   |    \------------------          \---------
                //                    /                 /      |         |                       \                   \
                //                    |      Cast (to = int64) |  Cast (to = float)              Not                 |
                //                    |             |          |         |                        |                  |
                //                    |         nameVMin       |       nameB                    nameQ                |
                //                    |             |          |         |                        |                  |
                //                  Add ------------'          | Scale (scale = 2.0)         Cast (to = int32)       |
                //                    |                        |         |                        |                  |
                //                    |                        |     nameSMax                  nameZ                 |
                //                    |                        |         |                        |                  |
                //                    |                        | Cast (to = int64)       ReduceSum (axes = [0])      |
                //                namePMin                     |         |                        |                  |
                //                    |                        |      nameVMax                 nameR                 |
                //                    |                        |         |                        |                  |
                //                    |                        '-- Add --'                Cast (to = float)          |
                //                    |   Initialize (D [k + 3, j]   |                            |                  |
                //                    |             |                |                            |                  |
                //                    |           nameD           namePMax                     nameRF                |
                //                    |             |                |                            |                  |
                //                    |             |                |                     Clip (min = 1.0)          |
                //                    |             |                |                            |                  |
                //                    |             |                |                          nameT                |
                //                    |             |----------------|----------------------------|--------\         |
                //                    |             |                |                            |         \        |
                //                    |   /---------'-------------\  |                            |          '----\  |
                //                  Gather                        Gather                          |               Gather
                //                    |                              |                            |                  |
                //                 nameGMin                       nameGMax                        |                nameW
                //                    |                              |                            |                  |
                //            ReduceMin (axes = [0])      ReduceMax (axes = [0])                  |        ReduceSum (axes = [0])
                //                    |                              |                            |                  |
                //                    |                              |                            |                nameK
                //                    |                              |                            |                  |
                //                    |                              |                            '------- Div ------'
                //                  nameJ                          nameL                                    |
                //                    |                              |                                   nameE
                //                    |                              |                                      |
                //                    '------------------- Concat (axis = 1) -------------------------------'
                //                                                   |
                //                                                 nameP
                //                                                   |
                //                                               P [j * 3]

                long[] axes = new long[] { 0 };
                // Allocate D, a constant tensor representing word embedding weights.
                var shapeD      = new long[] { _parent._currentVocab.GetNumWords() + 3, _parent._currentVocab.Dimension };
                var wordVectors = _parent._currentVocab.WordVectors;
                var tensorD     = new List <float>();

                tensorD.AddRange(wordVectors);
                // Out-of-vocab embedding vector for combining embeddings by mean.
                tensorD.AddRange(Enumerable.Repeat(0.0f, _parent._currentVocab.Dimension));
                // Out-of-vocab embedding vector for combining embeddings by element-wise min.
                tensorD.AddRange(Enumerable.Repeat(float.MaxValue, _parent._currentVocab.Dimension));
                // Out-of-vocab embedding vector for combining embeddings by element-wise max.
                tensorD.AddRange(Enumerable.Repeat(float.MinValue, _parent._currentVocab.Dimension));
                var nameD = ctx.AddInitializer(tensorD, shapeD, "WordEmbeddingWeights");

                // Allocate F, a value representing an out-of-dictionary word.
                var tensorF = _parent._currentVocab.GetNumWords();
                var nameF   = ctx.AddInitializer(tensorF, "NotFoundValueComp");

                // Retrieve X, name of input.
                var nameX = srcVariableName;

                // Do label encoding. Out-of-vocab tokens will be mapped to the size of vocabulary. Because the index of vocabulary
                // is zero-based, the size of vocabulary is just greater then the max indexes computed from in-vocab tokens by one.
                var nameY = ctx.AddIntermediateVariable(null, "LabelEncodedInput", true);
                var nodeY = ctx.CreateNode("LabelEncoder", nameX, nameY, ctx.GetNodeName("LabelEncoder"));

                nodeY.AddAttribute("classes_strings", _parent._currentVocab.GetWordLabels());
                nodeY.AddAttribute("default_int64", _parent._currentVocab.GetNumWords());

                // Do steps necessary for min and max embedding vectors.

                // Map to boolean vector representing missing words. The following Equal produces 1 if a token is missing and 0 otherwise.
                var nameA = ctx.AddIntermediateVariable(null, "NotFoundValuesBool", true);
                var nodeA = ctx.CreateNode("Equal", new[] { nameY, nameF }, new[] { nameA }, ctx.GetNodeName("Equal"), "");

                // Cast the not found vector to a vector of floats.
                var nameB = ctx.AddIntermediateVariable(null, "NotFoundValuesFloat", true);
                var nodeB = ctx.CreateNode("Cast", nameA, nameB, ctx.GetNodeName("Cast"), "");

                nodeB.AddAttribute("to", 1);

                // Scale the not found vector to get the location bias for max weights.
                var nameSMax = ctx.AddIntermediateVariable(null, "ScaleMax", true);
                var nodeSMax = ctx.CreateNode("Scale", nameB, nameSMax, ctx.GetNodeName("Scale"), "");

                nodeSMax.AddAttribute("scale", 2.0);

                // Cast scaled word label locations to ints.
                var nameVMin = ctx.AddIntermediateVariable(null, "CastMin", true);
                var nodeVMin = ctx.CreateNode("Cast", nameA, nameVMin, ctx.GetNodeName("Cast"), "");

                nodeVMin.AddAttribute("to", 7);

                var nameVMax = ctx.AddIntermediateVariable(null, "CastMax", true);
                var nodeVMax = ctx.CreateNode("Cast", nameSMax, nameVMax, ctx.GetNodeName("Cast"), "");

                nodeVMax.AddAttribute("to", 7);

                // Add the scaled options back to originals. The outputs of the following Add operators are almost identical
                // the output of the previous LabelEncoder. The only difference is that out-of-vocab tokens are mapped to k+1
                // for applying ReduceMin and k+2 for applying ReduceMax so that out-of-vocab tokens do not affect embedding results at all.
                var namePMin = ctx.AddIntermediateVariable(null, "AddMin", true);
                var nodePMin = ctx.CreateNode("Add", new[] { nameY, nameVMin }, new[] { namePMin }, ctx.GetNodeName("Add"), "");

                var namePMax = ctx.AddIntermediateVariable(null, "AddMax", true);
                var nodePMax = ctx.CreateNode("Add", new[] { nameY, nameVMax }, new[] { namePMax }, ctx.GetNodeName("Add"), "");

                // Map encoded words to their embedding vectors, mapping missing ones to min/max.
                var nameGMin = ctx.AddIntermediateVariable(null, "GatheredMin", true);
                var nodeGMin = ctx.CreateNode("Gather", new[] { nameD, namePMin }, new[] { nameGMin }, ctx.GetNodeName("Gather"), "");

                var nameGMax = ctx.AddIntermediateVariable(null, "GatheredMax", true);
                var nodeGMax = ctx.CreateNode("Gather", new[] { nameD, namePMax }, new[] { nameGMax }, ctx.GetNodeName("Gather"), "");

                // Merge all embedding vectors using element-wise min/max per embedding coordinate.
                var nameJ = ctx.AddIntermediateVariable(null, "MinWeights", true);
                var nodeJ = ctx.CreateNode("ReduceMin", nameGMin, nameJ, ctx.GetNodeName("ReduceMin"), "");

                nodeJ.AddAttribute("axes", axes);

                var nameL = ctx.AddIntermediateVariable(null, "MaxWeights", true);
                var nodeL = ctx.CreateNode("ReduceMax", nameGMax, nameL, ctx.GetNodeName("ReduceMax"), "");

                nodeL.AddAttribute("axes", axes);

                // Do steps necessary for mean embedding vector.

                // Map encoded words to their embedding vectors using Gather.
                var nameW = ctx.AddIntermediateVariable(null, "GatheredMean", true);
                var nodeW = ctx.CreateNode("Gather", new[] { nameD, nameY }, new[] { nameW }, ctx.GetNodeName("Gather"), "");

                // Find the sum of the embedding vectors.
                var nameK = ctx.AddIntermediateVariable(null, "SumWeights", true);
                var nodeK = ctx.CreateNode("ReduceSum", nameW, nameK, ctx.GetNodeName("ReduceSum"), "");

                nodeK.AddAttribute("axes", axes);

                // Flip the boolean vector representing missing words to represent found words.
                var nameQ = ctx.AddIntermediateVariable(null, "FoundValuesBool", true);
                var nodeQ = ctx.CreateNode("Not", nameA, nameQ, ctx.GetNodeName("Not"), "");

                // Cast the found words vector to ints.
                var nameZ = ctx.AddIntermediateVariable(null, "FoundValuesInt", true);
                var nodeZ = ctx.CreateNode("Cast", nameQ, nameZ, ctx.GetNodeName("Cast"), "");

                nodeZ.AddAttribute("to", 6);

                // Sum the number of total found words.
                var nameR = ctx.AddIntermediateVariable(null, "NumWordsFoundInt", true);
                var nodeR = ctx.CreateNode("ReduceSum", nameZ, nameR, ctx.GetNodeName("ReduceSum"), "");

                nodeR.AddAttribute("axes", axes);

                // Cast the found words to float.
                var nameRF = ctx.AddIntermediateVariable(null, "NumWordsFoundFloat", true);
                var nodeRF = ctx.CreateNode("Cast", nameR, nameRF, ctx.GetNodeName("Cast"), "");

                nodeRF.AddAttribute("to", 1);

                // Clip the number of found words to prevent division by 0.
                var nameT = ctx.AddIntermediateVariable(null, "NumWordsClippedFloat", true);
                var nodeT = ctx.CreateNode("Clip", nameRF, nameT, ctx.GetNodeName("Clip"), "");

                nodeT.AddAttribute("min", 1.0f);

                // Divide total sum by number of words found to get the average embedding vector of the input string vector.
                var nameE = ctx.AddIntermediateVariable(null, "MeanWeights", true);
                var nodeE = ctx.CreateNode("Div", new[] { nameK, nameT }, new[] { nameE }, ctx.GetNodeName("Div"), "");

                // Concatenate the final embeddings produced by the three reduction strategies.
                var nameP = dstVariableName;
                var nodeP = ctx.CreateNode("Concat", new[] { nameJ, nameE, nameL }, new[] { nameP }, ctx.GetNodeName("Concat"), "");

                nodeP.AddAttribute("axis", 1);
            }
コード例 #26
0
            /// <summary>
            /// Exports every wrapped predictor to ONNX and returns, for each one, the name of the ONNX variable that the
            /// caller should consume as that predictor's output.
            /// </summary>
            /// <param name="ctx">The ONNX export context that receives the generated variables, initializers and nodes.</param>
            /// <param name="featureColumn">Name of the ONNX variable holding the input features, forwarded to each predictor.</param>
            /// <param name="clipToZero">
            /// When <c>true</c>, each predictor's Probability output is fed to a <c>Clip</c> node together with a zero
            /// initializer and the clipped variable is returned; when <c>false</c>, each predictor's Score output is returned.
            /// </param>
            /// <returns>An array with one variable name per element of <c>Predictors</c>, in the same order.</returns>
            public string[] SaveAsOnnxPreProcess(OnnxContext ctx, string featureColumn, bool clipToZero)
            {
                var selectedOutputs = new string[Predictors.Length];

                for (int index = 0; index < Predictors.Length; index++)
                {
                    // Intermediate outputs of this predictor, in PredictedLabel / Score / Probability order.
                    string labelVar       = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, $"{DefaultColumnNames.PredictedLabel}_{index}", true);
                    string scoreVar       = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Score}_{index}", true);
                    string probabilityVar = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"{DefaultColumnNames.Probability}_{index}", true);
                    string[] predictorOutputNames = { labelVar, scoreVar, probabilityVar };

                    var singlePredictor = Predictors[index] as ISingleCanSaveOnnx;
                    Contracts.AssertValue(singlePredictor);
                    singlePredictor.SaveAsOnnx(ctx, predictorOutputNames, featureColumn);

                    if (!clipToZero)
                    {
                        // The array was handed to the predictor's exporter, so the Score slot is read back from it.
                        selectedOutputs[index] = predictorOutputNames[1];
                        continue;
                    }

                    // The zero initializer is passed as Clip's second input (its "min" bound in ONNX opset 11 and later).
                    string clippedVar = ctx.AddIntermediateVariable(NumberDataViewType.Single, $"ClipOutput_{index}", true);
                    selectedOutputs[index] = clippedVar;

                    const string clipOp = "Clip";
                    var zeroInit = ctx.AddInitializer(0.0f, "Zero");
                    ctx.CreateNode(clipOp, new[] { probabilityVar, zeroInit }, new[] { clippedVar }, ctx.GetNodeName(clipOp), "");
                }

                return selectedOutputs;
            }
コード例 #27
0
        /// <summary>
        /// Exports this scorer to ONNX. After the base class emits its part of the graph, the PredictedLabel column is
        /// produced by a <c>Binarizer</c> node followed by a <c>Cast</c> node. The Binarizer thresholds the Probability
        /// column at 0.5 when a third output column exists, and otherwise thresholds the Score column at 0.0.
        /// </summary>
        /// <param name="ctx">The ONNX export context. Nothing is emitted when it does not contain the Features column.</param>
        private protected override void SaveAsOnnxCore(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);
            Host.Assert(Bindings.InfoCount >= 2);

            // Without a Features column there is nothing to score, so no graph is emitted.
            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnxCore(ctx);

            int derivedCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedCount == 1);

            // Output column names in binding order: PredictedLabel, Score and, when present, Probability.
            var columnNames = new string[Bindings.InfoCount];
            for (int col = 0; col < columnNames.Length; col++)
            {
                columnNames[col] = Bindings.GetColumnName(Bindings.MapIinfoToCol(col));
            }

            const string binarizerOp = "Binarizer";
            var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true);

            // A probability is split at 0.5; a raw score is split on its sign (0.0).
            bool hasProbability = Bindings.InfoCount >= 3;
            if (hasProbability)
            {
                Host.Assert(ctx.ContainsColumn(columnNames[2]));
            }
            var binarizerInput = ctx.GetVariableName(columnNames[hasProbability ? 2 : 1]);
            var binarizer = ctx.CreateNode(binarizerOp, binarizerInput, binarizerOutput, ctx.GetNodeName(binarizerOp));
            binarizer.AddAttribute("threshold", hasProbability ? 0.5 : 0.0);

            // Convert the Binarizer output into the PredictedLabel variable.
            const string castOp = "Cast";
            var cast = ctx.CreateNode(castOp, binarizerOutput, ctx.GetVariableName(columnNames[0]), ctx.GetNodeName(castOp), "");
            cast.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.Boolean).ToType());
        }
コード例 #28
0
        /// <summary>
        /// Creates an ONNX inferencing graph by vectorizing and following the logic found in <see cref="Map"/>.
        /// The learned histograms are embedded as initializers. For each label, the graph sums:
        /// the log prior, the learned absent-feature log probability, the contribution of the features that are
        /// present (input value &gt; 0), and a correction removing the absent-feature contribution of those same features.
        /// </summary>
        /// <param name="ctx">The ONNX export context receiving initializers, intermediate variables and nodes.</param>
        /// <param name="outputNames">
        /// ONNX output variable names. Index 0 receives the predicted label (ArgMax of the scores plus one, cast to
        /// UInt32); index 1 receives the per-label scores (the summed log terms with axis 2 squeezed away).
        /// </param>
        /// <param name="featureColumn">Name of the ONNX variable holding the input feature vector.</param>
        /// <returns>Always <c>true</c>.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // Flatten the rows of _featureHistogram into one row-major buffer. Also build a buffer that repeats
            // _labelHistogram once per feature.
            // NOTE(review): the buffer sizes assume _featureHistogram has one row per label
            // (_featureHistogram.Length == _labelHistogram.Length) and that all rows share one length — TODO confirm.
            float[] featureHistogram       = new float[_featureHistogram[0].Length * _labelHistogram.Length];
            float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

            for (int i = 0; i < _featureHistogram.Length; i++)
            {
                Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
            }
            for (int i = 0; i < _featureHistogram[0].Length; i++)
            {
                Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
            }

            // Scalar constants and the model statistics as ONNX initializers.
            // labelHistogram takes the first block of labelHistogramExpanded, i.e. one copy of the label histogram, shaped [labels, 1].
            var one            = ctx.AddInitializer(1.0f, "one");
            var oneInt         = ctx.AddInitializer(1, typeof(int), "oneInt");
            var zero           = ctx.AddInitializer(0.0f, "zero");
            var labelCount     = ctx.AddInitializer((float)_labelCount, "labelCount");
            var trainingCount  = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
            var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

            // featureHistogram is shaped [rows, features]; labelHistogramExpanded is shaped [features, labels] and is
            // transposed further below so it lines up with featureHistogram.
            var featureHistogramName        = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
            var labelHistogramName          = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
            var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

            // Declared types of the intermediate tensors created below.
            var typeOne        = new VectorDataViewType(NumberDataViewType.Single, 1);
            var typeFea        = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
            var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
            var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);

            // isFeaturePresent = ExpandDims(Cast<float>(features > 0), axis 1): 1.0 where a feature value is strictly positive, else 0.0.
            var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
            var opType        = "Greater";

            ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

            opType = "Cast";
            var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
            var node       = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
            var t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();

            node.AddAttribute("to", t);

            // ExpandDims lives in the "com.microsoft" operator domain, not the default ONNX domain.
            opType = "ExpandDims";
            var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");

            ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");

            // Log prior: logOutput = Log(labelHistogram / totalTrainingCount).
            opType = "Div";
            var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");

            ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

            opType = "Log";
            var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");

            ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

            // Term 1: logOutput1 = LogMul(featureHistogram + 1, isFeaturePresent).
            // NOTE(review): LogMul is defined elsewhere; presumably it takes Log of its input and masks the result with
            // isFeaturePresent — verify against its implementation.
            opType = "Sum";
            var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");

            ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

            // Term 2: logOutput2 = LogMul(labelHistogramT + labelCount, isFeaturePresent).
            // Also computes absentFeatureCount = labelHistogramT - featureHistogram, used by term 3.
            opType = "Transpose";
            var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");

            ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");

            ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

            // Term 3: logOutput3 = LogMul(absentFeatureCount + 1, isFeaturePresent).
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

            // Per-feature log terms:
            //   logProb              = term1 - term2
            //   absentFeatureLogProb = term3 - term2
            opType = "Sub";
            var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");

            ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");

            ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

            // Sum both per-feature terms over axis 2. The keepdims attribute is not set, so the ONNX default applies.
            opType = "ReduceSum";
            var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");

            node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
            long[] list = { 2 };
            node.AddAttribute("axes", list);

            opType = "ReduceSum";
            var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");

            node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axes", list);

            // Cast the learned absent-feature log probabilities to float before combining them.
            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
            node       = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // score = (learnedAbsent - sum(absentFeatureLogProb)) + sum(logProb) + logPrior
            opType = "Sub";
            var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");

            ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
            ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // Score output: drop axis 2 from the combined log terms.
            opType = "Squeeze";
            var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");

            squeezeNode.AddAttribute("axes", new long[] { 2 });

            // Predicted label, step 1: ArgMax over axis 1 (dropping that axis), cast to float.
            opType = "ArgMax";
            var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");

            node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axis", 1);
            node.AddAttribute("keepdims", 0);

            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
            node       = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // Predicted label, step 2: add one to the zero-based ArgMax index.
            // NOTE(review): presumably because predicted-label keys are one-based — confirm.
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
            ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // Predicted label, step 3: cast the result to UInt32 into the predicted-label output.
            opType = "Cast";
            node   = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
            t      = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
            node.AddAttribute("to", t);

            return(true);
        }
コード例 #29
0
        /// <summary>
        /// Emits an ONNX <c>Scaler</c> node mapping <paramref name="srcVariableName"/> to <paramref name="dstVariableName"/>
        /// for column <paramref name="iinfo"/>. The column's normalization function then fills in the node's attributes.
        /// </summary>
        /// <returns>
        /// <c>true</c> when the node was emitted. <c>false</c> when the source type reports a value count of zero, or when
        /// the column's function cannot be saved to ONNX; in either case no node is created.
        /// </returns>
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            Contracts.AssertValue(ctx);
            Contracts.Assert(0 <= iinfo && iinfo < Infos.Length);
            Contracts.Assert(Infos[iinfo] == info);
            Contracts.Assert(CanSaveOnnx);

            // A value count of zero (presumably a variable-size vector — confirm) cannot be exported.
            // The function is only consulted when the count is non-zero.
            bool exportable = info.TypeSrc.ValueCount != 0 && _functions[iinfo].CanSaveOnnx;
            if (!exportable)
            {
                return false;
            }

            const string scalerOp = "Scaler";
            var scaler = ctx.CreateNode(scalerOp, srcVariableName, dstVariableName, ctx.GetNodeName(scalerOp));
            _functions[iinfo].OnnxInfo(ctx, scaler, info.TypeSrc.ValueCount);
            return true;
        }