コード例 #1
0
 /// <summary>
 /// Creates a unary node by storing the supplied node type, operand, result type and method.
 /// </summary>
 /// <param name="nodeType">Value stored as the node type.</param>
 /// <param name="expression">Value stored as the operand.</param>
 /// <param name="type">Value stored as the node's type.</param>
 /// <param name="method">Value stored as the node's method.</param>
 internal UnaryExpression(ExpressionType nodeType, Expression expression, Type type, MethodInfo method)
 {
     // Plain stores, listed in parameter order; none of them depends on another.
     _nodeType = nodeType;
     _operand = expression;
     _type = type;
     _method = method;
 }
コード例 #2
0
        /// <summary>
        /// Performs reverse-mode (adjoint) automatic differentiation of one dependent variable
        /// with respect to a set of independent variables.
        /// </summary>
        /// <param name="builder">Block builder. It is converted to a block from which the functions are
        /// extracted, and the product/addition expressions created for the derivatives are added to it.</param>
        /// <param name="derivativeExpressions">On return, an array with one slot per entry of
        /// <paramref name="independentVariables"/>; a slot stays <c>null</c> if the backward sweep never
        /// assigned a derivative to it. Set to <c>null</c> when no independent variables are supplied.</param>
        /// <param name="dependentVariable">The variable being differentiated; its <c>Index</c> is used as the
        /// index N at which the backward sweep starts.</param>
        /// <param name="independentVariables">The variables with respect to which derivatives are requested.</param>
        /// <exception cref="ArgumentNullException"><paramref name="independentVariables"/> is null, or
        /// <paramref name="builder"/> / <paramref name="dependentVariable"/> is null while at least one
        /// independent variable was supplied.</exception>
        public static void Differentiate(BlockExpressionBuilder builder, 
            out IList<Expression> derivativeExpressions,
            VectorParameterExpression dependentVariable,
            params VectorParameterExpression[] independentVariables)
        {
            if (independentVariables == null)
            {
                throw new ArgumentNullException("independentVariables");
            }

            // Nothing to differentiate against: callers receive null (kept for backward compatibility).
            if (independentVariables.Length == 0)
            {
                derivativeExpressions = null;
                return;
            }

            // Checked only after the early return so that calls which previously succeeded still do.
            if (object.ReferenceEquals(builder, null))
            {
                throw new ArgumentNullException("builder");
            }
            if (object.ReferenceEquals(dependentVariable, null))
            {
                throw new ArgumentNullException("dependentVariable");
            }

            var block = builder.ToBlock();
            // change the numbering of variables; arguments first, then locals
            List<int>[] jLookup;
            Function[] f;

            // we want a list of functions which can be unary or binary (in principle high order if it makes sense) and a look-up
            // each function is associated with a parameter, which is either an argument or a local
            // for each function index we return a list of the indices of the functions that reference the parameter
            GetFunctions(block, out f, out jLookup);

            int N = dependentVariable.Index;

            bool[] derivativeRequired; int[] derivativeExpressionIndex;
            IdentifyNeccesaryDerivatives(N, jLookup, independentVariables, out derivativeRequired, out derivativeExpressionIndex);

            // the list of operations needed to calculate the derivatives (*all* derivatives)
            derivativeExpressions = new Expression[independentVariables.Length];
            var dxNdx = new Expression[N + 1];
            // seed of the backward sweep: dxN/dxN = 1
            dxNdx[N] = new ConstantExpression<double>(1.0);
            for (int i = N - 1; i >= 0; i--)
            {
                if (!derivativeRequired[i])
                {
                    continue;
                }

                // dxN / dxi
                // need to find all operation indices j such that p(j) contains i
                // that is, all functions that have xi as an argument (these must therefore have a higher index than i)
                VectorParameterExpression total = new ConstantExpression<double>(0);
                var xi = f[i].Parameter;
                // need sum of dfj/dxi dXN/dxj
                foreach (var j in jLookup[i])
                {
                    // A consumer with an index above N lies outside dxNdx, so xN cannot depend on it.
                    if (j > N)
                    {
                        continue;
                    }

                    var dXNdxj = dxNdx[j]; // dXN/dxj
                    // An unfilled slot means no derivative was accumulated for xj (its iteration was skipped),
                    // so its contribution to the sum is zero; skipping also avoids passing null to the builder.
                    if (object.ReferenceEquals(dXNdxj, null))
                    {
                        continue;
                    }

                    var fj = f[j];
                    var dfjdxi = Differentiate(fj, xi, builder); // dfj/dxi
                    var product = builder.AddProductExpression(dfjdxi, dXNdxj);
                    total = builder.AddAdditionExpression(total, product);
                }
                dxNdx[i] = total;

                // -1 marks an intermediate variable whose derivative is not one of the requested outputs
                int targetIndex = derivativeExpressionIndex[i];
                if (targetIndex != -1)
                {
                    derivativeExpressions[targetIndex] = total;
                }
            }
        }
コード例 #3
0
 /// <summary>
 /// Creates the node by forwarding both operands, unchanged, to the base constructor.
 /// </summary>
 /// <param name="left">Expression passed to the base constructor as the left operand.</param>
 /// <param name="right">Expression passed to the base constructor as the right operand.</param>
 internal AssignBinaryExpression(Expression left, Expression right) : base(left, right) { }
コード例 #4
0
 /// <summary>
 /// Verifies that both operands are present and report the same <c>Type</c>.
 /// </summary>
 /// <param name="left">Left operand; must not be null.</param>
 /// <param name="right">Right operand; must not be null.</param>
 /// <exception cref="ArgumentNullException">Either operand is null.</exception>
 /// <exception cref="ArgumentException">The operands' <c>Type</c> values differ.</exception>
 private static void CheckTypes(Expression left, Expression right)
 {
     // Name the offending argument instead of failing with a NullReferenceException on .Type below.
     // ArgumentNullException derives from ArgumentException, so existing catch blocks still apply.
     if (object.ReferenceEquals(left, null))
     {
         throw new ArgumentNullException("left");
     }
     if (object.ReferenceEquals(right, null))
     {
         throw new ArgumentNullException("right");
     }

     if (left.Type != right.Type)
     {
         throw new ArgumentException("expressions Types mismatch");
     }
 }
コード例 #5
0
 /// <summary>
 /// Builds a multiplication node from two operands after checking that their types agree.
 /// </summary>
 /// <param name="left">Left operand; its <c>Type</c> becomes the type of the node.</param>
 /// <param name="right">Right operand.</param>
 /// <returns>A <c>SimpleBinaryExpression</c> tagged <c>ExpressionType.Multiply</c>.</returns>
 public static BinaryExpression Multiply(Expression left, Expression right)
 {
     // Validation runs first so that no node is created for mismatched operands.
     CheckTypes(left, right);

     var operandType = left.Type;
     BinaryExpression product = new SimpleBinaryExpression(ExpressionType.Multiply, left, right, operandType);
     return product;
 }
コード例 #6
0
 /// <summary>
 /// Builds a binary node of the requested kind from two operands after checking that their types agree.
 /// </summary>
 /// <param name="binaryType">Node kind passed through to the created node.</param>
 /// <param name="left">Left operand; its <c>Type</c> becomes the type of the node.</param>
 /// <param name="right">Right operand.</param>
 /// <returns>A <c>SimpleBinaryExpression</c> tagged with <paramref name="binaryType"/>.</returns>
 public static BinaryExpression MakeBinary(ExpressionType binaryType, Expression left, Expression right)
 {
     // Validation runs first so that no node is created for mismatched operands.
     CheckTypes(left, right);

     var operandType = left.Type;
     BinaryExpression node = new SimpleBinaryExpression(binaryType, left, right, operandType);
     return node;
 }
コード例 #7
0
 /// <summary>
 /// Builds a division node from two operands after checking that their types agree.
 /// </summary>
 /// <param name="left">Left operand; its <c>Type</c> becomes the type of the node.</param>
 /// <param name="right">Right operand.</param>
 /// <returns>A <c>SimpleBinaryExpression</c> tagged <c>ExpressionType.Divide</c>.</returns>
 public static BinaryExpression Divide(Expression left, Expression right)
 {
     // Validation runs first so that no node is created for mismatched operands.
     CheckTypes(left, right);

     var operandType = left.Type;
     BinaryExpression quotient = new SimpleBinaryExpression(ExpressionType.Divide, left, right, operandType);
     return quotient;
 }
コード例 #8
0
 /// <summary>
 /// Builds an <c>AssignBinaryExpression</c> from two operands after checking that their types agree.
 /// </summary>
 /// <param name="left">Left operand handed to the node.</param>
 /// <param name="right">Right operand handed to the node.</param>
 /// <returns>The newly created <c>AssignBinaryExpression</c>.</returns>
 public static BinaryExpression Assign(Expression left, Expression right)
 {
     // Validation runs first so that no node is created for mismatched operands.
     CheckTypes(left, right);

     BinaryExpression assignment = new AssignBinaryExpression(left, right);
     return assignment;
 }
コード例 #9
0
 /// <summary>
 /// Creates a binary node by storing its two operands.
 /// </summary>
 /// <param name="left">Value stored as the left operand.</param>
 /// <param name="right">Value stored as the right operand.</param>
 internal BinaryExpression(Expression left, Expression right)
 {
     // Independent stores; their order has no effect.
     this._right = right;
     this._left = left;
 }
コード例 #10
0
        /// <summary>
        /// Returns an expression for the partial derivative of <paramref name="function"/> with respect to
        /// <paramref name="variable"/>. The variable is located by comparing it (with <c>==</c>) against the
        /// operands of the function's expression.
        /// </summary>
        /// <param name="function">The function to differentiate; its <c>Expression</c> is inspected and its
        /// <c>Parameter</c> (the parameter associated with the function) is reused where the derivative can be
        /// written in terms of it.</param>
        /// <param name="variable">The variable to differentiate with respect to; expected to be one of the
        /// operands of the function's expression.</param>
        /// <param name="builder">Builder that receives any helper expressions created for the derivative.</param>
        /// <returns>An expression representing the derivative.</returns>
        /// <exception cref="NotImplementedException">The expression kind is not supported, the variable is not an
        /// operand of a supported binary expression, or the expression divides the variable by itself.</exception>
        private static Expression Differentiate(Function function, Expression variable, BlockExpressionBuilder builder)
        {
            // note that one of the function parameters is the variable
            var expression = function.Expression;
            if (expression is BinaryExpression)
            {
                var binaryExpression = (BinaryExpression)expression;
                bool isLeft = binaryExpression.Left == variable;
                bool isRight = binaryExpression.Right == variable;

                if (expression.NodeType == ExpressionType.Multiply)
                {
                    if (isLeft && isRight)
                    {
                        // d(x*x)/dx = 2x
                        return Expression.Multiply(new ConstantExpression<double>(2), variable);
                    }
                    if (isLeft)
                    {
                        return binaryExpression.Right; // d(x*y)/dx = y
                    }
                    if (isRight)
                    {
                        return binaryExpression.Left; // d(y*x)/dx = y
                    }
                }
                else if (expression.NodeType == ExpressionType.Add)
                {
                    if (isLeft && isRight)
                    {
                        return new ConstantExpression<double>(2); // d(x+x)/dx = 2
                    }
                    if (isLeft || isRight)
                    {
                        return new ConstantExpression<double>(1);
                    }
                }
                else if (expression.NodeType == ExpressionType.Subtract)
                {
                    if (isLeft && isRight)
                    {
                        return new ConstantExpression<double>(0); // d(x-x)/dx = 0
                    }
                    if (isLeft)
                    {
                        return new ConstantExpression<double>(1);
                    }
                    if (isRight)
                    {
                        return new ConstantExpression<double>(-1);
                    }
                }
                else if (expression.NodeType == ExpressionType.Divide)
                {
                    if (isLeft && isRight)
                    {
                        throw new NotImplementedException(); // should not happen
                    }
                    if (isLeft)
                    {
                        // d(x/y)/dx: built from a ScaleInverseExpression over y with scale 1.
                        var denominator = binaryExpression.Right as VectorParameterExpression;
                        var reference = binaryExpression.Right as ReferencingVectorParameterExpression<double>;
                        // The null check avoids dereferencing a failed 'as' cast; a denominator that is not a
                        // ReferencingVectorParameterExpression<double> is handled by the general branch.
                        if (!object.ReferenceEquals(reference, null) && reference.IsScalar)
                        {
                            return builder.AddLocalAssignment<double>(1 / reference.ScalarValue, new ScaleInverseExpression<double>(denominator, 1));
                        }
                        return builder.AddLocalAssignment<double>(new ScaleInverseExpression<double>(denominator, 1));
                    }
                    if (isRight)
                    {
                        return builder.AddNegateDivideExpression(function.Parameter, binaryExpression.Right as VectorParameterExpression); // i.e. x/y => x/y * -1/y
                    }
                }
                // A binary expression that produced no result above falls through to the checks below,
                // which end in NotImplementedException unless it is also a UnaryMathsExpression.
            }

            if (expression is UnaryMathsExpression)
            {
                var unaryExpression = (UnaryMathsExpression)expression;
                var unaryType = unaryExpression.UnaryType;

                if (unaryType == UnaryElementWiseOperation.ScaleOffset)
                {
                    // The derivative is the scale factor. A direct cast replaces the unchecked 'as', so a
                    // mismatching node fails with InvalidCastException instead of a NullReferenceException.
                    return new ConstantExpression<double>(((ScaleOffsetExpression<double>)unaryExpression).Scale);
                }
                if (unaryType == UnaryElementWiseOperation.Negate)
                {
                    return new ConstantExpression<double>(-1);
                }
                if (unaryType == UnaryElementWiseOperation.ScaleInverse)
                {
                    return builder.AddNegateDivideExpression(function.Parameter, unaryExpression.Operand); // i.e. x/y => x/y * -1/y
                }
                if (unaryType == UnaryElementWiseOperation.Exp)
                {
                    // The parameter associated with the function is returned as its own derivative.
                    return function.Parameter;
                }
                if (unaryType == UnaryElementWiseOperation.Log)
                {
                    return builder.AddInverseExpression(unaryExpression.Operand);
                }
                if (unaryType == UnaryElementWiseOperation.SquareRoot)
                {
                    return builder.AddHalfInverseSquareRootExpression(unaryExpression.Operand);
                }
                if (unaryType == UnaryElementWiseOperation.CumulativeNormal)
                {
                    return builder.AddGaussian(unaryExpression.Operand);
                }
                throw new NotImplementedException();
            }

            throw new NotImplementedException();
        }