コード例 #1
0
 /// <summary>
 /// Emits target code for a gradients statement: opens a block, runs
 /// tensorBackwards on the target tensor, binds each requested gradient to a
 /// local Tensor* via readGradients, emits the body, then calls zeroGradients
 /// on the target tensor and closes the block.
 /// </summary>
 public override void Visit(GradientsNode node)
 {
     // Opening brace of the emitted block plus the backward pass on the target tensor.
     Emit($"{{\ntensorBackwards({node.tensorID});\n");

     // Declare one Tensor* per gradient variable, read from its paired source tensor.
     for (int k = 0; k < node.GradVariables.Count; k++)
     {
         Emit($"Tensor* {node.GradVariables[k]} = readGradients({node.GradTensors[k]});\n");
     }

     Visit(node.Body);

     // Reset the gradients of the target tensor and close the emitted block.
     Emit($"zeroGradients({node.tensorID});\n}}\n");
 }
コード例 #2
0
ファイル: ASTBuilder.cs プロジェクト: Deaxz/ML4D
        /// <summary>
        /// Builds a <c>GradientsNode</c> from a parsed gradients statement.
        /// The node receives the target tensor's name, the visited body, and the
        /// paired lists of gradient variable names and source tensor names.
        /// </summary>
        public override Node VisitGradientsStmt(ML4DParser.GradientsStmtContext context)
        {
            // Same evaluation order as before: tensor name first, then the body.
            var tensorName = context.tensor.Text;
            var body = (LinesNode)Visit(context.body);
            var result = new GradientsNode(tensorName, body);

            // Copy the i-th gradient variable together with the i-th source tensor.
            int index = 0;
            while (index < context._gradvar.Count)
            {
                result.GradVariables.Add(context._gradvar[index].Text);
                result.GradTensors.Add(context._gradtensor[index].Text);
                index++;
            }

            return result;
        }
コード例 #3
0
        /// <summary>
        /// Semantically checks a gradients statement inside a new scope.
        /// Verifies that <c>node.tensorID</c> is declared with type "tensor".
        /// For each (gradient variable, source tensor) pair, verifies that the
        /// source is a <c>TensorSymbol</c> and declares the gradient variable as a
        /// "tensor" with the source's Rows/Columns. It then visits the children
        /// and closes the scope, even when a check throws.
        /// </summary>
        /// <exception cref="VariableNotDeclaredException">
        /// The target tensor or a source tensor has not been declared.
        /// </exception>
        /// <exception cref="VariableAlreadyDeclaredException">
        /// A gradient variable name is already visible in the current or a parent scope.
        /// </exception>
        /// <exception cref="Exception">The target or a source variable is not a tensor.</exception>
        public override void Visit(GradientsNode node)
        {
            SymbolTable.OpenScope();
            try
            {
                Symbol tensorDCL = SymbolTable.Retrieve(node.tensorID);

                // tensorDCL is null here, so the name must come from the node;
                // dereferencing tensorDCL would throw a NullReferenceException instead.
                if (tensorDCL is null)
                {
                    throw new VariableNotDeclaredException(
                              $"The variable \"{node.tensorID}\" cannot be assigned, as it has not been declared.");
                }
                if (tensorDCL.Type != "tensor")
                {
                    throw new Exception($"Variable \"{tensorDCL.Name}\" is not a tensor. Grads can only be derived from tensors.");
                }

                for (int i = 0; i < node.GradVariables.Count; i++)
                {
                    string gradVar    = node.GradVariables[i];
                    string gradTensor = node.GradTensors[i];

                    Symbol symbol = SymbolTable.Retrieve(gradTensor);

                    // A type pattern never matches null, so a null symbol falls through
                    // to the explicit null branch below.
                    if (symbol is TensorSymbol tensorSymbol)
                    {
                        if (SymbolTable.Retrieve(gradVar) is null)
                        {
                            // The gradient takes the same dimensions as the tensor it is read from.
                            SymbolTable.Insert(gradVar, "tensor", false, tensorSymbol.Rows, tensorSymbol.Columns);
                        }
                        else
                        {
                            throw new VariableAlreadyDeclaredException(
                                      $"The variable \"{gradVar}\" could not be declared, as it has already been declared in the current or parent scope.");
                        }
                    }
                    else if (symbol is null)
                    {
                        throw new VariableNotDeclaredException(
                                  $"The variable \"{gradTensor}\" cannot be assigned, as it has not been declared.");
                    }
                    else
                    {
                        throw new Exception(
                                  $"Variable \"{symbol.Name}\" is not a tensor. Grads can only be derived from tensors.");
                    }
                }

                base.Visit(node);
            }
            finally
            {
                // Keep OpenScope/CloseScope balanced even when a check above throws.
                SymbolTable.CloseScope();
            }
        }
コード例 #4
0
 /// <summary>
 /// Default handling for a gradients node: delegates to <c>VisitChildren</c>.
 /// Derived visitors override this to add behavior.
 /// </summary>
 public virtual void Visit(GradientsNode node) => VisitChildren(node);