コード例 #1
0
        /// <summary>
        /// Prints a table describing <paramref name="model"/>: a header row, one row per layer
        /// (name, type, output shape, parameter count and, for non-sequential models, the
        /// inbound connections), followed by the total, trainable and non-trainable parameter counts.
        /// </summary>
        /// <param name="model">Model whose layers and weights are summarized.</param>
        /// <param name="line_length">
        /// Width of the horizontal rules, and the scale applied to fractional <paramref name="positions"/>.
        /// Defaults to 65 for sequential-like models and 98 otherwise.
        /// </param>
        /// <param name="positions">
        /// End position of each column. When the last entry is at most 1 the entries are treated as
        /// fractions of <paramref name="line_length"/>. Defaults to {0.45, 0.85, 1.0} for
        /// sequential-like models and {0.33, 0.55, 0.67, 1.0} otherwise.
        /// </param>
        /// <param name="print_fn">Receives every output line; defaults to <see cref="Console.WriteLine(string)"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
        public static void PrintSummary(Model model, int? line_length = null, double[] positions = null, Action<string> print_fn = null)
        {
            if (model == null)
            {
                throw new ArgumentNullException(nameof(model));
            }

            if (print_fn == null)
            {
                print_fn = Console.WriteLine;
            }

            var relevant_nodes = new List<Node>();

            // Sequential models and non-graph (subclassed) models are always logged as a plain
            // sequence of layers; only graph networks need to be inspected.
            bool sequential_like = true;
            if (model.GetType().Name != "Sequential" && model._is_graph_network)
            {
                var nodes = new List<Node>();
                foreach (var depth_nodes in model._nodes_by_depth.Values)
                {
                    // More than one node at a depth, or a node fed by several layers,
                    // means the graph is not a simple chain.
                    if (depth_nodes.Count > 1 || (depth_nodes.Count == 1 && depth_nodes[0].inbound_layers.Length > 1))
                    {
                        sequential_like = false;
                        break;
                    }

                    nodes.AddRange(depth_nodes);
                }

                if (sequential_like)
                {
                    // A layer that owns two or more of the collected nodes is shared,
                    // which also rules out the sequential layout.
                    foreach (var layer in model.Layers)
                    {
                        var already_seen = false;
                        foreach (var node in layer._inbound_nodes)
                        {
                            if (!nodes.Contains(node))
                            {
                                continue;
                            }

                            if (already_seen)
                            {
                                sequential_like = false;
                                break;
                            }

                            already_seen = true;
                        }

                        if (!sequential_like)
                        {
                            break;
                        }
                    }
                }
            }

            // Header names and default layout for the chosen table shape.
            List<string> to_display;
            if (sequential_like)
            {
                line_length = line_length ?? 65;
                positions   = positions ?? new double[] { 0.45, 0.85, 1.0 };
                to_display  = new List<string> { "Layer (type)", "Output Shape", "Param #" };
            }
            else
            {
                line_length = line_length ?? 98;
                positions   = positions ?? new double[] { 0.33, 0.55, 0.67, 1.0 };
                to_display  = new List<string> { "Layer (type)", "Output Shape", "Param #", "Connected to" };

                foreach (var depth_nodes in model._nodes_by_depth.Values)
                {
                    relevant_nodes.AddRange(depth_nodes);
                }
            }

            int width = line_length.Value;

            // Fractional positions are scaled to absolute column ends.
            if (positions.Last() <= 1)
            {
                positions = positions.Select(p => width * p).ToArray();
            }

            var thin_rule  = new string('_', width);
            var thick_rule = new string('=', width);

            // Lays the fields out in fixed-width columns and prints the resulting line.
            Action<string[], double[]> print_row = (fields, column_ends) => {
                var line = "";
                for (var i = 0; i < fields.Length; i++)
                {
                    if (i > 0)
                    {
                        // Guarantee a blank between adjacent columns by overwriting the last
                        // character of the previous column (nothing to overwrite if it is empty).
                        var kept = line.Length > 0 ? line.Substring(0, line.Length - 1) : line;
                        line = kept + " ";
                    }

                    line += fields[i];

                    // Clip text that overflows the column; otherwise pad up to the column end.
                    // Substring must only be used when the line is long enough, or it throws.
                    var column_end = Convert.ToInt32(column_ends[i]);
                    line = line.Length > column_end
                        ? line.Substring(0, column_end)
                        : line.PadRight(column_end);
                }

                print_fn(line);
            };

            // Best effort: any failure while reading OutputShape is reported as "multiple".
            Func<Layer, string> describe_output_shape = layer => {
                try
                {
                    return layer.OutputShape.ToString();
                }
                catch (Exception)
                {
                    return "multiple";
                }
            };

            Action<Layer> print_layer_summary = layer => {
                var fields = new[] {
                    layer.name + " (" + layer.GetType().Name + ")",
                    describe_output_shape(layer),
                    layer.CountParams().ToString()
                };

                print_row(fields, positions);
            };

            Action<Layer> print_layer_summary_with_connections = layer => {
                var output_shape = describe_output_shape(layer);

                var connections = new List<string>();
                foreach (var node in layer._inbound_nodes)
                {
                    if (relevant_nodes.Count > 0 && !relevant_nodes.Contains(node))
                    {
                        // node is not part of the current network
                        continue;
                    }

                    for (var i = 0; i < node.inbound_layers.Length; i++)
                    {
                        connections.Add($"{node.inbound_layers[i].name}[{node.node_indices[i]}][{node.tensor_indices[i]}]");
                    }
                }

                var first_connection = connections.Count == 0 ? "" : connections[0];
                var fields = new[] {
                    layer.name + " (" + layer.GetType().Name + ")",
                    output_shape,
                    layer.CountParams().ToString(),
                    first_connection
                };
                print_row(fields, positions);

                // Remaining connections go on continuation rows with the other columns blank.
                for (var i = 1; i < connections.Count; i++)
                {
                    print_row(new[] { "", "", "", connections[i] }, positions);
                }
            };

            print_fn(thin_rule);
            print_row(to_display.ToArray(), positions);
            print_fn(thick_rule);

            var layers = model.Layers;
            for (var i = 0; i < layers.Length; i++)
            {
                if (sequential_like)
                {
                    print_layer_summary(layers[i]);
                }
                else
                {
                    print_layer_summary_with_connections(layers[i]);
                }

                // Thin rule between layers, thick rule closing the table after the last one.
                print_fn(i == layers.Length - 1 ? thick_rule : thin_rule);
            }

            // Prefer the collected trainable weights when the model exposes them.
            int trainable_count = model._collected_trainable_weights != null
                ? model._collected_trainable_weights.Select(x => K.CountParams(x)).Sum()
                : model.TrainableWeights.Select(x => K.CountParams(x)).Sum();

            var non_trainable_count = model.NonTrainableWeights.Select(x => K.CountParams(x)).Sum();

            print_fn($"Total params: {trainable_count + non_trainable_count}");
            print_fn($"Trainable params: {trainable_count}");
            print_fn($"Non-trainable params: {non_trainable_count}");
            print_fn(thin_rule);
        }
コード例 #2
0
 /// <summary>
 /// Adds up <c>K.CountParams</c> over every entry of <paramref name="weights"/>.
 /// </summary>
 /// <param name="weights">Symbols whose individual parameter counts are summed.</param>
 /// <returns>The sum of the per-symbol counts.</returns>
 public static int CountParams(KerasSymbol[] weights)
 {
     // Summing through a selector yields the same total without an intermediate list.
     return weights.Sum(symbol => K.CountParams(symbol));
 }