protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            // Exports the key -> indicator-vector mapping as a single ONNX-ML OneHotEncoder node.
            // The category list is the key values 1..KeyCount of the source item type.
            const string OpType = "OneHotEncoder";
            var keyCount   = info.TypeSrc.ItemType.KeyCount;
            var categories = Enumerable.Range(1, keyCount).Select(key => (long)key);

            var encoder = OnnxUtils.MakeNode(OpType, srcVariableName, dstVariableName, ctx.GetNodeName(OpType));
            OnnxUtils.NodeAddAttributes(encoder, "cats_int64s", categories);

            // "zeros" asks the ONNX operator to emit an all-zero vector (rather than fail) for values that are
            // not in cats_int64s. NOTE(review): presumably this covers the key value 0 — confirm.
            OnnxUtils.NodeAddAttributes(encoder, "zeros", true);

            ctx.AddNode(encoder);
            return true;
        }
Exemplo n.º 2
0
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            // Exports NaN replacement for column `iinfo` as an ONNX-ML Imputer node.
            // Only single-precision float (R4) data is supported; for anything else this returns false
            // so the caller knows the column could not be saved.
            var type = Infos[iinfo].TypeSrc;

            DataKind rawKind;
            if (type.IsVector)
            {
                rawKind = type.AsVector.ItemType.RawKind;
            }
            else if (type.IsKey)
            {
                rawKind = type.AsKey.RawKind;
            }
            else
            {
                rawKind = type.RawKind;
            }

            if (rawKind != DataKind.R4)
            {
                return(false);
            }

            string opType = "Imputer";
            var    node   = OnnxUtils.MakeNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType));

            // Only NaN entries are replaced.
            OnnxUtils.NodeAddAttributes(node, "replaced_value_float", Single.NaN);

            // ONNX-ML Imputer defines only the list-valued "imputed_value_floats" attribute; there is no scalar
            // "imputed_value_float". A one-element list is broadcast to every feature. Previously the scalar
            // paths used the non-existent singular name, so the replacement value was dropped from the model.
            if (type.IsVector && _repIsDefault[iinfo] != null)
            {
                // A vector column with a non-null _repIsDefault entry carries per-slot replacement values.
                OnnxUtils.NodeAddAttributes(node, "imputed_value_floats", (float[])_repValues[iinfo]);
            }
            else
            {
                // Scalar columns, and vector columns with a single shared replacement value.
                OnnxUtils.NodeAddAttributes(node, "imputed_value_floats", Enumerable.Repeat((float)_repValues[iinfo], 1));
            }

            ctx.AddNode(node);
            return(true);
        }
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
        {
            Host.CheckValue(ctx, nameof(ctx));

            // Emit one ONNX-ML LinearClassifier node that reads `featureColumn` and writes every name in `outputs`.
            const string OpType = "LinearClassifier";
            var inputNames  = new List<string> { featureColumn };
            var outputNames = new List<string>(outputs);
            var classifier  = OnnxUtils.MakeNode(OpType, inputNames, outputNames, ctx.GetNodeName(OpType));

            // post_transform = 0 requests no transform of the raw scores. NOTE(review): the ONNX-ML spec
            // describes post_transform as a string enum ('NONE', 'SOFTMAX', 'LOGISTIC', 'SOFTMAX_ZERO',
            // 'PROBIT'); confirm how OnnxUtils serializes this integer.
            OnnxUtils.NodeAddAttributes(classifier, "post_transform", 0);
            OnnxUtils.NodeAddAttributes(classifier, "multi_class", true);

            // The per-class weight vectors are concatenated, class after class, into one flat coefficient list.
            var flatCoefficients = _weights.SelectMany(classWeights => classWeights.DenseValues());
            OnnxUtils.NodeAddAttributes(classifier, "coefficients", flatCoefficients);
            OnnxUtils.NodeAddAttributes(classifier, "intercepts", _biases);
            OnnxUtils.NodeAddAttributes(classifier, "classlabels_strings", _labelNames);

            ctx.AddNode(classifier);
            return true;
        }
Exemplo n.º 4
0
        public bool SaveAsOnnx(OnnxContext ctx, string[] outputs, string featureColumn)
        {
            Host.CheckValue(ctx, nameof(ctx));
            // The regressor must be wired to exactly one output column name.
            Host.Check(Utils.Size(outputs) == 1);

            // Emit one ONNX-ML LinearRegressor node: featureColumn -> outputs[0].
            const string OpType = "LinearRegressor";
            var inputNames  = new List<string> { featureColumn };
            var outputNames = new List<string>(outputs);
            var regressor   = OnnxUtils.MakeNode(OpType, inputNames, outputNames, ctx.GetNodeName(OpType));

            // post_transform = 0: no transform applied to the linear output.
            OnnxUtils.NodeAddAttributes(regressor, "post_transform", 0);
            // A single regression target, so coefficients hold one weight per feature.
            OnnxUtils.NodeAddAttributes(regressor, "targets", 1);
            OnnxUtils.NodeAddAttributes(regressor, "coefficients", Weight.DenseValues());
            OnnxUtils.NodeAddAttributes(regressor, "intercepts", Bias);

            ctx.AddNode(regressor);
            return true;
        }
Exemplo n.º 5
0
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            // Only text-valued source columns can be exported as an ONNX-ML LabelEncoder.
            if (!info.TypeSrc.ItemType.IsText)
            {
                return false;
            }

            // Fetch the learned term dictionary for this column; its terms become the encoder's class list.
            var vocabulary = default(VBuffer<DvText>);
            var termMap    = (TermMap<DvText>)_termMap[iinfo].Map;
            termMap.GetTerms(ref vocabulary);

            const string OpType = "LabelEncoder";
            var encoder = OnnxUtils.MakeNode(OpType, srcVariableName, dstVariableName, ctx.GetNodeName(OpType));
            OnnxUtils.NodeAddAttributes(encoder, "classes_strings", vocabulary.DenseValues());

            // Fallback outputs used by the ONNX operator for inputs that are not in classes_strings:
            // -1 for string -> int lookups and the empty string for int -> string lookups.
            OnnxUtils.NodeAddAttributes(encoder, "default_int64", -1);
            OnnxUtils.NodeAddAttributes(encoder, "default_string", DvText.Empty);

            ctx.AddNode(encoder);
            return true;
        }
Exemplo n.º 6
0
        protected override bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string srcVariableName, string dstVariableName)
        {
            Contracts.AssertValue(ctx);
            Contracts.Assert(0 <= iinfo && iinfo < Infos.Length);
            Contracts.Assert(Infos[iinfo] == info);
            Contracts.Assert(CanSaveOnnx);

            // A column with ValueCount 0 (no fixed size) is not exported.
            var valueCount = info.TypeSrc.ValueCount;
            if (valueCount == 0)
            {
                return false;
            }

            // Build an ONNX-ML Scaler node and let the column's normalization function fill in its attributes.
            // The node is added to the graph only if that function reports success.
            const string OpType = "Scaler";
            var scaler  = OnnxUtils.MakeNode(OpType, srcVariableName, dstVariableName, ctx.GetNodeName(OpType));
            var wrapper = new OnnxUtils.NodeProtoWrapper(scaler);

            bool populated = _functions[iinfo].OnnxInfo(ctx, wrapper, valueCount);
            if (!populated)
            {
                return false;
            }

            ctx.AddNode(scaler);
            return true;
        }
Exemplo n.º 7
0
        public override void SaveAsOnnx(OnnxContext ctx)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.Assert(Bindable is IBindableCanSaveOnnx);

            // Nothing is exported unless the ONNX graph already contains the features column.
            if (!ctx.ContainsColumn(DefaultColumnNames.Features))
            {
                return;
            }

            base.SaveAsOnnx(ctx);

            int derivedCount = Bindings.DerivedColumnCount;
            Host.Assert(derivedCount == 1);

            // Resolve the output column names in iinfo order (PredictedLabel, Score, Probability).
            int infoCount = Bindings.InfoCount;
            var outColumnNames = new string[infoCount];
            int index = 0;
            while (index < infoCount)
            {
                outColumnNames[index] = Bindings.GetColumnName(Bindings.MapIinfoToCol(index));
                index++;
            }

            // The predicted label is derived from the Probability column (index 2), so the Binarizer is
            // emitted only when the base class actually produced that column.
            bool hasProbability = infoCount >= 3 && ctx.ContainsColumn(outColumnNames[2]);
            if (!hasProbability)
            {
                return;
            }

            const string OpType = "Binarizer";
            var probabilityInputs = new List<string> { ctx.GetVariableName(outColumnNames[2]) };
            var labelOutputs      = new List<string> { ctx.GetVariableName(outColumnNames[0]) };
            var binarizer         = OnnxUtils.MakeNode(OpType, probabilityInputs, labelOutputs, ctx.GetNodeName(OpType));

            // Threshold the probability at 0.5 to produce the predicted label.
            OnnxUtils.NodeAddAttributes(binarizer, "threshold", 0.5);
            ctx.AddNode(binarizer);
        }