/// <summary>
/// Checks that a wildcard resolve/set reflects a registry that was swapped into the hierarchy after the
/// first lookup: trainer1's "architecture" entry is replaced by a fresh registry, and a subsequent
/// wildcard set must reach both the replacement and trainer2's architecture.
/// </summary>
public void TestRegistryResolverModifyHierarchy()
{
    Registry root = new Registry(tags: "root");
    RegistryResolver resolver = new RegistryResolver(root);

    // two trainers below the root, each with its own architecture sub-registry
    Registry firstTrainer = new Registry(root, tags: "trainer");
    Registry secondTrainer = new Registry(root, new[] { "trainer" });
    Registry firstArchitecture = new Registry(firstTrainer, new[] { "architecture" });
    Registry secondArchitecture = new Registry(secondTrainer, new[] { "architecture" });

    root["trainer1"] = firstTrainer;
    root["trainer2"] = secondTrainer;
    firstTrainer["architecture"] = firstArchitecture;
    secondTrainer["architecture"] = secondArchitecture;
    firstArchitecture["complexity"] = 2;
    secondArchitecture["complexity"] = 3;

    int[] initialComplexities = resolver.ResolveGet<int>("*.architecture.complexity");
    Assert.AreEqual(new[] { 2, 3 }, initialComplexities);

    // replace trainer1's architecture after the resolver has already been used once
    Registry replacementArchitecture = new Registry(firstTrainer, "architecture");
    firstTrainer["architecture"] = replacementArchitecture;
    replacementArchitecture["complexity"] = 5;

    resolver.ResolveSet("*.architecture.complexity", 11, false, typeof(int));

    int[] updatedComplexities = resolver.ResolveGet<int>("*.architecture.complexity");
    Assert.AreEqual(new[] { 11, 11 }, updatedComplexities);
}
/// <summary>
/// Merges the "layers.*.weights" of two generated networks with a <see cref="WeightedNetworkMerger" />
/// (weights 3 and 8). Only the first network's weights should change. A second merge after removing the
/// merge entry must leave them untouched.
/// </summary>
public void TestWeightedNetworkMergerMerge()
{
    INetworkMerger merger = new WeightedNetworkMerger(3, 8);
    INetwork first = NetworkMergerTestUtils.GenerateNetwork(1);
    INetwork second = NetworkMergerTestUtils.GenerateNetwork(9);

    // reads the top-left element of the first matching weights array
    System.Func<IRegistryResolver, float> firstWeightOf =
        registryResolver => registryResolver.ResolveGet<INDArray>("layers.*.weights")[0].GetValue<float>(0, 0);

    merger.AddMergeEntry("layers.*.weights");
    merger.Merge(first, second);

    IRegistryResolver firstResolver = new RegistryResolver(first.Registry);
    IRegistryResolver secondResolver = new RegistryResolver(second.Registry);

    float mergedFirst = firstWeightOf(firstResolver);
    float untouchedSecond = firstWeightOf(secondResolver);

    // the merged network takes the weighted combination (compared at two decimals)
    double roundedMerged = System.Math.Round(mergedFirst * 100) / 100;
    Assert.AreEqual(6.82, roundedMerged);
    // the other network must not be modified
    Assert.AreEqual(9, untouchedSecond);

    // without a merge entry, merging again must be a no-op
    merger.RemoveMergeEntry("layers.*.weights");
    merger.Merge(first, second);

    float afterEmptyMerge = firstWeightOf(firstResolver);
    double roundedAfterEmptyMerge = System.Math.Round(afterEmptyMerge * 100) / 100;
    Assert.AreEqual(6.82, roundedAfterEmptyMerge);
}
/// <summary>
/// Merges the "layers.*.weights" of two generated networks with an <see cref="AverageNetworkMerger" />.
/// Only the first network's weights should change (to the average). A second merge after removing the
/// merge entry must leave them untouched.
/// </summary>
public void TestAverageNetworkMergerMerge()
{
    INetworkMerger averager = new AverageNetworkMerger();
    INetwork target = NetworkMergerTestUtils.GenerateNetwork(1);
    INetwork source = NetworkMergerTestUtils.GenerateNetwork(5);

    averager.AddMergeEntry("layers.*.weights");
    averager.Merge(target, source);

    IRegistryResolver targetResolver = new RegistryResolver(target.Registry);
    IRegistryResolver sourceResolver = new RegistryResolver(source.Registry);

    INDArray[] targetWeights = targetResolver.ResolveGet<INDArray>("layers.*.weights");
    INDArray[] sourceWeights = sourceResolver.ResolveGet<INDArray>("layers.*.weights");

    float targetValue = targetWeights[0].GetValue<float>(0, 0);
    float sourceValue = sourceWeights[0].GetValue<float>(0, 0);

    // the target now holds the average of both networks
    Assert.AreEqual(3, targetValue);
    // the source network must not be modified
    Assert.AreEqual(5, sourceValue);

    // without a merge entry, merging again must be a no-op
    averager.RemoveMergeEntry("layers.*.weights");
    averager.Merge(target, source);

    INDArray[] targetWeightsAfter = targetResolver.ResolveGet<INDArray>("layers.*.weights");
    float targetValueAfter = targetWeightsAfter[0].GetValue<float>(0, 0);
    Assert.AreEqual(3, targetValueAfter);
}
/// <summary>
/// Runs every registered value modifier over the registry entries its identifier resolves to, then writes
/// each (possibly modified) value back through the resolver.
/// <see cref="INDArray" />, <see cref="INumber" /> and boxed <c>double</c> values are passed through the
/// modifier chain. Any other value is written back unchanged.
/// </summary>
/// <param name="localRegistry">The registry whose entries are resolved and modified.</param>
/// <param name="handler">The computation handler passed on to each modifier.</param>
private void ApplyValueModifiers(IRegistry localRegistry, IComputationHandler handler)
{
    if (_valueModifiers.Count == 0)
    {
        return;
    }

    RegistryResolver resolver = new RegistryResolver(localRegistry);

    foreach (string identifier in _valueModifiers.Keys)
    {
        var modifiers = _valueModifiers[identifier];

        string[] resolvedIdentifiers;
        object[] resolvedValues = resolver.ResolveGet<object>(identifier, out resolvedIdentifiers);

        for (int index = 0; index < resolvedValues.Length; index++)
        {
            string resolvedIdentifier = resolvedIdentifiers[index];
            object current = resolvedValues[index];

            INDArray ndArray = current as INDArray;
            if (ndArray != null)
            {
                foreach (IValueModifier modifier in modifiers)
                {
                    ndArray = modifier.Modify(resolvedIdentifier, ndArray, handler);
                }

                resolvedValues[index] = ndArray;
            }
            else
            {
                INumber number = current as INumber;
                if (number != null)
                {
                    foreach (IValueModifier modifier in modifiers)
                    {
                        number = modifier.Modify(resolvedIdentifier, number, handler);
                    }

                    resolvedValues[index] = number;
                }
                else
                {
                    // only boxed doubles are handled here; other primitive types are left as they are
                    double? boxedDouble = current as double?;
                    if (boxedDouble != null)
                    {
                        foreach (IValueModifier modifier in modifiers)
                        {
                            boxedDouble = modifier.Modify(resolvedIdentifier, boxedDouble.Value, handler);
                        }

                        resolvedValues[index] = boxedDouble.Value;
                    }
                }
            }

            // pass the value as object (not the concrete local) so the same ResolveSet overload/type argument is used
            resolver.ResolveSet(resolvedIdentifier, resolvedValues[index]);
        }
    }
}
/// <summary>
/// Sets a float entry in each of two sibling registries through fully specified identifiers and checks
/// that each one reads back its own value.
/// </summary>
public void TestRegistryResolverModifyDirect()
{
    Registry root = new Registry(tags: "root");
    RegistryResolver resolver = new RegistryResolver(root);

    Registry firstTrainer = new Registry(root, tags: "trainer");
    Registry secondTrainer = new Registry(root, tags: "trainer");
    root["trainer1"] = firstTrainer;
    root["trainer2"] = secondTrainer;

    // the entries have to exist before they can be set through the resolver
    firstTrainer["accuracy"] = 0.0f;
    secondTrainer["accuracy"] = 0.0f;

    resolver.ResolveSet("trainer1.accuracy", 1.0f, false, typeof(float));
    resolver.ResolveSet("trainer2.accuracy", 2.0f, false, typeof(float));

    float firstAccuracy = resolver.ResolveGet<float>("trainer1.accuracy")[0];
    Assert.AreEqual(1.0f, firstAccuracy);

    float secondAccuracy = resolver.ResolveGet<float>("trainer2.accuracy")[0];
    Assert.AreEqual(2.0f, secondAccuracy);
}
/// <summary>
/// Exercises wildcard and tag-filtered identifiers ("trainer*", "*&lt;trainer&gt;", "*&lt;architecture&gt;")
/// for both setting and getting. The test also checks the fully resolved identifiers returned alongside
/// the values, and that an untagged child registry is not matched by the tag filter.
/// </summary>
public void TestRegistryResolverModifyComplex()
{
    Registry root = new Registry(tags: "root");
    RegistryResolver resolver = new RegistryResolver(root);

    Registry firstTrainer = new Registry(root, tags: "trainer");
    Registry secondTrainer = new Registry(root, new[] { "trainer" });
    Registry oddTrainer = new Registry(root, new[] { "trainer" });
    Registry untaggedChild = new Registry(root);
    Registry firstArchitecture = new Registry(firstTrainer, new[] { "architecture" });
    Registry oddArchitecture = new Registry(oddTrainer, new[] { "architecture" });

    root["trainer1"] = firstTrainer;
    root["trainer2"] = secondTrainer;
    root["weirdtrainer"] = oddTrainer;
    root["childtoignore"] = untaggedChild;
    firstTrainer["architecture"] = firstArchitecture;
    oddTrainer["architecture"] = oddArchitecture;
    firstArchitecture["complexity"] = 2;
    oddArchitecture["complexity"] = 3;
    firstTrainer["accuracy"] = 0.0f;
    secondTrainer["accuracy"] = 0.0f;

    resolver.ResolveSet("trainer*.accuracy", 1.0f, false, typeof(float));
    resolver.ResolveSet("*<trainer>.*<architecture>.complexity", 9, false, typeof(int));

    string[] resolved = null;

    float[] accuracies = resolver.ResolveGet<float>("trainer*.accuracy", out resolved);
    Assert.AreEqual(new[] { 1.0f, 1.0f }, accuracies);
    Assert.AreEqual(new[] { "trainer1.accuracy", "trainer2.accuracy" }, resolved);

    int[] complexities = resolver.ResolveGet<int>("*<trainer>.architecture.complexity");
    Assert.AreEqual(new[] { 9, 9 }, complexities);

    IRegistry[] trainers = resolver.ResolveGet<IRegistry>("*<trainer>", out resolved);
    Assert.AreEqual(new IRegistry[] { firstTrainer, secondTrainer, oddTrainer }, trainers);
    Assert.AreEqual(new[] { "trainer1", "trainer2", "weirdtrainer" }, resolved);
}
/// <summary>
/// Specify how multiple networks are merged into a single one. <paramref name="root" /> is <em>not</em>
/// considered for the calculation. It is merely the storage container. (Although root can also be in
/// <paramref name="networks" />).
/// </summary>
/// <param name="root">
/// The root network that will be modified. Since the <see cref="INetworkMerger" /> does not know how
/// to create a <see cref="INetwork" />, it will be passed not returned.
/// </param>
/// <param name="networks">
/// The networks that will be merged into the <paramref name="root" />. Can contain <paramref name="root" />
/// itself.
/// </param>
/// <param name="handler">
/// A handler can be specified optionally. If not passed (but required),
/// <see cref="ITraceable.AssociatedHandler" /> will be used.
/// </param>
/// <exception cref="System.ArgumentNullException">
/// If <paramref name="root" /> or <paramref name="networks" /> is <c>null</c>.
/// </exception>
public void Merge(INetwork root, IEnumerable<INetwork> networks, IComputationHandler handler = null)
{
    if (root == null)
    {
        throw new System.ArgumentNullException(nameof(root));
    }

    if (networks == null)
    {
        throw new System.ArgumentNullException(nameof(networks));
    }

    string[] mergeKeys = CopyMatchIdentifiers();

    if (mergeKeys.Length == 0)
    {
        _log.Warn($"Attempted merge network {root} with networks {networks} using handler {handler} but no merge keys were set so nothing will happen. This is probably not intended.");

        // without merge keys nothing is collected or written back, so walking the networks is pointless
        return;
    }

    IRegistryResolver rootResolver = new RegistryResolver(root.Registry);

    // maps every fully resolved identifier to the values found at it across all networks
    IDictionary<string, IList<object>> resolvedDataArrays = new Dictionary<string, IList<object>>(mergeKeys.Length);
    int numNetworks = 0;

    // collect the values of every merge key from every network
    foreach (INetwork network in networks)
    {
        IRegistryResolver resolver = new RegistryResolver(network.Registry);

        foreach (string mergeKey in mergeKeys)
        {
            string[] fullyResolvedIdentifiers;
            object[] values = resolver.ResolveGet<object>(mergeKey, out fullyResolvedIdentifiers);

            Debug.Assert(fullyResolvedIdentifiers.Length == values.Length);

            for (int i = 0; i < values.Length; i++)
            {
                // project extension method; presumably stores and returns a new list when the key is missing — TODO confirm
                IList<object> allValuesAtKey = resolvedDataArrays.TryGetValue(fullyResolvedIdentifiers[i], () => new List<object>());
                allValuesAtKey.Add(values[i]);
            }
        }

        numNetworks++;
    }

    // merge the collected values per identifier and store the result in the root network
    foreach (KeyValuePair<string, IList<object>> keyDataPair in resolvedDataArrays)
    {
        int numObjects = keyDataPair.Value.Count;

        if (numObjects != numNetworks)
        {
            _log.Warn($"Inconsistent network states for identifier \"{keyDataPair.Key}\", only {keyDataPair.Value.Count} have it but there are {numNetworks} networks.");
        }

        object merged = Merge(keyDataPair.Value.ToArray(), handler);
        rootResolver.ResolveSet(keyDataPair.Key, merged);
    }
}
/// <summary>
/// Sample routine that builds a small network and initialises its weights and biases with Gaussian
/// initialisers. It prints the network registry and every resolved "layers.*.weights" value to the console.
/// The network is replaced by a deep copy both before and after initialisation.
/// </summary>
private static void SampleNetworkArchitecture()
{
    SigmaEnvironment sigma = SigmaEnvironment.Create("test");
    IComputationHandler handler = new CpuFloat32Handler();
    ITrainer trainer = sigma.CreateTrainer("test_trainer");

    trainer.Network = new Network();
    trainer.Network.Architecture = InputLayer.Construct(2, 2)
        + ElementwiseLayer.Construct(2 * 2)
        + FullyConnectedLayer.Construct(2)
        + 2 * (FullyConnectedLayer.Construct(4) + FullyConnectedLayer.Construct(2))
        + OutputLayer.Construct(2);
    trainer.Network = (INetwork) trainer.Network.DeepCopy();

    trainer.Operator = new CpuMultithreadedOperator(10);
    trainer.AddInitialiser("*.weights", new GaussianInitialiser(standardDeviation: 0.1f));
    trainer.AddInitialiser("*.bias*", new GaussianInitialiser(standardDeviation: 0.01f, mean: 0.03f));
    trainer.Initialise(handler);
    trainer.Network = (INetwork) trainer.Network.DeepCopy();

    Console.WriteLine(trainer.Network.Registry);

    IRegistryResolver weightResolver = new RegistryResolver(trainer.Network.Registry);
    const string separator = "===============";

    Console.WriteLine(separator);
    object[] layerWeights = weightResolver.ResolveGet<object>("layers.*.weights");
    Console.WriteLine(string.Join("\n", layerWeights));
    Console.WriteLine(separator);
}
/// <summary>
/// Initialises <c>Network</c> with the given handler. It then applies every registered initialiser to the
/// values its identifier resolves to inside the network's "layers" registry. Identifiers are processed in
/// the order given by <c>RegistryUtils.CompareIdentifierSpecificityAscending</c>.
/// </summary>
/// <param name="handler">The computation handler used for network and value initialisation.</param>
/// <param name="initialisedNumberCount">How many <see cref="INumber" /> values were passed to an initialiser.</param>
/// <param name="initialisedNDArrayCount">How many <see cref="INDArray" /> values were passed to an initialiser.</param>
private void InitialiseNetwork(IComputationHandler handler, out int initialisedNumberCount, out int initialisedNDArrayCount)
{
    Network.Initialise(handler);

    initialisedNDArrayCount = 0;
    initialisedNumberCount = 0;

    RegistryResolver layerResolver = new RegistryResolver(Network.Registry.Get<IRegistry>("layers"));

    List<string> identifiersBySpecificity = new List<string>(_initialisers.Keys);
    identifiersBySpecificity.Sort(RegistryUtils.CompareIdentifierSpecificityAscending);

    foreach (string identifier in identifiersBySpecificity)
    {
        object[] matches = layerResolver.ResolveGet(identifier, new object[0]);
        IInitialiser initialiser = _initialisers[identifier];

        foreach (object match in matches)
        {
            INDArray array = match as INDArray;
            if (array != null)
            {
                initialiser.Initialise(array, handler, Sigma.Random);
                initialisedNDArrayCount++;
                continue;
            }

            INumber number = match as INumber;
            if (number == null)
            {
                // neither an array nor a number: nothing to initialise
                continue;
            }

            initialiser.Initialise(number, handler, Sigma.Random);
            initialisedNumberCount++;
        }
    }
}
/// <summary>
/// Resolve all matching identifiers in this registry. For the detailed supported syntax see <see cref="IRegistryResolver"/>.
/// Forwards to the <c>RegistryResolver</c> member.
/// </summary>
/// <typeparam name="T">The most specific common type of the variables to retrieve.</typeparam>
/// <param name="matchIdentifier">The full match identifier.</param>
/// <param name="values">Optional buffer for the results, reused if it is large enough and not null (useful when the same request is issued repeatedly).</param>
/// <returns>An array of the values found at all matching identifiers.</returns>
public T[] ResolveGet<T>(string matchIdentifier, T[] values = null) => RegistryResolver.ResolveGet(matchIdentifier, values);
/// <summary>
/// Resolve all matching identifiers in this registry and report which fully qualified identifiers matched.
/// For the detailed supported syntax see <see cref="IRegistryResolver"/>. Forwards to the <c>RegistryResolver</c> member.
/// </summary>
/// <typeparam name="T">The most specific common type of the variables to retrieve.</typeparam>
/// <param name="matchIdentifier">The full match identifier.</param>
/// <param name="fullMatchedIdentifierArray">Receives the fully matched identifiers corresponding to the given match identifier.</param>
/// <param name="values">Optional buffer for the results, reused if it is large enough and not null (useful when the same request is issued repeatedly).</param>
/// <returns>An array of the values found at all matching identifiers.</returns>
public T[] ResolveGet<T>(string matchIdentifier, out string[] fullMatchedIdentifierArray, T[] values = null) =>
    RegistryResolver.ResolveGet(matchIdentifier, out fullMatchedIdentifierArray, values);