Exemplo n.º 1
0
        /// <summary>
        /// Runs one training step on <paramref name="sent"/>.
        /// Walks the transition sequence returned by <c>GetOracle</c> until the
        /// configuration is terminal. At each step it adds the softmax cross-entropy
        /// of the transition scores and, when the oracle supplies a label, of the
        /// label scores. It then back-propagates and updates both
        /// <c>FixedParams</c> and <c>VariedParams</c>.
        /// </summary>
        /// <param name="sent">The sentence to train on.</param>
        /// <returns>
        /// Summed loss divided by <c>sent.Count</c>. Returns 0 for an empty sentence
        /// instead of NaN (<c>0f / 0</c>).
        /// </returns>
        public float TrainOne(Sent sent)
        {
            G.Need = true;
            var c = new ArcStandardConfig(sent, false, _rngChoice);
            var losses = 0f;

            try
            {
                GetTokenRepr(sent, true);

                while (!c.IsTerminal())
                {
                    var oracle = c.GetOracle(_conf.OraType, out float[] target, out string slabel);
                    GetTrans(c, out Tensor op, out Tensor label);
                    G.SoftmaxWithCrossEntropy(op, target, out float loss);
                    losses += loss;
                    if (slabel != null)
                    {
                        // In the DepLabel mapping 0 is unk and 1 is root, while the
                        // label scorer's index 0 is root; hence the -1 shift.
                        var labelid = _train.DepLabel[slabel] - 1;
                        G.SoftmaxWithCrossEntropy(label, labelid, out loss);
                        losses += loss;
                    }

                    c.Apply(oracle, slabel);
                }

                G.Backward();
                _opt.Update(FixedParams);
                _opt.Update(VariedParams);
            }
            finally
            {
                // Always reset the per-sentence state, even on failure, so that
                // nothing from this sentence carries over into the next call.
                VariedParams.Clear();
                G.Clear();
            }

            // Guard against 0f / 0 == NaN for an empty sentence.
            return sent.Count > 0 ? losses / sent.Count : 0f;
        }
Exemplo n.º 2
0
        /// <summary>
        /// Greedily parses <paramref name="sent"/> with gradient tracking switched off.
        /// At each step it applies the highest-scoring transition that
        /// <c>CanApply</c> accepts, together with the arg-max dependency label.
        /// If no transition qualifies, shift is applied.
        /// </summary>
        /// <param name="sent">The sentence to parse.</param>
        public void PredictOne(Sent sent)
        {
            // Inference mode (TrainOne sets this to true).
            G.Need = false;

            var config = new ArcStandardConfig(sent, false, null);
            GetTokenRepr(sent, false);

            while (!config.IsTerminal())
            {
                GetTrans(config, out Tensor op, out Tensor label);

                // The label scorer puts root at index 0, whereas the DepLabel mapping
                // reserves 0 for unk and 1 for root; shift by one to translate.
                var predicted = _train.DepLabel[label.W.MaxIndex() + 1];

                var opScores  = op.W.Storage;
                var bestOp    = ArcStandardConfig.Op.shift; // fallback when nothing qualifies
                var bestScore = float.NegativeInfinity;

                for (var k = 0; k < 3; k++)
                {
                    var candidate = (ArcStandardConfig.Op)k;

                    // Written as !(a > b) rather than a <= b so that a NaN score
                    // is never selected, just like a plain '>' test.
                    if (!(opScores[k] > bestScore))
                    {
                        continue;
                    }
                    if (!config.CanApply(candidate, predicted))
                    {
                        continue;
                    }

                    bestScore = opScores[k];
                    bestOp    = candidate;
                }

                config.Apply(bestOp, predicted);
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Scores the current configuration.
        /// The feature vector concatenates the <c>Repr</c> of <c>Stack.Back(1)</c>,
        /// <c>Stack.Back(0)</c> and <c>Buffer.Front(0)</c>; presumably these are the
        /// second stack item, the top stack item and the first buffer item. A missing
        /// slot is replaced by the shared <c>_non</c> tensor. The vector feeds two
        /// heads, each a Relu hidden layer followed by an output layer.
        /// </summary>
        /// <param name="c">The parser configuration to featurize.</param>
        /// <param name="op">Receives the transition scores.</param>
        /// <param name="label">Receives the dependency-label scores.</param>
        private void GetTrans(ArcStandardConfig c, out Tensor op, out Tensor label)
        {
            var s1 = c.Stack.Back(1);
            var s0 = c.Stack.Back(0);
            var b0 = c.Buffer.Front(0);

            var features = G.Concat(s1?.Repr ?? _non, s0?.Repr ?? _non, b0?.Repr ?? _non);

            // _non is reused for every missing stack/buffer slot, so its reference
            // count grows by one per missing slot.
            var placeholders = (s1 == null ? 1 : 0) + (s0 == null ? 1 : 0) + (b0 == null ? 1 : 0);
            _non.RefCount += placeholders;

            var opHidden    = G.Relu(_act.Step(G, features));
            op              = _actOutput.Step(G, opHidden);
            var labelHidden = G.Relu(_label.Step(G, features));
            label           = _labelOutput.Step(G, labelHidden);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Greedily parses <paramref name="sent"/> with gradient tracking switched off.
        /// At each step it applies the highest-scoring transition that
        /// <c>CanApply</c> accepts, together with the arg-max dependency label.
        /// </summary>
        /// <remarks>
        /// The transition used to be the unconstrained arg-max of the scores, so an
        /// illegal transition could be passed to <c>Apply</c>. An example is presumably
        /// an arc transition with fewer than two stack items. Selection now applies the
        /// same legality check and shift fallback as <c>PredictOne</c>.
        /// </remarks>
        /// <param name="sent">The sentence to parse.</param>
        public void TestOne(Sent sent)
        {
            G.Need = false;

            var c = new ArcStandardConfig(sent, false, null);

            GetTokenRepr(sent, false);
            while (!c.IsTerminal())
            {
                GetTrans(c, out Tensor op, out Tensor label);

                // in labelid 0 is root
                // in mapping 1 is root, 0 is always unk
                var labelid = label.W.MaxIndex();
                var plabel  = _train.DepLabel[labelid + 1];

                // Pick the best-scoring transition that is legal here; shift is the
                // fallback if none passes CanApply. The loop bound of 3 matches
                // PredictOne (presumably the three arc-standard ops).
                var scores   = op.W.Storage;
                var optScore = float.NegativeInfinity;
                var optTrans = ArcStandardConfig.Op.shift;
                for (var j = 0; j < 3; ++j)
                {
                    var candidate = (ArcStandardConfig.Op)j;
                    if (scores[j] > optScore && c.CanApply(candidate, plabel))
                    {
                        optScore = scores[j];
                        optTrans = candidate;
                    }
                }

                c.Apply(optTrans, plabel);
            }
        }