/// <summary>
/// Appends ONNX nodes that turn the per-class score tensor <paramref name="inputName"/> into a
/// predicted label (written to <c>outputNames[0]</c> as UInt32) and a score output
/// (written to <c>outputNames[1]</c>).
/// </summary>
/// <param name="ctx">ONNX export context that receives the new nodes and variables.</param>
/// <param name="inputName">Name of the score tensor; the best class is taken along axis 1.</param>
/// <param name="outputNames">At least two names: [0] predicted label, [1] scores.</param>
public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] outputNames)
{
    Contracts.Assert(outputNames.Length >= 2);

    // Position of the highest score along axis 1; keepdims=1 preserves the reduced axis.
    const string argMaxOp = "ArgMax";
    var bestIndex = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput");
    var argMax = ctx.CreateNode(argMaxOp, inputName, bestIndex, ctx.GetNodeName(argMaxOp), "");
    argMax.AddAttribute("keepdims", 1);
    argMax.AddAttribute("axis", 1);

    // Shift the zero-based ArgMax index up by one.
    // NOTE(review): presumably because the label output is a 1-based key value; confirm with the consumer.
    const string addOp = "Add";
    var oneInitializer = ctx.AddInitializer(1);
    var shiftedIndex = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput");
    ctx.CreateNode(addOp, new[] { bestIndex, oneInitializer }, new[] { shiftedIndex }, ctx.GetNodeName(addOp), "");

    // The shifted index becomes the predicted label, cast to UInt32.
    const string castOp = "Cast";
    var castToLabel = ctx.CreateNode(castOp, shiftedIndex, outputNames[0], ctx.GetNodeName(castOp), "");
    var uint32Type = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
    castToLabel.AddAttribute("to", uint32Type);

    // Per the ONNX spec, Max over a single input is an element-wise identity, so this node
    // forwards the scores unchanged to the score output.
    const string maxOp = "Max";
    ctx.CreateNode(maxOp, inputName, outputNames[1], ctx.GetNodeName(maxOp), "");
}
/// <summary>
/// Emits an ONNX subgraph that tokenizes <paramref name="srcVariableName"/> with the
/// com.microsoft Tokenizer (empty separator), squeezes axis 0, maps each single-character
/// string to its UTF-16 code via a LabelEncoder, and casts the result to UInt16 into
/// <paramref name="dstVariableName"/>.
/// </summary>
/// <param name="ctx">ONNX export context that receives the new nodes and variables.</param>
/// <param name="srcVariableName">Name of the input text variable.</param>
/// <param name="dstVariableName">Name of the output variable that receives the UInt16 codes.</param>
private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
{
    // The lookup table covers code units 0 through 65534 inclusive. Strings outside it get the
    // LabelEncoder's default value, which is not set here.
    const int lookupSize = 65535;

    // Tokenizer with an empty separator and a minimum token length of 1.
    const string tokenizerOp = "Tokenizer";
    string tokens = ctx.AddIntermediateVariable(null, "TokenizerOutput", true);
    var tokenizer = ctx.CreateNode(tokenizerOp, srcVariableName, tokens, ctx.GetNodeName(tokenizerOp), "com.microsoft");
    tokenizer.AddAttribute("mark", _parent._useMarkerChars);
    tokenizer.AddAttribute("mincharnum", 1);
    tokenizer.AddAttribute("pad_value", "");
    tokenizer.AddAttribute("separators", new string[] { "" });

    // Remove axis 0 of the tokenizer output.
    const string squeezeOp = "Squeeze";
    var squeezed = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
    var squeeze = ctx.CreateNode(squeezeOp, tokens, squeezed, ctx.GetNodeName(squeezeOp), "");
    squeeze.AddAttribute("axes", new long[] { 0 });

    // Map every single-character string to its numeric code.
    const string labelEncoderOp = "LabelEncoder";
    var codes = ctx.AddIntermediateVariable(null, "LabelEncoderOutput", true);
    var labelEncoder = ctx.CreateNode(labelEncoderOp, squeezed, codes, ctx.GetNodeName(labelEncoderOp));
    IEnumerable<int> codeUnits = Enumerable.Range(0, lookupSize);
    IEnumerable<string> keys = codeUnits.Select(code => ((char)code).ToString());
    IEnumerable<long> values = codeUnits.Select(code => (long)code);
    labelEncoder.AddAttribute("keys_strings", keys);
    labelEncoder.AddAttribute("values_int64s", values);

    // Narrow the int64 codes to UInt16 for the final output.
    const string castOp = "Cast";
    var cast = ctx.CreateNode(castOp, codes, dstVariableName, ctx.GetNodeName(castOp), "");
    var uint16Type = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType();
    cast.AddAttribute("to", uint16Type);
}
/// <summary>
/// Emits an ONNX subgraph for column <paramref name="iinfo"/>:
/// <list type="bullet">
/// <item>com.microsoft Tokenizer with an empty separator,</item>
/// <item>Squeeze on axis 1,</item>
/// <item>LabelEncoder mapping single-character strings to their UTF-16 codes,</item>
/// <item>Cast to UInt16 into <paramref name="dstVariableName"/>.</item>
/// </list>
/// Requires ONNX opset 9 or later.
/// </summary>
/// <param name="ctx">ONNX export context that receives the new nodes and variables.</param>
/// <param name="iinfo">Index of the column being exported; selects scalar vs. vector text typing.</param>
/// <param name="srcVariableName">Name of the input text variable.</param>
/// <param name="dstVariableName">Name of the output variable that receives the UInt16 codes.</param>
private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
{
    const int requiredOpSet = 9;
    ctx.CheckOpSetVersion(requiredOpSet, LoaderSignature);

    // The intermediate string tensors are typed as a fixed-length text vector when the source
    // column is a vector, otherwise as scalar text.
    DataViewType textType = _isSourceVector[iinfo]
        ? (DataViewType)new VectorDataViewType(TextDataViewType.Instance, _sourceVectorLength[iinfo])
        : TextDataViewType.Instance;

    // The lookup table covers code units 0 through 65534 inclusive.
    const int lookupSize = 65535;

    // Tokenizer with an empty separator and a minimum token length of 1.
    const string tokenizerOp = "Tokenizer";
    string tokens = ctx.AddIntermediateVariable(textType, "TokenizerOutput", true);
    var tokenizer = ctx.CreateNode(tokenizerOp, srcVariableName, tokens, ctx.GetNodeName(tokenizerOp), "com.microsoft");
    tokenizer.AddAttribute("mark", _parent._useMarkerChars);
    tokenizer.AddAttribute("mincharnum", 1);
    tokenizer.AddAttribute("pad_value", "");
    tokenizer.AddAttribute("separators", new string[] { "" });

    // Remove axis 1 of the tokenizer output.
    const string squeezeOp = "Squeeze";
    var squeezed = ctx.AddIntermediateVariable(textType, "SqueezeOutput");
    var squeeze = ctx.CreateNode(squeezeOp, tokens, squeezed, ctx.GetNodeName(squeezeOp), "");
    squeeze.AddAttribute("axes", new long[] { 1 });

    // Map every single-character string to its numeric code.
    const string labelEncoderOp = "LabelEncoder";
    var codes = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");
    var labelEncoder = ctx.CreateNode(labelEncoderOp, squeezed, codes, ctx.GetNodeName(labelEncoderOp));
    IEnumerable<int> codeUnits = Enumerable.Range(0, lookupSize);
    IEnumerable<string> keys = codeUnits.Select(code => ((char)code).ToString());
    IEnumerable<long> values = codeUnits.Select(code => (long)code);
    labelEncoder.AddAttribute("keys_strings", keys);
    labelEncoder.AddAttribute("values_int64s", values);

    // Narrow the int64 codes to UInt16 for the final output.
    const string castOp = "Cast";
    var cast = ctx.CreateNode(castOp, codes, dstVariableName, ctx.GetNodeName(castOp), "");
    var uint16Type = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType();
    cast.AddAttribute("to", uint16Type);
}
/// <summary>
/// Deserializes an <see cref="ExpressionTransformer"/> from <paramref name="ctx"/>.
/// </summary>
/// <param name="env">Host environment; must not be null.</param>
/// <param name="ctx">Model load context positioned at this transformer's model; must not be null.</param>
/// <returns>A transformer with one <c>ColumnInfo</c> per serialized output column.</returns>
private static ExpressionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
    Contracts.CheckValue(env, nameof(env));
    env.CheckValue(ctx, nameof(ctx));
    ctx.CheckAtModel(GetVersionInfo());

    // *** Binary format ***
    // int: number of output columns
    // for each output column:
    //   int: number of inputs
    //   foreach input
    //     int: Id of the input column name
    //   int: Id of the expression
    //   int: Id of the output column name
    //   int: the index of the vector input (or -1)
    //   int[]: The data kinds of the input columns

    var columnCount = ctx.Reader.ReadInt32();
    env.CheckDecode(columnCount > 0);

    var columns = new ColumnInfo[columnCount];
    for (int i = 0; i < columnCount; i++)
    {
        var inputSize = ctx.Reader.ReadInt32();
        env.CheckDecode(inputSize >= 0);

        var inputColumnNames = new string[inputSize];
        for (int j = 0; j < inputSize; j++)
        {
            inputColumnNames[j] = ctx.LoadNonEmptyString();
        }

        var expression = ctx.LoadNonEmptyString();
        var outputColumnName = ctx.LoadNonEmptyString();

        // -1 means "no vector input". Any other value is passed with inputTypes to
        // ParseAndBindLambda as an index into this column's inputs. Rejecting out-of-range
        // values here reports a corrupt model as a decode error rather than a later index failure.
        var vectorInputColumn = ctx.Reader.ReadInt32();
        env.CheckDecode(vectorInputColumn >= -1 && vectorInputColumn < inputSize);

        var inputTypes = new DataViewType[inputSize];
        for (int j = 0; j < inputSize; j++)
        {
            var dataKindIndex = ctx.Reader.ReadInt32();
            var kind = InternalDataKindExtensions.FromIndex(dataKindIndex);
            inputTypes[j] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
        }

        var node = ExpressionEstimator.ParseAndBindLambda(env, expression, vectorInputColumn, inputTypes, out var perm);
        columns[i] = new ColumnInfo(env, inputColumnNames, inputTypes, expression, outputColumnName, vectorInputColumn, node, perm);
    }

    return new ExpressionTransformer(env, columns);
}
/// <summary>
/// Resolves the public <see cref="DataKind"/> that corresponds to the raw representation of
/// <paramref name="type"/>. For vector types the element type's raw type is tried first. If
/// that fails, or the type is not a vector, the type's own raw type is tried.
/// </summary>
/// <param name="type">The data view type to inspect.</param>
/// <returns>The resolved <see cref="DataKind"/>.</returns>
/// <remarks>Throws the exception built by <c>Contracts.ExceptNotSupp</c> when no kind can be resolved.</remarks>
public static DataKind RawKind(this DataViewType type)
{
    if (IsVector(type) && InternalDataKindExtensions.TryGetDataKind(ItemType(type).RawType, out InternalDataKind itemKind))
    {
        return Internal2DataKind(itemKind);
    }

    if (!InternalDataKindExtensions.TryGetDataKind(type.RawType, out InternalDataKind rawKind))
    {
        throw Contracts.ExceptNotSupp($"Unable to guess kind for type {type}.");
    }

    return Internal2DataKind(rawKind);
}
/// <summary>
/// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
/// </summary>
/// <param name="ctx">ONNX export context that receives the initializers, variables and nodes.</param>
/// <param name="outputNames">
/// Output variable names: [0] receives the predicted label (ArgMax index + 1, cast to UInt32);
/// [1] receives the per-label scores (the final summed tensor with axis 2 squeezed out).
/// </param>
/// <param name="featureColumn">Name of the input feature variable.</param>
/// <returns>Always <c>true</c>.</returns>
bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
{
    // Flatten the model statistics into dense float buffers so they can become ONNX initializers.
    float[] featureHistogram = new float[_featureHistogram[0].Length * _labelHistogram.Length];
    float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

    // Row-major copy: row i of _featureHistogram lands at offset i * rowLength.
    for (int i = 0; i < _featureHistogram.Length; i++)
    {
        Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
    }

    // Repeat the label histogram once per feature.
    // NOTE(review): the copy length and stride use _featureHistogram.Length, while the buffer is
    // sized with _labelHistogram.Length. This assumes the two are equal; confirm.
    for (int i = 0; i < _featureHistogram[0].Length; i++)
    {
        Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
    }

    // Scalar constants used by the graph.
    var one = ctx.AddInitializer(1.0f, "one");
    var oneInt = ctx.AddInitializer(1, typeof(int), "oneInt");
    var zero = ctx.AddInitializer(0.0f, "zero");
    var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
    var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");

    // The first _labelHistogram.Length entries of the expanded buffer (one copy of the label
    // histogram), shaped as a column of height _labelHistogram.Length.
    var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

    // Full statistic tensors.
    var featureHistogramName = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
    var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
    var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

    // Declared types for the intermediate variables below.
    var typeOne = new VectorDataViewType(NumberDataViewType.Single, 1);
    var typeFea = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
    var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
    var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);

    // isFeaturePresent = ExpandDims(Cast<float>(featureColumn > 0), 1): a 0/1 float mask of
    // which features are positive, with an extra axis inserted at position 1.
    var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
    var opType = "Greater";
    ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

    opType = "Cast";
    var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
    var node = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
    var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);

    opType = "ExpandDims";
    var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");
    ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");

    // Log prior per label: logOutput = Log(labelHistogram / totalTrainingCount).
    opType = "Div";
    var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");
    ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

    opType = "Log";
    var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");
    ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

    // log1 = LogMul(featureHistogram + 1, isFeaturePresent).
    // LogMul is defined elsewhere; presumably log(input) masked by isFeaturePresent. Confirm.
    // NOTE(review): this SumOutput is declared with _inputType, while the later same-shaped
    // intermediates use typeFea. Confirm the declared type.
    opType = "Sum";
    var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");
    ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

    // log2 = LogMul(Transpose(labelHistogramExpanded) + labelCount, isFeaturePresent).
    // Transpose has no perm attribute, so (per ONNX) it reverses the axes, aligning the
    // expanded label histogram with featureHistogram.
    opType = "Transpose";
    var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");
    ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

    // Count of training examples per cell where the feature was absent:
    // absentFeatureCount = labelHistogramTrans - featureHistogram.
    opType = "Sub";
    var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");
    ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
    ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

    // log3 = LogMul(absentFeatureCount + 1, isFeaturePresent).
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
    ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
    var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
    LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

    // Combine the terms: logProb = log1 - log2 and absentFeatureLogProb = log3 - log2.
    opType = "Sub";
    var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");
    ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

    opType = "Sub";
    var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");
    ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

    // Sum both terms over axis 2.
    // NOTE(review): axis 2 assumes a leading batch axis in front of the declared
    // label-by-feature shape. Confirm.
    opType = "ReduceSum";
    var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
    node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
    long[] list = { 2 };
    node.AddAttribute("axes", list);

    opType = "ReduceSum";
    var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
    node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
    node.AddAttribute("axes", list);

    // Bring the learned absentFeaturesLogProb initializer to float so it can be combined with
    // the float graph values.
    opType = "Cast";
    castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
    node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);

    // Final per-label score:
    // (absentFeaturesLogProb - reduced absentFeatureLogProb) + reduced logProb + log prior.
    opType = "Sub";
    var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");
    ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
    ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

    // Score output: drop the trailing singleton axis 2.
    opType = "Squeeze";
    var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");
    squeezeNode.AddAttribute("axes", new long[] { 2 });

    // Predicted label: index of the best score along axis 1 (keepdims=0).
    opType = "ArgMax";
    var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");
    node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
    node.AddAttribute("axis", 1);
    node.AddAttribute("keepdims", 0);

    opType = "Cast";
    castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
    node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
    node.AddAttribute("to", t);

    // Label = ArgMax index + 1, then cast to UInt32 for outputNames[0].
    // (Not a log term: this step was previously mislabeled "log3".)
    opType = "Sum";
    sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
    ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

    opType = "Cast";
    node = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
    t = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
    node.AddAttribute("to", t);

    return(true);
}