Example #1
0
        /// <summary>
        /// Builds the initial gradient lists used to start backpropagation, keyed by target tensor id.
        /// </summary>
        /// <param name="target_tensor_ids">Ids of the tensors being differentiated.</param>
        /// <param name="sources_that_are_targets">Targets that are also sources. A target with no recorded
        /// producing op is seeded with ones only if it appears here.</param>
        /// <param name="output_gradients">Caller-supplied upstream gradients, indexed in parallel with
        /// <paramref name="target_tensor_ids"/>. A null or empty array, or a null entry, means
        /// "seed this target with ones".</param>
        /// <param name="tensor_tape">Maps a tensor id to the id of the op that produced it; a value of -1
        /// is treated the same as a missing entry.</param>
        /// <param name="op_tape">Recorded ops, keyed by op id.</param>
        /// <returns>
        /// Map from target id to its list of initial gradients. A target that has no explicit gradient,
        /// no recorded producer and is not a source gets no entry.
        /// </returns>
        /// <exception cref="System.ArgumentException">
        /// <paramref name="output_gradients"/> is non-empty but shorter than <paramref name="target_tensor_ids"/>.
        /// </exception>
        /// <exception cref="RuntimeError">The tensor tape names a producing op that is missing from <paramref name="op_tape"/>.</exception>
        /// <exception cref="ValueError">None of the producing op's outputs has the target's id.</exception>
        UnorderedMapEnumerable <long, List <Tensor> > InitialGradients(long[] target_tensor_ids,
                                                                       UnorderedMap <long, TapeTensor> sources_that_are_targets,
                                                                       Tensor[] output_gradients,
                                                                       TensorTape tensor_tape,
                                                                       OpTape <BackwardFunction, TapeTensor> op_tape)
        {
            // A null array is treated like an empty one: the caller supplied no upstream gradients.
            bool has_output_gradients = output_gradients != null && output_gradients.Length > 0;

            // Report a mismatched gradient list clearly instead of failing with IndexOutOfRangeException mid-loop.
            if (has_output_gradients && output_gradients.Length < target_tensor_ids.Length)
            {
                throw new System.ArgumentException(
                    "output_gradients must be empty or contain one entry per target tensor.",
                    nameof(output_gradients));
            }

            var result = new UnorderedMapEnumerable <long, List <Tensor> >();

            for (int i = 0; i < target_tensor_ids.Length; ++i)
            {
                var id = target_tensor_ids[i];

                // An explicit upstream gradient takes precedence over seeding with ones.
                var supplied = has_output_gradients ? output_gradients[i] : null;
                if (supplied != null)
                {
                    result[id].Add(supplied);
                    continue;
                }

                if (tensor_tape.find(id, out var op_id) && op_id != -1)
                {
                    // The target was produced by a recorded op. Reuse the op id from the lookup above
                    // instead of indexing tensor_tape a second time.
                    if (!op_tape.find(op_id, out var op_it))
                    {
                        throw new RuntimeError("Internal state of the gradient tape is invalid: " +
                                               "failed to find operation producing a tensor");
                    }

                    bool found = false;
                    foreach (var output_info in op_it.output_tensor_info)
                    {
                        if (output_info.GetID() == id)
                        {
                            // Seed with ones built from the matching op output.
                            result[id].Add(output_info.OnesLike());
                            found = true;
                            break;
                        }
                    }

                    if (!found)
                    {
                        throw new ValueError("Internal state of the gradient tape is invalid: " +
                                             "none of operations outputs match expected tensor");
                    }
                }
                else if (sources_that_are_targets.find(id, out var source_tensor))
                {
                    // No recorded producer, but the target is also a source: seed it with ones.
                    result[id].Add(source_tensor.OnesLike());
                }
            }

            return result;
        }
Example #2
0
        /// <summary>
        /// Creates a tape with empty tensor/op bookkeeping, gives it the next value of
        /// <c>tape_nesting_id_counter</c> as its nesting id, and adds it to the global tape set.
        /// </summary>
        /// <param name="persistent">Stored in <c>persistent_</c>.</param>
        /// <param name="watch_accessed_variables">Stored in the field of the same name.</param>
        public Tape(bool persistent, bool watch_accessed_variables)
        {
            // Fresh, empty bookkeeping structures.
            op_tape_      = new OpTape <BackwardFunction, TapeTensor>();
            tensor_tape_  = new TensorTape();
            tensor_usage_ = new UnorderedMap <long, long>();

            persistent_ = persistent;
            this.watch_accessed_variables = watch_accessed_variables;

            // Publish the tape only after every field has been initialised.
            nesting_id = ++tape_nesting_id_counter;
            tf.GetTapeSet().Add(this);
        }
Example #3
0
        // Disabled field (kept commented out): a deque-backed stack whose element references
        // are not invalidated by pushes and pops at the back.
        // Stack<AccumulatorCallState> call_state_;

        /// <summary>
        /// Creates an empty gradient tape. Records whether eager execution was active at
        /// construction time and, if it was, calls <c>tf.Context.start_step()</c>.
        /// </summary>
        /// <param name="persistent">Stored in <c>_persistent</c>.</param>
        /// <param name="watch_accessed_variables">Not used by this constructor.</param>
        public Tape(bool persistent, bool watch_accessed_variables)
        {
            bool eager = tf.Context.executing_eagerly();

            _persistent      = persistent;
            _created_eagerly = eager;

            tensor_tape_  = new TensorTape();
            op_tape_      = new OpTape();
            tensor_usage_ = new UnorderedMap <Tensor, long>();

            if (eager)
            {
                tf.Context.start_step();
            }

            // nesting_id = ++tape_nesting_id_counter;
        }
        /// <summary>
        /// Walks the tape backwards from <paramref name="target"/> and collects the state needed to
        /// start backpropagation.
        /// </summary>
        /// <remarks>
        /// The returned state holds:
        /// <list type="bullet">
        /// <item><c>op_tape</c>: every recorded op reachable from the targets.</item>
        /// <item><c>tensor_usage_counts</c>: how many reachable ops consume each tensor.</item>
        /// <item><c>op_missing_tensor</c>: for each producing op, how many of the tensors in
        /// <c>tensor_usage_counts</c> it produced.</item>
        /// </list>
        /// </remarks>
        /// <param name="target">Tensors to start the walk from.</param>
        /// <param name="tensor_tape">Maps a tensor to the id of the op that produced it; -1 means no recorded producer.</param>
        /// <param name="op_tape">Recorded ops keyed by op id. Emptied when <paramref name="persistent_tape"/> is false.</param>
        /// <param name="sources_set">Not used by this method.</param>
        /// <param name="persistent_tape">When false, entries are removed from <paramref name="op_tape"/> as they are visited and the tape is cleared at the end.</param>
        /// <returns>The collected backprop state.</returns>
        public BackpropInitialState PrepareBackprop(Tensor[] target,
                                                    TensorTape tensor_tape,
                                                    OpTape op_tape,
                                                    UnorderedSet <Tensor> sources_set,
                                                    bool persistent_tape)
        {
            var state   = new BackpropInitialState();
            var pending = new Queue <Tensor>(target);

            // Phase 1: discover every recorded op reachable from the targets.
            while (pending.Count != 0)
            {
                var tensor = pending.Dequeue();

                // Tensors unknown to the tape have nothing to expand.
                if (!tensor_tape.find(tensor, out var op_id))
                {
                    continue;
                }

                // -1 marks a tensor with no recorded producer.
                if (op_id == -1)
                {
                    continue;
                }

                if (!op_tape.find(op_id, out var op_entry))
                {
                    continue;
                }

                // Already collected through another path.
                if (state.op_tape.find(op_id, out _))
                {
                    continue;
                }

                state.op_tape.emplace(op_id, op_entry);

                foreach (var input in op_entry.input_tensor_id)
                {
                    if (!state.tensor_usage_counts.find(input))
                    {
                        // First consumer seen for this input: start its count and explore its producer.
                        state.tensor_usage_counts[input] = 1;
                        if (tensor_tape.find(input))
                        {
                            pending.Enqueue(input);
                        }
                    }
                    else
                    {
                        state.tensor_usage_counts[input]++;
                    }
                }

                if (!persistent_tape)
                {
                    op_tape.Remove(op_id);
                }
            }

            // Phase 2: for every consumed tensor with a recorded producer, bump that producer's count.
            foreach (var usage in state.tensor_usage_counts)
            {
                if (!tensor_tape.find(usage.Key, out var producer_id) || producer_id == -1)
                {
                    continue;
                }

                state.op_missing_tensor[producer_id] += 1;
            }

            // Phase 3: the entries needed for gradient computation were copied into state.op_tape
            // above, so a non-persistent source tape can be emptied. (The C++ implementation also
            // calls backward_function_deleter on each remaining entry here; this port does not.)
            if (!persistent_tape)
            {
                op_tape.Clear();
            }

            return state;
        }