/// <summary>
/// Runs the compiled function once. It feeds <paramref name="inputs"/> to the tensors in
/// <c>this.inputs</c>, evaluates <c>this.outputs</c> together with <c>this.updates_op</c>
/// in a single <c>session.Run</c> call, and returns the evaluated outputs.
/// </summary>
/// <param name="inputs">
/// Values to feed, matched positionally with <c>this.inputs</c>. Pairing uses
/// <see cref="Enumerable.Zip{TFirst, TSecond, TResult}"/>, so surplus entries on either
/// side are ignored.
/// </param>
/// <returns>One tensor per entry of <c>this.outputs</c>, in the same order.</returns>
/// <remarks>
/// The update ops are appended to the fetch list only so that they execute. Their fetched
/// values are not returned. Those discarded results and the <c>TFTensor</c>s created for
/// the feed values are disposed before returning, so their native memory is released
/// deterministically instead of waiting for finalization.
/// </remarks>
public override List <Tensor> Call(List <Array> inputs)
        {
            var feed_dict = new Dictionary <Tensor, Array>();

            foreach (var(tensor, value) in Enumerable.Zip(this.inputs, inputs, (a, b) => (a, b)))
            {
                // TODO: sparse inputs are not handled yet. The Keras reference converts a
                // sparse value to COO form here and feeds (indices, data, shape) instead.
                feed_dict[tensor] = value;
            }

            var session = K._SESSION;
            var outputs = new List <TFOutput>();

            foreach (var o in this.outputs)
            {
                outputs.Add(K.In(o));
            }
            // Fetched only so the updates execute; these results are dropped below.
            foreach (TFOutput o in this.updates_op)
            {
                outputs.Add(o);
            }

            var _inputs = new List <TFOutput>();
            var _values = new List <TFTensor>();

            try
            {
                foreach (KeyValuePair <Tensor, Array> pair in feed_dict)
                {
                    _inputs.Add(K.In(pair.Key));
                    // Implicit Array -> TFTensor conversion: a fresh tensor owned by this call.
                    _values.Add(pair.Value);
                }

                var updated = session.Run(
                    inputs: _inputs.ToArray(),
                    inputValues: _values.ToArray(),
                    outputs: outputs.ToArray());

                // Only the first this.outputs.Count results are handed back to the caller;
                // release the fetched update-op results that would otherwise be abandoned.
                for (int i = this.outputs.Count; i < updated.Length; i++)
                {
                    updated[i]?.Dispose();
                }

                return(updated.Get(0, this.outputs.Count).Select(t => K.Out(t)).ToList());
            }
            finally
            {
                // NOTE(review): relies on TF_SessionRun leaving ownership of the feed
                // tensors with the caller, so they are safe to free once Run has returned.
                foreach (var fed in _values)
                {
                    fed?.Dispose();
                }
            }
        }
Beispiel #2
0
        /// <summary>
        /// Runs the compiled function once. It feeds <paramref name="inputs"/> to the tensors
        /// in <c>this.inputs</c>, fetches <c>this.outputs</c>, runs <c>this.updates_op</c> as
        /// targets, and returns the fetched outputs.
        /// </summary>
        /// <param name="inputs">
        /// Values to feed, matched positionally with <c>this.inputs</c>. Pairing uses
        /// <see cref="Enumerable.Zip{TFirst, TSecond, TResult}"/>, so surplus entries on
        /// either side are ignored.
        /// </param>
        /// <returns>One tensor per entry of <c>this.outputs</c>, in the same order.</returns>
        /// <remarks>
        /// Before the main run, any ops returned by <c>tf.GetGlobalVariablesInitializer()</c>
        /// are executed one <c>session.Run</c> call at a time. Each op name is written to the
        /// console, followed by a dump of every operation in the graph. The <c>TFTensor</c>s
        /// created for the feed values are disposed before returning, so their native memory
        /// is released deterministically.
        /// </remarks>
        public override List <Tensor> Call(List <Array> inputs)
        {
            var feed_dict = new Dictionary <Tensor, Array>();

            foreach (var(tensor, value) in Enumerable.Zip(this.inputs, inputs, (a, b) => (a, b)))
            {
                // TODO: sparse inputs are not handled yet. The Keras reference converts a
                // sparse value to COO form here and feeds (indices, data, shape) instead.
                feed_dict[tensor] = value;
            }

            var session = K._SESSION;

            var init = tf.GetGlobalVariablesInitializer();

            if (init.Length > 0)
            {
                Console.WriteLine("Initializing variables:");
                foreach (var op in init)
                {
                    Console.WriteLine(" - " + op.Name);
                    // Target-only run: no feeds, no fetches.
                    session.Run(new TFOutput[0], new TFTensor[0], new TFOutput[0], new[] { op });
                }

                Console.WriteLine("Operations:");
                foreach (var op in tf.GetEnumerator())
                {
                    Console.WriteLine(" - " + op.Name);
                }
                Console.WriteLine();
            }

            var runner = session.GetRunner();

            foreach (var o in this.outputs)
            {
                runner.Fetch(K.In(o).output);
            }

            // Update ops are targets, so they run but produce no fetched results.
            foreach (var op in this.updates_op)
            {
                runner.AddTarget(op);
            }

            var feedTensors = new List <TFTensor>();

            try
            {
                foreach (KeyValuePair <Tensor, Array> pair in feed_dict)
                {
                    TensorFlowTensor placeholder = K.In(pair.Key);
                    // Implicit Array -> TFTensor conversion: a fresh tensor owned by this call.
                    TFTensor fed = pair.Value;
                    feedTensors.Add(fed);
                    runner.AddInput(placeholder.output, fed);
                }

                var updated = runner.Run();

                return(updated.Get(0, this.outputs.Count).Select(t => K.Out(t)).ToList());
            }
            finally
            {
                // NOTE(review): relies on TF_SessionRun leaving ownership of the feed
                // tensors with the caller, so they are safe to free once Run has returned.
                foreach (var fed in feedTensors)
                {
                    fed?.Dispose();
                }
            }
        }