コード例 #1
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry.
        /// Adds the value resolved from the configured "registry_entry" identifier to a running total stored in this hook's
        /// parameter registry and publishes the total (or, in "average_mode", the total divided by the number of invocations
        /// since the last reset) to the "shared_result_entry" identifier.
        /// The running total and invocation count are reset when the current interval equals "reset_interval" or, if
        /// "reset_every" is positive, whenever the current interval is a multiple of "reset_every".
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            string sourceIdentifier = ParameterRegistry.Get<string>("registry_entry");
            string targetIdentifier = ParameterRegistry.Get<string>("shared_result_entry");

            double currentValue = resolver.ResolveGetSingle<double>(sourceIdentifier);
            double runningTotal = ParameterRegistry.Get<double>("accumulated_value");

            int interval      = HookUtils.GetCurrentInterval(registry, TimeStep.TimeScale);
            int resetInterval = ParameterRegistry.Get<int>("reset_interval");
            int resetEvery    = ParameterRegistry.Get<int>("reset_every");
            int invocations   = ParameterRegistry.Get<int>("count_since_reset");

            bool periodicReset = resetEvery > 0 && interval % resetEvery == 0;

            if (periodicReset || interval == resetInterval)
            {
                runningTotal = 0.0;
                invocations  = 0;
            }

            runningTotal += currentValue;
            invocations  += 1;

            double published = ParameterRegistry.Get<bool>("average_mode") ? runningTotal / invocations : runningTotal;

            ParameterRegistry["count_since_reset"] = invocations;
            ParameterRegistry["accumulated_value"] = runningTotal;
            resolver.ResolveSet(targetIdentifier, published, addIdentifierIfNotExists: true);
        }
コード例 #2
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Applies the configured "modifier" to the value behind "parameter_identifier" and writes the modified value back
        /// under the same identifier. Only <see cref="INumber"/> and <see cref="INDArray"/> values are supported.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        /// <exception cref="InvalidOperationException">If the resolved value is neither an INumber nor an INDArray.</exception>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            IValueModifier modifier   = ParameterRegistry.Get<IValueModifier>("modifier");
            string         identifier = ParameterRegistry.Get<string>("parameter_identifier");
            object         parameter  = resolver.ResolveGetSingle<object>(identifier);

            // Kept as object so ResolveSet is called with the same static type as before.
            object modified;

            if (parameter is INumber)
            {
                INumber number = (INumber) parameter;
                modified = modifier.Modify(identifier, number, number.AssociatedHandler);
            }
            else if (parameter is INDArray)
            {
                INDArray array = (INDArray) parameter;
                modified = modifier.Modify(identifier, array, array.AssociatedHandler);
            }
            else
            {
                throw new InvalidOperationException($"Cannot apply modifier {modifier} to parameter \"{identifier}\" with value {parameter}, " +
                                                    $"parameter is neither {nameof(INumber)} nor {nameof(INDArray)}.");
            }

            resolver.ResolveSet(identifier, modified);
        }
コード例 #3
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry.
        /// Runs the network (with trainingPass: false) over every block of the trainer's additional named data iterator
        /// "validation_iterator_name". For each block it passes the activations of the external output
        /// "final_external_output_alias" (with their trace cleared) and the block's targets to ScoreIntermediate.
        /// The whole pass is bracketed by ScoreBegin and ScoreEnd.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        /// <exception cref="InvalidOperationException">If the named validation iterator does not exist in the trainer, or the
        /// network has no external output with the configured final output alias.</exception>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            INetwork network = resolver.ResolveGetSingle<INetwork>("network.self");
            ITrainer trainer = resolver.ResolveGetSingle<ITrainer>("trainer.self");
            string   validationIteratorName   = ParameterRegistry.Get<string>("validation_iterator_name");
            string   finalExternalOutputAlias = ParameterRegistry.Get<string>("final_external_output_alias");
            string   activationsAlias         = ParameterRegistry.Get<string>("output_activations_alias");
            string   targetsAlias             = ParameterRegistry.Get<string>("targets_alias");

            // single lookup instead of ContainsKey + indexer
            IDataIterator validationIterator;
            if (!trainer.AdditionalNameDataIterators.TryGetValue(validationIteratorName, out validationIterator))
            {
                throw new InvalidOperationException($"Additional named data iterator for validation with name \"{validationIteratorName}\" does not exist in referenced trainer {trainer} but is required.");
            }

            ScoreBegin(registry, resolver);

            foreach (var block in validationIterator.Yield(Operator.Handler, Operator.Sigma))
            {
                trainer.ProvideExternalInputData(network, block);
                network.Run(Operator.Handler, trainingPass: false);

                ILayerBuffer outputLayerBuffer;
                if (!TryFindExternalOutputLayerBuffer(network, finalExternalOutputAlias, out outputLayerBuffer))
                {
                    throw new InvalidOperationException($"Cannot find final output with alias \"{finalExternalOutputAlias}\" in the current network (but is required to score validation).");
                }

                INDArray finalOutputPredictions = Operator.Handler.ClearTrace(outputLayerBuffer.Outputs[finalExternalOutputAlias].Get<INDArray>(activationsAlias));

                ScoreIntermediate(finalOutputPredictions, block[targetsAlias], Operator.Handler);
            }

            ScoreEnd(registry, resolver);
        }

        /// <summary>
        /// Find the first external-output layer buffer of the given network whose external outputs include the given alias.
        /// </summary>
        /// <param name="network">The network whose external output layer buffers are searched.</param>
        /// <param name="outputAlias">The external output alias to look for (compared with string.Equals).</param>
        /// <param name="layerBuffer">The matching layer buffer, or null if none was found.</param>
        /// <returns>True if a layer buffer exposing the alias was found, false otherwise.</returns>
        private static bool TryFindExternalOutputLayerBuffer(INetwork network, string outputAlias, out ILayerBuffer layerBuffer)
        {
            foreach (ILayerBuffer candidate in network.YieldExternalOutputsLayerBuffers())
            {
                foreach (string externalOutput in candidate.ExternalOutputs)
                {
                    if (externalOutput.Equals(outputAlias))
                    {
                        layerBuffer = candidate;

                        return true;
                    }
                }
            }

            layerBuffer = null;

            return false;
        }
コード例 #4
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry.
        /// For every k in the configured "tops" array, resolves "shared.classification_accuracy_top{k}" and reports all of
        /// them together as a k -> accuracy dictionary.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            int[] tops = ParameterRegistry.Get<int[]>("tops");
            IDictionary<int, double> accuracyByTop = new Dictionary<int, double>();

            for (int index = 0; index < tops.Length; index++)
            {
                int top = tops[index];

                // indexer (not Add) so repeated entries in "tops" simply overwrite
                accuracyByTop[top] = resolver.ResolveGetSingle<double>($"shared.classification_accuracy_top{top}");
            }

            Report(accuracyByTop);
        }
コード例 #5
0
ファイル: PassNetworkHook.cs プロジェクト: xiaoxiongnpu/Sigma
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Feeds the stored data block through the network (trainingPass: false) and forwards the "activations" of the
        /// "external_default" output link to the stored receiver.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            var block    = (IDictionary<string, INDArray>) ParameterRegistry[DataIdentifier];
            var receiver = (IPassNetworkReceiver) ParameterRegistry[ReceiverIdentifier];

            INetwork      network  = resolver.ResolveGetSingle<INetwork>("network.self");
            IDataProvider provider = new DefaultDataProvider();

            provider.SetExternalOutputLink("external_default", (targetsRegistry, layer, targetBlock) =>
            {
                INDArray activations = (INDArray) targetsRegistry["activations"];

                receiver.ReceivePass(activations);
            });

            DataUtils.ProvideExternalInputData(provider, network, block);
            network.Run(Operator.Handler, trainingPass: false);
            DataUtils.ProvideExternalOutputData(provider, network, block);
        }
コード例 #6
0
        /// <inheritdoc />
        public string GetName(IRegistry registry, IRegistryResolver resolver, object sender)
        {
            // _bufferParameters is an instance field reused across calls, so concurrent calls on the same instance would
            // interfere; callers presumably serialise access (e.g. by locking on the namer) - TODO confirm for all callers.
            try
            {
                for (int i = 0; i < _parameterIdentifiers.Length; i++)
                {
                    _bufferParameters[i] = resolver.ResolveGetSingle<object>(_parameterIdentifiers[i]);
                }

                return string.Format(_formatString, _bufferParameters);
            }
            finally
            {
                // Always drop the resolved values, even if resolution or formatting (e.g. a FormatException) throws,
                // so the reused buffer does not keep references to them between calls.
                for (int i = 0; i < _bufferParameters.Length; i++)
                {
                    _bufferParameters[i] = null;
                }
            }
        }
コード例 #7
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry.
        /// Resolves each of the "accumulated_identifiers" as a double and stores it in the reused "value_buffer"
        /// dictionary under the value identifier at the same position. It then reports the buffer together with the
        /// current epoch and iteration from the registry.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            string[] accumulatedIdentifiers = ParameterRegistry.Get<string[]>("accumulated_identifiers");
            string[] valueIdentifiers       = ParameterRegistry.Get<string[]>("value_identifiers");

            IDictionary<string, object> valueBuffer = ParameterRegistry.Get<IDictionary<string, object>>("value_buffer");

            int position = 0;

            foreach (string valueIdentifier in valueIdentifiers)
            {
                // TODO let callee decide if it's a number (double) / something else
                valueBuffer[valueIdentifier] = resolver.ResolveGetSingle<double>(accumulatedIdentifiers[position]);
                position++;
            }

            bool reportEpochIteration = ParameterRegistry.Get<bool>("report_epoch_iteration");
            int  epoch                = registry.Get<int>("epoch");
            int  iteration            = registry.Get<int>("iteration");

            ReportValues(valueBuffer, reportEpochIteration, epoch, iteration);
        }
コード例 #8
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Reads the shared success flag of the target maximisation run identified by "uid". If that run failed (or the
        /// flag is missing), a warning is logged. Otherwise the maximised inputs are passed to OnTargetMaximisationSuccess.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            INDArray desiredTargets = ParameterRegistry.Get<INDArray>("desired_targets");
            int      uid            = ParameterRegistry.Get<int>("uid");

            if (!resolver.ResolveGetSingleWithDefault($"shared.target_maximisation_result_{uid}_success", false))
            {
                _logger.Warn($"Failed target maximisation for {desiredTargets}, nothing to print.");

                return;
            }

            IComputationHandler handler           = Operator.Handler;
            INDArray            maximisedInputs   = resolver.ResolveGetSingle<INDArray>($"shared.target_maximisation_result_{uid}_input");

            OnTargetMaximisationSuccess(handler, maximisedInputs, desiredTargets);
        }
コード例 #9
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Resolves "registry_entry_to_save", passes it through the configured "select_function" and writes the result
        /// as a binary file whose name is produced by the configured "file_namer".
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            string      registryEntryToSave = ParameterRegistry.Get<string>("registry_entry_to_save");
            INamer      fileNamer           = ParameterRegistry.Get<INamer>("file_namer");
            object      toSerialise         = resolver.ResolveGetSingle<object>(registryEntryToSave);
            bool        verbose             = ParameterRegistry.Get<bool>("verbose");
            Func<T, T>  selectFunction      = ParameterRegistry.Get<Func<T, T>>("select_function");

            toSerialise = selectFunction.Invoke((T) toSerialise);

            string fileName;

            // The namer object is locked around both name generation and writing, so concurrent invocations sharing the
            // namer are serialised.
            lock (fileNamer)
            {
                fileName = fileNamer.GetName(registry, resolver, this);
                Serialisation.WriteBinaryFile(toSerialise, fileName, verbose: false);
            }

            if (verbose)
            {
                // Log the name actually written, not the namer object itself.
                _logger.Info($"Saved \"{registryEntryToSave}\" to \"{SigmaEnvironment.Globals.Get<string>("storage_path")}{fileName}\".");
            }
        }
コード例 #10
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Performs target maximisation on copies of the network and optimiser. Starting from a randomised input, it
        /// repeatedly runs the network, computes the sum of squared differences between the "external_default" activations
        /// and "desired_targets", and lets the optimiser update the input against the gradient of that cost.
        /// Up to "max_optimisation_attempts" attempts of up to "max_optimisation_steps" steps are made. The input is
        /// re-randomised after each failed attempt, and an attempt stops as soon as the cost is at most "desired_cost".
        /// The success flag is written to the identifier given by "shared_result_success_key". On success, the optimised
        /// input (reshaped to the network's input shape) is written to the identifier given by "shared_result_input_key".
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            // we need copies of network and optimiser as to not affect the current internal state
            INetwork              network        = (INetwork) resolver.ResolveGetSingle<INetwork>("network.self").DeepCopy();
            BaseGradientOptimiser optimiser      = (BaseGradientOptimiser) resolver.ResolveGetSingle<BaseGradientOptimiser>("optimiser.self").ShallowCopy();
            INDArray              desiredTargets = ParameterRegistry.Get<INDArray>("desired_targets");
            IComputationHandler   handler        = new DebugHandler(Operator.Handler);

            long[] inputShape = network.YieldExternalInputsLayerBuffers().First().Parameters.Get<long[]>("shape");

            IDictionary<string, INDArray> block = DataUtils.MakeBlock("targets", desiredTargets); // desired targets don't change during execution

            double   desiredCost = ParameterRegistry.Get<double>("desired_cost"), currentCost = double.MaxValue;
            int      maxOptimisationAttempts = ParameterRegistry.Get<int>("max_optimisation_attempts");
            int      maxOptimisationSteps    = ParameterRegistry.Get<int>("max_optimisation_steps");
            int      optimisationSteps       = 0;
            int      attempt                 = 0;
            bool     success                 = false;
            INDArray maximisedInputs         = CreateRandomisedInput(handler, inputShape);

            for (; attempt < maxOptimisationAttempts; attempt++)
            {
                optimisationSteps = 0;

                do
                {
                    // trace current inputs and run network as normal
                    uint traceTag = handler.BeginTrace();
                    block["inputs"] = handler.Trace(maximisedInputs.Reshape(ArrayUtils.Concatenate(new[] { 1L, 1L }, inputShape)), traceTag);

                    handler.BeginSession();

                    DataUtils.ProvideExternalInputData(network, block);
                    network.Run(handler, trainingPass: false);

                    // fetch current outputs and optimise against them (towards desired targets)
                    INDArray currentTargets = network.YieldExternalOutputsLayerBuffers().First(b => b.ExternalOutputs.Contains("external_default"))
                                              .Outputs["external_default"].Get<INDArray>("activations");
                    INumber squaredDifference = handler.Sum(handler.Pow(handler.Subtract(handler.FlattenTimeAndFeatures(currentTargets), desiredTargets), 2));

                    handler.ComputeDerivativesTo(squaredDifference);

                    handler.EndSession();

                    INDArray gradient = handler.GetDerivative(block["inputs"]);
                    maximisedInputs = handler.ClearTrace(optimiser.Optimise("inputs", block["inputs"], gradient, handler));

                    currentCost = squaredDifference.GetValueAs<double>();

                    if (currentCost <= desiredCost)
                    {
                        // Success is tracked explicitly. Inferring it from the step counter misreports the outcome when
                        // no attempts are configured, or when the step limit is <= 0.
                        success = true;
                        break;
                    }
                } while (++optimisationSteps < maxOptimisationSteps);

                if (success)
                {
                    break;
                }

                maximisedInputs = CreateRandomisedInput(handler, inputShape); // reset input for the next attempt
            }

            maximisedInputs.ReshapeSelf(inputShape);

            string sharedResultInput   = ParameterRegistry.Get<string>("shared_result_input_key");
            string sharedResultSuccess = ParameterRegistry.Get<string>("shared_result_success_key");

            if (!success)
            {
                _logger.Debug($"Aborted target maximisation for {desiredTargets}, failed after {maxOptimisationSteps} optimisation steps in {maxOptimisationAttempts} attempts (exceeded limit, current cost {currentCost} but desired {desiredCost}).");

                resolver.ResolveSet(sharedResultSuccess, false, addIdentifierIfNotExists: true);
            }
            else
            {
                // optimisationSteps is zero-based for the successful step (the loop broke before incrementing it)
                _logger.Debug($"Successfully finished target optimisation for {desiredTargets} after {optimisationSteps + 1} optimisation steps in attempt {attempt + 1}.");

                resolver.ResolveSet(sharedResultSuccess, true, addIdentifierIfNotExists: true);
                resolver.ResolveSet(sharedResultInput, maximisedInputs, addIdentifierIfNotExists: true);
            }
        }
コード例 #11
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Currently only resolves "network.self" and does nothing else with it.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            // The resolution is kept (its result is unused) so this stub behaves exactly as before.
            // TODO(review): unfinished - the intended next step was DataUtils.ProvideExternalInputData(network, DataUtils.MakeBlock(...)).
            resolver.ResolveGetSingle<INetwork>("network.self");
        }
コード例 #12
0
        /// <summary>
        /// Check whether the monitored parameter has reached a new extremum in the configured direction ("target" Min or Max).
        /// The value is read directly from the registry when SimpleDirectEntries[0] is set, and via the resolver otherwise.
        /// If the stored "current_extremum" is NaN (presumably its initial state - TODO confirm), or the value is strictly
        /// better than it, the value becomes the new stored extremum and the criteria are met.
        /// </summary>
        /// <param name="registry">The registry containing the monitored value.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        /// <returns>True if a new extremum was recorded, false otherwise.</returns>
        public override bool CheckCriteria(IRegistry registry, IRegistryResolver resolver)
        {
            ExtremaTarget target          = ParameterRegistry.Get<ExtremaTarget>("target");
            string        parameter       = ParameterRegistry.Get<string>("parameter_identifier");
            double        value           = SimpleDirectEntries[0] ? registry.Get<double>(parameter) : resolver.ResolveGetSingle<double>(parameter);
            double        currentExtremum = ParameterRegistry.Get<double>("current_extremum");

            bool isNewMinimum = target == ExtremaTarget.Min && value < currentExtremum;
            bool isNewMaximum = target == ExtremaTarget.Max && value > currentExtremum;

            if (!double.IsNaN(currentExtremum) && !isNewMinimum && !isNewMaximum)
            {
                return false;
            }

            ParameterRegistry["current_extremum"] = value;

            return true;
        }
コード例 #13
0
        /// <summary>
        /// Check whether the monitored parameter (converted to double) reaches the configured "threshold_value" according
        /// to the configured comparison "target".
        /// The criteria fire when the threshold is reached and either it was not reached at the previous check or
        /// "fire_continously" is set. The current threshold state is stored as "last_check_met" for the next check.
        /// </summary>
        /// <param name="registry">The registry containing the monitored value.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        /// <returns>True if the criteria fire for this check, false otherwise.</returns>
        public override bool CheckCriteria(IRegistry registry, IRegistryResolver resolver)
        {
            string identifier = ParameterRegistry.Get<string>("parameter_identifier");
            object raw        = SimpleDirectEntries[0] ? registry.Get(identifier) : resolver.ResolveGetSingle<object>(identifier);
            double numeric    = (double) Convert.ChangeType(raw, typeof(double));

            bool reached = _InternalThresholdReached(numeric, ParameterRegistry.Get<double>("threshold_value"), ParameterRegistry.Get<ComparisonTarget>("target"));
            bool fire    = false;

            // "last_check_met" must be read before it is overwritten below, and only when the threshold is reached.
            if (reached)
            {
                fire = !ParameterRegistry.Get<bool>("last_check_met") || ParameterRegistry.Get<bool>("fire_continously");
            }

            ParameterRegistry["last_check_met"] = reached;

            return fire;
        }
コード例 #14
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry if optional conditional criteria are satisfied.
        /// Reports the value currently stored under "shared.classification_accuracy".
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            Report(resolver.ResolveGetSingle<double>("shared.classification_accuracy"));
        }