/// <summary>
/// Performs reverse-mode (adjoint) automatic differentiation of the dependent variable
/// with respect to the supplied independent variables. Operations needed for the
/// derivatives are appended to <paramref name="builder"/>.
/// </summary>
/// <param name="builder">Builder whose block is differentiated and which receives the derivative operations.</param>
/// <param name="derivativeExpressions">
/// On return, an array with one slot per independent variable, or <c>null</c> when no
/// independent variables are supplied. A slot is filled only when the corresponding
/// derivative is computed; which slot belongs to which function index is decided by
/// <c>IdentifyNeccesaryDerivatives</c>.
/// </param>
/// <param name="dependentVariable">The variable being differentiated; its <c>Index</c> marks the end of the backward sweep.</param>
/// <param name="independentVariables">The variables with respect to which derivatives are requested.</param>
public static void Differentiate(BlockExpressionBuilder builder, out IList<Expression> derivativeExpressions,
    VectorParameterExpression dependentVariable, params VectorParameterExpression[] independentVariables)
{
    if (independentVariables.Length == 0)
    {
        derivativeExpressions = null;
        return;
    }

    // One function per parameter (argument or local), plus, for every function index,
    // the indices of the functions that reference that function's parameter.
    Function[] functions;
    List<int>[] referencingFunctions;
    GetFunctions(builder.ToBlock(), out functions, out referencingFunctions);

    int dependentIndex = dependentVariable.Index;

    bool[] derivativeRequired;
    int[] derivativeExpressionIndex;
    IdentifyNeccesaryDerivatives(dependentIndex, referencingFunctions, independentVariables,
        out derivativeRequired, out derivativeExpressionIndex);

    derivativeExpressions = new Expression[independentVariables.Length];

    // adjoints[k] holds d x_N / d x_k, where N is the index of the dependent variable.
    var adjoints = new Expression[dependentIndex + 1];
    adjoints[dependentIndex] = new ConstantExpression<double>(1.0);

    // Sweep backwards from the dependent variable towards the arguments.
    for (int i = dependentIndex - 1; i >= 0; i--)
    {
        if (derivativeRequired[i])
        {
            // d x_N / d x_i = sum over every function j that takes x_i as an argument of
            // (d f_j / d x_i) * (d x_N / d x_j).
            VectorParameterExpression total = new ConstantExpression<double>(0);
            var xi = functions[i].Parameter;

            // NOTE(review): adjoints[j] is read without checking that j is within
            // [0, N] or that the entry was populated - presumably guaranteed by the
            // lookup / required-derivative analysis; confirm.
            foreach (var j in referencingFunctions[i])
            {
                var partial = Differentiate(functions[j], xi, builder); // d f_j / d x_i
                var downstream = adjoints[j];                           // d x_N / d x_j
                var product = builder.AddProductExpression(partial, downstream);
                total = builder.AddAdditionExpression(total, product);
            }

            adjoints[i] = total;

            int targetIndex = derivativeExpressionIndex[i];
            if (targetIndex != -1)
            {
                derivativeExpressions[targetIndex] = total;
            }
        }
    }
}
/// <summary>
/// Performs reverse-mode (adjoint) automatic differentiation: walks the functions of the
/// builder's block backwards from the dependent variable and accumulates, for each
/// required variable, the derivative of the dependent variable with respect to it.
/// </summary>
/// <param name="builder">Supplies the block to differentiate and receives the operations that compute the derivatives.</param>
/// <param name="derivativeExpressions">
/// Set to <c>null</c> when <paramref name="independentVariables"/> is empty; otherwise set to
/// an array of the same length as <paramref name="independentVariables"/> whose entries are
/// assigned as the corresponding derivatives are built.
/// </param>
/// <param name="dependentVariable">The variable to differentiate.</param>
/// <param name="independentVariables">The variables to differentiate with respect to.</param>
public static void Differentiate(BlockExpressionBuilder builder, out IList<Expression> derivativeExpressions,
    VectorParameterExpression dependentVariable, params VectorParameterExpression[] independentVariables)
{
    if (independentVariables.Length == 0)
    {
        derivativeExpressions = null;
        return;
    }

    var block = builder.ToBlock();

    // Each function is associated with a parameter (an argument or a local). For each
    // function index, the look-up lists the functions that reference that parameter.
    List<int>[] usersOf;
    Function[] functionList;
    GetFunctions(block, out functionList, out usersOf);

    int last = dependentVariable.Index;

    bool[] needed;
    int[] outputSlot;
    IdentifyNeccesaryDerivatives(last, usersOf, independentVariables, out needed, out outputSlot);

    derivativeExpressions = new Expression[independentVariables.Length];

    // sensitivities[k] is the derivative of the dependent variable with respect to variable k;
    // the dependent variable's derivative with respect to itself is 1.
    var sensitivities = new Expression[last + 1];
    sensitivities[last] = new ConstantExpression<double>(1.0);

    int index = last;
    while (--index >= 0)
    {
        if (!needed[index])
        {
            continue;
        }

        var variable = functionList[index].Parameter;

        // Chain rule: sum, over every function that takes this variable as an argument,
        // of (derivative of that function w.r.t. the variable) * (its own sensitivity).
        VectorParameterExpression sum = new ConstantExpression<double>(0);
        foreach (var user in usersOf[index])
        {
            var userFunction = functionList[user];
            var localDerivative = Differentiate(userFunction, variable, builder);
            var term = builder.AddProductExpression(localDerivative, sensitivities[user]);
            sum = builder.AddAdditionExpression(sum, term);
        }

        sensitivities[index] = sum;

        int slot = outputSlot[index];
        if (slot != -1)
        {
            derivativeExpressions[slot] = sum;
        }
    }
}
/// <summary>
/// Builds the expression for the partial derivative of a single function with respect
/// to one of its arguments.
/// </summary>
/// <param name="function">
/// The function to differentiate. Its <c>Expression</c> is the operation and its
/// <c>Parameter</c> is the variable that holds the operation's result.
/// </param>
/// <param name="variable">The argument of <paramref name="function"/> to differentiate with respect to.</param>
/// <param name="builder">Builder to which any extra operations needed by the derivative are appended.</param>
/// <returns>An expression for d(function)/d(variable).</returns>
/// <exception cref="NotImplementedException">
/// The operation is neither a supported binary operation involving <paramref name="variable"/>
/// nor a supported unary operation, or it is <c>variable / variable</c>.
/// </exception>
private static Expression Differentiate(Function function, Expression variable, BlockExpressionBuilder builder)
{
    // Note that one of the function's operands is expected to be the variable.
    var expression = function.Expression;

    var binary = expression as BinaryExpression;
    if (binary != null)
    {
        bool leftIsVariable = binary.Left == variable;
        bool rightIsVariable = binary.Right == variable;

        switch (expression.NodeType)
        {
            case ExpressionType.Multiply:
                if (leftIsVariable && rightIsVariable)
                {
                    // d(x * x)/dx = 2x
                    return Expression.Multiply(new ConstantExpression<double>(2), variable);
                }
                if (leftIsVariable)
                {
                    return binary.Right;
                }
                if (rightIsVariable)
                {
                    return binary.Left;
                }
                break;

            case ExpressionType.Add:
                if (leftIsVariable && rightIsVariable)
                {
                    return new ConstantExpression<double>(2);
                }
                if (leftIsVariable || rightIsVariable)
                {
                    return new ConstantExpression<double>(1);
                }
                break;

            case ExpressionType.Subtract:
                if (leftIsVariable && rightIsVariable)
                {
                    return new ConstantExpression<double>(0);
                }
                if (leftIsVariable)
                {
                    return new ConstantExpression<double>(1);
                }
                if (rightIsVariable)
                {
                    return new ConstantExpression<double>(-1);
                }
                break;

            case ExpressionType.Divide:
                if (leftIsVariable && rightIsVariable)
                {
                    throw new NotImplementedException(); // should not happen
                }
                if (leftIsVariable)
                {
                    // d(x / y)/dx = 1 / y
                    var denominator = binary.Right as VectorParameterExpression;
                    var referencing = binary.Right as ReferencingVectorParameterExpression<double>;

                    // Only a referencing expression can report a scalar value. Any other
                    // denominator takes the general path instead of failing with a
                    // NullReferenceException on the unchecked 'as' result.
                    if (referencing != null && referencing.IsScalar)
                    {
                        return builder.AddLocalAssignment<double>(1 / referencing.ScalarValue,
                            new ScaleInverseExpression<double>(denominator, 1));
                    }
                    return builder.AddLocalAssignment<double>(new ScaleInverseExpression<double>(denominator, 1));
                }
                if (rightIsVariable)
                {
                    // i.e. x/y => x/y * -1/y
                    return builder.AddNegateDivideExpression(function.Parameter, binary.Right as VectorParameterExpression);
                }
                break;
        }
        // No binary rule applied: fall through to the unary check below, which throws
        // NotImplementedException unless the expression is also a unary maths expression.
    }

    var unary = expression as UnaryMathsExpression;
    if (unary == null)
    {
        throw new NotImplementedException();
    }

    switch (unary.UnaryType)
    {
        case UnaryElementWiseOperation.ScaleOffset:
            // Direct cast: an unexpected element type now surfaces as an
            // InvalidCastException rather than a NullReferenceException.
            return new ConstantExpression<double>(((ScaleOffsetExpression<double>)unary).Scale);

        case UnaryElementWiseOperation.Negate:
            return new ConstantExpression<double>(-1);

        case UnaryElementWiseOperation.ScaleInverse:
            // i.e. x/y => x/y * -1/y
            return builder.AddNegateDivideExpression(function.Parameter, unary.Operand);

        case UnaryElementWiseOperation.Exp:
            // The function's own result is returned as the derivative of exp.
            return function.Parameter;

        case UnaryElementWiseOperation.Log:
            return builder.AddInverseExpression(unary.Operand);

        case UnaryElementWiseOperation.SquareRoot:
            return builder.AddHalfInverseSquareRootExpression(unary.Operand);

        case UnaryElementWiseOperation.CumulativeNormal:
            return builder.AddGaussian(unary.Operand);

        default:
            throw new NotImplementedException();
    }
}
/// <summary>
/// Creates an executor whose block expression builder is constructed from the supplied
/// independent variables.
/// </summary>
/// <param name="independentVariables">Passed unchanged to the <c>BlockExpressionBuilder</c> constructor.</param>
public DeferringExecutor(IList<NArray> independentVariables)
{
    // The parameterless base constructor runs implicitly.
    _builder = new BlockExpressionBuilder(independentVariables);
}
/// <summary>
/// Differentiates the function with respect to the variable, returning an expression
/// for the partial derivative.
/// </summary>
/// <param name="function">
/// The function to differentiate: <c>function.Expression</c> is the operation and
/// <c>function.Parameter</c> is the variable assigned the operation's result.
/// </param>
/// <param name="variable">The operand of the function with respect to which to differentiate.</param>
/// <param name="builder">Receives any additional operations required to express the derivative.</param>
/// <returns>The derivative expression.</returns>
/// <exception cref="NotImplementedException">
/// Thrown for <c>variable / variable</c>, and for any operation for which no
/// differentiation rule matches.
/// </exception>
private static Expression Differentiate(Function function, Expression variable, BlockExpressionBuilder builder)
{
    // Note that one of the function's operands is expected to be the variable.
    var expression = function.Expression;

    var binaryExpression = expression as BinaryExpression;
    if (binaryExpression != null)
    {
        var nodeType = expression.NodeType;
        bool wrtLeft = binaryExpression.Left == variable;
        bool wrtRight = binaryExpression.Right == variable;
        bool wrtBoth = wrtLeft && wrtRight;

        if (nodeType == ExpressionType.Multiply)
        {
            // d(x * x)/dx = 2x; otherwise the derivative is the other operand.
            if (wrtBoth) { return Expression.Multiply(new ConstantExpression<double>(2), variable); }
            if (wrtLeft) { return binaryExpression.Right; }
            if (wrtRight) { return binaryExpression.Left; }
        }
        else if (nodeType == ExpressionType.Add)
        {
            if (wrtBoth) { return new ConstantExpression<double>(2); }
            if (wrtLeft) { return new ConstantExpression<double>(1); }
            if (wrtRight) { return new ConstantExpression<double>(1); }
        }
        else if (nodeType == ExpressionType.Subtract)
        {
            if (wrtBoth) { return new ConstantExpression<double>(0); }
            if (wrtLeft) { return new ConstantExpression<double>(1); }
            if (wrtRight) { return new ConstantExpression<double>(-1); }
        }
        else if (nodeType == ExpressionType.Divide)
        {
            if (wrtBoth)
            {
                throw new NotImplementedException(); // should not happen
            }

            if (wrtLeft)
            {
                // d(x / y)/dx = 1 / y. The scalar shortcut is only available when the
                // denominator is a referencing expression. The 'as' result is checked
                // for null so that any other denominator takes the general path rather
                // than throwing a NullReferenceException.
                var right = binaryExpression.Right as ReferencingVectorParameterExpression<double>;
                bool scalarDenominator = right != null && right.IsScalar;

                if (scalarDenominator)
                {
                    return builder.AddLocalAssignment<double>(1 / right.ScalarValue,
                        new ScaleInverseExpression<double>(binaryExpression.Right as VectorParameterExpression, 1));
                }

                return builder.AddLocalAssignment<double>(
                    new ScaleInverseExpression<double>(binaryExpression.Right as VectorParameterExpression, 1));
            }

            if (wrtRight)
            {
                // i.e. x/y => x/y * -1/y
                return builder.AddNegateDivideExpression(function.Parameter,
                    binaryExpression.Right as VectorParameterExpression);
            }
        }
        // A binary expression for which no rule matched falls through to the check below.
    }

    var unaryExpression = expression as UnaryMathsExpression;
    if (unaryExpression == null)
    {
        throw new NotImplementedException();
    }

    var operation = unaryExpression.UnaryType;

    if (operation == UnaryElementWiseOperation.ScaleOffset)
    {
        // Direct cast instead of an unchecked 'as': a mismatched element type fails
        // with an InvalidCastException rather than a NullReferenceException.
        var scaleOffset = (ScaleOffsetExpression<double>)unaryExpression;
        return new ConstantExpression<double>(scaleOffset.Scale);
    }
    if (operation == UnaryElementWiseOperation.Negate)
    {
        return new ConstantExpression<double>(-1);
    }
    if (operation == UnaryElementWiseOperation.ScaleInverse)
    {
        // i.e. x/y => x/y * -1/y
        return builder.AddNegateDivideExpression(function.Parameter, unaryExpression.Operand);
    }
    if (operation == UnaryElementWiseOperation.Exp)
    {
        // The function's own result is returned as the derivative of exp.
        return function.Parameter;
    }
    if (operation == UnaryElementWiseOperation.Log)
    {
        return builder.AddInverseExpression(unaryExpression.Operand);
    }
    if (operation == UnaryElementWiseOperation.SquareRoot)
    {
        return builder.AddHalfInverseSquareRootExpression(unaryExpression.Operand);
    }
    if (operation == UnaryElementWiseOperation.CumulativeNormal)
    {
        return builder.AddGaussian(unaryExpression.Operand);
    }

    throw new NotImplementedException();
}