/// <summary>
/// Generates target code for a gradients statement: opens a block, emits a
/// <c>tensorBackwards</c> call on the target tensor, binds each requested
/// gradient to a <c>Tensor*</c> local via <c>readGradients</c>, emits the body,
/// then emits <c>zeroGradients</c> on the target tensor and closes the block.
/// </summary>
/// <param name="node">The gradients statement to generate code for.</param>
public override void Visit(GradientsNode node)
{
    // Open the block and trigger the backward pass on the differentiated tensor.
    Emit($"{{\ntensorBackwards({node.tensorID});\n");

    // GradVariables[k] receives the gradient read from GradTensors[k]; the two lists are index-paired.
    for (int k = 0; k < node.GradVariables.Count; k++)
    {
        Emit($"Tensor* {node.GradVariables[k]} = readGradients({node.GradTensors[k]});\n");
    }

    Visit(node.Body);

    // Reset the gradients of the target tensor and close the block opened above.
    Emit($"zeroGradients({node.tensorID});\n}}\n");
}
/// <summary>
/// Builds a <see cref="GradientsNode"/> from a parsed gradients statement.
/// The k-th <c>gradvar</c> token is paired with the k-th <c>gradtensor</c> token.
/// </summary>
/// <param name="context">Parser context for the gradients statement.</param>
/// <returns>The constructed <see cref="GradientsNode"/>.</returns>
public override Node VisitGradientsStmt(ML4DParser.GradientsStmtContext context)
{
    // Read the tensor name before visiting the body, matching the original argument evaluation order.
    string tensorName = context.tensor.Text;
    LinesNode bodyNode = (LinesNode)Visit(context.body);
    GradientsNode result = new GradientsNode(tensorName, bodyNode);

    var variableTokens = context._gradvar;
    var tensorTokens = context._gradtensor;
    for (int k = 0; k < variableTokens.Count; k++)
    {
        result.GradVariables.Add(variableTokens[k].Text);
        result.GradTensors.Add(tensorTokens[k].Text);
    }

    return result;
}
/// <summary>
/// Symbol-table pass for a gradients statement. Opens a new scope, validates
/// that the differentiated tensor is declared and has type "tensor", declares
/// each gradient variable in that scope, visits the children, and closes the scope.
/// </summary>
/// <param name="node">The gradients statement being checked.</param>
/// <exception cref="VariableNotDeclaredException">
/// The differentiated tensor or one of the source tensors is not declared.
/// </exception>
/// <exception cref="VariableAlreadyDeclaredException">
/// A gradient variable name is already declared in the current or a parent scope.
/// </exception>
/// <exception cref="Exception">
/// The differentiated tensor or a source tensor is not a tensor.
/// </exception>
public override void Visit(GradientsNode node)
{
    // Gradient variables declared below live only inside this statement's scope.
    SymbolTable.OpenScope();
    try
    {
        Symbol tensorDCL = SymbolTable.Retrieve(node.tensorID);
        if (tensorDCL is null)
        {
            // Use node.tensorID here: tensorDCL is null, so reading tensorDCL.Name would
            // raise a NullReferenceException instead of the intended diagnostic.
            throw new VariableNotDeclaredException(
                $"The variable \"{node.tensorID}\" cannot be assigned, as it has not been declared.");
        }
        if (tensorDCL.Type != "tensor")
        {
            throw new Exception($"Variable \"{tensorDCL.Name}\" is not a tensor. Grads can only be derived from tensors.");
        }

        for (int i = 0; i < node.GradVariables.Count; i++)
        {
            string gradVar = node.GradVariables[i];
            string gradTensor = node.GradTensors[i];
            Symbol symbol = SymbolTable.Retrieve(gradTensor);

            // The `is TensorSymbol` type pattern evaluates to false for null, so no separate
            // null check is needed here; the null case is handled by the `else if` below.
            if (symbol is TensorSymbol tensorSymbol)
            {
                if (SymbolTable.Retrieve(gradVar) is null)
                {
                    // The gradient variable takes the dimensions of the tensor it is read from.
                    SymbolTable.Insert(gradVar, "tensor", false, tensorSymbol.Rows, tensorSymbol.Columns);
                }
                else
                {
                    throw new VariableAlreadyDeclaredException(
                        $"The variable \"{gradVar}\" could not be declared, as it has already been declared in the current or parent scope.");
                }
            }
            else if (symbol is null)
            {
                throw new VariableNotDeclaredException(
                    $"The variable \"{gradTensor}\" cannot be assigned, as it has not been declared.");
            }
            else
            {
                throw new Exception(
                    $"Variable \"{symbol.Name}\" is not a tensor. Grads can only be derived from tensors.");
            }
        }

        base.Visit(node);
    }
    finally
    {
        // Close the scope even when a diagnostic is thrown, so the symbol table stays balanced.
        SymbolTable.CloseScope();
    }
}
/// <summary>Default handling for a gradients statement: delegates to <c>VisitChildren</c>.</summary>
/// <param name="node">The gradients statement to traverse.</param>
public virtual void Visit(GradientsNode node) => VisitChildren(node);