/// <summary>
/// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
/// Runs an iterative gradient-based optimisation of the network's external inputs towards the
/// configured desired targets, retrying with fresh randomised inputs up to a configured number of attempts.
/// Publishes the result (success flag, and on success the maximised input) into the resolver's shared registry.
/// </summary>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	// we need copies of network and optimiser as to not affect the current internal state
	INetwork network = (INetwork) resolver.ResolveGetSingle<INetwork>("network.self").DeepCopy();
	BaseGradientOptimiser optimiser = (BaseGradientOptimiser) resolver.ResolveGetSingle<BaseGradientOptimiser>("optimiser.self").ShallowCopy();
	INDArray desiredTargets = ParameterRegistry.Get<INDArray>("desired_targets");
	IComputationHandler handler = new DebugHandler(Operator.Handler);

	// NOTE(review): assumes the first external-inputs layer buffer carries the relevant "shape" parameter — confirm for multi-input networks
	long[] inputShape = network.YieldExternalInputsLayerBuffers().First().Parameters.Get<long[]>("shape");

	IDictionary<string, INDArray> block = DataUtils.MakeBlock("targets", desiredTargets); // desired targets don't change during execution

	double desiredCost = ParameterRegistry.Get<double>("desired_cost"), currentCost = double.MaxValue;
	int maxOptimisationAttempts = ParameterRegistry.Get<int>("max_optimisation_attempts");
	int maxOptimisationSteps = ParameterRegistry.Get<int>("max_optimisation_steps");
	int optimisationSteps = 0;
	INDArray maximisedInputs = CreateRandomisedInput(handler, inputShape);

	for (int i = 0; i < maxOptimisationAttempts; i++)
	{
		optimisationSteps = 0;

		do
		{
			// trace current inputs and run network as normal
			uint traceTag = handler.BeginTrace();
			// prepend [1, 1] (presumably batch and time dimensions) so the flat input shape matches the network's expected layout
			block["inputs"] = handler.Trace(maximisedInputs.Reshape(ArrayUtils.Concatenate(new[] { 1L, 1L }, inputShape)), traceTag);

			handler.BeginSession();

			DataUtils.ProvideExternalInputData(network, block);
			network.Run(handler, trainingPass: false);

			// fetch current outputs and optimise against them (towards desired targets)
			INDArray currentTargets = network.YieldExternalOutputsLayerBuffers().First(b => b.ExternalOutputs.Contains("external_default"))
				.Outputs["external_default"].Get<INDArray>("activations");
			// squared-error cost between current activations and the desired targets
			INumber squaredDifference = handler.Sum(handler.Pow(handler.Subtract(handler.FlattenTimeAndFeatures(currentTargets), desiredTargets), 2));
			handler.ComputeDerivativesTo(squaredDifference);

			handler.EndSession();

			INDArray gradient = handler.GetDerivative(block["inputs"]);
			maximisedInputs = handler.ClearTrace(optimiser.Optimise("inputs", block["inputs"], gradient, handler));

			currentCost = squaredDifference.GetValueAs<double>();

			// reached the desired cost — jump straight to validation; the goto also skips the
			// post-increment below, so optimisationSteps stays < maxOptimisationSteps and the
			// success branch is taken
			if (currentCost <= desiredCost)
			{
				goto Validation;
			}
		} while (++optimisationSteps < maxOptimisationSteps);

		maximisedInputs = CreateRandomisedInput(handler, inputShape); // reset input
	}

Validation:
	// strip the prepended [1, 1] dimensions again before publishing the result
	maximisedInputs.ReshapeSelf(inputShape);

	string sharedResultInput = ParameterRegistry.Get<string>("shared_result_input_key");
	string sharedResultSuccess = ParameterRegistry.Get<string>("shared_result_success_key");

	// optimisationSteps only reaches maxOptimisationSteps when every attempt's inner loop ran to
	// exhaustion without hitting the desired cost (success exits via the goto above)
	if (optimisationSteps >= maxOptimisationSteps)
	{
		_logger.Debug($"Aborted target maximisation for {desiredTargets}, failed after {maxOptimisationSteps} optimisation steps in {maxOptimisationAttempts} attempts (exceeded limit, current cost {currentCost} but desired {desiredCost}).");

		resolver.ResolveSet(sharedResultSuccess, false, addIdentifierIfNotExists: true);
	}
	else
	{
		// BUGFIX: previously interpolated {optimiser} (the optimiser object) instead of the step count
		_logger.Debug($"Successfully finished target optimisation for {desiredTargets} after {optimisationSteps} optimisation steps.");

		resolver.ResolveSet(sharedResultSuccess, true, addIdentifierIfNotExists: true);
		resolver.ResolveSet(sharedResultInput, maximisedInputs, addIdentifierIfNotExists: true);
	}
}