/// <summary>
/// Performs reverse-mode (adjoint) automatic differentiation of
/// <paramref name="dependentVariable"/> with respect to each of the
/// <paramref name="independentVariables"/>.
/// </summary>
/// <param name="builder">
/// Builder holding the operations to differentiate. It is converted to a block to read the
/// operations, and the product/addition expressions that make up the derivatives are created through it.
/// </param>
/// <param name="derivativeExpressions">
/// On return, an array with one slot per independent variable holding the expression for the
/// derivative of the dependent variable with respect to it, or <c>null</c> when no independent
/// variables are supplied. The slot used for each variable is decided by
/// <c>IdentifyNeccesaryDerivatives</c> (presumably the position in
/// <paramref name="independentVariables"/> - confirm against that method).
/// </param>
/// <param name="dependentVariable">The variable x_N being differentiated; its index bounds the backward sweep.</param>
/// <param name="independentVariables">The variables with respect to which derivatives are required.</param>
public static void Differentiate(BlockExpressionBuilder builder, out IList<Expression> derivativeExpressions, VectorParameterExpression dependentVariable, params VectorParameterExpression[] independentVariables)
{
    // Nothing to differentiate with respect to; callers receive null as before.
    if (independentVariables == null || independentVariables.Length == 0)
    {
        derivativeExpressions = null;
        return;
    }

    var block = builder.ToBlock();

    // f[k] is the function that defines parameter k (arguments first, then locals);
    // jLookup[k] lists the indices of the functions that take parameter k as an operand.
    List<int>[] jLookup;
    Function[] f;
    GetFunctions(block, out f, out jLookup);

    int N = dependentVariable.Index;

    bool[] derivativeRequired;
    int[] derivativeExpressionIndex;
    IdentifyNeccesaryDerivatives(N, jLookup, independentVariables, out derivativeRequired, out derivativeExpressionIndex);

    derivativeExpressions = new Expression[independentVariables.Length];

    // dxNdx[k] holds dxN/dxk once computed; the sweep is seeded with dxN/dxN = 1.
    var dxNdx = new Expression[N + 1];
    dxNdx[N] = new ConstantExpression<double>(1.0);

    // Walk backwards so that dxN/dxj is available for every j > i before dxN/dxi is assembled.
    for (int i = N - 1; i >= 0; i--)
    {
        if (!derivativeRequired[i])
        {
            continue;
        }

        // dxN/dxi = sum over j of (dfj/dxi * dxN/dxj), for all functions j that have xi as an argument
        // (these must therefore have a higher index than i).
        VectorParameterExpression total = new ConstantExpression<double>(0);
        var xi = f[i].Parameter;
        foreach (var j in jLookup[i])
        {
            // Operations assigned after the dependent variable cannot influence it, so their
            // contribution is zero. They also lie beyond the end of dxNdx.
            if (j > N)
            {
                continue;
            }

            // No derivative was accumulated for xj (it was skipped above), so there is nothing to
            // propagate through it; passing null on to the builder would be an error.
            var dXNdxj = dxNdx[j]; // dXN/dxj
            if (dXNdxj == null)
            {
                continue;
            }

            var fj = f[j];
            var dfjdxi = Differentiate(fj, xi, builder); // dfj/dxi
            var product = builder.AddProductExpression(dfjdxi, dXNdxj);
            total = builder.AddAdditionExpression(total, product);
        }

        dxNdx[i] = total;

        // A target index of -1 marks an intermediate derivative that is not one of the requested outputs.
        int targetIndex = derivativeExpressionIndex[i];
        if (targetIndex != -1)
        {
            derivativeExpressions[targetIndex] = total;
        }
    }
}
/// <summary>
/// Records that the operation with index <paramref name="operationIndex"/> uses
/// <paramref name="expression"/> as an operand.
/// </summary>
/// <param name="operationIndices">Per-parameter lists of the operations that reference that parameter.</param>
/// <param name="expression">The operand being referenced.</param>
/// <param name="operationIndex">Index of the operation that references the operand.</param>
private static void UpdateOperationIndices(List<int>[] operationIndices, VectorParameterExpression expression, int operationIndex)
{
    int referencedIndex = expression.Index;

    // An index of -1 indicates a constant, which has no entry in the look-up.
    if (referencedIndex == -1)
    {
        return;
    }

    operationIndices[referencedIndex].Add(operationIndex);
}
/// <summary>
/// Builds, for a block, the table of functions that define each parameter together with the
/// reverse look-up of which operations use each parameter.
/// </summary>
/// <param name="block">The block whose argument parameters, local parameters and operations are read.</param>
/// <param name="f">
/// On return, an array sized to the number of arguments plus locals. The leading entries are the
/// arguments themselves; each operation is stored at the index of the parameter it assigns.
/// </param>
/// <param name="jLookup">
/// On return, <c>jLookup[i]</c> lists the indices of the operations whose right-hand side has
/// parameter <c>i</c> as an operand.
/// </param>
/// <exception cref="NotImplementedException">
/// Thrown when an operation does not assign to a <c>VectorParameterExpression</c>, or when an operand
/// of a binary operation is not a <c>VectorParameterExpression</c>.
/// </exception>
private static void GetFunctions(VectorBlockExpression block, out Function[] f, out List<int>[] jLookup)
{
    int N = block.ArgumentParameters.Count + block.LocalParameters.Count;

    // jLookup[i] stores the indices of the operations which make use of expression i.
    jLookup = new List<int>[N];
    for (int i = 0; i < N; ++i)
    {
        jLookup[i] = new List<int>();
    }

    f = new Function[N];

    // The first set of functions are simply the arguments.
    for (int i = 0; i < block.ArgumentParameters.Count; ++i)
    {
        f[i] = new Function() { Expression = block.ArgumentParameters[i], Parameter = block.ArgumentParameters[i] };
    }

    // The next set are the operations.
    for (int i = 0; i < block.Operations.Count; ++i)
    {
        var assignment = block.Operations[i];

        var parameter = assignment.Left as VectorParameterExpression;
        if (parameter == null)
        {
            throw new NotImplementedException();
        }

        // Operations are stored by the index of the parameter they assign, which is not checked
        // against their position in the block.
        int operationIndex = parameter.Index;
        var operation = assignment.Right;
        f[operationIndex] = new Function() { Expression = operation, Parameter = parameter };

        var binary = operation as BinaryExpression;
        var unary = operation as UnaryMathsExpression;
        if (binary != null)
        {
            // Fail in the same way as for an unsupported left-hand side, rather than letting a
            // failed cast surface as a NullReferenceException inside UpdateOperationIndices.
            var left = binary.Left as VectorParameterExpression;
            var right = binary.Right as VectorParameterExpression;
            if (left == null || right == null)
            {
                throw new NotImplementedException("Binary operands must be VectorParameterExpression instances.");
            }

            UpdateOperationIndices(jLookup, left, operationIndex);
            UpdateOperationIndices(jLookup, right, operationIndex);
        }
        else if (unary != null)
        {
            UpdateOperationIndices(jLookup, unary.Operand, operationIndex);
        }
    }
}