예제 #1
0
        /// <summary>
        /// Gives Keras history to every tensor in <paramref name="tensors"/> that lacks it.
        /// It does this by wrapping the tensor's producing op in a <c>TensorFlowOpLayer</c>.
        /// Inputs of that op that use Keras history are processed first, recursively.
        /// The op's other inputs are evaluated and stored on the layer as constants, keyed by input position.
        /// </summary>
        /// <param name="tensors">Tensors to process; any tensor whose <c>KerasHistory</c> is already set is skipped.</param>
        /// <param name="processed_ops">
        /// Ops that have already been wrapped. Ops found in this list are skipped, and every newly wrapped op
        /// is appended to it. The same list is shared by all recursive calls.
        /// </param>
        /// <param name="created_layers">Receives each newly created <c>TensorFlowOpLayer</c>, in creation order.</param>
        public static void CreateKerasHistoryHelper(Tensors tensors, List <Operation> processed_ops, List <Layer> created_layers)
        {
            foreach (var tensor in tensors)
            {
                // Already part of a Keras graph: nothing to wrap.
                if (tensor.KerasHistory != null)
                {
                    continue;
                }

                var producer = tensor.op;
                if (processed_ops.Contains(producer))
                {
                    continue;
                }

                var kerasInputs    = new List<Tensor>();
                var constantInputs = new Dictionary<int, NDArray>();
                foreach (var (position, opInput) in enumerate(producer.inputs._inputs))
                {
                    if (!uses_keras_history(opInput))
                    {
                        // Not connected to Keras: snapshot its value as a constant for this input slot.
                        tf_with(ops.init_scope(), delegate
                        {
                            constantInputs[position] = keras.backend.eval_in_eager_or_function(opInput);
                        });
                        continue;
                    }

                    kerasInputs.Add(opInput);
                }

                // Wrap upstream ops before this one so their layers exist first.
                CreateKerasHistoryHelper(kerasInputs, processed_ops, created_layers);

                var opLayer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
                {
                    NodeDef   = producer.node_def,
                    Constants = constantInputs,
                    Name      = producer.name
                });
                created_layers.Add(opLayer);

                // NOTE(review): presumably this is what attaches KerasHistory to producer.outputs; confirm in TensorFlowOpLayer.
                opLayer.SetConnectivityMetadata(kerasInputs, producer.outputs);
                processed_ops.Add(producer);
            }
        }
        /// <summary>
        /// Wraps the producing op of every tensor in <paramref name="tensors"/> that lacks Keras history
        /// in a <c>TensorFlowOpLayer</c>, after recursively doing the same for the op's Keras-connected inputs.
        /// Exactly one layer is created per op, and each op is recorded in <paramref name="processed_ops"/>
        /// so it is not wrapped twice.
        /// </summary>
        /// <param name="tensors">Tensors to process; any tensor whose <c>KerasHistory</c> is already set is skipped.</param>
        /// <param name="processed_ops">
        /// Ops that have already been wrapped. Ops found in this list are skipped, and every newly wrapped op
        /// is appended to it. The same list is shared by all recursive calls.
        /// </param>
        /// <param name="created_layers">Receives each newly created <c>TensorFlowOpLayer</c>, in creation order.</param>
        public static void CreateKerasHistoryHelper(Tensors tensors, List <Operation> processed_ops, List <Layer> created_layers)
        {
            foreach (var tensor in tensors)
            {
                // Tensors that already carry Keras history need no wrapping.
                if (tensor.KerasHistory != null)
                {
                    continue;
                }

                var op = tensor.op;
                if (processed_ops.Contains(op))
                {
                    continue;
                }

                // First gather ALL Keras-connected inputs; the layer is built only after this loop finishes.
                var layer_inputs = new List<Tensor>();
                foreach (var op_input in op.inputs._inputs)
                {
                    if (uses_keras_history(op_input))
                    {
                        layer_inputs.Add(op_input);
                    }
                    // NOTE(review): inputs without Keras history are not captured as constants here,
                    // so the created layer carries no value for them. Confirm this is intended for this version.
                }

                // Recurse once per op (not once per input), so upstream ops get their layers first.
                CreateKerasHistoryHelper(layer_inputs, processed_ops, created_layers);

                var op_layer = new TensorFlowOpLayer(new TensorFlowOpLayerArgs
                {
                    NodeDef = op.node_def,
                    Name    = op.name
                });
                created_layers.Add(op_layer);
                op_layer.SetConnectivityMetadata(layer_inputs, op.outputs);

                // Record the op so that its other outputs, or later paths reaching it, do not create a second layer.
                processed_ops.Add(op);
            }
        }