예제 #1
0
        /// <summary>
        /// Writes this multi-class linear predictor into the PFA document bound to <paramref name="ctx"/>
        /// and returns a PFA expression that scores <paramref name="input"/> with it.
        /// </summary>
        /// <param name="ctx">PFA context that receives the record type and the cell holding the model.</param>
        /// <param name="input">PFA expression producing the features to score.</param>
        /// <returns>
        /// A <c>m.link.softmax</c> call wrapped around <c>model.reg.linear</c> applied to
        /// <paramref name="input"/> and a reference to the declared predictor cell.
        /// </returns>
        public JToken SaveAsPfa(BoundPfaContext ctx, JToken input)
        {
            Host.CheckValue(ctx, nameof(ctx));
            Host.CheckValue(input, nameof(input));

            // The same identifier names both the PFA record type and the cell that stores the model;
            // keep it in one constant rather than repeating the literal further down.
            const string typeName = "MCLinearPredictor";
            JToken typeDecl = typeName;

            // When RegisterType reports true, the full record declaration is emitted inline; otherwise
            // typeDecl stays the bare type name (presumably because it was declared earlier — confirm).
            if (ctx.Pfa.RegisterType(typeName))
            {
                JObject type = new JObject();
                type["type"] = "record";
                type["name"] = typeName;

                JArray fields = new JArray();
                // NOTE(review): AddReturn is invoked on a null JObject; this relies on the AddReturn
                // extension allocating a fresh object for a null receiver — confirm in its definition.
                JObject jobj = null;
                // "coeff": array of arrays of double (one inner array per entry of _weights).
                fields.Add(jobj.AddReturn("name", "coeff").AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Array(PfaUtils.Type.Double))));
                // "const": array of double (the bias terms).
                fields.Add(jobj.AddReturn("name", "const").AddReturn("type", PfaUtils.Type.Array(PfaUtils.Type.Double)));
                type["fields"] = fields;
                typeDecl = type;
            }

            JObject predictor = new JObject();
            // Every weight vector is written through DenseValues(), so each row is emitted in dense form.
            predictor["coeff"] = new JArray(_weights.Select(w => new JArray(w.DenseValues())));
            predictor["const"] = new JArray(_biases);

            var cell = ctx.DeclareCell(typeName, typeDecl, predictor);
            var cellRef = PfaUtils.Cell(cell);

            return PfaUtils.Call("m.link.softmax", PfaUtils.Call("model.reg.linear", input, cellRef));
        }
예제 #2
0
        /// <summary>
        /// Exports the term lookup for column <paramref name="iinfo"/> as PFA. The term dictionary is
        /// declared as a "TermMap" cell (term text -> term index). The returned expression looks each
        /// source term up in that cell and yields -1 when the term is absent.
        /// </summary>
        /// <param name="ctx">PFA context receiving the cell and, for vector columns, a helper function.</param>
        /// <param name="iinfo">Index of the column being exported.</param>
        /// <param name="info">Column info; must match <c>Infos[iinfo]</c>.</param>
        /// <param name="srcToken">PFA expression for the source column value.</param>
        /// <returns>The lookup expression, or <c>null</c> when the source item type is not text.</returns>
        protected override JToken SaveAsPfaCore(BoundPfaContext ctx, int iinfo, ColInfo info, JToken srcToken)
        {
            Contracts.AssertValue(ctx);
            Contracts.Assert(0 <= iinfo && iinfo < Infos.Length);
            Contracts.Assert(Infos[iinfo] == info);
            Contracts.AssertValue(srcToken);
            Contracts.Assert(CanSavePfa);

            // Only text-valued sources are handled here; any other item type yields null.
            if (!info.TypeSrc.ItemType.IsText)
            {
                return null;
            }

            // Pull the dictionary terms and build a JSON object keyed by term text, valued by index.
            var textMap = (TermMap<DvText>)_termMap[iinfo].Map;
            VBuffer<DvText> termValues = default(VBuffer<DvText>);
            textMap.GetTerms(ref termValues);

            var termToIndex = new JObject();
            foreach (var entry in termValues.Items())
            {
                termToIndex[entry.Value.ToString()] = entry.Key;
            }

            string declaredCell = ctx.DeclareCell("TermMap", PfaUtils.Type.Map(PfaUtils.Type.Int), termToIndex);
            JObject mapCell = PfaUtils.Cell(declaredCell);

            // Scalar source: look the token up directly, returning -1 when the key is missing.
            if (!info.TypeSrc.IsVector)
            {
                return PfaUtils.If(
                    PfaUtils.Call("map.containsKey", mapCell, srcToken),
                    PfaUtils.Index(mapCell, srcToken),
                    -1);
            }

            // Vector source: register a user function doing the same lookup on its "term" parameter,
            // then apply it to every element via a.map.
            var mapperName = ctx.GetFreeFunctionName("mapTerm");
            var mapperParams = new JArray(PfaUtils.Param("term", PfaUtils.Type.String));
            var mapperBody = PfaUtils.If(
                PfaUtils.Call("map.containsKey", mapCell, "term"),
                PfaUtils.Index(mapCell, "term"),
                -1);
            ctx.Pfa.AddFunc(mapperName, mapperParams, PfaUtils.Type.Int, mapperBody);

            return PfaUtils.Call("a.map", srcToken, PfaUtils.FuncRef("u." + mapperName));
        }