Ejemplo n.º 1
0
        /// <summary>
        /// Invoke this hook with a certain parameter registry.
        /// Runs the network (with <c>trainingPass: false</c>) on every block yielded by the configured validation
        /// iterator and hands the final output activations, together with the block's targets, to
        /// <c>ScoreIntermediate</c>. The whole pass is bracketed by <c>ScoreBegin</c> and <c>ScoreEnd</c>.
        /// </summary>
        /// <param name="registry">The registry containing the required values for this hook's execution.</param>
        /// <param name="resolver">A helper resolver for complex registry entries (automatically cached).</param>
        /// <exception cref="InvalidOperationException">
        /// Thrown if the trainer has no additional named data iterator with the configured validation iterator name,
        /// or if no external output of the network carries the configured final external output alias.
        /// </exception>
        public override void SubInvoke(IRegistry registry, IRegistryResolver resolver)
        {
            INetwork network = resolver.ResolveGetSingle<INetwork>("network.self");
            ITrainer trainer = resolver.ResolveGetSingle<ITrainer>("trainer.self");
            string validationIteratorName = ParameterRegistry.Get<string>("validation_iterator_name");
            string finalExternalOutputAlias = ParameterRegistry.Get<string>("final_external_output_alias");
            string activationsAlias = ParameterRegistry.Get<string>("output_activations_alias");
            string targetsAlias = ParameterRegistry.Get<string>("targets_alias");

            // Single lookup instead of ContainsKey followed by the indexer.
            IDataIterator validationIterator;
            if (!trainer.AdditionalNameDataIterators.TryGetValue(validationIteratorName, out validationIterator))
            {
                throw new InvalidOperationException($"Additional named data iterator for validation with name \"{validationIteratorName}\" does not exist in referenced trainer {trainer} but is required.");
            }

            ScoreBegin(registry, resolver);

            foreach (var block in validationIterator.Yield(Operator.Handler, Operator.Sigma))
            {
                // Feed the validation block into the network and run it as a non-training pass.
                trainer.ProvideExternalInputData(network, block);
                network.Run(Operator.Handler, trainingPass: false);

                // The output lookup must happen after the run; it throws if the alias is missing.
                INDArray finalOutputPredictions = FindFinalOutputPredictions(network, finalExternalOutputAlias, activationsAlias);

                ScoreIntermediate(finalOutputPredictions, block[targetsAlias], Operator.Handler);
            }

            ScoreEnd(registry, resolver);
        }

        /// <summary>
        /// Find the activations of the first external output whose alias equals <paramref name="finalExternalOutputAlias"/>,
        /// passed through <c>Operator.Handler.ClearTrace</c>.
        /// </summary>
        /// <param name="network">The network whose external output layer buffers are searched.</param>
        /// <param name="finalExternalOutputAlias">The alias of the external output to look for.</param>
        /// <param name="activationsAlias">The key under which the activations are stored in the matching output.</param>
        /// <returns>The result of <c>ClearTrace</c> on the activations of the first matching external output.</returns>
        /// <exception cref="InvalidOperationException">Thrown if no external output with the given alias exists.</exception>
        private INDArray FindFinalOutputPredictions(INetwork network, string finalExternalOutputAlias, string activationsAlias)
        {
            foreach (ILayerBuffer layerBuffer in network.YieldExternalOutputsLayerBuffers())
            {
                foreach (string outputAlias in layerBuffer.ExternalOutputs)
                {
                    // Ordinal comparison: same semantics as string.Equals(string), stated explicitly.
                    if (outputAlias.Equals(finalExternalOutputAlias, StringComparison.Ordinal))
                    {
                        // Returning on the first match replaces the former goto out of the nested loops.
                        return Operator.Handler.ClearTrace(layerBuffer.Outputs[outputAlias].Get<INDArray>(activationsAlias));
                    }
                }
            }

            throw new InvalidOperationException($"Cannot find final output with alias \"{finalExternalOutputAlias}\" in the current network (but is required to score validation).");
        }