Ejemplo n.º 1
0
        /// <summary>
        /// Creates a function bound to backend <paramref name="k"/> that maps <paramref name="inputs"/>
        /// to <paramref name="outputs"/> and carries a list of update ops built from <paramref name="updates"/>.
        /// </summary>
        /// <param name="k">Backend providing the graph builder (<c>k.tf</c>) and the <c>In</c> tensor conversion.</param>
        /// <param name="inputs">Tensors that will be fed when the function is called.</param>
        /// <param name="outputs">Tensors whose values the function produces.</param>
        /// <param name="updates">
        /// Update specs; may be null (treated as empty). A two-element list [variable, newValue] is turned
        /// into an assign op; any other list contributes only its first element, which is assumed to
        /// already be an op.
        /// </param>
        /// <param name="name">Name stored on the function.</param>
        public TFFunction(TensorFlowBackend k, List <Tensor> inputs, List <Tensor> outputs, List <List <Tensor> > updates, string name)
        {
            this.K = k;
            this.tf = k.tf;
            this.inputs = inputs;
            this.outputs = outputs;
            this.name = name;

            var pending = updates ?? new List<List<Tensor>>();
            var ops = new List<TFOutput>();

            foreach (var spec in pending)
            {
                if (spec.Count != 2)
                {
                    // Not a (variable, value) pair: the first element is taken as a ready-made op.
                    ops.Add(K.In(spec[0]));
                    continue;
                }

                var variable = K.In(spec[0]);
                var newValue = K.In(spec[1]);
                ops.Add(tf.Assign(variable, newValue));
            }

            // Stored as a flat list of ops; they are not combined into a single grouped op.
            this.updates_op = ops;
        }
Ejemplo n.º 2
0
        /// <summary>
        /// Runs the function once via <c>session.Run</c>. It feeds the supplied values, fetches this
        /// function's outputs followed by its update ops, and returns only the outputs.
        /// </summary>
        /// <param name="inputs">
        /// Values matched positionally against this function's input tensors. Pairing stops at the
        /// shorter of the two lists (Enumerable.Zip semantics).
        /// </param>
        /// <returns>The fetched values for this function's outputs, converted back through <c>K.Out</c>.</returns>
        public override List <Tensor> Call(List <Array> inputs)
        {
            // Feed map: if the same tensor appears more than once, the last value wins.
            var feeds = new Dictionary<Tensor, Array>();
            foreach (var pair in Enumerable.Zip(this.inputs, inputs, (placeholder, data) => new { placeholder, data }))
            {
                // Sparse values are not handled; Keras would convert them to (indices, values, shape) here.
                feeds[pair.placeholder] = pair.data;
            }

            var session = K._SESSION;

            // Fetch list: the outputs first, then the update ops, so the outputs occupy the leading slots.
            var fetches = new List<TFOutput>();
            foreach (var output in this.outputs)
            {
                fetches.Add(K.In(output));
            }
            foreach (TFOutput updateOp in this.updates_op)
            {
                fetches.Add(updateOp);
            }

            // Parallel key/value arrays in the shape session.Run expects.
            var feedKeys = new List<TFOutput>();
            var feedValues = new List<TFTensor>();
            foreach (var entry in feeds)
            {
                feedKeys.Add(K.In(entry.Key));
                feedValues.Add(entry.Value);
            }

            var results = session.Run(
                inputs: feedKeys.ToArray(),
                inputValues: feedValues.ToArray(),
                outputs: fetches.ToArray());

            // Keep only the leading this.outputs.Count results and discard the update-op results.
            // Assumes session.Run returns results in fetch order.
            var outputResults = results.Get(0, this.outputs.Count);
            return outputResults.Select(r => K.Out(r)).ToList();
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Runs the function once through a session runner. It first runs any initializer ops reported
        /// by the graph, then feeds <paramref name="inputs"/> into this function's input tensors,
        /// fetches its outputs, and runs its update ops as targets.
        /// </summary>
        /// <param name="inputs">
        /// Values matched positionally to this function's input tensors. Pairing stops at the shorter
        /// list (Enumerable.Zip semantics), so surplus entries on either side are ignored.
        /// </param>
        /// <returns>The fetched value of each of this function's outputs, in order, converted via <c>K.Out</c>.</returns>
        public override List <Tensor> Call(List <Array> inputs)
        {
            var feed_dict = new Dictionary <Tensor, Array>();

            foreach (var(tensor, value) in Enumerable.Zip(this.inputs, inputs, (a, b) => (a, b)))
            {
                // TODO: sparse inputs are not supported. Keras converts a sparse value to
                // (indices, values, dense_shape) at this point before feeding it.
                feed_dict[tensor] = value;
            }

            var session = K._SESSION;

            // Run each initializer op returned by the graph as its own target, in the order returned.
            // NOTE(review): this relies on GetGlobalVariablesInitializer returning only variables that
            // are not yet initialized (TensorFlowSharp clears its pending list on each call). If it
            // returned every variable, each Call would reset them. Confirm against the version in use.
            foreach (var op in tf.GetGlobalVariablesInitializer())
            {
                // Trace through Debug (compiled out of release builds) rather than Console, so the
                // library does not write to its host application's stdout.
                System.Diagnostics.Debug.WriteLine("Initializing variable: " + op.Name);
                session.Run(new TFOutput[0], new TFTensor[0], new TFOutput[0], new[] { op });
            }

            var runner = session.GetRunner();

            // Outputs are fetched in order, so the leading results correspond to this.outputs.
            foreach (var o in this.outputs)
            {
                runner.Fetch(K.In(o).output);
            }

            // Update ops are run as targets only; their values are not fetched.
            foreach (var op in this.updates_op)
            {
                runner.AddTarget(op);
            }

            foreach (KeyValuePair <Tensor, Array> pair in feed_dict)
            {
                TensorFlowTensor t = K.In(pair.Key);
                runner.AddInput(t.output, pair.Value);
            }

            var updated = runner.Run();

            return(updated.Get(0, this.outputs.Count).Select(t => K.Out(t)).ToList());
        }