Пример #1
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Starting from a randomised input, repeatedly applies a copy of the gradient optimiser to the network's
        /// inputs so that the summed squared difference between the "external_default" output activations and the
        /// desired targets drops to at most the desired cost. Each attempt runs up to "max_optimisation_steps" steps;
        /// a failed attempt restarts from a fresh randomised input, up to "max_optimisation_attempts" attempts.
        /// On success the optimised input is written to the "shared_result_input_key" entry; the outcome (true/false)
        /// is always written to the "shared_result_success_key" entry.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            // we need copies of network and optimiser as to not affect the current internal state
            INetwork network = (INetwork)resolver.ResolveGetSingle <INetwork>("network.self").DeepCopy();
            BaseGradientOptimiser optimiser    = (BaseGradientOptimiser)resolver.ResolveGetSingle <BaseGradientOptimiser>("optimiser.self").ShallowCopy();
            INDArray            desiredTargets = ParameterRegistry.Get <INDArray>("desired_targets");
            IComputationHandler handler        = new DebugHandler(Operator.Handler);

            long[] inputShape = network.YieldExternalInputsLayerBuffers().First().Parameters.Get <long[]>("shape");

            // desired targets don't change during execution
            IDictionary <string, INDArray> block = DataUtils.MakeBlock("targets", desiredTargets);

            double   desiredCost             = ParameterRegistry.Get <double>("desired_cost");
            double   currentCost             = Double.MaxValue;
            int      maxOptimisationAttempts = ParameterRegistry.Get <int>("max_optimisation_attempts");
            int      maxOptimisationSteps    = ParameterRegistry.Get <int>("max_optimisation_steps");
            int      optimisationSteps       = 0;
            INDArray maximisedInputs         = CreateRandomisedInput(handler, inputShape);

            // Success is tracked explicitly instead of being inferred from the step counter afterwards.
            // The counter-based check reported success when no attempt ran (attempts <= 0).
            // It reported failure when the single forced step of the do-while succeeded with steps <= 0.
            bool success = false;

            for (int attempt = 0; attempt < maxOptimisationAttempts && !success; attempt++)
            {
                optimisationSteps = 0;

                do
                {
                    // trace current inputs and run network as normal
                    uint traceTag = handler.BeginTrace();
                    block["inputs"] = handler.Trace(maximisedInputs.Reshape(ArrayUtils.Concatenate(new[] { 1L, 1L }, inputShape)), traceTag);

                    handler.BeginSession();

                    DataUtils.ProvideExternalInputData(network, block);
                    network.Run(handler, trainingPass: false);

                    // fetch current outputs and optimise against them (towards desired targets)
                    INDArray currentTargets = network.YieldExternalOutputsLayerBuffers().First(b => b.ExternalOutputs.Contains("external_default"))
                                              .Outputs["external_default"].Get <INDArray>("activations");
                    INumber squaredDifference = handler.Sum(handler.Pow(handler.Subtract(handler.FlattenTimeAndFeatures(currentTargets), desiredTargets), 2));

                    handler.ComputeDerivativesTo(squaredDifference);

                    handler.EndSession();

                    INDArray gradient = handler.GetDerivative(block["inputs"]);
                    maximisedInputs = handler.ClearTrace(optimiser.Optimise("inputs", block["inputs"], gradient, handler));

                    currentCost = squaredDifference.GetValueAs <double>();

                    if (currentCost <= desiredCost)
                    {
                        // leaves the do-while without evaluating its condition, so optimisationSteps stays the
                        // zero-based index of the successful step; the outer loop stops on !success
                        success = true;
                        break;
                    }
                } while (++optimisationSteps < maxOptimisationSteps);

                if (!success)
                {
                    maximisedInputs = CreateRandomisedInput(handler, inputShape); // reset input for the next attempt
                }
            }

            maximisedInputs.ReshapeSelf(inputShape);

            string sharedResultInput   = ParameterRegistry.Get <string>("shared_result_input_key");
            string sharedResultSuccess = ParameterRegistry.Get <string>("shared_result_success_key");

            if (!success)
            {
                _logger.Debug($"Aborted target maximisation for {desiredTargets}, failed after {maxOptimisationSteps} optimisation steps in {maxOptimisationAttempts} attempts (exceeded limit, current cost {currentCost} but desired {desiredCost}).");

                resolver.ResolveSet(sharedResultSuccess, false, addIdentifierIfNotExists: true);
            }
            else
            {
                // optimisationSteps is zero-based, so the number of steps executed in the successful attempt is one more
                _logger.Debug($"Successfully finished target optimisation for {desiredTargets} after {optimisationSteps + 1} optimisation steps.");

                resolver.ResolveSet(sharedResultSuccess, true, addIdentifierIfNotExists: true);
                resolver.ResolveSet(sharedResultInput, maximisedInputs, addIdentifierIfNotExists: true);
            }
        }
Пример #2
0
 /// <summary>Describes this threshold criteria using its parameter identifier, comparison target and threshold value.</summary>
 /// <returns>A human-readable description of the criteria.</returns>
 public override string ToString()
 {
     string identifier = ParameterRegistry.Get<string>("parameter_identifier");
     ComparisonTarget comparison = ParameterRegistry.Get<ComparisonTarget>("target");
     double threshold = ParameterRegistry.Get<double>("threshold_value");

     return $"threshold criteria for \"{identifier}\" when {comparison} {threshold}";
 }
Пример #3
0
 /// <summary>Describes this extrema criteria using its parameter identifier and extrema target.</summary>
 /// <returns>A human-readable description of the criteria.</returns>
 public override string ToString()
 {
     string identifier = ParameterRegistry.Get<string>("parameter_identifier");
     ExtremaTarget extrema = ParameterRegistry.Get<ExtremaTarget>("target");

     return $"extrema criteria for \"{identifier}\" when at {extrema}";
 }
Пример #4
0
 /// <summary>Describes this repeat criteria using its wrapped base criteria and the target repetition count.</summary>
 /// <returns>A human-readable description of the criteria.</returns>
 public override string ToString()
 {
     HookInvokeCriteria baseCriteria = ParameterRegistry.Get<HookInvokeCriteria>("base_criteria");
     long repetitions = ParameterRegistry.Get<long>("target_repetitions");

     return $"repeat criteria [{baseCriteria}] {repetitions} times";
 }
Пример #5
0
        /// <summary>
        /// Check whether the monitored parameter has reached the configured threshold.
        /// The parameter is read directly from <paramref name="registry"/> when the first simple-direct-entry flag is set,
        /// otherwise it is resolved through <paramref name="resolver"/>; it is then converted to a double.
        /// Fires only when the threshold is reached and either it was not reached at the previous check
        /// or "fire_continously" is set. The reached state is stored as "last_check_met" for the next check.
        /// </summary>
        /// <param name="registry">The registry read directly for simple direct entries.</param>
        /// <param name="resolver">The resolver used for non-direct entries.</param>
        /// <returns>True if the hook should fire for this check.</returns>
        public override bool CheckCriteria(IRegistry registry, IRegistryResolver resolver)
        {
            string identifier = ParameterRegistry.Get<string>("parameter_identifier");

            object raw;
            if (SimpleDirectEntries[0])
            {
                raw = registry.Get(identifier);
            }
            else
            {
                raw = resolver.ResolveGetSingle<object>(identifier);
            }

            double current = (double)Convert.ChangeType(raw, typeof(double));
            double threshold = ParameterRegistry.Get<double>("threshold_value");
            ComparisonTarget comparison = ParameterRegistry.Get<ComparisonTarget>("target");
            bool reached = _InternalThresholdReached(current, threshold, comparison);

            // fire on the transition into the reached state, or on every reached check when firing continuously
            bool shouldFire = false;
            if (reached)
            {
                shouldFire = !ParameterRegistry.Get<bool>("last_check_met") || ParameterRegistry.Get<bool>("fire_continously");
            }

            ParameterRegistry["last_check_met"] = reached;

            return shouldFire;
        }
Пример #6
0
        /// <summary>
        /// Append another criteria to this multi-and criteria's "criterias" list (mutating this instance).
        /// </summary>
        /// <param name="criteria">The criteria to add.</param>
        /// <returns>This same instance, allowing further chained calls.</returns>
        public override MultiAndCriteria And(HookInvokeCriteria criteria)
        {
            IList<HookInvokeCriteria> combined = ParameterRegistry.Get<IList<HookInvokeCriteria>>("criterias");
            combined.Add(criteria);

            return this;
        }
Пример #7
0
 /// <summary>Describes this multi-and criteria by joining all contained criteria with " and ".</summary>
 /// <returns>A human-readable description of the criteria.</returns>
 public override string ToString()
 {
     IList<HookInvokeCriteria> members = ParameterRegistry.Get<IList<HookInvokeCriteria>>("criterias");
     string joined = string.Join(" and ", members);

     return $"multi and criteria [{joined}]";
 }
Пример #8
0
 /// <summary>Describes this not criteria by wrapping the description of the negated criteria.</summary>
 /// <returns>A human-readable description of the criteria.</returns>
 public override string ToString()
 {
     HookInvokeCriteria negated = ParameterRegistry.Get<HookInvokeCriteria>("criteria");

     return $"not criteria [{negated}]";
 }
Пример #9
0
 /// <summary>Checks the wrapped criteria and inverts its result.</summary>
 /// <param name="registry">The registry forwarded to the wrapped criteria.</param>
 /// <param name="resolver">The resolver forwarded to the wrapped criteria.</param>
 /// <returns>True exactly when the wrapped criteria is not met.</returns>
 public override bool CheckCriteria(IRegistry registry, IRegistryResolver resolver)
 {
     HookInvokeCriteria inner = ParameterRegistry.Get<HookInvokeCriteria>("criteria");
     bool innerMet = inner.CheckCriteria(registry, resolver);

     return !innerMet;
 }
Пример #10
0
 /// <summary>Logs an info message built from the "format_string" parameter, with {0} = epoch and {1} = iteration.</summary>
 /// <param name="epoch">The epoch value substituted as format argument 0.</param>
 /// <param name="iteration">The iteration value substituted as format argument 1.</param>
 protected virtual void Report(int epoch, int iteration)
 {
     string template = ParameterRegistry.Get<string>("format_string");
     string message = string.Format(template, epoch, iteration);

     _logger.Info(message);
 }
        /// <summary>
        /// End a validation scoring session.
        /// Computes the accuracy as correct / total classifications (floating-point division) and writes it
        /// to the resolver entry named by the "result_key" parameter.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        protected override void ScoreEnd(IRegistry registry, IRegistryResolver resolver)
        {
            string key = ParameterRegistry.Get<string>("result_key");
            int correct = ParameterRegistry.Get<int>("correct_classifications");
            int total = ParameterRegistry.Get<int>("total_classifications");

            // NOTE(review): total == 0 yields NaN (or Infinity), not an exception — confirm that is acceptable
            double accuracy = (double)correct / total;

            // NOTE(review): the trailing `true` is presumably addIdentifierIfNotExists — confirm against IRegistryResolver
            resolver.ResolveSet(key, accuracy, true);
        }
Пример #12
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry.
        /// Delegates entirely to the action stored under the "invoke_action" parameter.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            Action<IRegistry, IRegistryResolver> callback = ParameterRegistry.Get<Action<IRegistry, IRegistryResolver>>("invoke_action");
            callback(registry, resolver);
        }