/// <summary>
/// Assigns a layer to <paramref name="vertex"/>. Any predecessor vertex that does not have a layer yet
/// gets one first (recursively), so that its output shape can be used as an input shape here.
/// </summary>
/// <param name="graph">The component graph that contains <paramref name="vertex"/>.</param>
/// <param name="vertex">The vertex whose <c>Layer</c> property is assigned.</param>
/// <param name="random">The random number generator passed through to <c>Layer.CreateFromArchitecture</c>.</param>
private static void CreateLayerInGraph(ComponentGraph graph, ComponentVertex vertex, RandomNumberGenerator<float> random)
{
    if (graph.InDegree(vertex) != 0)
    {
        // gather the output shapes of all predecessors, building their layers on demand
        IList<Shape> inputShapes = new List<Shape>();
        foreach (Edge<ComponentVertex> inEdge in graph.InEdges(vertex))
        {
            ComponentVertex predecessor = inEdge.Source;
            if (predecessor.Layer == null)
            {
                NetworkGraphBuilder.CreateLayerInGraph(graph, predecessor, random);
            }

            inputShapes.Add(predecessor.Layer.OutputShape);
        }

        // one input uses the single-shape overload; several inputs use the shape-list overload
        if (inputShapes.Count == 1)
        {
            vertex.Layer = Layer.CreateFromArchitecture(inputShapes[0], vertex.Architecture, random);
        }
        else
        {
            vertex.Layer = Layer.CreateFromArchitecture(inputShapes, vertex.Architecture, random);
        }

        return;
    }

    // a vertex without inputs is built for an arbitrary placeholder shape;
    // the source layer is expected to be an input layer that overrides it
    vertex.Layer = Layer.CreateFromArchitecture(new Shape(Shape.BWHC, -1, 100, 100, 100), vertex.Architecture, random);
}
/// <summary>
/// Extracts named definitions (components in <c>name=value</c> form) from <paramref name="components"/>.
/// </summary>
/// <param name="components">
/// The architecture components. Every component that matches <c>name=value</c> is removed from this list,
/// so on return it holds only the remaining (edge and single-vertex) components.
/// </param>
/// <param name="elements">Receives definitions whose value is a plain architecture string, keyed by name.</param>
/// <param name="modules">
/// Receives definitions whose value is enclosed in <c>StartQualifier</c>/<c>EndQualifier</c>;
/// the enclosed text is parsed as a nested graph.
/// </param>
/// <exception cref="ArgumentException">
/// A name is defined more than once (as an element, a module, or both), or a definition has an empty value.
/// </exception>
private static void ParseComponents(IList<string> components, out Dictionary<string, string> elements, out Dictionary<string, ComponentGraph> modules)
{
    Regex regex = new Regex(@"^(\w+)\s*=\s*(.+)$", RegexOptions.ECMAScript);

    elements = new Dictionary<string, string>();
    modules = new Dictionary<string, ComponentGraph>();

    for (int i = 0; i < components.Count; i++)
    {
        string component = components[i];

        // Regex.Match never returns null; the pattern always defines exactly two groups
        Match match = regex.Match(component);
        if (!match.Success)
        {
            continue;
        }

        string key = match.Groups[1].Value.Trim();
        string value = match.Groups[2].Value.Trim();

        // a name may be defined only once, whether as a plain element or as a module
        if (elements.ContainsKey(key) || modules.ContainsKey(key))
        {
            throw new ArgumentException(
                string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture_DuplicateVertex, component));
        }

        // "A=   " matches the pattern but trims to nothing; reject it instead of failing in First()/Last()
        if (value.Length == 0)
        {
            throw new ArgumentException(
                string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture, component));
        }

        if (value.First() == NetworkGraphBuilder.StartQualifier && value.Last() == NetworkGraphBuilder.EndQualifier)
        {
            modules.Add(key, NetworkGraphBuilder.ParseArchitecture(value.Substring(1, value.Length - 2), true));
        }
        else
        {
            elements.Add(key, value);
        }

        // the definition has been consumed; step back so the next component is not skipped
        components.RemoveAt(i--);
    }
}
/// <summary>
/// Parses an architecture string into a preliminary <see cref="ComponentGraph"/>.
/// </summary>
/// <param name="architecture">The architecture string: layers, module definitions, and edges.</param>
/// <param name="isNested">
/// <b>true</b> when <paramref name="architecture"/> is the body of a nested module. In that case a graph with
/// several sources gets a single split vertex (<c>SP{n}</c>) in front of them. A graph with several sinks gets
/// a single <c>CONCAT</c> vertex after them.
/// </param>
/// <returns>The parsed component graph.</returns>
private static ComponentGraph ParseArchitecture(string architecture, bool isNested)
{
    // split into layers, module definitions, and edges
    IList<string> components = ComponentGraph.SplitArchitecture(architecture, NetworkGraphBuilder.Delimiter);

    // pull out the "A=..." definitions; afterwards only "A-B-..." style components remain
    Dictionary<string, string> elements;
    Dictionary<string, ComponentGraph> modules;
    NetworkGraphBuilder.ParseComponents(components, out elements, out modules);

    ComponentGraph graph = ComponentGraph.FromComponents(components, elements, modules);
    if (!isNested)
    {
        return graph;
    }

    // a nested graph must have exactly one entry: fan several sources out from a split vertex
    // (the sources are materialized before the graph is modified)
    IList<ComponentVertex> entryVertices = graph.Sources.ToList();
    if (entryVertices.Count > 1)
    {
        string splitArchitecture = string.Format(CultureInfo.InvariantCulture, "SP{0}", entryVertices.Count);
        ComponentVertex splitVertex = new ComponentVertex(Guid.NewGuid().ToString(), splitArchitecture);
        graph.AddEdges(entryVertices.Select(x => new Edge<ComponentVertex>(splitVertex, x)));
    }

    // ...and exactly one exit: join several sinks into a concat vertex
    IList<ComponentVertex> exitVertices = graph.Sinks.ToList();
    if (exitVertices.Count > 1)
    {
        ComponentVertex concatVertex = new ComponentVertex(Guid.NewGuid().ToString(), "CONCAT");
        graph.AddEdges(exitVertices.Select(x => new Edge<ComponentVertex>(x, concatVertex)));
    }

    return graph;
}
/// <summary>
/// Builds a component graph from components that remain after the named definitions have been extracted.
/// </summary>
/// <param name="components">
/// The components. A component with several parts (split by <c>NetworkGraphBuilder.Splitter</c>) becomes a
/// chain of edges. A component with a single part becomes a standalone vertex.
/// </param>
/// <param name="elements">Named element definitions; a part equal to a name uses that definition's architecture.</param>
/// <param name="modules">Named module graphs; a part equal to a name inserts a clone of that graph.</param>
/// <returns>The component graph. Every vertex receives a fresh unique key before it is returned.</returns>
/// <exception cref="ArgumentException">A component is empty or malformed, or an edge is duplicated.</exception>
public static ComponentGraph FromComponents(IEnumerable<string> components, IDictionary<string, string> elements, IDictionary<string, ComponentGraph> modules)
{
    ComponentGraph graph = new ComponentGraph();

    // vertices of named elements, reused so that several components can refer to the same vertex
    Dictionary<string, ComponentVertex> namedVertices = new Dictionary<string, ComponentVertex>();

    foreach (string component in components)
    {
        IList<string> parts = ComponentGraph.SplitArchitecture(component, NetworkGraphBuilder.Splitter);

        // a single non-empty part is a standalone vertex
        if (parts.Count == 1 && !string.IsNullOrEmpty(parts[0]))
        {
            graph.AddVertex(new ComponentVertex(Guid.NewGuid().ToString(), parts[0]));
            continue;
        }

        // anything else must be a chain of at least two non-empty parts
        if (parts.Count < 2 || parts.Any(string.IsNullOrEmpty))
        {
            throw new ArgumentException(
                string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture, component));
        }

        // exactly one of these is non-null once the first part has been processed
        ComponentVertex previousVertex = null;
        ComponentGraph previousGraph = null;

        for (int i = 0; i < parts.Count; i++)
        {
            string key = parts[i];
            ComponentVertex currentVertex = null;
            ComponentGraph currentGraph = null;

            // a part is a sub-graph if it is an inline nested module or names a defined module
            if (key.First() == NetworkGraphBuilder.StartQualifier && key.Last() == NetworkGraphBuilder.EndQualifier)
            {
                currentGraph = NetworkGraphBuilder.ParseArchitecture(key.Substring(1, key.Length - 2), true);
            }
            else
            {
                ComponentGraph moduleGraph;
                if (modules.TryGetValue(key, out moduleGraph))
                {
                    // each use of a module gets its own copy
                    currentGraph = moduleGraph.Clone(true) as ComponentGraph;
                }
            }

            // otherwise the part is a vertex: a shared named element, or an anonymous layer
            if (currentGraph == null)
            {
                string elementArchitecture;
                if (!elements.TryGetValue(key, out elementArchitecture))
                {
                    currentVertex = new ComponentVertex(Guid.NewGuid().ToString(), key);
                }
                else if (!namedVertices.TryGetValue(key, out currentVertex))
                {
                    currentVertex = new ComponentVertex(key, elementArchitecture);
                    namedVertices[key] = currentVertex;
                }
            }

            if (i > 0)
            {
                bool added;
                if (currentGraph != null)
                {
                    added = previousVertex != null
                        ? ComponentGraph.AddEdge(graph, previousVertex, currentGraph)
                        : ComponentGraph.AddEdge(graph, previousGraph, currentGraph);
                }
                else
                {
                    added = previousVertex != null
                        ? ComponentGraph.AddEdge(graph, previousVertex, currentVertex)
                        : ComponentGraph.AddEdge(graph, previousGraph, currentVertex);
                }

                if (!added)
                {
                    throw new ArgumentException(
                        string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture_DuplicateEdge, parts[i - 1], parts[i]));
                }
            }

            previousVertex = currentVertex;
            previousGraph = currentGraph;
        }
    }

    // give every vertex a fresh key so that embedded graphs end up with unique layers
    foreach (ComponentVertex vertex in graph.Vertices)
    {
        vertex.Key = Guid.NewGuid().ToString();
    }

    return graph;
}
/// <summary>
/// Builds a <see cref="NetworkGraph"/> from a string that contains network architecture.
/// </summary>
/// <param name="architecture">The network architecture.</param>
/// <param name="addActivationLayers">Determines whether <c>AddActivationLayers</c> is applied to the resulting graph.</param>
/// <param name="addLossLayer">Determines whether <c>AddLossLayers</c> is applied to the resulting graph.</param>
/// <returns>The <see cref="NetworkGraph"/> this method creates.</returns>
/// <exception cref="ArgumentNullException"><paramref name="architecture"/> is <b>null</b>.</exception>
/// <exception cref="ArgumentException">
/// <paramref name="architecture"/> is empty, defines no layers, or is otherwise invalid.
/// </exception>
public static NetworkGraph CreateNetworkGraph(string architecture, bool addActivationLayers, bool addLossLayer)
{
    if (architecture == null)
    {
        throw new ArgumentNullException(nameof(architecture));
    }

    // null has been rejected above, so only the empty case remains
    if (architecture.Length == 0)
    {
        throw new ArgumentException(Properties.Resources.E_InvalidNetArchitecture_NoLayers, nameof(architecture));
    }

    // 1. parse architecture string and build preliminary graph
    ComponentGraph componentGraph = NetworkGraphBuilder.ParseArchitecture(architecture, false);

    // an architecture made only of definitions ("A=...") leaves nothing to build a network from
    if (!componentGraph.Vertices.Any())
    {
        throw new ArgumentException(Properties.Resources.E_InvalidNetArchitecture_NoLayers, nameof(architecture));
    }

    // 2. create layers in the preliminary graph; starting from every sink reaches every vertex
    // NOTE(review): no generator is supplied, so null is passed to Layer.CreateFromArchitecture — confirm intended
    RandomNumberGenerator<float> random = null;
    foreach (ComponentVertex sink in componentGraph.Sinks)
    {
        NetworkGraphBuilder.CreateLayerInGraph(componentGraph, sink, random);
    }

    // 3. convert to network graph
    NetworkGraph graph = new NetworkGraph();
    foreach (Edge<ComponentVertex> edge in componentGraph.Edges)
    {
        graph.AddEdge(edge.Source.Layer, edge.Target.Layer);
    }

    // vertices that have no edges at all (a single-layer architecture, or a disconnected layer)
    // are sinks with no incoming edges; the edge loop above never sees them, so add them explicitly
    foreach (ComponentVertex sink in componentGraph.Sinks)
    {
        if (componentGraph.InDegree(sink) == 0)
        {
            graph.AddVertex(sink.Layer);
        }
    }

    // 4. add missing loss layers
    if (addLossLayer)
    {
        NetworkGraphBuilder.AddLossLayers(graph);
    }

    // 5. add missing activation layers
    if (addActivationLayers)
    {
        NetworkGraphBuilder.AddActivationLayers(graph);
    }

    // 6. initialize stochastic biases with ReLU activations
    NetworkGraphBuilder.InitializeReLUs(graph);

    return graph;
}
/// <summary>
/// Creates a neural network from a string that contains network architecture.
/// Missing activation layers are added; loss layers are not added.
/// </summary>
/// <param name="architecture">The network architecture.</param>
/// <returns>The <see cref="Network"/> object this method creates.</returns>
/// <exception cref="ArgumentNullException"><paramref name="architecture"/> is <b>null</b>.</exception>
/// <exception cref="ArgumentException"><paramref name="architecture"/> is empty or invalid.</exception>
public static Network FromArchitecture(string architecture)
{
    NetworkGraph graph = NetworkGraphBuilder.CreateNetworkGraph(architecture, addActivationLayers: true, addLossLayer: false);
    return new Network(graph);
}
/// <summary>
/// Creates a classification neural network from a string that contains network architecture.
/// Missing activation layers and loss layers are added to the network graph.
/// </summary>
/// <param name="architecture">The network architecture.</param>
/// <param name="classes">The classes the network should be able to classify into.</param>
/// <param name="allowedClasses">The classes the network is allowed to classify.</param>
/// <param name="blankClass">The blank class that represents none of the real classes.</param>
/// <returns>The <see cref="ClassificationNetwork"/> object this method creates.</returns>
public static ClassificationNetwork FromArchitecture(string architecture, IList<string> classes, IList<string> allowedClasses, string blankClass) =>
    new ClassificationNetwork(
        NetworkGraphBuilder.CreateNetworkGraph(architecture, addActivationLayers: true, addLossLayer: true),
        classes,
        allowedClasses,
        blankClass);