コード例 #1
0
 /// <summary>
 /// Default ONNX export hook: does no work and reports <c>false</c>.
 /// </summary>
 /// <param name="ctx">ONNX context; not read by this default implementation.</param>
 /// <param name="schema">Role-mapped schema; not read by this default implementation.</param>
 /// <param name="outputNames">Output variable names; not read by this default implementation.</param>
 /// <returns>Always <c>false</c>. Being virtual, derived types may override this.</returns>
 public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
 {
     return false;
 }
コード例 #2
0
 /// <summary>
 /// Explicit <c>ICanSaveOnnx</c> implementation that unconditionally answers <c>true</c>.
 /// </summary>
 /// <param name="ctx">ONNX context; not consulted.</param>
 /// <returns>Always <c>true</c>.</returns>
 bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
 {
     return true;
 }
コード例 #3
0
        /// <summary>
        /// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
        /// </summary>
        /// <param name="ctx">ONNX context that receives every initializer, intermediate variable and node created here.</param>
        /// <param name="outputNames">
        /// Output variable names. Index 0 receives the final UInt32 cast of (ArgMax + 1);
        /// index 1 receives the Squeeze of the summed per-label values.
        /// </param>
        /// <param name="featureColumn">Name of the input variable; it is the first operand of the initial Greater node.</param>
        /// <returns>Always <c>true</c>.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            const int minimumOpSetVersion = 9;

            // NOTE(review): CheckOpSetVersion is defined elsewhere; presumably it rejects contexts
            // targeting an opset older than 9 - confirm.
            ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes");

            // Flat buffers sized (feature count) * (label count).
            float[] featureHistogram       = new float[_featureHistogram[0].Length * _labelHistogram.Length];
            float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

            // Flatten the jagged _featureHistogram row by row.
            for (int i = 0; i < _featureHistogram.Length; i++)
            {
                Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
            }
            // Repeat _labelHistogram once per feature.
            // NOTE(review): the copy length and stride use _featureHistogram.Length, so this assumes
            // _featureHistogram.Length == _labelHistogram.Length - TODO confirm.
            for (int i = 0; i < _featureHistogram[0].Length; i++)
            {
                Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
            }

            // Scalar constants and the [labels, 1] label histogram (the first _labelHistogram.Length
            // entries of the expanded buffer).
            var one            = ctx.AddInitializer(1.0f, "one");
            var oneInt         = ctx.AddInitializer(1, typeof(int), "oneInt");
            var zero           = ctx.AddInitializer(0.0f, "zero");
            var labelCount     = ctx.AddInitializer((float)_labelCount, "labelCount");
            var trainingCount  = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
            var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

            // Tensor initializers for the learned histograms and the absent-feature log-probabilities.
            var featureHistogramName        = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
            var labelHistogramName          = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
            var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

            // Variable types reused for the intermediate variables below.
            var typeOne        = new VectorDataViewType(NumberDataViewType.Single, 1);
            var typeFea        = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
            var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
            var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);

            // isFeaturePresent = ExpandDims(Cast<Single>(featureColumn > 0), oneInt)
            var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
            var opType        = "Greater";

            ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

            opType = "Cast";
            var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
            var node       = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
            var t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();

            node.AddAttribute("to", t);

            // ExpandDims is created in the "com.microsoft" domain, unlike the other nodes here.
            opType = "ExpandDims";
            var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");

            ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");

            // Initialize logProb: logOutput = Log(labelHistogram / trainingCount)
            opType = "Div";
            var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");

            ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

            opType = "Log";
            var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");

            ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

            // log1: LogMul over (featureHistogram + 1).
            // NOTE(review): LogMul is defined elsewhere; presumably it emits Log followed by a Mul
            // with isFeaturePresent - confirm.
            opType = "Sum";
            var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");

            ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

            // log2: LogMul over (Transpose(labelHistogramExpanded) + labelCount).
            // absentFeatureCount = Transpose(labelHistogramExpanded) - featureHistogram is also computed here.
            opType = "Transpose";
            var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");

            ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");

            ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

            // log3: LogMul over (absentFeatureCount + 1).
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

            // Result: logProb = log1 - log2, absentFeatureLogProb = log3 - log2.
            opType = "Sub";
            var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");

            ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");

            ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

            // Both differences are reduced with ReduceSum over axis 2 (the same "axes" array is reused).
            opType = "ReduceSum";
            var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");

            node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
            long[] list = { 2 };
            node.AddAttribute("axes", list);

            opType = "ReduceSum";
            var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");

            node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axes", list);

            // Cast the learned absent-feature log-probabilities to Single, then subtract the reduced
            // absent-feature term from them.
            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
            node       = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            opType = "Sub";
            var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");

            ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

            // Per-label total = subOutput + logProbReduceSum + logOutput.
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
            ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // outputNames[1] = Squeeze(total, axes = [2]).
            opType = "Squeeze";
            var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");

            squeezeNode.AddAttribute("axes", new long[] { 2 });

            // Index of the largest total along axis 1, without keeping the reduced dimension.
            opType = "ArgMax";
            var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");

            node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axis", 1);
            node.AddAttribute("keepdims", 0);

            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
            node       = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // Add one to the ArgMax index before the final UInt32 cast.
            // NOTE(review): presumably because the predicted key value is one-based - confirm.
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
            ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // outputNames[0] = Cast<UInt32>(index + 1).
            opType = "Cast";
            node   = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
            t      = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
            node.AddAttribute("to", t);

            return(true);
        }
コード例 #4
0
 /// <summary>
 /// Reports whether every entry of <c>Predictors</c> implements <c>ICanSaveOnnx</c> and itself
 /// answers <c>true</c> for <paramref name="ctx"/>.
 /// </summary>
 /// <param name="ctx">ONNX context forwarded to each predictor.</param>
 /// <returns>
 /// <c>true</c> only when all predictors are ONNX-capable and agree; an entry that does not
 /// implement <c>ICanSaveOnnx</c> makes the result <c>false</c>.
 /// </returns>
 public bool CanSaveOnnx(OnnxContext ctx)
 {
     return Predictors.All(pred => pred is ICanSaveOnnx saver && saver.CanSaveOnnx(ctx));
 }
コード例 #5
0
 /// <summary>
 /// Explicit <c>ICanSaveOnnx</c> implementation that returns the value of <c>CanSaveOnnxCore</c>.
 /// </summary>
 /// <param name="ctx">ONNX context; not consulted here.</param>
 /// <returns>The current value of <c>CanSaveOnnxCore</c>.</returns>
 bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
 {
     return CanSaveOnnxCore;
 }
コード例 #6
0
 /// <summary>
 /// Unconditionally reports that ONNX export is available.
 /// </summary>
 /// <param name="ctx">ONNX context; not consulted.</param>
 /// <returns>Always <c>true</c>.</returns>
 public bool CanSaveOnnx(OnnxContext ctx)
 {
     return true;
 }
コード例 #7
0
 /// <summary>
 /// Explicit <c>ICanSaveOnnx</c> implementation that forwards the question to <c>_impl</c>.
 /// </summary>
 /// <param name="ctx">ONNX context passed through unchanged.</param>
 /// <returns>Whatever <c>_impl.CanSaveOnnx</c> returns for <paramref name="ctx"/>.</returns>
 bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
 {
     return _impl.CanSaveOnnx(ctx);
 }
コード例 #8
0
        /// <summary>
        /// Adds a default-valued initializer of shape [1, size] to <paramref name="ctx"/>, named after the
        /// source column bound at <paramref name="iinfo"/>. The element type follows the bound column's raw type.
        /// </summary>
        /// <param name="ctx">ONNX context that receives the initializer.</param>
        /// <param name="iinfo">Index into <c>_bindings.ColumnTypes</c> and <c>_bindings.SrcCols</c>.</param>
        /// <param name="srcVariableName">Not read by this method.</param>
        /// <param name="dstVariableName">Not read by this method.</param>
        /// <returns><c>true</c> when an initializer was added; <c>false</c> when the raw type is not handled.</returns>
        private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
        {
            var boundType = _bindings.ColumnTypes[iinfo];
            string initializerName = Source.Schema[_bindings.SrcCols[iinfo]].Name;
            Type rawType = boundType.RawType;

            // A vector of known size keeps its length; anything else is exported as a single element.
            int length = boundType is VectorDataViewType && boundType.IsKnownSizeVector()
                ? boundType.GetVectorSize()
                : 1;
            var dims = new long[] { 1, length };

            // Small signed/unsigned integer types all go through the int[] overload, tagged with the raw type.
            if (rawType == typeof(int) ||
                rawType == typeof(short) || rawType == typeof(ushort) ||
                rawType == typeof(sbyte) || rawType == typeof(byte))
            {
                ctx.AddInitializer(new int[length], rawType, dims, initializerName, false);
                return true;
            }

            // uint and ulong share the ulong[] overload; the flag is true only for ulong.
            if (rawType == typeof(uint) || rawType == typeof(ulong))
            {
                ctx.AddInitializer(new ulong[length], rawType == typeof(ulong), dims, initializerName, false);
                return true;
            }

            if (rawType == typeof(bool))
            {
                ctx.AddInitializer(new bool[length], dims, initializerName, false);
                return true;
            }

            if (rawType == typeof(long))
            {
                ctx.AddInitializer(new long[length], dims, initializerName, false);
                return true;
            }

            if (rawType == typeof(float))
            {
                ctx.AddInitializer(new float[length], dims, initializerName, false);
                return true;
            }

            if (rawType == typeof(double))
            {
                ctx.AddInitializer(new double[length], dims, initializerName, false);
                return true;
            }

            if (rawType == typeof(string) || boundType is TextDataViewType)
            {
                // Text defaults to empty strings rather than null entries.
                var blanks = new string[length];
                int k = 0;
                while (k < length)
                {
                    blanks[k] = "";
                    k++;
                }

                ctx.AddInitializer(blanks, dims, initializerName, false);
                return true;
            }

            return false;
        }
コード例 #9
0
 /// <summary>
 /// Allows ONNX export only when the context's ONNX version is <c>OnnxVersion.Experimental</c>.
 /// </summary>
 /// <param name="ctx">ONNX context whose version is inspected.</param>
 /// <returns><c>true</c> when <c>ctx.GetOnnxVersion()</c> equals <c>OnnxVersion.Experimental</c>; otherwise <c>false</c>.</returns>
 public override bool CanSaveOnnx(OnnxContext ctx)
 {
     var version = ctx.GetOnnxVersion();
     return version == OnnxVersion.Experimental;
 }
コード例 #10
0
 /// <summary>
 /// Unconditionally reports that ONNX export is available.
 /// </summary>
 /// <param name="ctx">ONNX context; not consulted.</param>
 /// <returns>Always <c>true</c>.</returns>
 public bool CanSaveOnnx(OnnxContext ctx) => true;
コード例 #11
0
 /// <summary>
 /// Intentionally empty: this method adds nothing to <paramref name="ctx"/>.
 /// </summary>
 /// <param name="ctx">ONNX context; left untouched.</param>
 public void SaveAsOnnx(OnnxContext ctx)
 {
 }
コード例 #12
0
 /// <summary>
 /// Performs no work on <paramref name="nodeProtoWrapper"/> and reports <c>false</c>.
 /// </summary>
 /// <param name="ctx">ONNX context; not read.</param>
 /// <param name="nodeProtoWrapper">Node wrapper; not modified.</param>
 /// <param name="featureCount">Feature count; not read.</param>
 /// <returns>Always <c>false</c>.</returns>
 public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) => false;
コード例 #13
0
 /// <summary>
 /// Abstract ONNX hook; each derived type supplies the implementation.
 /// </summary>
 /// <param name="ctx">ONNX context handed to the implementation.</param>
 /// <param name="nodeProtoWrapper">Wrapper around an ONNX node handed to the implementation.</param>
 /// <param name="featureCount">Feature count handed to the implementation.</param>
 /// <returns>
 /// Implementation-defined. NOTE(review): presumably <c>true</c> means the node was populated
 /// successfully - confirm against callers.
 /// </returns>
 public abstract bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount);
コード例 #14
0
        /// <summary>
        /// Adds a default-valued initializer of shape [1, size] named <paramref name="srcVariableName"/> to
        /// <paramref name="ctx"/>. The element type follows the raw type of <paramref name="columnType"/>.
        /// </summary>
        /// <param name="ctx">ONNX context; its opset version is checked against 9 before anything is added.</param>
        /// <param name="srcVariableName">Name given to the initializer.</param>
        /// <param name="columnType">Column type that determines the element type and the vector size.</param>
        /// <returns><c>true</c> when an initializer was added; <c>false</c> when the raw type is not handled.</returns>
        private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, DataViewType columnType)
        {
            const int minimumOpSetVersion = 9;
            ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

            Type rawType = columnType.RawType;

            // A vector of known size keeps its length; anything else is exported as a single element.
            bool isSizedVector = columnType is VectorDataViewType && columnType.IsKnownSizeVector();
            int count = isSizedVector ? columnType.GetVectorSize() : 1;
            var shape = new long[] { 1, count };

            // Small signed/unsigned integer types all go through the int[] overload, tagged with the raw type.
            if (rawType == typeof(int) ||
                rawType == typeof(short) || rawType == typeof(ushort) ||
                rawType == typeof(sbyte) || rawType == typeof(byte))
            {
                ctx.AddInitializer(new int[count], rawType, shape, srcVariableName, false);
                return true;
            }

            // uint and ulong share the ulong[] overload; the flag is true only for ulong.
            if (rawType == typeof(uint) || rawType == typeof(ulong))
            {
                ctx.AddInitializer(new ulong[count], rawType == typeof(ulong), shape, srcVariableName, false);
                return true;
            }

            if (rawType == typeof(bool))
            {
                ctx.AddInitializer(new bool[count], shape, srcVariableName, false);
                return true;
            }

            if (rawType == typeof(long))
            {
                ctx.AddInitializer(new long[count], shape, srcVariableName, false);
                return true;
            }

            if (rawType == typeof(float))
            {
                ctx.AddInitializer(new float[count], shape, srcVariableName, false);
                return true;
            }

            if (rawType == typeof(double))
            {
                ctx.AddInitializer(new double[count], shape, srcVariableName, false);
                return true;
            }

            if (rawType == typeof(string) || columnType is TextDataViewType)
            {
                // Text defaults to empty strings rather than null entries.
                var emptyStrings = new string[count];
                for (int slot = 0; slot != emptyStrings.Length; ++slot)
                {
                    emptyStrings[slot] = "";
                }

                ctx.AddInitializer(emptyStrings, shape, srcVariableName, false);
                return true;
            }

            return false;
        }
コード例 #15
0
 /// <summary>
 /// Default core export hook: does no work and reports <c>false</c>.
 /// </summary>
 /// <param name="ctx">ONNX context; not read by this default implementation.</param>
 /// <param name="schema">Role-mapped schema; not read by this default implementation.</param>
 /// <param name="outputNames">Output variable names; not read by this default implementation.</param>
 /// <returns>Always <c>false</c>. Being virtual, derived types may override this.</returns>
 private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
 {
     return false;
 }
コード例 #16
0
 /// <summary>
 /// Abstract ONNX export hook; each derived type supplies the implementation.
 /// </summary>
 /// <param name="ctx">ONNX context handed to the implementation.</param>
 /// <param name="srcVariableName">Name of the source variable.</param>
 /// <param name="dstVariableName">Name of the destination variable.</param>
 /// <returns>
 /// Implementation-defined. NOTE(review): presumably <c>true</c> means the export succeeded -
 /// confirm against callers.
 /// </returns>
 public abstract bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName);
コード例 #17
0
 /// <summary>
 /// Explicit <c>ICanSaveOnnx</c> implementation: export is possible only when <c>ValueMapper</c>
 /// implements <c>ICanSaveOnnx</c> and itself answers <c>true</c>.
 /// </summary>
 /// <param name="ctx">ONNX context forwarded to the value mapper.</param>
 /// <returns><c>false</c> when <c>ValueMapper</c> is null or not an <c>ICanSaveOnnx</c>; otherwise the mapper's own answer.</returns>
 bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
 {
     return ValueMapper is ICanSaveOnnx saver && saver.CanSaveOnnx(ctx);
 }
コード例 #18
0
            /// <summary>
            /// Emits the ONNX nodes that map one input column of tokens to one output column holding the
            /// element-wise min, the mean, and the element-wise max of the tokens' embedding vectors,
            /// concatenated in that order.
            /// </summary>
            /// <param name="ctx">ONNX context that receives the initializers, intermediate variables and nodes.</param>
            /// <param name="srcVariableName">Name of the input variable (X in the diagram below).</param>
            /// <param name="dstVariableName">Name of the output variable written by the final Concat (P in the diagram below).</param>
            private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
            {
                // Converts 1 column that is taken as input to the transform into one column of output
                //
                // Missing words are mapped to k for finding average, k + 1 for finding min, and k + 2 for finding max
                // Those spots in the dictionary contain a vector of 0s, max floats, and min floats, respectively
                //
                // Symbols:
                // j: length of latent vector of every word in the pretrained model
                // n: length of input tensor (number of words)
                // X: word input, a tensor with n elements.
                // k: # of words in pretrained model (known when transform is created)
                // S: word labels, k tensor (known when transform is created)
                // D: word embeddings, (k + 3)-by-j tensor (known when transform is created). The extra three embeddings
                //      at the end are used for out of vocab words.
                // F: location value representing missing words, equal to k
                // P: output, a j * 3 tensor
                //
                //                                                      X [n]
                //                                                       |
                //                                                     nameX
                //                                                       |
                //                           LabelEncoder (classes_strings = S [k], default_int64 = k)
                //                                                       |
                //                            /----------------------- nameY -----------------------\
                //                           /   |                       |                           \
                //     Initialize (F)-------/----|------ nameF ------> Equal                          \
                //                         /     |                       |                             \
                //                        /      |                     nameA                            \
                //                       /       |                     / |  \                            \
                //                      /        '-------------|      /  |   \                            \
                //                     /                 ------|-----/   |    \------------------          \---------
                //                    /                 /      |         |                       \                   \
                //                    |      Cast (to = int64) |  Cast (to = float)              Not                 |
                //                    |             |          |         |                        |                  |
                //                    |         nameVMin       |       nameB                    nameQ                |
                //                    |             |          |         |                        |                  |
                //                  Add ------------'          | Scale (scale = 2.0)         Cast (to = int32)       |
                //                    |                        |         |                        |                  |
                //                    |                        |     nameSMax                  nameZ                 |
                //                    |                        |         |                        |                  |
                //                    |                        | Cast (to = int64)       ReduceSum (axes = [0])      |
                //                namePMin                     |         |                        |                  |
                //                    |                        |      nameVMax                 nameR                 |
                //                    |                        |         |                        |                  |
                //                    |                        '-- Add --'                Cast (to = float)          |
                //                    |   Initialize (D [k + 3, j])  |                            |                  |
                //                    |             |                |                            |                  |
                //                    |           nameD           namePMax                     nameRF                |
                //                    |             |                |                            |                  |
                //                    |             |                |                     Clip (min = 1.0)          |
                //                    |             |                |                            |                  |
                //                    |             |                |                          nameT                |
                //                    |             |----------------|----------------------------|--------\         |
                //                    |             |                |                            |         \        |
                //                    |   /---------'-------------\  |                            |          '----\  |
                //                  Gather                        Gather                          |               Gather
                //                    |                              |                            |                  |
                //                 nameGMin                       nameGMax                        |                nameW
                //                    |                              |                            |                  |
                //            ReduceMin (axes = [0])      ReduceMax (axes = [0])                  |        ReduceSum (axes = [0])
                //                    |                              |                            |                  |
                //                    |                              |                            |                nameK
                //                    |                              |                            |                  |
                //                    |                              |                            '------- Div ------'
                //                  nameJ                          nameL                                    |
                //                    |                              |                                   nameE
                //                    |                              |                                      |
                //                    '------------------- Concat (axis = 1) -------------------------------'
                //                                                   |
                //                                                 nameP
                //                                                   |
                //                                               P [j * 3]

                // Every reduction below runs over axis 0; this one array is shared by all of them.
                long[] axes = new long[] { 0 };
                // Allocate D, a constant tensor representing word embedding weights.
                var shapeD      = new long[] { _parent._currentVocab.GetNumWords() + 3, _parent._currentVocab.Dimension };
                var wordVectors = _parent._currentVocab.WordVectors;
                var tensorD     = new List <float>();

                tensorD.AddRange(wordVectors);
                // Out-of-vocab embedding vector for combining embeddings by mean.
                tensorD.AddRange(Enumerable.Repeat(0.0f, _parent._currentVocab.Dimension));
                // Out-of-vocab embedding vector for combining embeddings by element-wise min.
                tensorD.AddRange(Enumerable.Repeat(float.MaxValue, _parent._currentVocab.Dimension));
                // Out-of-vocab embedding vector for combining embeddings by element-wise max.
                tensorD.AddRange(Enumerable.Repeat(float.MinValue, _parent._currentVocab.Dimension));
                var nameD = ctx.AddInitializer(tensorD, shapeD, "WordEmbeddingWeights");

                // Allocate F, a value representing an out-of-dictionary word.
                var tensorF = _parent._currentVocab.GetNumWords();
                var nameF   = ctx.AddInitializer(tensorF, "NotFoundValueComp");

                // Retrieve X, name of input.
                var nameX = srcVariableName;

                // Do label encoding. Out-of-vocab tokens will be mapped to the size of vocabulary. Because the index of vocabulary
                // is zero-based, the size of vocabulary is just greater than the max index computed from in-vocab tokens by one.
                var nameY = ctx.AddIntermediateVariable(null, "LabelEncodedInput", true);
                var nodeY = ctx.CreateNode("LabelEncoder", nameX, nameY, ctx.GetNodeName("LabelEncoder"));

                nodeY.AddAttribute("classes_strings", _parent._currentVocab.GetWordLabels());
                nodeY.AddAttribute("default_int64", _parent._currentVocab.GetNumWords());

                // Do steps necessary for min and max embedding vectors.

                // Map to boolean vector representing missing words. The following Equal produces 1 if a token is missing and 0 otherwise.
                var nameA = ctx.AddIntermediateVariable(null, "NotFoundValuesBool", true);
                var nodeA = ctx.CreateNode("Equal", new[] { nameY, nameF }, new[] { nameA }, ctx.GetNodeName("Equal"), "");

                // Cast the not found vector to a vector of floats ("to" = 1 is float per the diagram above).
                var nameB = ctx.AddIntermediateVariable(null, "NotFoundValuesFloat", true);
                var nodeB = ctx.CreateNode("Cast", nameA, nameB, ctx.GetNodeName("Cast"), "");

                nodeB.AddAttribute("to", 1);

                // Scale the not found vector to get the location bias for max weights.
                var nameSMax = ctx.AddIntermediateVariable(null, "ScaleMax", true);
                var nodeSMax = ctx.CreateNode("Scale", nameB, nameSMax, ctx.GetNodeName("Scale"), "");

                nodeSMax.AddAttribute("scale", 2.0);

                // Cast both location biases to int64 ("to" = 7 per the diagram above): the unscaled
                // not-found flags (nameA) for min, and the scaled flags (nameSMax) for max.
                var nameVMin = ctx.AddIntermediateVariable(null, "CastMin", true);
                var nodeVMin = ctx.CreateNode("Cast", nameA, nameVMin, ctx.GetNodeName("Cast"), "");

                nodeVMin.AddAttribute("to", 7);

                var nameVMax = ctx.AddIntermediateVariable(null, "CastMax", true);
                var nodeVMax = ctx.CreateNode("Cast", nameSMax, nameVMax, ctx.GetNodeName("Cast"), "");

                nodeVMax.AddAttribute("to", 7);

                // Add the scaled options back to originals. The outputs of the following Add operators are almost identical to
                // the output of the previous LabelEncoder. The only difference is that out-of-vocab tokens are mapped to k+1
                // for applying ReduceMin and k+2 for applying ReduceMax so that out-of-vocab tokens do not affect embedding results at all.
                var namePMin = ctx.AddIntermediateVariable(null, "AddMin", true);
                var nodePMin = ctx.CreateNode("Add", new[] { nameY, nameVMin }, new[] { namePMin }, ctx.GetNodeName("Add"), "");

                var namePMax = ctx.AddIntermediateVariable(null, "AddMax", true);
                var nodePMax = ctx.CreateNode("Add", new[] { nameY, nameVMax }, new[] { namePMax }, ctx.GetNodeName("Add"), "");

                // Map encoded words to their embedding vectors, mapping missing ones to min/max.
                var nameGMin = ctx.AddIntermediateVariable(null, "GatheredMin", true);
                var nodeGMin = ctx.CreateNode("Gather", new[] { nameD, namePMin }, new[] { nameGMin }, ctx.GetNodeName("Gather"), "");

                var nameGMax = ctx.AddIntermediateVariable(null, "GatheredMax", true);
                var nodeGMax = ctx.CreateNode("Gather", new[] { nameD, namePMax }, new[] { nameGMax }, ctx.GetNodeName("Gather"), "");

                // Merge all embedding vectors using element-wise min/max per embedding coordinate.
                var nameJ = ctx.AddIntermediateVariable(null, "MinWeights", true);
                var nodeJ = ctx.CreateNode("ReduceMin", nameGMin, nameJ, ctx.GetNodeName("ReduceMin"), "");

                nodeJ.AddAttribute("axes", axes);

                var nameL = ctx.AddIntermediateVariable(null, "MaxWeights", true);
                var nodeL = ctx.CreateNode("ReduceMax", nameGMax, nameL, ctx.GetNodeName("ReduceMax"), "");

                nodeL.AddAttribute("axes", axes);

                // Do steps necessary for mean embedding vector.

                // Map encoded words to their embedding vectors using Gather.
                var nameW = ctx.AddIntermediateVariable(null, "GatheredMean", true);
                var nodeW = ctx.CreateNode("Gather", new[] { nameD, nameY }, new[] { nameW }, ctx.GetNodeName("Gather"), "");

                // Find the sum of the embedding vectors.
                var nameK = ctx.AddIntermediateVariable(null, "SumWeights", true);
                var nodeK = ctx.CreateNode("ReduceSum", nameW, nameK, ctx.GetNodeName("ReduceSum"), "");

                nodeK.AddAttribute("axes", axes);

                // Flip the boolean vector representing missing words to represent found words.
                var nameQ = ctx.AddIntermediateVariable(null, "FoundValuesBool", true);
                var nodeQ = ctx.CreateNode("Not", nameA, nameQ, ctx.GetNodeName("Not"), "");

                // Cast the found words vector to ints ("to" = 6 is int32 per the diagram above).
                var nameZ = ctx.AddIntermediateVariable(null, "FoundValuesInt", true);
                var nodeZ = ctx.CreateNode("Cast", nameQ, nameZ, ctx.GetNodeName("Cast"), "");

                nodeZ.AddAttribute("to", 6);

                // Sum the number of total found words.
                var nameR = ctx.AddIntermediateVariable(null, "NumWordsFoundInt", true);
                var nodeR = ctx.CreateNode("ReduceSum", nameZ, nameR, ctx.GetNodeName("ReduceSum"), "");

                nodeR.AddAttribute("axes", axes);

                // Cast the found words to float.
                var nameRF = ctx.AddIntermediateVariable(null, "NumWordsFoundFloat", true);
                var nodeRF = ctx.CreateNode("Cast", nameR, nameRF, ctx.GetNodeName("Cast"), "");

                nodeRF.AddAttribute("to", 1);

                // Clip the number of found words to prevent division by 0.
                var nameT = ctx.AddIntermediateVariable(null, "NumWordsClippedFloat", true);
                var nodeT = ctx.CreateNode("Clip", nameRF, nameT, ctx.GetNodeName("Clip"), "");

                nodeT.AddAttribute("min", 1.0f);

                // Divide total sum by number of words found to get the average embedding vector of the input string vector.
                var nameE = ctx.AddIntermediateVariable(null, "MeanWeights", true);
                var nodeE = ctx.CreateNode("Div", new[] { nameK, nameT }, new[] { nameE }, ctx.GetNodeName("Div"), "");

                // Concatenate the final embeddings produced by the three reduction strategies (min, mean, max).
                var nameP = dstVariableName;
                var nodeP = ctx.CreateNode("Concat", new[] { nameJ, nameE, nameL }, new[] { nameP }, ctx.GetNodeName("Concat"), "");

                nodeP.AddAttribute("axis", 1);
            }
コード例 #19
0
 /// <summary>
 /// Explicit <c>ICanSaveOnnx</c> implementation: export is possible only when <c>Bindable</c>
 /// implements <c>ICanSaveOnnx</c> and itself answers <c>true</c>.
 /// </summary>
 /// <param name="ctx">ONNX context forwarded to the bindable.</param>
 /// <returns><c>false</c> when <c>Bindable</c> is null or not an <c>ICanSaveOnnx</c>; otherwise the bindable's own answer.</returns>
 bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
 {
     return Bindable is ICanSaveOnnx saver && saver.CanSaveOnnx(ctx);
 }
コード例 #20
0
 /// <summary>
 /// ONNX export is offered only when the parent keeps diacritics, numbers and punctuation.
 /// </summary>
 /// <param name="ctx">ONNX context; not consulted.</param>
 /// <returns>
 /// <c>true</c> when <c>_keepDiacritics</c>, <c>_keepNumbers</c> and <c>_keepPunctuations</c>
 /// are all set on <c>_parent</c>; otherwise <c>false</c>.
 /// </returns>
 public bool CanSaveOnnx(OnnxContext ctx)
 {
     return _parent._keepDiacritics
         && _parent._keepNumbers
         && _parent._keepPunctuations;
 }
コード例 #21
0
 /// <summary>
 /// Explicit <c>ISingleCanSaveOnnx</c> implementation that forwards the export to <c>_impl</c>.
 /// </summary>
 /// <param name="ctx">ONNX context passed through unchanged.</param>
 /// <param name="outputNames">Output variable names passed through unchanged.</param>
 /// <param name="featureColumn">Feature column name passed through unchanged.</param>
 /// <returns>Whatever <c>_impl.SaveAsOnnx</c> returns.</returns>
 bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
 {
     return _impl.SaveAsOnnx(ctx, outputNames, featureColumn);
 }
コード例 #22
0
 /// <summary>
 /// Default answer for ONNX export capability: <c>false</c>.
 /// </summary>
 /// <param name="ctx">ONNX context; not consulted by this default implementation.</param>
 /// <returns>Always <c>false</c>. Being virtual, derived types may override this.</returns>
 public virtual bool CanSaveOnnx(OnnxContext ctx)
 {
     return false;
 }
コード例 #23
0
 /// <summary>
 /// Abstract ONNX export hook; each derived type supplies the implementation.
 /// </summary>
 /// <param name="ctx">ONNX context handed to the implementation.</param>
 /// <param name="outputNames">Names of the output variables.</param>
 /// <param name="featureColumn">Name of the feature column variable.</param>
 /// <returns>
 /// Implementation-defined. NOTE(review): presumably <c>true</c> means the export succeeded -
 /// confirm against callers.
 /// </returns>
 public abstract bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn);
コード例 #24
0
        /// <summary>
        /// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
        /// </summary>
        /// <param name="ctx">ONNX context used to register initializers, intermediate variables and nodes.</param>
        /// <param name="outputNames">
        /// Output variable names. <c>outputNames[1]</c> is written by the Transpose of the summed scores;
        /// <c>outputNames[0]</c> is written by the final Cast to the <c>DataKind.UInt32</c> type.
        /// </param>
        /// <param name="featureColumn">Name of the input variable that is compared against zero to detect present features.</param>
        /// <returns>Always <c>true</c>.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // Flat buffers used as 2-D initializers below; both hold
            // _featureHistogram[0].Length * _labelHistogram.Length floats.
            float[] featureHistogram       = new float[_featureHistogram[0].Length * _labelHistogram.Length];
            float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

            // Flatten the jagged _featureHistogram row by row into one contiguous buffer.
            for (int i = 0; i < _featureHistogram.Length; i++)
            {
                Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
            }
            // Repeat _labelHistogram once per i in [0, _featureHistogram[0].Length).
            // NOTE(review): the offset and count use _featureHistogram.Length, which assumes it
            // equals _labelHistogram.Length - confirm.
            for (int i = 0; i < _featureHistogram[0].Length; i++)
            {
                Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
            }

            // Scalar constants, plus a column ([_labelHistogram.Length, 1]) built from the first
            // _labelHistogram.Length entries of the expanded buffer.
            var one            = ctx.AddInitializer(1.0f, "one");
            var zero           = ctx.AddInitializer(0.0f, "zero");
            var labelCount     = ctx.AddInitializer((float)_labelCount, "labelCount");
            var trainingCount  = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
            var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

            // 2-D initializers: the flattened feature histogram, the repeated label histogram,
            // and _absentFeaturesLogProb as a column.
            var featureHistogramName        = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
            var labelHistogramName          = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
            var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

            // greaterOutput = featureColumn > zero
            var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
            var opType        = "Greater";

            ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

            // isFeaturePresent = greaterOutput cast to the type that corresponds to DataKind.Single.
            opType = "Cast";
            var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
            var node             = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
            var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();

            node.AddAttribute("to", t);

            // Initialize logProb: logOutput = Log(labelHistogram / totalTrainingCount)
            opType = "Div";
            var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);

            ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

            opType = "Log";
            var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);

            ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

            // log1: (featureHistogram + one) is passed with isFeaturePresent to LogMul, producing logOutput1.
            // NOTE(review): LogMul is defined elsewhere; presumably it emits a Log followed by a Mul - confirm.
            opType = "Sum";
            var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);

            ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

            // log2: transpose the repeated label histogram, then
            //   absentFeatureCount = transposed - featureHistogram
            //   logOutput2         = LogMul(transposed + labelCount, isFeaturePresent)
            opType = "Transpose";
            var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);

            ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);

            ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
            ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

            // log3: logOutput3 = LogMul(absentFeatureCount + one, isFeaturePresent)
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
            ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

            // result: logProb = logOutput1 - logOutput2; absentFeatureLogProb = logOutput3 - logOutput2
            opType = "Sub";
            var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);

            ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

            opType = "Sub";
            var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);

            ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

            // Reduce both log-probability tensors along axis 1.
            opType = "ReduceSum";
            var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);

            node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
            long[] list = { 1 };
            node.AddAttribute("axes", list);

            opType = "ReduceSum";
            var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);

            node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axes", list);

            // Cast the absentFeaturesLogProb initializer to the DataKind.Single type.
            opType = "Cast";
            var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);

            node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
            t    = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // subOutput = castOutput - absentFeatureLogProbReduceSum
            opType = "Sub";
            var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);

            ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

            // Final per-label value: subOutput + logProbReduceSum + logOutput
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
            ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // The transposed sum is written straight to outputNames[1].
            // NOTE(review): transposeOutput is allocated here but never wired to any node in this method.
            opType = "Transpose";
            var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);

            ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");

            // ArgMax over the (untransposed) sum; no attributes are set on this node.
            opType = "ArgMax";
            var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);

            ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");

            // Cast the ArgMax result to the DataKind.Single type so it can be added to the float `one`.
            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
            node       = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // Add one to the index (presumably to turn the 0-based ArgMax index into a 1-based label value - confirm).
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
            ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // Cast to the DataKind.UInt32 type and write the result to outputNames[0].
            opType = "Cast";
            node   = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
            t      = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
            node.AddAttribute("to", t);

            return(true);
        }
コード例 #25
0
 /// <summary>
 /// Base implementation of the per-column ONNX export hook: it ignores all of its arguments
 /// and returns false. The method is virtual, so derived types can override it.
 /// </summary>
 private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName,
                                               string dstVariableName)
 {
     return false;
 }
コード例 #26
0
 /// <summary>
 /// Explicit <see cref="ISaveAsOnnx"/> entry point; hands <paramref name="ctx"/> to
 /// <c>SaveAsOnnxCore</c>.
 /// </summary>
 void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
 {
     SaveAsOnnxCore(ctx);
 }
コード例 #27
0
 /// <summary>
 /// Returns true only when <c>_mapper</c> implements <see cref="ICanSaveOnnx"/> and that
 /// implementation returns true for <paramref name="ctx"/>; otherwise false.
 /// </summary>
 bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
 {
     return _mapper is ICanSaveOnnx onnxMapper && onnxMapper.CanSaveOnnx(ctx);
 }
コード例 #28
0
        /// <summary>
        /// Emits the ONNX sub-graph that computes, for a batch of examples, a score against every
        /// centroid and the index of the best-matching (minimum-score) centroid.
        /// </summary>
        /// <param name="ctx">ONNX context used to register initializers, intermediate variables and nodes.</param>
        /// <param name="outputNames">
        /// Output variable names. <c>outputNames[1]</c> receives the score tensor Y;
        /// <c>outputNames[0]</c> receives the ArgMin over axis 1 of Y.
        /// </param>
        /// <param name="featureColumn">Name of the input variable X.</param>
        /// <returns>Always <c>true</c>.</returns>
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // Computation graph of distances to all centroids for a batch of examples. Note that a centroid is just
            // the center of a cluster. We use [] to denote the dimension of a variable; for example, X [3, 2] means
            // that X is a 3-by-2 tensor. In addition, for a matrix X, X^T denotes its transpose.
            //
            // Symbols:
            // l: # of examples.
            // n: # of features per input example.
            // X: input examples, l-by-n tensor.
            // C: centroids, k-by-n tensor.
            // C^2: 2-norm of all centroid vectors, its shape is [k].
            // Y: 2-norm of difference between examples and centroids, l-by-k tensor. The value at the i-th row and
            // k-th column, Y[i,k], is the distance from example i to centroid k.
            // L: the id of the nearest centroid for each input example, its shape is [l].
            //
            // .------------------------------------------------------.
            // |                                                      |
            // |                                                      v
            // X [l, n] --> ReduceSumSquare --> X^2 [l]             Gemm (alpha=-2, transB=1) <-- C [k, n]
            //                                   |                    |
            //                                   |                    v
            //                                   `------> Add <---- -2XC^T [l, k]
            //                                             |
            //                                             v
            //                                             Z [l, k] ----------> Add <------------C^2 [k]
            //                                                                   |
            //                                                                   v
            //                                           L [l] <--- ArgMin <---  Y [l, k]

            // Allocate C, which is a constant tensor in the prediction phase.
            var shapeC = new long[] { _centroids.Length, _centroids[0].Length };
            var tensorC = new List<float>();

            foreach (var centroid in _centroids)
            {
                tensorC.AddRange(centroid.DenseValues());
            }
            var nameC = ctx.AddInitializer(tensorC, shapeC, "C");

            // Save C^2 as an initializer because it's a constant.
            var shapeC2 = new long[] { _centroidL2s.Length };
            var nameC2 = ctx.AddInitializer(_centroidL2s, shapeC2, "C2");

            // Retrieve the name of X.
            var nameX = featureColumn;

            // Compute X^2 from X.
            var nameX2 = ctx.AddIntermediateVariable(null, "X2", true);
            var reduceNodeX2 = ctx.CreateNode("ReduceSumSquare", nameX, nameX2, ctx.GetNodeName("ReduceSumSquare"), "");

            // Reduce over the feature axis only. Without "axes", ONNX ReduceSumSquare reduces over every
            // axis and would yield one scalar for the whole batch instead of one value per example.
            // "keepdims" is left at its ONNX default (1), so the result is [l, 1] and broadcasts
            // against the [l, k] Gemm output in the Add below.
            reduceNodeX2.AddAttribute("axes", new long[] { 1 });

            // Compute -2XC^T. Note that Gemm always takes three inputs. Since we only have two here,
            // a dummy one, named zero, is created.
            var zeroName = ctx.AddInitializer(new Float[] { 0f }, null, "zero");
            var nameXC2 = ctx.AddIntermediateVariable(null, "XC2", true);
            var gemmNodeXC2 = ctx.CreateNode("Gemm", new[] { nameX, nameC, zeroName }, new[] { nameXC2 }, ctx.GetNodeName("Gemm"), "");

            gemmNodeXC2.AddAttribute("alpha", -2f);
            gemmNodeXC2.AddAttribute("transB", 1);

            // Compute Z = X^2 - 2XC^T. The variable name is obtained from the context, like X2 and XC2
            // above, rather than hard-coded, so it cannot clash with another variable called "Z" in the same graph.
            var nameZ = ctx.AddIntermediateVariable(null, "Z", true);
            ctx.CreateNode("Add", new[] { nameX2, nameXC2 }, new[] { nameZ }, ctx.GetNodeName("Add"), "");

            // Compute Y = Z + C^2.
            var nameY = outputNames[1];
            ctx.CreateNode("Add", new[] { nameZ, nameC2 }, new[] { nameY }, ctx.GetNodeName("Add"), "");

            // Compute the most-matched cluster index, L.
            var nameL = outputNames[0];
            var predictNodeL = ctx.CreateNode("ArgMin", nameY, nameL, ctx.GetNodeName("ArgMin"), "");

            predictNodeL.AddAttribute("axis", 1);
            predictNodeL.AddAttribute("keepdims", 1);

            return true;
        }