/// <summary>
/// Translates a symbolic expression tree into an AutoDiff term so that its numeric constants
/// (and optionally its variable weights) are represented by AutoDiff variables.
/// </summary>
/// <param name="node">Root of the (sub)tree to translate.</param>
/// <param name="variables">Receives the AutoDiff variables standing for tunable values. It gets one per
/// constant node, and one weight per variable node when <paramref name="updateVariableWeights"/> is set.
/// For a start node it gets the offset (beta) followed by the scale (alpha).</param>
/// <param name="parameters">Receives one AutoDiff variable per variable node, standing for that node's input.</param>
/// <param name="variableNames">Receives each variable node's name, in the same order as <paramref name="parameters"/>.</param>
/// <param name="updateVariableWeights">When <c>true</c> a variable node becomes <c>weight * input</c>; otherwise just <c>input</c>.</param>
/// <param name="term">The translated term, or <c>null</c> when translation fails.</param>
/// <returns><c>false</c> when the tree contains a symbol that has no AutoDiff counterpart here.</returns>
/// <remarks>
/// Lists are appended in depth-first, left-to-right order. Nothing is rolled back on failure, so the
/// lists may hold partial content when this method returns <c>false</c>.
/// </remarks>
private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, bool updateVariableWeights, out AutoDiff.Term term) {
  term = null;
  var symbol = node.Symbol;

  // Leaves.
  if (symbol is Constant) {
    var constantValue = new AutoDiff.Variable();
    variables.Add(constantValue);
    term = constantValue;
    return true;
  }
  if (symbol is Variable) {
    var variableNode = node as VariableTreeNode;
    var input = new AutoDiff.Variable();
    parameters.Add(input);
    variableNames.Add(variableNode.VariableName);
    if (!updateVariableWeights) {
      term = input;
      return true;
    }
    var weight = new AutoDiff.Variable();
    variables.Add(weight);
    term = AutoDiff.TermBuilder.Product(weight, input);
    return true;
  }

  // N-ary arithmetic: translate every operand first (left to right), then combine.
  if (symbol is Addition || symbol is Subtraction || symbol is Multiplication || symbol is Division) {
    var operands = new List<AutoDiff.Term>();
    foreach (var child in node.Subtrees) {
      AutoDiff.Term operand;
      if (!TryTransformToAutoDiff(child, variables, parameters, variableNames, updateVariableWeights, out operand)) {
        return false;
      }
      operands.Add(operand);
    }

    if (symbol is Addition) {
      term = AutoDiff.TermBuilder.Sum(operands);
    } else if (symbol is Subtraction) {
      if (operands.Count == 1) {
        // A single argument means negation.
        term = -operands[0];
      } else {
        for (int k = 1; k < operands.Count; k++) {
          operands[k] = -operands[k];
        }
        term = AutoDiff.TermBuilder.Sum(operands);
      }
    } else if (symbol is Multiplication) {
      term = operands.Count == 1
        ? operands[0]
        : operands.Aggregate((acc, next) => new AutoDiff.Product(acc, next));
    } else {
      // Division: a single argument means the reciprocal; otherwise x0 * (1/x1) * (1/x2) ...
      term = operands.Count == 1
        ? 1.0 / operands[0]
        : operands.Aggregate((acc, next) => new AutoDiff.Product(acc, 1.0 / next));
    }
    return true;
  }

  // Unary functions: select the builder for this symbol, then apply it to the translated argument.
  Func<AutoDiff.Term, AutoDiff.Term> unary = null;
  if (symbol is Logarithm) {
    unary = arg => AutoDiff.TermBuilder.Log(arg);
  } else if (symbol is Exponential) {
    unary = arg => AutoDiff.TermBuilder.Exp(arg);
  } else if (symbol is Square) {
    unary = arg => AutoDiff.TermBuilder.Power(arg, 2.0);
  } else if (symbol is SquareRoot) {
    unary = arg => AutoDiff.TermBuilder.Power(arg, 0.5);
  } else if (symbol is Sine) {
    unary = arg => sin(arg);
  } else if (symbol is Cosine) {
    unary = arg => cos(arg);
  } else if (symbol is Tangent) {
    unary = arg => tan(arg);
  } else if (symbol is Erf) {
    unary = arg => erf(arg);
  } else if (symbol is Norm) {
    unary = arg => norm(arg);
  }
  if (unary != null) {
    AutoDiff.Term argument;
    if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out argument)) {
      return false;
    }
    term = unary(argument);
    return true;
  }

  // Root: model output is body * alpha + beta; beta is registered before alpha.
  if (symbol is StartSymbol) {
    var alpha = new AutoDiff.Variable();
    var beta = new AutoDiff.Variable();
    variables.Add(beta);
    variables.Add(alpha);
    AutoDiff.Term body;
    if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out body)) {
      return false;
    }
    term = body * alpha + beta;
    return true;
  }

  // Unsupported symbol.
  return false;
}
/// <summary>
/// Adapts a compiled parametric AutoDiff term to alglib's function-value callback.
/// </summary>
/// <param name="compiledFunc">The compiled term to evaluate.</param>
/// <returns>
/// A callback that writes <c>compiledFunc.Evaluate(c, x)</c> into <c>func</c>.
/// The trailing user-state object is ignored.
/// </returns>
private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
  alglib.ndimensional_pfunc callback = delegate(double[] c, double[] x, ref double func, object o) {
    func = compiledFunc.Evaluate(c, x);
  };
  return callback;
}
/// <summary>
/// Adapts a compiled parametric AutoDiff term to alglib's value-and-gradient callback.
/// </summary>
/// <param name="compiledFunc">The compiled term to differentiate.</param>
/// <returns>
/// A callback that calls <c>compiledFunc.Differentiate(c, x)</c>. It stores the tuple's
/// <c>Item2</c> in <c>func</c> and copies the first <c>grad.Length</c> entries of <c>Item1</c>
/// into <c>grad</c>. The trailing user-state object is ignored.
/// </returns>
private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
  alglib.ndimensional_pgrad callback = delegate(double[] c, double[] x, ref double func, double[] grad, object o) {
    var valueAndGradient = compiledFunc.Differentiate(c, x);
    // func is set before the copy, so it is already updated even if the copy throws.
    func = valueAndGradient.Item2;
    Array.Copy(valueAndGradient.Item1, grad, grad.Length);
  };
  return callback;
}
/// <summary>
/// Translates a symbolic expression tree into an AutoDiff term in which every constant and every
/// variable weight is represented by an AutoDiff variable.
/// </summary>
/// <param name="node">Root of the (sub)tree to translate.</param>
/// <param name="variables">Receives the AutoDiff variables standing for tunable values. It gets one per
/// constant node and one weight per variable node. For a start node it gets the offset (beta)
/// followed by the scale (alpha).</param>
/// <param name="parameters">Receives one AutoDiff variable per variable node, standing for that node's input.</param>
/// <param name="variableNames">Receives each variable node's name, in the same order as <paramref name="parameters"/>.</param>
/// <param name="term">The translated term, or <c>null</c> when translation fails.</param>
/// <returns><c>false</c> when the tree contains a symbol that has no AutoDiff counterpart here.</returns>
/// <remarks>
/// A variable node always becomes <c>weight * input</c>. Single-argument arithmetic nodes are
/// translated as in the overload that takes <c>updateVariableWeights</c>:
/// <c>-x</c> for subtraction, <c>x</c> for multiplication and <c>1/x</c> for division.
/// Lists are appended in depth-first, left-to-right order. Nothing is rolled back on failure, so the
/// lists may hold partial content when this method returns <c>false</c>.
/// </remarks>
private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
  term = null;
  var symbol = node.Symbol;

  if (symbol is Constant) {
    var constantValue = new AutoDiff.Variable();
    variables.Add(constantValue);
    term = constantValue;
    return true;
  }
  if (symbol is Variable) {
    var varNode = node as VariableTreeNode;
    var input = new AutoDiff.Variable();
    parameters.Add(input);
    variableNames.Add(varNode.VariableName);
    var weight = new AutoDiff.Variable();
    variables.Add(weight);
    term = AutoDiff.TermBuilder.Product(weight, input);
    return true;
  }

  if (symbol is Addition || symbol is Subtraction || symbol is Multiplication || symbol is Division) {
    var operands = new List<AutoDiff.Term>();
    foreach (var subTree in node.Subtrees) {
      AutoDiff.Term operand;
      if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out operand)) {
        return false;
      }
      operands.Add(operand);
    }

    if (symbol is Addition) {
      term = AutoDiff.TermBuilder.Sum(operands);
    } else if (symbol is Subtraction) {
      if (operands.Count == 1) {
        // Unary subtraction is negation. Summing a single operand would silently drop the sign.
        term = -operands[0];
      } else {
        for (int i = 1; i < operands.Count; i++) {
          operands[i] = -operands[i];
        }
        term = AutoDiff.TermBuilder.Sum(operands);
      }
    } else if (symbol is Multiplication) {
      if (operands.Count == 1) {
        // A single factor is the operand itself. Reading a second subtree would throw.
        term = operands[0];
      } else {
        term = AutoDiff.TermBuilder.Product(operands[0], operands[1], operands.Skip(2).ToArray());
      }
    } else {
      if (operands.Count == 1) {
        // Unary division is the reciprocal. Reading a second subtree would throw.
        term = 1.0 / operands[0];
      } else {
        // x0 / x1 / x2 / ... == x0 * (1/x1) * (1/x2) * ...
        term = AutoDiff.TermBuilder.Product(operands[0], 1.0 / operands[1], operands.Skip(2).Select(f => 1.0 / f).ToArray());
      }
    }
    return true;
  }

  Func<AutoDiff.Term, AutoDiff.Term> unary = null;
  if (symbol is Logarithm) {
    unary = arg => AutoDiff.TermBuilder.Log(arg);
  } else if (symbol is Exponential) {
    unary = arg => AutoDiff.TermBuilder.Exp(arg);
  } else if (symbol is Square) {
    unary = arg => AutoDiff.TermBuilder.Power(arg, 2.0);
  } else if (symbol is SquareRoot) {
    unary = arg => AutoDiff.TermBuilder.Power(arg, 0.5);
  } else if (symbol is Sine) {
    unary = arg => sin(arg);
  } else if (symbol is Cosine) {
    unary = arg => cos(arg);
  } else if (symbol is Tangent) {
    unary = arg => tan(arg);
  } else if (symbol is Erf) {
    unary = arg => erf(arg);
  } else if (symbol is Norm) {
    unary = arg => norm(arg);
  }
  if (unary != null) {
    AutoDiff.Term argument;
    if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out argument)) {
      return false;
    }
    term = unary(argument);
    return true;
  }

  if (symbol is StartSymbol) {
    // Model output is branch * alpha + beta; beta is registered before alpha.
    var alpha = new AutoDiff.Variable();
    var beta = new AutoDiff.Variable();
    variables.Add(beta);
    variables.Add(alpha);
    AutoDiff.Term branchTerm;
    if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
      return false;
    }
    term = branchTerm * alpha + beta;
    return true;
  }

  // Unsupported symbol.
  return false;
}