Example #1
0
        /// <summary>
        /// Calculates the gradient (first derivative) of a node with respect to a pointer node.
        /// When the static member 'Compact' is true, the resulting expression tree is passed through
        /// FNodeCompacter.CompactNode before being returned, because the raw gradient calculation leaves
        /// a lot of unneeded expressions that could be cancelled out.
        /// </summary>
        /// <remarks>
        /// Resolution order:
        /// 1. A pointer node yields a constant one if it has the same PointerName as X, otherwise a constant zero.
        /// 2. Any other node that is not a result (function) node yields a constant zero.
        /// 3. A function node that does not have X as a descendant yields a constant zero.
        /// 4. Otherwise the call is dispatched on the inner function's name signature to the matching GradientOf* helper.
        /// Only results produced by step 4 are compacted; the constant nodes of steps 1-3 are returned as built.
        /// </remarks>
        /// <param name="Node">The node to calculate the gradient over</param>
        /// <param name="X">The pointer node (parameter) we are differentiating with respect to</param>
        /// <returns>A node representing the gradient of <paramref name="Node"/> with respect to <paramref name="X"/></returns>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="Node"/> is null</exception>
        /// <exception cref="NotSupportedException">Thrown when the node's inner function has no gradient rule in this method</exception>
        internal static FNode Gradient(FNode Node, FNodePointer X)
        {

            // Fail with a descriptive error instead of a NullReferenceException on the first property access //
            if (Node == null)
                throw new ArgumentNullException("Node");

            // The node is a pointer node: d(X)/d(X) = 1, d(Y)/d(X) = 0 //
            if (Node.Affinity == FNodeAffinity.PointerNode)
            {
                // Direct cast: a node whose affinity and runtime type disagree raises InvalidCastException here //
                FNodePointer pointer = (FNodePointer)Node;
                if (pointer.PointerName == X.PointerName)
                    return new FNodeValue(Node.ParentNode, Cell.OneValue(X.ReturnAffinity()));
                else
                    return new FNodeValue(Node.ParentNode, Cell.ZeroValue(X.ReturnAffinity()));
            }

            // The node is not a function node, so it cannot depend on X //
            // NOTE(review): this zero is typed with Node's affinity while every other branch uses X's affinity - confirm this is intentional //
            if (Node.Affinity != FNodeAffinity.ResultNode)
                return new FNodeValue(Node.ParentNode, Cell.ZeroValue(Node.ReturnAffinity()));

            // Check if the node, which we now know is a function, has X as a descendant //
            if (!FNodeAnalysis.IsDecendent(X, Node))
                return new FNodeValue(Node.ParentNode, Cell.ZeroValue(X.ReturnAffinity()));

            // Otherwise the function depends on X and has to be differentiated //

            // Get the name signature of the function held by the result node //
            string name_sig = ((FNodeResult)Node).InnerFunction.NameSig;

            // Go through each differentiable function //
            FNode gradient = null;
            switch (name_sig)
            {

                // Unary operators //
                case FunctionNames.UNI_PLUS:
                    gradient = GradientOfUniPlus(Node, X);
                    break;
                case FunctionNames.UNI_MINUS:
                    gradient = GradientOfUniMinus(Node, X);
                    break;

                // Binary arithmetic operators //
                case FunctionNames.OP_ADD:
                    gradient = GradientOfAdd(Node, X);
                    break;
                case FunctionNames.OP_SUB:
                    gradient = GradientOfSubtract(Node, X);
                    break;
                case FunctionNames.OP_MUL:
                    gradient = GradientOfMultiply(Node, X);
                    break;
                case FunctionNames.OP_DIV:
                    gradient = GradientOfDivide(Node, X);
                    break;

                // Logarithm, exponential and power //
                case FunctionNames.FUNC_LOG:
                    gradient = GradientOfLog(Node, X);
                    break;
                case FunctionNames.FUNC_EXP:
                    gradient = GradientOfExp(Node, X);
                    break;
                case FunctionNames.FUNC_POWER:
                    // NOTE(review): only GradientOfPowerLower is dispatched here; confirm it covers X appearing in the exponent //
                    gradient = GradientOfPowerLower(Node, X);
                    break;

                // Trigonometric functions //
                case FunctionNames.FUNC_SIN:
                    gradient = GradientOfSin(Node, X);
                    break;
                case FunctionNames.FUNC_COS:
                    gradient = GradientOfCos(Node, X);
                    break;
                case FunctionNames.FUNC_TAN:
                    gradient = GradientOfTan(Node, X);
                    break;

                // Hyperbolic functions //
                case FunctionNames.FUNC_SINH:
                    gradient = GradientOfSinh(Node, X);
                    break;
                case FunctionNames.FUNC_COSH:
                    gradient = GradientOfCosh(Node, X);
                    break;
                case FunctionNames.FUNC_TANH:
                    gradient = GradientOfTanh(Node, X);
                    break;

                case FunctionNames.FUNC_LOGIT:
                    gradient = GradientOfLogit(Node, X);
                    break;

                case FunctionNames.FUNC_NDIST:
                    gradient = GradientOfNDIST(Node, X);
                    break;

                // No gradient rule exists for this function; NotSupportedException is still caught by catch (Exception) //
                default:
                    throw new NotSupportedException(string.Format("Function is not differentiable : {0}", name_sig));
            }

            // Optionally simplify the expression tree produced by the gradient helpers //
            if (Compact)
                gradient = FNodeCompacter.CompactNode(gradient);

            return gradient;

        }