public void SaveAsOnnxPostProcess(OnnxContext ctx, string inputName, string[] outputNames)
            {
                // Post-processing graph for a score tensor named inputName:
                //   outputNames[0] <- Cast<UInt32>(ArgMax(input, axis = 1, keepdims = 1) + 1)
                //   outputNames[1] <- Max(input)   (single-input, element-wise Max)
                Contracts.Assert(outputNames.Length >= 2);

                // Index of the largest entry along axis 1; the reduced dimension is kept.
                const string argMaxOp = "ArgMax";
                var argMaxResult = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "ArgMaxOutput");
                var argMax = ctx.CreateNode(argMaxOp, inputName, argMaxResult, ctx.GetNodeName(argMaxOp), "");
                argMax.AddAttribute("keepdims", 1);
                argMax.AddAttribute("axis", 1);

                // Shift the zero-based index up by one (presumably to yield a 1-based key value —
                // confirm against the label key type expected by consumers of outputNames[0]).
                const string addOp = "Add";
                var oneConstant = ctx.AddInitializer(1);
                var shiftedIndex = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "AddOutput");
                ctx.CreateNode(addOp, new[] { argMaxResult, oneConstant }, new[] { shiftedIndex }, ctx.GetNodeName(addOp), "");

                // Narrow the shifted Int64 index to UInt32 and publish it as the first output.
                const string castOp = "Cast";
                var cast = ctx.CreateNode(castOp, shiftedIndex, outputNames[0], ctx.GetNodeName(castOp), "");
                cast.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType());

                // Second output: a Max node fed by the single input tensor.
                const string maxOp = "Max";
                ctx.CreateNode(maxOp, inputName, outputNames[1], ctx.GetNodeName(maxOp), "");
            }
Ejemplo n.º 2
0
            /// <summary>
            /// Appends the ONNX nodes that turn the text in <paramref name="srcVariableName"/> into
            /// per-character UInt16 codes written to <paramref name="dstVariableName"/>, as the chain
            /// Tokenizer (com.microsoft) → Squeeze(axes = 0) → LabelEncoder → Cast(UInt16).
            /// </summary>
            /// <param name="ctx">ONNX context that receives the nodes and intermediate variables.</param>
            /// <param name="srcVariableName">Name of the input text variable.</param>
            /// <param name="dstVariableName">Name of the variable produced by the final Cast node.</param>
            private void SaveAsOnnxCore(OnnxContext ctx, string srcVariableName, string dstVariableName)
            {
                // Tokenizer with a single empty separator and mincharnum = 1. With an empty separator the
                // contrib op is expected to emit one token per character — confirm against the ORT contrib op docs.
                const string tokenizerOp = "Tokenizer";
                string tokens = ctx.AddIntermediateVariable(null, "TokenizerOutput", true);
                var tokenizer = ctx.CreateNode(tokenizerOp, srcVariableName, tokens, ctx.GetNodeName(tokenizerOp), "com.microsoft");
                tokenizer.AddAttribute("mark", _parent._useMarkerChars);
                tokenizer.AddAttribute("mincharnum", 1);
                tokenizer.AddAttribute("pad_value", "");
                tokenizer.AddAttribute("separators", new string[] { "" });

                // Remove axis 0 from the tokenizer output.
                const string squeezeOp = "Squeeze";
                var squeezed = ctx.AddIntermediateVariable(null, "SqueezeOutput", true);
                var squeeze = ctx.CreateNode(squeezeOp, tokens, squeezed, ctx.GetNodeName(squeezeOp), "");
                squeeze.AddAttribute("axes", new long[] { 0 });

                // Lookup table from each single-character string to its UTF-16 code unit value.
                // Enumerable.Range(0, 65535) produces 0..65534, so "\uFFFF" is not one of the keys.
                const string labelEncoderOp = "LabelEncoder";
                var codes = ctx.AddIntermediateVariable(null, "LabelEncoderOutput", true);
                var labelEncoder = ctx.CreateNode(labelEncoderOp, squeezed, codes, ctx.GetNodeName(labelEncoderOp));
                const int tableSize = 65535;
                IEnumerable<string> keys = Enumerable.Range(0, tableSize).Select(code => ((char)code).ToString());
                IEnumerable<long> values = Enumerable.Range(0, tableSize).Select(code => Convert.ToInt64(code));
                labelEncoder.AddAttribute("keys_strings", keys);
                labelEncoder.AddAttribute("values_int64s", values);

                // Narrow the Int64 codes to UInt16 in the destination variable.
                const string castOp = "Cast";
                var cast = ctx.CreateNode(castOp, codes, dstVariableName, ctx.GetNodeName(castOp), "");
                cast.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType());
            }
Ejemplo n.º 3
0
            /// <summary>
            /// Appends the ONNX nodes that turn column <paramref name="iinfo"/>'s text into per-character
            /// UInt16 codes: Tokenizer (com.microsoft) → Squeeze(axes = 1) → LabelEncoder → Cast(UInt16).
            /// Requires opset 9 or later (checked via <c>ctx.CheckOpSetVersion</c>).
            /// </summary>
            /// <param name="ctx">ONNX context that receives the nodes and intermediate variables.</param>
            /// <param name="iinfo">Column index used to look up whether the source is a text vector and its length.</param>
            /// <param name="srcVariableName">Name of the input text variable.</param>
            /// <param name="dstVariableName">Name of the variable produced by the final Cast node.</param>
            private void SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariableName, string dstVariableName)
            {
                const int minimumOpSetVersion = 9;
                ctx.CheckOpSetVersion(minimumOpSetVersion, LoaderSignature);

                // Declared type of the tokenizer/squeeze intermediates: a text vector of the source length
                // when the source column is a vector, otherwise scalar text.
                DataViewType textType = _isSourceVector[iinfo]
                    ? new VectorDataViewType(TextDataViewType.Instance, _sourceVectorLength[iinfo])
                    : (DataViewType)TextDataViewType.Instance;

                // Tokenizer with a single empty separator and mincharnum = 1. With an empty separator the
                // contrib op is expected to emit one token per character — confirm against the ORT contrib op docs.
                const string tokenizerOp = "Tokenizer";
                string tokens = ctx.AddIntermediateVariable(textType, "TokenizerOutput", true);
                var tokenizer = ctx.CreateNode(tokenizerOp, srcVariableName, tokens, ctx.GetNodeName(tokenizerOp), "com.microsoft");
                tokenizer.AddAttribute("mark", _parent._useMarkerChars);
                tokenizer.AddAttribute("mincharnum", 1);
                tokenizer.AddAttribute("pad_value", "");
                tokenizer.AddAttribute("separators", new string[] { "" });

                // Remove axis 1 from the tokenizer output.
                const string squeezeOp = "Squeeze";
                var squeezed = ctx.AddIntermediateVariable(textType, "SqueezeOutput");
                var squeeze = ctx.CreateNode(squeezeOp, tokens, squeezed, ctx.GetNodeName(squeezeOp), "");
                squeeze.AddAttribute("axes", new long[] { 1 });

                // Lookup table from each single-character string to its UTF-16 code unit value.
                // Enumerable.Range(0, 65535) produces 0..65534, so "\uFFFF" is not one of the keys.
                const string labelEncoderOp = "LabelEncoder";
                var codes = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput");
                var labelEncoder = ctx.CreateNode(labelEncoderOp, squeezed, codes, ctx.GetNodeName(labelEncoderOp));
                const int tableSize = 65535;
                IEnumerable<string> keys = Enumerable.Range(0, tableSize).Select(code => ((char)code).ToString());
                IEnumerable<long> values = Enumerable.Range(0, tableSize).Select(code => Convert.ToInt64(code));
                labelEncoder.AddAttribute("keys_strings", keys);
                labelEncoder.AddAttribute("values_int64s", values);

                // Narrow the Int64 codes to UInt16 in the destination variable.
                const string castOp = "Cast";
                var cast = ctx.CreateNode(castOp, codes, dstVariableName, ctx.GetNodeName(castOp), "");
                cast.AddAttribute("to", InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt16).ToType());
            }
Ejemplo n.º 4
0
        /// <summary>
        /// Deserializes an <see cref="ExpressionTransformer"/> from <paramref name="ctx"/>, re-parsing and
        /// re-binding each stored expression against its recorded input types.
        /// </summary>
        /// <param name="env">Host environment; used for argument and decode checks.</param>
        /// <param name="ctx">Model load context positioned at this transformer's model.</param>
        /// <returns>A transformer with one <c>ColumnInfo</c> per serialized output column.</returns>
        private static ExpressionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
        {
            Contracts.CheckValue(env, nameof(env));
            env.CheckValue(ctx, nameof(ctx));
            ctx.CheckAtModel(GetVersionInfo());

            // *** Binary format ***
            // int: number of output columns
            // for each output column:
            //   int: number of inputs
            //   foreach input
            //     int: Id of the input column name
            //   int: Id of the expression
            //   int: Id of the output column name
            //   int: the index of the vector input (or -1)
            //   int[]: The data kinds of the input columns

            var columnCount = ctx.Reader.ReadInt32();
            env.CheckDecode(columnCount > 0);

            var columns = new ColumnInfo[columnCount];
            for (int i = 0; i < columnCount; i++)
            {
                var inputSize = ctx.Reader.ReadInt32();
                env.CheckDecode(inputSize >= 0);

                var inputColumnNames = new string[inputSize];
                for (int j = 0; j < inputSize; j++)
                {
                    inputColumnNames[j] = ctx.LoadNonEmptyString();
                }

                var expression = ctx.LoadNonEmptyString();
                var outputColumnName = ctx.LoadNonEmptyString();

                // The vector input is either absent (-1) or an index into this column's inputs. Bounding it
                // above as well turns a corrupt model into a decode error here, instead of passing an
                // out-of-range index on to ParseAndBindLambda / ColumnInfo.
                var vectorInputColumn = ctx.Reader.ReadInt32();
                env.CheckDecode(vectorInputColumn >= -1 && vectorInputColumn < inputSize);

                var inputTypes = new DataViewType[inputSize];
                for (int j = 0; j < inputSize; j++)
                {
                    var dataKindIndex = ctx.Reader.ReadInt32();
                    var kind = InternalDataKindExtensions.FromIndex(dataKindIndex);
                    inputTypes[j] = ColumnTypeExtensions.PrimitiveTypeFromKind(kind);
                }

                var node = ExpressionEstimator.ParseAndBindLambda(env, expression, vectorInputColumn, inputTypes, out var perm);
                columns[i] = new ColumnInfo(env, inputColumnNames, inputTypes, expression, outputColumnName, vectorInputColumn, node, perm);
            }

            return new ExpressionTransformer(env, columns);
        }
        /// <summary>
        /// Returns the <see cref="DataKind"/> that corresponds to the raw representation of <paramref name="type"/>.
        /// For vector types the item type's raw type is tried first; if that does not map to a known kind
        /// (or the type is not a vector), the type's own raw type is tried.
        /// </summary>
        /// <param name="type">The data view type to inspect.</param>
        /// <returns>The matching public <see cref="DataKind"/>.</returns>
        /// <remarks>Throws the exception produced by <c>Contracts.ExceptNotSupp</c> when neither raw type maps to a kind.</remarks>
        public static DataKind RawKind(this DataViewType type)
        {
            // && short-circuits, so ItemType is only consulted for vector types.
            if (IsVector(type) && InternalDataKindExtensions.TryGetDataKind(ItemType(type).RawType, out InternalDataKind itemKind))
            {
                return Internal2DataKind(itemKind);
            }

            if (InternalDataKindExtensions.TryGetDataKind(type.RawType, out InternalDataKind rawKind))
            {
                return Internal2DataKind(rawKind);
            }

            throw Contracts.ExceptNotSupp($"Unable to guess kind for type {type}.");
        }
Ejemplo n.º 6
0
        /// <summary>
        /// Creates an Onnx inferencing model by vectorizing and following the logic found in <see cref="Map"/>
        /// </summary>
        /// <remarks>
        /// Builds an ONNX subgraph from the stored histograms. The squeezed per-class sum of
        /// (prior + present-feature log terms + absent-feature adjustment) is written to <c>outputNames[1]</c>.
        /// ArgMax over that sum (axis 1), plus one and cast to UInt32, is written to <c>outputNames[0]</c>.
        /// NOTE(review): the shapes declared via AddIntermediateVariable are hints for the graph builder;
        /// they were not checked against the runtime shapes of every node here.
        /// NOTE(review): the flattening below assumes <c>_featureHistogram.Length == _labelHistogram.Length</c>
        /// (one histogram row per label) — confirm against the trainer.
        /// </remarks>
        /// <param name="ctx">ONNX context that receives the initializers, intermediate variables and nodes.</param>
        /// <param name="outputNames">Output variable names: [0] receives ArgMax + 1 cast to UInt32, [1] receives the per-class sums.</param>
        /// <param name="featureColumn">Name of the ONNX variable holding the input feature values.</param>
        /// <returns>Always <c>true</c>.</returns>
        bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string featureColumn)
        {
            // Flat buffers of size featureCount * labelCount for the two histogram initializers.
            float[] featureHistogram       = new float[_featureHistogram[0].Length * _labelHistogram.Length];
            float[] labelHistogramExpanded = new float[_featureHistogram[0].Length * _labelHistogram.Length];

            // Flatten the jagged feature histogram row by row (one row per _featureHistogram entry).
            for (int i = 0; i < _featureHistogram.Length; i++)
            {
                Array.Copy(_featureHistogram[i], 0, featureHistogram, i * _featureHistogram[i].Length, _featureHistogram[i].Length);
            }
            // Tile the label histogram once per feature, giving a [feature][label] layout.
            for (int i = 0; i < _featureHistogram[0].Length; i++)
            {
                Array.Copy(_labelHistogram, 0, labelHistogramExpanded, i * _featureHistogram.Length, _featureHistogram.Length);
            }

            // Scalar constants used by the graph.
            var one            = ctx.AddInitializer(1.0f, "one");
            var oneInt         = ctx.AddInitializer(1, typeof(int), "oneInt");
            var zero           = ctx.AddInitializer(0.0f, "zero");
            var labelCount     = ctx.AddInitializer((float)_labelCount, "labelCount");
            var trainingCount  = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
            // One copy of the label histogram (the first tile), shaped { labelCount, 1 }.
            var labelHistogram = ctx.AddInitializer(labelHistogramExpanded.Take(_labelHistogram.Length), new long[] { _labelHistogram.Length, 1 }, "labelHistogram");

            // Tensor initializers: feature histogram { rows, featureCount }, tiled label histogram
            // { featureCount, labelCount }, and the learned absent-feature log-probabilities { n, 1 }.
            var featureHistogramName        = ctx.AddInitializer(featureHistogram, new long[] { _featureHistogram.Length, _featureHistogram[0].Length }, "featureHistogram");
            var labelHistogramName          = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
            var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");

            // Declared Single vector types: [1], [featureCount], [labelCount, featureCount], [labelCount, 1].
            var typeOne        = new VectorDataViewType(NumberDataViewType.Single, 1);
            var typeFea        = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
            var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
            var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);

            // Presence mask: feature value > 0.
            var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
            var opType        = "Greater";

            ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");

            // Boolean mask -> Single (1.0 / 0.0).
            opType = "Cast";
            var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
            var node       = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
            var t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();

            node.AddAttribute("to", t);

            // Insert a dimension at axis 1 (the oneInt initializer) via the com.microsoft ExpandDims op.
            opType = "ExpandDims";
            var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");

            ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");

            //initialize logProb
            // Prior term: log(labelHistogram / totalTrainingCount).
            opType = "Div";
            var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");

            ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");

            opType = "Log";
            var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");

            ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");

            //log1
            // featureHistogram + 1, then combined with the presence mask by LogMul (helper defined elsewhere;
            // presumably log of the sum weighted by isFeaturePresent — see its definition).
            opType = "Sum";
            var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");

            ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);

            //log2
            // Transpose the tiled label histogram so it lines up with featureHistogram.
            opType = "Transpose";
            var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");

            ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");

            // absentFeatureCount = transposed label histogram - feature histogram.
            opType = "Sub";
            var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");

            ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");

            // transposed label histogram + labelCount, then LogMul with the presence mask.
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);

            //log3
            // absentFeatureCount + 1, then LogMul with the presence mask.
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
            ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");

            LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);

            //result
            // logProb = log1 - log2.
            opType = "Sub";
            var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");

            ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");

            // absentFeatureLogProb = log3 - log2.
            opType = "Sub";
            var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");

            ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");

            // Sum both terms over axis 2 (the same axes list is reused for the second ReduceSum).
            opType = "ReduceSum";
            var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");

            node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
            long[] list = { 2 };
            node.AddAttribute("axes", list);

            opType = "ReduceSum";
            var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");

            node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axes", list);

            // Cast the learned absent-feature log-probabilities initializer to Single.
            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
            node       = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // Learned absent-feature log-prob minus the reduced absent-feature term.
            opType = "Sub";
            var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");

            ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");

            // Per-class total: absent-feature adjustment + reduced present-feature term + prior (logOutput).
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
            ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            // Drop axis 2 and publish the per-class totals as outputNames[1].
            opType = "Squeeze";
            var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");

            squeezeNode.AddAttribute("axes", new long[] { 2 });

            // Index of the largest per-class total along axis 1 (reduced axis dropped).
            opType = "ArgMax";
            var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");

            node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
            node.AddAttribute("axis", 1);
            node.AddAttribute("keepdims", 0);

            // Int64 index -> Single so it can be added to the Single "one" initializer.
            opType     = "Cast";
            castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
            node       = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
            t          = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
            node.AddAttribute("to", t);

            // Predicted label: ArgMax index + 1 (presumably a 1-based key value — confirm), cast to UInt32
            // into outputNames[0].
            opType    = "Sum";
            sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
            ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");

            opType = "Cast";
            node   = ctx.CreateNode(opType, sumOutput, outputNames[0], ctx.GetNodeName(opType), "");
            t      = InternalDataKindExtensions.ToInternalDataKind(DataKind.UInt32).ToType();
            node.AddAttribute("to", t);

            return(true);
        }