/// <summary>
/// Default ONNX export hook; derived predictors that support export override this.
/// </summary>
/// <returns>False: the base implementation exports nothing.</returns>
public virtual bool SaveAsOnnx(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
{
    // Base class cannot serialize itself to ONNX.
    return false;
}
/// <summary>This component can always be exported to ONNX; the context is not consulted.</summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
/// <summary>
/// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>.
/// Emits the full naive-Bayes scoring graph: per-label log-probabilities (written to
/// outputNames[1]) and the 1-based index of the winning label (written to outputNames[0]).
/// </summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    // ExpandDims (com.microsoft domain) and the attribute set used below require opset 9.
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, "MulticlassNaiveBayes");
    // Flatten the jagged per-label feature histogram into one row-major buffer,
    // and tile the label histogram so it can be laid out as a 2-D initializer.
    float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
    float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];
    for (int i = 0; i < _featureHistogram.Length; i++)
    {
        Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
    }
    for (int i = 0; i < _featureHistogram[0].Length; i++)
    {
        Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
    }
    // Scalar and tensor constants baked into the model.
    var one = ctx.AddInitializer(1.0f, "one");
    var oneInt = ctx.AddInitializer(1, typeof(int), "oneInt");
    var zero = ctx.AddInitializer(0.0f, "zero");
    var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
    var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
    var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");
    var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
    var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
    var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");
    // Reusable shapes for intermediate variables:
    //   typeOne        [1]                 typeFea       [#features]
    //   typeLabelByFea [#labels,#features] typeLabelByOne [#labels,1]
    var typeOne = new VectorDataViewType(NumberDataViewType.Single, 1);
    var typeFea = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
    var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
    var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);
    // Mask of which features are present in the input: (feature > 0) cast to float.
    var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
    var opType = "Greater";
    ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");
    opType = "Cast";
    var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
    var node = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
    var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);
    // Insert a leading axis so the mask broadcasts against [#labels,#features] tensors.
    // ExpandDims lives in the com.microsoft custom-op domain.
    opType = "ExpandDims";
    var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");
    ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");
    //initialize logProb: log(P(label)) = log(labelHistogram / totalTrainingCount)
    opType = "Div";
    var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");
    ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");
    opType = "Log";
    var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");
    ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");
    //log1: log(featureCount + 1), masked by feature presence
    opType = "Sum";
    var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");
    ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);
    //log2: log(labelCount + labelOccurrenceCount), masked by feature presence
    opType = "Transpose";
    var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");
    ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");
    opType = "Sub";
    var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");
    ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
    ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);
    //log3: log(absentFeatureCount + 1), masked by feature presence
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
    ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);
    //result: combine the per-feature terms into per-label log-probabilities
    opType = "Sub";
    var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");
    ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");
    opType = "Sub";
    var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");
    ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");
    // Reduce over the feature axis (axis 2 of the [batch,#labels,#features] tensor).
    opType = "ReduceSum";
    var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
    node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
    long[] list = { 2 };
    node.AddAttribute("axes", list);
    opType = "ReduceSum";
    var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
    node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", list);
    // Correct by the trained absent-feature log-probability constant.
    opType = "Cast";
    castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
    node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);
    opType = "Sub";
    var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");
    ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
    ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    // Drop the trailing unit axis and emit the score vector to outputNames[1].
    opType = "Squeeze";
    var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");
    squeezeNode.AddAttribute("axes", new long[] { 2 });
    // ArgMax over the label axis picks the winning label index.
    opType = "ArgMax";
    var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");
    node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
    node.AddAttribute("axis", 1);
    node.AddAttribute("keepdims", 0);
    opType = "Cast";
    castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
    node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);
    // Shift the zero-based ArgMax result by one — presumably to match the 1-based
    // key-typed PredictedLabel convention; TODO confirm against the scorer.
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
    ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    opType = "Cast";
    node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
    node.AddAttribute("to", t);
    return(true);
}
/// <summary>
/// Exportable only when every underlying predictor is itself ONNX-exportable
/// (non-exportable or non-ICanSaveOnnx predictors veto the whole ensemble).
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    foreach (var pred in Predictors)
    {
        if ((pred as ICanSaveOnnx)?.CanSaveOnnx(ctx) != true)
            return false;
    }
    return true;
}
/// <summary>Delegates the export-capability decision to the derived-class flag.</summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return CanSaveOnnxCore;
}
/// <summary>This component is unconditionally ONNX-exportable.</summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return true;
}
/// <summary>Forwards the export-capability query to the wrapped implementation.</summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return _impl.CanSaveOnnx(ctx);
}
/// <summary>
/// Registers a zero/empty-valued ONNX initializer for column <paramref name="iinfo"/>,
/// shaped [1, size] and typed after the column's raw type.
/// </summary>
/// <returns>True if the column's type is supported; false otherwise.</returns>
private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    var columnType = _bindings.ColumnTypes[iinfo];
    string inputColumnName = Source.Schema[_bindings.SrcCols[iinfo]].Name;
    Type rawType = columnType.RawType;
    // Known-size vectors keep their length; everything else is treated as a scalar.
    int length = (columnType is VectorDataViewType && columnType.IsKnownSizeVector())
        ? columnType.GetVectorSize()
        : 1;
    var dims = new long[] { 1, length };
    if (rawType == typeof(int) || rawType == typeof(short) || rawType == typeof(ushort) ||
        rawType == typeof(sbyte) || rawType == typeof(byte))
    {
        // Narrow integer types are all materialized through an int buffer.
        ctx.AddInitializer(new int[length], rawType, dims, inputColumnName, false);
    }
    else if (rawType == typeof(uint) || rawType == typeof(ulong))
    {
        // Unsigned types share a ulong buffer; the flag records whether it was ulong.
        ctx.AddInitializer(new ulong[length], rawType == typeof(ulong), dims, inputColumnName, false);
    }
    else if (rawType == typeof(bool))
    {
        ctx.AddInitializer(new bool[length], dims, inputColumnName, false);
    }
    else if (rawType == typeof(long))
    {
        ctx.AddInitializer(new long[length], dims, inputColumnName, false);
    }
    else if (rawType == typeof(float))
    {
        ctx.AddInitializer(new float[length], dims, inputColumnName, false);
    }
    else if (rawType == typeof(double))
    {
        ctx.AddInitializer(new double[length], dims, inputColumnName, false);
    }
    else if (rawType == typeof(string) || columnType is TextDataViewType)
    {
        // Text columns get empty strings rather than nulls.
        var emptyStrings = new string[length];
        for (int i = 0; i < length; i++)
            emptyStrings[i] = "";
        ctx.AddInitializer(emptyStrings, dims, inputColumnName, false);
    }
    else
    {
        // Unsupported raw type — nothing was emitted.
        return false;
    }
    return true;
}
/// <summary>Exportable only when the context targets the experimental ONNX version.</summary>
public override bool CanSaveOnnx(OnnxContext ctx)
{
    return ctx.GetOnnxVersion() == OnnxVersion.Experimental;
}
/// <summary>This component is unconditionally ONNX-exportable.</summary>
public bool CanSaveOnnx(OnnxContext ctx) => true;
/// <summary>No-op: this component contributes no nodes or initializers to the ONNX graph.</summary>
public void SaveAsOnnx(OnnxContext ctx)
{
    // Intentionally empty.
}
/// <summary>This loss/component has no ONNX node representation; always declines.</summary>
public bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount) => false;
/// <summary>
/// Fills in <paramref name="nodeProtoWrapper"/> with this component's ONNX node information.
/// </summary>
/// <returns>True if ONNX info was produced; false if the component cannot be represented.</returns>
public abstract bool OnnxInfo(OnnxContext ctx, OnnxUtils.NodeProtoWrapper nodeProtoWrapper, int featureCount);
/// <summary>
/// Emits a zero/empty-valued ONNX initializer named <paramref name="srcVariableName"/>,
/// shaped [1, size] after <paramref name="columnType"/>. Requires opset 9.
/// </summary>
/// <returns>True when the column's raw type is supported; false otherwise.</returns>
private bool SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, DataViewType columnType)
{
    const int minimumOpSetVersion = 9;
    ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);
    Type elemType = columnType.RawType;
    // Vector columns of known size preserve their length; anything else is scalar.
    int count = 1;
    if (columnType is VectorDataViewType && columnType.IsKnownSizeVector())
        count = columnType.GetVectorSize();
    var shape = new long[] { 1, count };
    if (elemType == typeof(int) || elemType == typeof(short) || elemType == typeof(ushort) ||
        elemType == typeof(sbyte) || elemType == typeof(byte))
    {
        // All narrow signed/unsigned integers are carried in an int buffer.
        ctx.AddInitializer(new int[count], elemType, shape, srcVariableName, false);
        return true;
    }
    if (elemType == typeof(uint) || elemType == typeof(ulong))
    {
        // ulong buffer shared by both unsigned widths; flag distinguishes ulong.
        ctx.AddInitializer(new ulong[count], elemType == typeof(ulong), shape, srcVariableName, false);
        return true;
    }
    if (elemType == typeof(bool))
    {
        ctx.AddInitializer(new bool[count], shape, srcVariableName, false);
        return true;
    }
    if (elemType == typeof(long))
    {
        ctx.AddInitializer(new long[count], shape, srcVariableName, false);
        return true;
    }
    if (elemType == typeof(float))
    {
        ctx.AddInitializer(new float[count], shape, srcVariableName, false);
        return true;
    }
    if (elemType == typeof(double))
    {
        ctx.AddInitializer(new double[count], shape, srcVariableName, false);
        return true;
    }
    if (elemType == typeof(string) || columnType is TextDataViewType)
    {
        // Text is initialized with empty strings, never nulls.
        var blanks = new string[count];
        for (int i = 0; i < count; i++)
            blanks[i] = "";
        ctx.AddInitializer(blanks, shape, srcVariableName, false);
        return true;
    }
    // Raw type not representable as an ONNX initializer here.
    return false;
}
/// <summary>
/// Core ONNX export hook; overridden by exportable predictors.
/// The base implementation declines to export.
/// </summary>
private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, RoleMappedSchema schema, string[] outputNames)
{
    return false;
}
/// <summary>
/// Writes this component's ONNX nodes, reading from <paramref name="srcVariableName"/>
/// and producing <paramref name="dstVariableName"/>.
/// </summary>
/// <returns>True if the component was saved; false if it cannot be represented in ONNX.</returns>
public abstract bool SaveOnnx(OnnxContext ctx, string srcVariableName, string dstVariableName);
/// <summary>Exportable only when the underlying value mapper supports ONNX export.</summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return ValueMapper is ICanSaveOnnx onnxMapper && onnxMapper.CanSaveOnnx(ctx);
}
/// <summary>
/// Emits the ONNX subgraph that maps a vector of input words to the concatenation of the
/// element-wise min, mean, and max of their pretrained embedding vectors (output length j * 3).
/// Out-of-vocab words are routed to three sentinel rows appended to the embedding matrix so
/// they are neutral under each reduction.
/// </summary>
private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    // Converts 1 column that is taken as input to the transform into one column of output
    //
    // Missing words are mapped to k for finding average, k + 1 for finding min, and k + 2 for finding max
    // Those spots in the dictionary contain a vector of 0s, max floats, and min floats, respectively
    //
    // Symbols:
    // j: length of latent vector of every word in the pretrained model
    // n: length of input tensor (number of words)
    // X: word input, a tensor with n elements.
    // k: # of words in pretrained model (known when transform is created)
    // S: word labels, k tensor (known when transform is created)
    // D: word embeddings, (k + 3)-by-j tensor(known when transform is created). The extra three embeddings
    //    at the end are used for out of vocab words.
    // F: location value representing missing words, equal to k
    // P: output, a j * 3 tensor
    //
    //                                     X [n]
    //                                       |
    //                                     nameX
    //                                       |
    //                  LabelEncoder (classes_strings = S [k], default_int64 = k)
    //                                       |
    //              /----------------------- nameY -----------------------\
    //             /                         |        |                    \
    //  Initialize (F)-------/----|------ nameF ------> Equal               \
    //           ... (see original file history for the full ASCII wiring of the
    //           min / mean / max branches; structure is reproduced by the code below)
    //
    long[] axes = new long[] { 0 };
    // Allocate D, a constant tensor representing word embedding weights.
    var shapeD = new long[] { _parent._currentVocab.GetNumWords() + 3, _parent._currentVocab.Dimension };
    var wordVectors = _parent._currentVocab.WordVectors;
    var tensorD = new List<float>();
    tensorD.AddRange(wordVectors);
    // Out-of-vocab embedding vector for combining embeddings by mean.
    tensorD.AddRange(Enumerable.Repeat(0.0f, _parent._currentVocab.Dimension));
    // Out-of-vocab embedding vector for combining embeddings by element-wise min.
    tensorD.AddRange(Enumerable.Repeat(float.MaxValue, _parent._currentVocab.Dimension));
    // Out-of-vocab embedding vector for combining embeddings by element-wise max.
    tensorD.AddRange(Enumerable.Repeat(float.MinValue, _parent._currentVocab.Dimension));
    var nameD = ctx.AddInitializer(tensorD, shapeD, "WordEmbeddingWeights");
    // Allocate F, a value representing an out-of-dictionary word.
    var tensorF = _parent._currentVocab.GetNumWords();
    var nameF = ctx.AddInitializer(tensorF, "NotFoundValueComp");
    // Retrieve X, name of input.
    var nameX = srcVariableName;
    // Do label encoding. Out-of-vocab tokens will be mapped to the size of vocabulary. Because the index of vocabulary
    // is zero-based, the size of vocabulary is just greater then the max indexes computed from in-vocab tokens by one.
    var nameY = ctx.AddIntermediateVariable(null, "LabelEncodedInput", true);
    var nodeY = ctx.CreateNode("LabelEncoder", nameX, nameY, ctx.GetNodeName("LabelEncoder"));
    nodeY.AddAttribute("classes_strings", _parent._currentVocab.GetWordLabels());
    nodeY.AddAttribute("default_int64", _parent._currentVocab.GetNumWords());
    // Do steps necessary for min and max embedding vectors.
    // Map to boolean vector representing missing words. The following Equal produces 1 if a token is missing and 0 otherwise.
    var nameA = ctx.AddIntermediateVariable(null, "NotFoundValuesBool", true);
    var nodeA = ctx.CreateNode("Equal", new[] { nameY, nameF }, new[] { nameA }, ctx.GetNodeName("Equal"), "");
    // Cast the not found vector to a vector of floats (to = 1 is ONNX's FLOAT).
    var nameB = ctx.AddIntermediateVariable(null, "NotFoundValuesFloat", true);
    var nodeB = ctx.CreateNode("Cast", nameA, nameB, ctx.GetNodeName("Cast"), "");
    nodeB.AddAttribute("to", 1);
    // Scale the not found vector to get the location bias for max weights
    // (missing words shift by 2 to land on the k+2 sentinel row).
    var nameSMax = ctx.AddIntermediateVariable(null, "ScaleMax", true);
    var nodeSMax = ctx.CreateNode("Scale", nameB, nameSMax, ctx.GetNodeName("Scale"), "");
    nodeSMax.AddAttribute("scale", 2.0);
    // Cast scaled word label locations to ints (to = 7 is ONNX's INT64).
    var nameVMin = ctx.AddIntermediateVariable(null, "CastMin", true);
    var nodeVMin = ctx.CreateNode("Cast", nameA, nameVMin, ctx.GetNodeName("Cast"), "");
    nodeVMin.AddAttribute("to", 7);
    var nameVMax = ctx.AddIntermediateVariable(null, "CastMax", true);
    var nodeVMax = ctx.CreateNode("Cast", nameSMax, nameVMax, ctx.GetNodeName("Cast"), "");
    nodeVMax.AddAttribute("to", 7);
    // Add the scaled options back to originals. The outputs of the following Add operators are almost identical
    // the output of the previous LabelEncoder. The only difference is that out-of-vocab tokens are mapped to k+1
    // for applying ReduceMin and k+2 for applying ReduceMax so that out-of-vocab tokens do not affect embedding results at all.
    var namePMin = ctx.AddIntermediateVariable(null, "AddMin", true);
    var nodePMin = ctx.CreateNode("Add", new[] { nameY, nameVMin }, new[] { namePMin }, ctx.GetNodeName("Add"), "");
    var namePMax = ctx.AddIntermediateVariable(null, "AddMax", true);
    var nodePMax = ctx.CreateNode("Add", new[] { nameY, nameVMax }, new[] { namePMax }, ctx.GetNodeName("Add"), "");
    // Map encoded words to their embedding vectors, mapping missing ones to min/max.
    var nameGMin = ctx.AddIntermediateVariable(null, "GatheredMin", true);
    var nodeGMin = ctx.CreateNode("Gather", new[] { nameD, namePMin }, new[] { nameGMin }, ctx.GetNodeName("Gather"), "");
    var nameGMax = ctx.AddIntermediateVariable(null, "GatheredMax", true);
    var nodeGMax = ctx.CreateNode("Gather", new[] { nameD, namePMax }, new[] { nameGMax }, ctx.GetNodeName("Gather"), "");
    // Merge all embedding vectors using element-wise min/max per embedding coordinate.
    var nameJ = ctx.AddIntermediateVariable(null, "MinWeights", true);
    var nodeJ = ctx.CreateNode("ReduceMin", nameGMin, nameJ, ctx.GetNodeName("ReduceMin"), "");
    nodeJ.AddAttribute("axes", axes);
    var nameL = ctx.AddIntermediateVariable(null, "MaxWeights", true);
    var nodeL = ctx.CreateNode("ReduceMax", nameGMax, nameL, ctx.GetNodeName("ReduceMax"), "");
    nodeL.AddAttribute("axes", axes);
    // Do steps necessary for mean embedding vector.
    // Map encoded words to their embedding vectors using Gather.
    var nameW = ctx.AddIntermediateVariable(null, "GatheredMean", true);
    var nodeW = ctx.CreateNode("Gather", new[] { nameD, nameY }, new[] { nameW }, ctx.GetNodeName("Gather"), "");
    // Find the sum of the embedding vectors.
    var nameK = ctx.AddIntermediateVariable(null, "SumWeights", true);
    var nodeK = ctx.CreateNode("ReduceSum", nameW, nameK, ctx.GetNodeName("ReduceSum"), "");
    nodeK.AddAttribute("axes", axes);
    // Flip the boolean vector representing missing words to represent found words.
    var nameQ = ctx.AddIntermediateVariable(null, "FoundValuesBool", true);
    var nodeQ = ctx.CreateNode("Not", nameA, nameQ, ctx.GetNodeName("Not"), "");
    // Cast the found words vector to ints (to = 6 is ONNX's INT32).
    var nameZ = ctx.AddIntermediateVariable(null, "FoundValuesInt", true);
    var nodeZ = ctx.CreateNode("Cast", nameQ, nameZ, ctx.GetNodeName("Cast"), "");
    nodeZ.AddAttribute("to", 6);
    // Sum the number of total found words.
    var nameR = ctx.AddIntermediateVariable(null, "NumWordsFoundInt", true);
    var nodeR = ctx.CreateNode("ReduceSum", nameZ, nameR, ctx.GetNodeName("ReduceSum"), "");
    nodeR.AddAttribute("axes", axes);
    // Cast the found words to float.
    var nameRF = ctx.AddIntermediateVariable(null, "NumWordsFoundFloat", true);
    var nodeRF = ctx.CreateNode("Cast", nameR, nameRF, ctx.GetNodeName("Cast"), "");
    nodeRF.AddAttribute("to", 1);
    // Clip the number of found words to prevent division by 0.
    var nameT = ctx.AddIntermediateVariable(null, "NumWordsClippedFloat", true);
    var nodeT = ctx.CreateNode("Clip", nameRF, nameT, ctx.GetNodeName("Clip"), "");
    nodeT.AddAttribute("min", 1.0f);
    // Divide total sum by number of words found to get the average embedding vector of the input string vector.
    var nameE = ctx.AddIntermediateVariable(null, "MeanWeights", true);
    var nodeE = ctx.CreateNode("Div", new[] { nameK, nameT }, new[] { nameE }, ctx.GetNodeName("Div"), "");
    // Concatenate the final embeddings produced by the three reduction strategies.
    var nameP = dstVariableName;
    var nodeP = ctx.CreateNode("Concat", new[] { nameJ, nameE, nameL }, new[] { nameP }, ctx.GetNodeName("Concat"), "");
    nodeP.AddAttribute("axis", 1);
}
/// <summary>Exportable only when the bindable mapper supports ONNX export.</summary>
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx)
{
    return Bindable is ICanSaveOnnx onnxBindable && onnxBindable.CanSaveOnnx(ctx);
}
/// <summary>
/// Export is only possible when the transform removes nothing — i.e. diacritics,
/// numbers, and punctuation are all kept, making the operation an identity in ONNX terms.
/// </summary>
public bool CanSaveOnnx(OnnxContext ctx)
{
    return _parent._keepDiacritics
        && _parent._keepNumbers
        && _parent._keepPunctuations;
}
/// <summary>Forwards ONNX export to the wrapped implementation.</summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    return _impl.SaveAsOnnx(ctx, outputNames, featureColumn);
}
/// <summary>
/// Default capability check; derived classes that support ONNX export override this.
/// </summary>
public virtual bool CanSaveOnnx(OnnxContext ctx)
{
    // Not exportable unless a derived class says otherwise.
    return false;
}
/// <summary>
/// Writes this predictor's ONNX nodes, reading features from <paramref name="featureColumn"/>
/// and producing the outputs named in <paramref name="outputNames"/>.
/// </summary>
/// <returns>True if the model was saved; false if it cannot be represented in ONNX.</returns>
public abstract bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn);
/// <summary>
/// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>.
/// Variant of the naive-Bayes export that uses untyped intermediate variables
/// (shapes left null); writes scores to outputNames[1] and the predicted label to outputNames[0].
/// </summary>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    // Flatten the jagged per-label feature histogram and tile the label histogram.
    float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
    float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];
    for (int i = 0; i < _featureHistogram.Length; i++)
    {
        Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
    }
    for (int i = 0; i < _featureHistogram[0].Length; i++)
    {
        Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
    }
    // Model constants.
    var one = ctx.AddInitializer(1.0f, "one");
    var zero = ctx.AddInitializer(0.0f, "zero");
    var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
    var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
    var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");
    var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
    var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
    var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");
    // Feature-presence mask: (feature > 0) cast to float.
    var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
    var opType = "Greater";
    ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");
    opType = "Cast";
    var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
    var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
    var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);
    //initialize logProb: log(P(label)) = log(labelHistogram / totalTrainingCount)
    opType = "Div";
    var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);
    ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");
    opType = "Log";
    var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
    ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");
    //log1: log(featureCount + 1), masked by feature presence
    opType = "Sum";
    var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
    ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);
    //log2: log(labelCount + labelOccurrenceCount), masked by feature presence
    opType = "Transpose";
    var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);
    ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");
    opType = "Sub";
    var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);
    ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
    ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);
    //log3: log(absentFeatureCount + 1), masked by feature presence
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
    ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);
    //result: combine the per-feature terms into per-label log-probabilities
    opType = "Sub";
    var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);
    ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");
    opType = "Sub";
    var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);
    ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");
    // Reduce over axis 1. NOTE(review): the opset-9 overload on L3-L6 reduces over axis 2
    // for the same computation — presumably the tensors here carry one fewer leading axis;
    // confirm against the shapes this experimental path produces.
    opType = "ReduceSum";
    var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
    node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
    long[] list = { 1 };
    node.AddAttribute("axes", list);
    opType = "ReduceSum";
    var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
    node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", list);
    // Correct by the trained absent-feature log-probability constant.
    opType = "Cast";
    var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);
    node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);
    opType = "Sub";
    var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);
    ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
    ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    // Emit scores. NOTE(review): transposeOutput is declared but never used — the Transpose
    // node writes straight to outputNames[1]; looks like leftover scaffolding, verify intent.
    opType = "Transpose";
    var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);
    ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");
    // ArgMax picks the winning label. NOTE(review): no axis/keepdims attributes are set here,
    // unlike the opset-9 overload — the exporter relies on ONNX defaults; confirm.
    opType = "ArgMax";
    var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);
    ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
    opType = "Cast";
    castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
    node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);
    // Shift the zero-based ArgMax result by one — presumably to match the 1-based
    // key-typed PredictedLabel convention; TODO confirm against the scorer.
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
    ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    opType = "Cast";
    node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
    node.AddAttribute("to", t);
    return(true);
}
/// <summary>
/// Per-column ONNX export hook; overridden by transforms that can export a column.
/// The base implementation declines.
/// </summary>
private protected virtual bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
{
    return false;
}
/// <summary>Explicit interface entry point; delegates to the protected core implementation.</summary>
void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx)
{
    SaveAsOnnxCore(ctx);
}
/// <summary>Exportable only when the underlying mapper supports ONNX export.</summary>
// Replaces the non-idiomatic `cond ? x : false` ternary (IDE0075) with the
// equivalent pattern-match-and-short-circuit form; behavior is unchanged.
bool ICanSaveOnnx.CanSaveOnnx(OnnxContext ctx) => _mapper is ICanSaveOnnx onnxMapper && onnxMapper.CanSaveOnnx(ctx);
/// <summary>
/// Emits the ONNX subgraph computing, for each input example, the squared distances to all
/// cluster centroids (outputNames[1]) and the index of the nearest centroid (outputNames[0]),
/// using the expansion |x - c|^2 = x^2 - 2xc^T + c^2.
/// </summary>
/// <returns>True — this predictor always exports successfully.</returns>
public bool SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    // Computation graph of distances to all centriods for a batch of examples. Note that a centriod is just
    // the center of a cluster. We use [] to denote the dimension of a variable; for example, X [3, 2] means
    // that X is a 3-by-2 tensor. In addition, for a matrix X, X^T denotes its transpose.
    //
    // Symbols:
    // l: # of examples.
    // n: # of features per input example.
    // X: input examples, l-by-n tensor.
    // C: centriods, k-by-n tensor.
    // C^2: 2-norm of all centriod vectors, its shape is [k].
    // Y: 2-norm of difference between examples and centriods, l-by-k tensor. The value at i-th row and k-th
    // column row, Y[i,k], is the distance from example i to centrioid k.
    // L: the id of the nearest centriod for each input example, its shape is [l].
    //
    //     .------------------------------------------------------.
    //     |                                                      |
    //     |                                                      v
    // X [l, n] --> ReduceSumSquare --> X^2 [l]     Gemm (alpha=-2, transB=1) <-- C [k, n]
    //     |                               |
    //     |                               v
    //     `------> Add <---- -2XC^T [l, k]
    //               |
    //               v
    //          Z [l, k] ----------> Add <------------C^2 [k]
    //                                |
    //                                v
    //     L [l] <--- ArgMin <--- Y [l, k]
    //
    // Allocate C, which is a constant tensor in prediction phase
    var shapeC = new long[] { _centroids.Length, _centroids[0].Length };
    var tensorC = new List<float>();
    foreach (var centriod in _centroids)
    {
        tensorC.AddRange(centriod.DenseValues());
    }
    var nameC = ctx.AddInitializer(tensorC, shapeC, "C");
    // Save C^2 as an initializer because it's a constant.
    var shapeC2 = new long[] { _centroidL2s.Length };
    var nameC2 = ctx.AddInitializer(_centroidL2s, shapeC2, "C2");
    // Retrieve the name of X
    var nameX = featureColumn;
    // Compute X^2 from X
    var nameX2 = ctx.AddIntermediateVariable(null, "X2", true);
    var reduceNodeX2 = ctx.CreateNode("ReduceSumSquare", nameX, nameX2, ctx.GetNodeName("ReduceSumSquare"), "");
    // Compute -2XC^T. Note that Gemm always takes three inputs. Since we only have two here,
    // a dummpy one is created.
    var zeroName = ctx.AddInitializer(new Float[] { 0f }, null, "zero");
    var nameXC2 = ctx.AddIntermediateVariable(null, "XC2", true);
    var gemmNodeXC2 = ctx.CreateNode("Gemm", new[] { nameX, nameC, zeroName }, new[] { nameXC2 }, ctx.GetNodeName("Gemm"), "");
    gemmNodeXC2.AddAttribute("alpha", -2f);
    gemmNodeXC2.AddAttribute("transB", 1);
    // Compute Z = X^2 - 2XC^T.
    // NOTE(review): the intermediate name is hard-coded to "Z" instead of going through
    // ctx.AddIntermediateVariable (a call to it sits disabled here) — a fixed name risks
    // colliding with another graph variable; confirm this is intentional.
    var nameZ = "Z"; // ctx.AddIntermediateVariable(null, "Z", true);
    var addNodeZ = ctx.CreateNode("Add", new[] { nameX2, nameXC2 }, new[] { nameZ }, ctx.GetNodeName("Add"), "");
    // Compute Y = Z + C^2
    var nameY = outputNames[1];
    var addNodeY = ctx.CreateNode("Add", new[] { nameZ, nameC2 }, new[] { nameY }, ctx.GetNodeName("Add"), "");
    // Compute the most-matched cluster index, L
    var nameL = outputNames[0];
    var predictNodeL = ctx.CreateNode("ArgMin", nameY, nameL, ctx.GetNodeName("ArgMin"), "");
    predictNodeL.AddAttribute("axis", 1);
    predictNodeL.AddAttribute("keepdims", 1);
    return(true);
}