Exemplo n.º 1
0
        /// <summary>
        /// Extracts named definitions of the form <c>name = value</c> from the architecture components.
        /// </summary>
        /// <param name="components">
        /// The architecture components. Every component that matches the definition pattern is removed from this list,
        /// so on return it holds only the components that are not definitions.
        /// </param>
        /// <param name="elements">
        /// Receives the definitions whose value is not wrapped in the start/end qualifiers, keyed by name.
        /// </param>
        /// <param name="modules">
        /// Receives the definitions whose value is wrapped in <c>NetworkGraphBuilder.StartQualifier</c> and
        /// <c>NetworkGraphBuilder.EndQualifier</c>; the text between the qualifiers is parsed into a nested
        /// <c>ComponentGraph</c>.
        /// </param>
        /// <exception cref="ArgumentException">
        /// A name is defined more than once (as an element, a module, or both), or a definition has an empty value.
        /// </exception>
        private static void ParseComponents(IList <string> components, out Dictionary <string, string> elements, out Dictionary <string, ComponentGraph> modules)
        {
            Regex regex = new Regex(@"^(\w+)\s*=\s*(.+)$", RegexOptions.ECMAScript);

            elements = new Dictionary <string, string>();
            modules  = new Dictionary <string, ComponentGraph>();

            for (int i = 0; i < components.Count; i++)
            {
                string component = components[i];
                Match  match     = regex.Match(component);
                if (!match.Success)
                {
                    // not a definition - leave it in the list for the caller
                    continue;
                }

                string key   = match.Groups[1].Value.Trim();
                string value = match.Groups[2].Value.Trim();

                // a name may be defined only once, regardless of whether it was defined as an element or a module
                if (elements.ContainsKey(key) || modules.ContainsKey(key))
                {
                    throw new ArgumentException(
                              string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture_DuplicateVertex, component));
                }

                // (.+) may capture nothing but whitespace (e.g. "A = " or a trailing '\r'), which trims to an empty value
                if (value.Length == 0)
                {
                    throw new ArgumentException(
                              string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture, component));
                }

                if (value.Length >= 2 &&
                    value[0] == NetworkGraphBuilder.StartQualifier &&
                    value[value.Length - 1] == NetworkGraphBuilder.EndQualifier)
                {
                    modules.Add(key, NetworkGraphBuilder.ParseArchitecture(value.Substring(1, value.Length - 2), true));
                }
                else
                {
                    elements.Add(key, value);
                }

                // step back so the element that shifts into slot i is not skipped
                components.RemoveAt(i--);
            }
        }
Exemplo n.º 2
0
            /// <summary>
            /// Builds a component graph from the connection components of an architecture.
            /// </summary>
            /// <param name="components">
            /// The components to connect. Each one is split by <c>NetworkGraphBuilder.Splitter</c> into a chain of parts,
            /// and every pair of consecutive parts is joined by an edge.
            /// </param>
            /// <param name="elements">
            /// Named layer architectures. Within one call, every occurrence of the same element name maps to the same vertex.
            /// </param>
            /// <param name="modules">
            /// Named nested graphs. Every occurrence of a module name inserts a separate <c>Clone(true)</c> of that graph.
            /// </param>
            /// <returns>The constructed graph; every vertex is given a new GUID key before it is returned.</returns>
            /// <exception cref="ArgumentException">
            /// A component splits into no usable parts or contains an empty part, or <c>ComponentGraph.AddEdge</c>
            /// rejects an edge (reported with the duplicate-edge message).
            /// </exception>
            public static ComponentGraph FromComponents(IEnumerable <string> components, IDictionary <string, string> elements, IDictionary <string, ComponentGraph> modules)
            {
                ComponentGraph output = new ComponentGraph();

                // vertices created for named elements, so repeated mentions of a name reuse one vertex
                Dictionary <string, ComponentVertex> namedVertices = new Dictionary <string, ComponentVertex>();

                foreach (string component in components)
                {
                    IList <string> parts = ComponentGraph.SplitArchitecture(component, NetworkGraphBuilder.Splitter);

                    // a lone part becomes an unconnected vertex whose architecture is the part text itself
                    if (parts.Count == 1 && !string.IsNullOrEmpty(parts[0]))
                    {
                        output.AddVertex(new ComponentVertex(Guid.NewGuid().ToString(), parts[0]));
                        continue;
                    }

                    if (parts.Count < 2 || parts.Any(string.IsNullOrEmpty))
                    {
                        throw new ArgumentException(
                                  string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture, component));
                    }

                    // after the first part, exactly one of these refers to the previous link in the chain
                    ComponentVertex previousVertex = null;
                    ComponentGraph  previousGraph  = null;

                    for (int index = 0; index < parts.Count; index++)
                    {
                        string part = parts[index];

                        ComponentGraph subgraph = ResolveChainSubgraph(part, modules);
                        if (subgraph != null)
                        {
                            if (index != 0)
                            {
                                bool added = previousVertex == null
                                    ? ComponentGraph.AddEdge(output, previousGraph, subgraph)
                                    : ComponentGraph.AddEdge(output, previousVertex, subgraph);
                                ThrowIfEdgeRejected(added, parts, index);
                            }

                            previousGraph  = subgraph;
                            previousVertex = null;
                        }
                        else
                        {
                            ComponentVertex partVertex = ResolveChainVertex(part, elements, namedVertices);
                            if (index != 0)
                            {
                                bool added = previousVertex == null
                                    ? ComponentGraph.AddEdge(output, previousGraph, partVertex)
                                    : ComponentGraph.AddEdge(output, previousVertex, partVertex);
                                ThrowIfEdgeRejected(added, parts, index);
                            }

                            previousGraph  = null;
                            previousVertex = partVertex;
                        }
                    }
                }

                // recreate vertices keys, for the embedded graphs to have unique layers
                foreach (ComponentVertex vertex in output.Vertices)
                {
                    vertex.Key = Guid.NewGuid().ToString();
                }

                return output;
            }

            /// <summary>
            /// Returns the nested graph a (non-empty) chain part refers to, or <see langword="null"/> when the part is a vertex.
            /// </summary>
            private static ComponentGraph ResolveChainSubgraph(string part, IDictionary <string, ComponentGraph> modules)
            {
                // inline graph: the part is wrapped in the start/end qualifiers
                if (part[0] == NetworkGraphBuilder.StartQualifier && part[part.Length - 1] == NetworkGraphBuilder.EndQualifier)
                {
                    return NetworkGraphBuilder.ParseArchitecture(part.Substring(1, part.Length - 2), true);
                }

                // named module: each reference receives its own clone
                ComponentGraph module;
                return modules.TryGetValue(part, out module) ? module.Clone(true) as ComponentGraph : null;
            }

            /// <summary>
            /// Returns the vertex for a chain part: a shared vertex for a named element, otherwise a new vertex
            /// whose architecture is the part text.
            /// </summary>
            private static ComponentVertex ResolveChainVertex(string part, IDictionary <string, string> elements, Dictionary <string, ComponentVertex> namedVertices)
            {
                string architecture;
                if (!elements.TryGetValue(part, out architecture))
                {
                    return new ComponentVertex(Guid.NewGuid().ToString(), part);
                }

                ComponentVertex shared;
                if (!namedVertices.TryGetValue(part, out shared))
                {
                    shared = new ComponentVertex(part, architecture);
                    namedVertices[part] = shared;
                }

                return shared;
            }

            /// <summary>
            /// Throws the duplicate-edge error for the link between <c>parts[index - 1]</c> and <c>parts[index]</c>
            /// when <paramref name="added"/> is <see langword="false"/>.
            /// </summary>
            private static void ThrowIfEdgeRejected(bool added, IList <string> parts, int index)
            {
                if (!added)
                {
                    throw new ArgumentException(
                              string.Format(CultureInfo.InvariantCulture, Properties.Resources.E_InvalidNetArchitecture_DuplicateEdge, parts[index - 1], parts[index]));
                }
            }
Exemplo n.º 3
0
        /// <summary>
        /// Builds a network graph from an architecture string.
        /// </summary>
        /// <param name="architecture">The network architecture to parse.</param>
        /// <param name="addActivationLayers">
        /// <see langword="true"/> to run <c>NetworkGraphBuilder.AddActivationLayers</c> on the resulting graph.
        /// </param>
        /// <param name="addLossLayer">
        /// <see langword="true"/> to run <c>NetworkGraphBuilder.AddLossLayers</c> on the resulting graph.
        /// </param>
        /// <returns>The network graph.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="architecture"/> is <see langword="null"/>.</exception>
        /// <exception cref="ArgumentException"><paramref name="architecture"/> is an empty string.</exception>
        public static NetworkGraph CreateNetworkGraph(string architecture, bool addActivationLayers, bool addLossLayer)
        {
            if (architecture == null)
            {
                throw new ArgumentNullException(nameof(architecture));
            }

            if (architecture.Length == 0)
            {
                throw new ArgumentException(Properties.Resources.E_InvalidNetArchitecture_NoLayers, nameof(architecture));
            }

            // 1. parse architecture string and build preliminary graph
            ComponentGraph componentGraph = NetworkGraphBuilder.ParseArchitecture(architecture, false);

            // 2. create layers in the preliminary graph
            // NOTE(review): no random generator is passed (a seeded new Random(0) was once considered);
            // confirm that CreateLayerInGraph handles a null generator as intended.
            RandomNumberGenerator <float> random = null;
            foreach (ComponentVertex sink in componentGraph.Sinks)
            {
                NetworkGraphBuilder.CreateLayerInGraph(componentGraph, sink, random);
            }

            // 3. convert to network graph: each component edge becomes an edge between the layers of its end points.
            // NOTE(review): only edges are copied, so a component vertex without edges is not added by this loop.
            // Nested RNN graphs are not expanded here; an earlier commented-out variant that spliced RNNLayer.Graph
            // into the network graph was removed as dead code (see version control history).
            NetworkGraph graph = new NetworkGraph();
            foreach (Edge <ComponentVertex> edge in componentGraph.Edges)
            {
                graph.AddEdge(edge.Source.Layer, edge.Target.Layer);
            }

            // 4. add missing loss layers
            if (addLossLayer)
            {
                NetworkGraphBuilder.AddLossLayers(graph);
            }

            // 5. add missing activation layers
            if (addActivationLayers)
            {
                NetworkGraphBuilder.AddActivationLayers(graph);
            }

            // 6. initialize stochastic biases with ReLU activations
            NetworkGraphBuilder.InitializeReLUs(graph);

            return graph;
        }