/// <summary>
/// Measure the wall time that passed since the previous invocation of this hook and publish
/// the last, average, minimum and maximum running times of a sliding window of samples.
/// The very first invocation only records a timestamp and publishes nothing.
/// </summary>
/// <remarks>
/// Published entries (prefixed with the <c>shared_result_base_key</c> parameter):
/// <c>_last</c>, <c>_average</c>, <c>_min</c> and <c>_max</c>. Minimum and maximum are always taken over
/// the untrimmed window; only the average is affected by <c>remove_extremas</c>.
/// </remarks>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	if (ParameterRegistry.ContainsKey("last_time"))
	{
		bool trimExtremes = ParameterRegistry.Get<bool>("remove_extremas");
		long previousTimestamp = ParameterRegistry.Get<long>("last_time");
		long now = Operator.RunningTimeMilliseconds;
		long elapsed = now - previousTimestamp;
		LinkedList<long> history = ParameterRegistry.Get<LinkedList<long>>("last_running_times");
		string baseKey = ParameterRegistry.Get<string>("shared_result_base_key");
		int windowSize = ParameterRegistry.Get<int>("average_span");

		// one sample is appended per call, so dropping a single oldest sample keeps the window bounded
		history.AddLast(elapsed);
		if (history.Count > windowSize)
		{
			history.RemoveFirst();
		}

		LinkedList<long> samples = history;

		if (trimExtremes)
		{
			// work on a copy so the persisted window keeps its outliers
			samples = new LinkedList<long>(history);

			// TODO magic number: trims floor(sqrt(n / 2)) samples, alternating between current max and min
			int outlierCount = (int)Math.Sqrt(history.Count / 2.0f);
			for (int remaining = outlierCount - 1; remaining >= 0; remaining--)
			{
				long outlier = remaining % 2 == 0 ? samples.Max() : samples.Min();
				samples.Remove(outlier);
			}
		}

		// integer (long) division, matching the millisecond resolution of the samples
		long average = samples.Sum() / samples.Count;

		resolver.ResolveSet(baseKey + "_last", elapsed, addIdentifierIfNotExists: true);
		resolver.ResolveSet(baseKey + "_average", average, addIdentifierIfNotExists: true);
		resolver.ResolveSet(baseKey + "_min", history.Min(), addIdentifierIfNotExists: true);
		resolver.ResolveSet(baseKey + "_max", history.Max(), addIdentifierIfNotExists: true);
	}

	// re-read the clock here so the time spent inside this method is not attributed to the next interval
	ParameterRegistry["last_time"] = Operator.RunningTimeMilliseconds;
}
/// <inheritdoc />
public virtual void SynchroniseSet<T>(IRegistry registry, string key, T val, Action<T> onSuccess = null, Action<Exception> onError = null)
{
	// registries owned by a running operator must be modified through that operator's command queue
	IOperator owningOperator = null;
	foreach (IOperator candidate in Sigma.RunningOperatorsByTrainer.Values)
	{
		if (ReferenceEquals(candidate.Registry, registry))
		{
			owningOperator = candidate;
			break;
		}
	}

	if (owningOperator != null)
	{
		//TODO: test if callback is called
		//TODO: on error check sources for other to set the value
		owningOperator.InvokeCommand(new SetValueCommand<T>(key, val, () => onSuccess?.Invoke(val)));
		return;
	}

	// otherwise set the value directly through a (cached) resolver for that registry
	IRegistryResolver resolver = RegistryResolvers.TryGetValue(registry, () => new RegistryResolver(registry));
	bool anyValueSet = resolver.ResolveSet(key, val, true, typeof(T)).Length > 0;

	if (!anyValueSet)
	{
		onError?.Invoke(new KeyNotFoundException($"{key} was not found in {registry} and could not be created."));
		return;
	}

	onSuccess?.Invoke(val);
}
/// <summary>
/// Add the current value of the <c>registry_entry</c> registry entry to a running accumulator and publish the
/// accumulated sum (or, in <c>average_mode</c>, the mean since the last reset) under <c>shared_result_entry</c>.
/// </summary>
/// <remarks>
/// The accumulator is reset when the current interval equals <c>reset_interval</c>, or when <c>reset_every</c>
/// is positive and the current interval is a multiple of it. The persisted <c>accumulated_value</c> is always
/// the raw sum, never the average.
/// </remarks>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	string sourceEntry = ParameterRegistry.Get<string>("registry_entry");
	string targetEntry = ParameterRegistry.Get<string>("shared_result_entry");
	double currentValue = resolver.ResolveGetSingle<double>(sourceEntry);
	double accumulated = ParameterRegistry.Get<double>("accumulated_value");
	int interval = HookUtils.GetCurrentInterval(registry, TimeStep.TimeScale);
	int resetAt = ParameterRegistry.Get<int>("reset_interval");
	int resetPeriod = ParameterRegistry.Get<int>("reset_every");
	int sampleCount = ParameterRegistry.Get<int>("count_since_reset");

	// the period check is short-circuited so a non-positive reset_every never reaches the modulo
	bool resetNow = interval == resetAt || (resetPeriod > 0 && interval % resetPeriod == 0);
	if (resetNow)
	{
		accumulated = 0.0;
		sampleCount = 0;
	}

	sampleCount++;
	double newTotal = currentValue + accumulated;

	double published = ParameterRegistry.Get<bool>("average_mode") ? newTotal / sampleCount : newTotal;

	ParameterRegistry["count_since_reset"] = sampleCount;
	ParameterRegistry["accumulated_value"] = newTotal;

	resolver.ResolveSet(targetEntry, published, addIdentifierIfNotExists: true);
}
/// <summary>
/// Finish a validation scoring session by publishing the accuracy
/// (<c>correct_classifications</c> divided by <c>total_classifications</c>) under <c>result_key</c>.
/// </summary>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
protected override void ScoreEnd(IRegistry registry, IRegistryResolver resolver)
{
	string resultKey = ParameterRegistry.Get<string>("result_key");
	int correct = ParameterRegistry.Get<int>("correct_classifications");
	int total = ParameterRegistry.Get<int>("total_classifications");

	// floating-point division; a zero total yields NaN or infinity rather than throwing
	double accuracy = correct / (double) total;

	resolver.ResolveSet(resultKey, accuracy, addIdentifierIfNotExists: true);
}
/// <summary>
/// Apply the configured <c>modifier</c> to the registry parameter named by <c>parameter_identifier</c>
/// and write the modified value back to the same identifier.
/// </summary>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
/// <exception cref="InvalidOperationException">If the resolved parameter is neither an INumber nor an INDArray.</exception>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	IValueModifier modifier = ParameterRegistry.Get<IValueModifier>("modifier");
	string identifier = ParameterRegistry.Get<string>("parameter_identifier");
	object parameter = resolver.ResolveGetSingle<object>(identifier);

	// kept as object so the value is written back exactly as the original code did
	object modified;
	INumber number = parameter as INumber;

	if (number != null)
	{
		modified = modifier.Modify(identifier, number, number.AssociatedHandler);
	}
	else
	{
		// numbers take precedence; arrays are only considered when the value is not an INumber
		INDArray array = parameter as INDArray;

		if (array == null)
		{
			throw new InvalidOperationException($"Cannot apply modifier {modifier} to parameter \"{identifier}\" with value {parameter}, " +
												$"parameter is neither {nameof(INumber)} nor {nameof(INDArray)}.");
		}

		modified = modifier.Modify(identifier, array, array.AssociatedHandler);
	}

	resolver.ResolveSet(identifier, modified);
}
/// <summary>
/// Write each configured value to the registry entry whose key sits at the same position in the key array.
/// </summary>
/// <remarks>
/// Keys and values are paired by index. A value array shorter than the key array raises an
/// <see cref="System.IndexOutOfRangeException"/>; surplus values are ignored.
/// </remarks>
/// <param name="registry">The registry containing the required values for this command's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	string[] targetKeys = (string[]) ParameterRegistry[KeyIdentifier];
	T[] newValues = (T[]) ParameterRegistry[ValueIdentifier];

	int position = 0;
	foreach (string targetKey in targetKeys)
	{
		//TODO: validate if successfully set and call error otherwise (for each key?)
		resolver.ResolveSet(targetKey, newValues[position], AddItentifierIfNotExists, typeof(T));
		position++;
	}
}
/// <summary>
/// Finish a validation scoring session by publishing one top-k accuracy per entry of <c>tops</c>.
/// Each score is written to <c>result_base_key</c> with the k value appended (e.g. base key + "3").
/// </summary>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
protected override void ScoreEnd(IRegistry registry, IRegistryResolver resolver)
{
	int[] tops = ParameterRegistry.Get<int[]>("tops");

	// loop-invariant: neither the base key nor the total depends on the current k, so read them once
	string resultBaseKey = ParameterRegistry.Get<string>("result_base_key");
	int totalClassifications = ParameterRegistry.Get<int>("total_classifications");

	foreach (int top in tops)
	{
		int correctClassifications = ParameterRegistry.Get<int>($"correct_classifications_top{top}");

		// floating-point division; a zero total yields NaN or infinity rather than throwing
		double score = ((double) correctClassifications) / totalClassifications;

		resolver.ResolveSet(resultBaseKey + top, score, addIdentifierIfNotExists: true);
	}
}
/// <summary>
/// Evaluate the configured <c>metric_function</c> on every value found under <c>registry_entry_to_process</c>
/// and publish the mean metric under <c>metric_shared_result_identifier</c>.
/// </summary>
/// <remarks>
/// Each resolved entry may be a single T, an enumerable of T, or a string-keyed dictionary of T
/// (whose values are used). If no values are found, the published mean is NaN (0.0 / 0).
/// </remarks>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
/// <exception cref="InvalidOperationException">If a resolved entry is null or of an unsupported type.</exception>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	IComputationHandler handler = Operator.Handler;
	string registryEntryToProcess = ParameterRegistry.Get<string>("registry_entry_to_process");
	Func<T, IComputationHandler, INumber> metricFunction = ParameterRegistry.Get<Func<T, IComputationHandler, INumber>>("metric_function");
	string metricSharedResultIdentifier = ParameterRegistry.Get<string>("metric_shared_result_identifier");

	object[] entries = resolver.ResolveGet<object>(registryEntryToProcess);

	double totalMetric = 0.0;
	int count = 0;

	foreach (object entry in entries)
	{
		// a null entry is neither T nor a collection of T; report it like any other unsupported entry
		// instead of crashing with a NullReferenceException on entry.GetType() below
		if (entry == null)
		{
			throw new InvalidOperationException($"Cannot process metric for null entry with identifier \"{registryEntryToProcess}\", must be {typeof(T)} or enumerable thereof.");
		}

		T entryAsT = entry as T;
		IEnumerable<T> entryAsEnumerable = entry as IEnumerable<T>;
		IDictionary<string, T> entryAsDictionary = entry as IDictionary<string, T>;

		// dictionaries enumerate key/value pairs, so use their values as the collection of T
		if (entryAsDictionary != null)
		{
			entryAsEnumerable = entryAsDictionary.Values;
		}

		if (entryAsT != null)
		{
			totalMetric += metricFunction.Invoke(entryAsT, handler).GetValueAs<double>();
			count++;
		}
		else if (entryAsEnumerable != null)
		{
			foreach (T value in entryAsEnumerable)
			{
				totalMetric += metricFunction.Invoke(value, handler).GetValueAs<double>();
				count++;
			}
		}
		else
		{
			throw new InvalidOperationException($"Cannot process metric for entry of type {entry.GetType()} with identifier \"{registryEntryToProcess}\", must be {typeof(T)} or enumerable thereof.");
		}
	}

	double resultMetric = totalMetric / count;

	resolver.ResolveSet(metricSharedResultIdentifier, resultMetric, addIdentifierIfNotExists: true);
}
/// <summary>
/// Search, by gradient descent on a copy of the network, for an input that drives the network's
/// <c>external_default</c> output towards <c>desired_targets</c>. Whether the search succeeded is published under
/// <c>shared_result_success_key</c>; on success, the input found is published under <c>shared_result_input_key</c>.
/// </summary>
/// <remarks>
/// Up to <c>max_optimisation_attempts</c> attempts are made. Each attempt starts from a freshly randomised input
/// and runs up to <c>max_optimisation_steps</c> steps (at least one), stopping as soon as the squared error is at
/// most <c>desired_cost</c>.
/// </remarks>
/// <param name="registry">The registry containing the required values for this hook's execution.</param>
/// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
{
	// we need copies of network and optimiser as to not affect the current internal state
	INetwork network = (INetwork) resolver.ResolveGetSingle<INetwork>("network.self").DeepCopy();
	BaseGradientOptimiser optimiser = (BaseGradientOptimiser) resolver.ResolveGetSingle<BaseGradientOptimiser>("optimiser.self").ShallowCopy();
	INDArray desiredTargets = ParameterRegistry.Get<INDArray>("desired_targets");
	IComputationHandler handler = new DebugHandler(Operator.Handler);

	long[] inputShape = network.YieldExternalInputsLayerBuffers().First().Parameters.Get<long[]>("shape");

	IDictionary<string, INDArray> block = DataUtils.MakeBlock("targets", desiredTargets); // desired targets don't change during execution

	double desiredCost = ParameterRegistry.Get<double>("desired_cost");
	double currentCost = Double.MaxValue;
	int maxOptimisationAttempts = ParameterRegistry.Get<int>("max_optimisation_attempts");
	int maxOptimisationSteps = ParameterRegistry.Get<int>("max_optimisation_steps");
	int optimisationSteps = 0;

	// explicit success flag: inferring success from the step counter misreports the outcome when no attempt is
	// made at all (max_optimisation_attempts == 0) or when max_optimisation_steps is zero or negative
	bool targetReached = false;

	INDArray maximisedInputs = CreateRandomisedInput(handler, inputShape);

	for (int i = 0; i < maxOptimisationAttempts && !targetReached; i++)
	{
		optimisationSteps = 0;

		do
		{
			// trace current inputs and run network as normal
			uint traceTag = handler.BeginTrace();
			block["inputs"] = handler.Trace(maximisedInputs.Reshape(ArrayUtils.Concatenate(new[] { 1L, 1L }, inputShape)), traceTag);

			handler.BeginSession();

			DataUtils.ProvideExternalInputData(network, block);
			network.Run(handler, trainingPass: false);

			// fetch current outputs and optimise against them (towards desired targets)
			INDArray currentTargets = network.YieldExternalOutputsLayerBuffers().First(b => b.ExternalOutputs.Contains("external_default"))
				.Outputs["external_default"].Get<INDArray>("activations");
			INumber squaredDifference = handler.Sum(handler.Pow(handler.Subtract(handler.FlattenTimeAndFeatures(currentTargets), desiredTargets), 2));

			handler.ComputeDerivativesTo(squaredDifference);

			handler.EndSession();

			INDArray gradient = handler.GetDerivative(block["inputs"]);

			maximisedInputs = handler.ClearTrace(optimiser.Optimise("inputs", block["inputs"], gradient, handler));

			currentCost = squaredDifference.GetValueAs<double>();

			if (currentCost <= desiredCost)
			{
				targetReached = true;

				// break skips the loop condition, so optimisationSteps stays the zero-based index of this step
				break;
			}
		} while (++optimisationSteps < maxOptimisationSteps);

		if (!targetReached)
		{
			maximisedInputs = CreateRandomisedInput(handler, inputShape); // reset input
		}
	}

	maximisedInputs.ReshapeSelf(inputShape);

	string sharedResultInput = ParameterRegistry.Get<string>("shared_result_input_key");
	string sharedResultSuccess = ParameterRegistry.Get<string>("shared_result_success_key");

	if (!targetReached)
	{
		_logger.Debug($"Aborted target maximisation for {desiredTargets}, failed after {maxOptimisationSteps} optimisation steps in {maxOptimisationAttempts} attempts (exceeded limit, current cost {currentCost} but desired {desiredCost}).");

		resolver.ResolveSet(sharedResultSuccess, false, addIdentifierIfNotExists: true);
	}
	else
	{
		// report the number of steps taken in the successful attempt (the original logged the optimiser object here)
		_logger.Debug($"Successfully finished target optimisation for {desiredTargets} after {optimisationSteps + 1} optimisation steps.");

		resolver.ResolveSet(sharedResultSuccess, true, addIdentifierIfNotExists: true);
		resolver.ResolveSet(sharedResultInput, maximisedInputs, addIdentifierIfNotExists: true);
	}
}