/// <summary>
/// Runs a single training step on one sentence by following the oracle
/// transition sequence from the initial configuration to a terminal one.
/// At every step the cross-entropy loss of the transition scores is
/// accumulated, plus the loss of the label scores whenever the oracle
/// supplies a label. Afterwards the graph is back-propagated, both
/// parameter sets are updated, and the per-sentence state is cleared.
/// </summary>
/// <param name="sent">Sentence to train on.</param>
/// <returns>The accumulated loss divided by <c>sent.Count</c>.</returns>
public float TrainOne(Sent sent)
{
    // NOTE(review): presumably switches the graph into gradient-recording mode — confirm.
    G.Need = true;

    var config = new ArcStandardConfig(sent, false, _rngChoice);
    GetTokenRepr(sent, true);

    float totalLoss = 0f;
    while (!config.IsTerminal())
    {
        var goldOp = config.GetOracle(_conf.OraType, out float[] opTarget, out string goldLabel);
        GetTrans(config, out Tensor opScores, out Tensor labelScores);

        G.SoftmaxWithCrossEntropy(opScores, opTarget, out float opLoss);
        totalLoss += opLoss;

        // The oracle returns a null label for steps that carry no label loss.
        if (goldLabel != null)
        {
            // In the DepLabel mapping 0 is unk and 1 is root, so shift down by
            // one to obtain the index used by the label classifier.
            var goldLabelId = _train.DepLabel[goldLabel] - 1;
            G.SoftmaxWithCrossEntropy(labelScores, goldLabelId, out float labelLoss);
            totalLoss += labelLoss;
        }

        config.Apply(goldOp, goldLabel);
    }

    G.Backward();
    _opt.Update(FixedParams);
    _opt.Update(VariedParams);

    // Per-sentence parameters and the graph are discarded after each update.
    VariedParams.Clear();
    G.Clear();

    return totalLoss / sent.Count;
}
/// <summary>
/// Greedily parses one sentence. At each step the highest-scoring label is
/// picked first, then the highest-scoring transition that
/// <c>CanApply</c> accepts for that label is applied to the configuration.
/// If no transition is accepted, <c>Op.shift</c> is applied.
/// </summary>
/// <param name="sent">Sentence to parse; the configuration built from it is advanced until terminal.</param>
public void PredictOne(Sent sent)
{
    // NOTE(review): presumably disables gradient bookkeeping for inference — confirm.
    G.Need = false;

    var config = new ArcStandardConfig(sent, false, null);
    GetTokenRepr(sent, false);

    while (!config.IsTerminal())
    {
        GetTrans(config, out Tensor opScores, out Tensor labelScores);

        // Classifier label id 0 is root; in the DepLabel mapping root is 1
        // and 0 is always unk, hence the +1 when looking the label up.
        var bestLabelId = labelScores.W.MaxIndex();
        var predictedLabel = _train.DepLabel[bestLabelId + 1];

        var transitionScores = opScores.W.Storage;
        var bestScore = float.NegativeInfinity;
        var bestOp = ArcStandardConfig.Op.shift;

        // Scan the three transition scores, keeping the best applicable one.
        for (var opIndex = 0; opIndex < 3; ++opIndex)
        {
            // Only strictly better scores are considered, so CanApply is
            // queried solely for candidates that could replace the current best.
            if (!(transitionScores[opIndex] > bestScore))
            {
                continue;
            }

            var candidate = (ArcStandardConfig.Op)opIndex;
            if (config.CanApply(candidate, predictedLabel))
            {
                bestScore = transitionScores[opIndex];
                bestOp = candidate;
            }
        }

        config.Apply(bestOp, predictedLabel);
    }
}
/// <summary>
/// Scores the current configuration. The representations of
/// <c>Stack.Back(1)</c>, <c>Stack.Back(0)</c> and <c>Buffer.Front(0)</c>
/// are concatenated (with <c>_non</c> standing in for any missing item)
/// and fed through two separate networks: one for the transition and one
/// for the dependency label.
/// </summary>
/// <param name="c">Configuration whose stack and buffer are inspected.</param>
/// <param name="op">Output of the transition network.</param>
/// <param name="label">Output of the label network.</param>
private void GetTrans(ArcStandardConfig c, out Tensor op, out Tensor label)
{
    var s1 = c.Stack.Back(1);
    var s0 = c.Stack.Back(0);
    var b0 = c.Buffer.Front(0);

    var features = G.Concat(
        s1?.Repr ?? _non,
        s0?.Repr ?? _non,
        b0?.Repr ?? _non);

    // _non.RefCount is bumped once for every item that was missing (null).
    // NOTE(review): presumably this tracks how many times _non entered the graph — confirm.
    var missing = 0;
    if (s1 == null)
    {
        missing++;
    }
    if (s0 == null)
    {
        missing++;
    }
    if (b0 == null)
    {
        missing++;
    }
    _non.RefCount += missing;

    // Transition head: _act -> ReLU -> _actOutput.
    var opHidden = G.Relu(_act.Step(G, features));
    op = _actOutput.Step(G, opHidden);

    // Label head: _label -> ReLU -> _labelOutput, over the same features.
    var labelHidden = G.Relu(_label.Step(G, features));
    label = _labelOutput.Step(G, labelHidden);
}
/// <summary>
/// Greedily parses one sentence using the raw argmax of both the label
/// scores and the transition scores at every step. Unlike
/// <c>PredictOne</c>, the chosen transition is applied without first
/// consulting <c>CanApply</c>.
/// </summary>
/// <param name="sent">Sentence to parse; the configuration built from it is advanced until terminal.</param>
public void TestOne(Sent sent)
{
    // NOTE(review): presumably disables gradient bookkeeping for inference — confirm.
    G.Need = false;

    var config = new ArcStandardConfig(sent, false, null);
    GetTokenRepr(sent, false);

    while (!config.IsTerminal())
    {
        GetTrans(config, out Tensor opScores, out Tensor labelScores);

        // Classifier label id 0 is root; in the DepLabel mapping root is 1
        // and 0 is always unk, hence the +1 when looking the label up.
        var bestLabelId = labelScores.W.MaxIndex();
        var predictedLabel = _train.DepLabel[bestLabelId + 1];

        // The index of the top transition score is cast directly to the Op enum.
        var bestOpId = opScores.W.MaxIndex();
        var chosenOp = (ArcStandardConfig.Op)bestOpId;

        config.Apply(chosenOp, predictedLabel);
    }
}