Beispiel #1
0
        /// <summary>
        /// Records a single operation on the tape: takes the next operation id, bumps the
        /// usage counter of every input tensor, maps every output tensor to the new id and
        /// stores an <c>OpTapeEntry</c> describing the operation under that id.
        /// </summary>
        /// <param name="op_type">Operation type name stored in the tape entry.</param>
        /// <param name="input_tensors">Inputs of the operation; the usage counter of each one is incremented and the array itself is stored in the entry.</param>
        /// <param name="output_tensors">Outputs of the operation; each one is mapped to the new operation id with a usage count of 1, and the array itself is stored in the entry.</param>
        /// <param name="backward_function">Backward function stored in the tape entry.</param>
        /// <remarks>
        /// Returns without touching any tape state when <c>ShouldRecord</c> rejects the inputs.
        /// </remarks>
        public void RecordOperation(string op_type,
                                    Tensor[] input_tensors,
                                    TapeTensor[] output_tensors,
                                    BackwardFunction backward_function)
        {
            var recordable = ShouldRecord(input_tensors);
            if (!recordable)
            {
                return;
            }

            // The id is consumed only once we know the operation is going to be recorded.
            var op_id = next_op_id_++;

            for (var idx = 0; idx < input_tensors.Length; idx++)
            {
                tensor_usage_[input_tensors[idx]]++;
            }

            for (var idx = 0; idx < output_tensors.Length; idx++)
            {
                var output = output_tensors[idx];
                tf.Logger.Debug($"RecordOperation: tensor_tape_[{output.GetID()}] = {op_id}");

                // Each output is owned by this operation and starts with a usage count of 1.
                tensor_tape_[output.GetTensor()]  = op_id;
                tensor_usage_[output.GetTensor()] = 1;
            }

            var entry = new OpTapeEntry
            {
                op_type            = op_type,
                output_tensor_info = output_tensors,
                input_tensor_id    = input_tensors,
                backward_function  = backward_function
            };

            op_tape_[op_id] = entry;
        }
        /// <summary>
        /// Records a single operation on the tape, identifying tensors by id. The ids and
        /// dtypes of <paramref name="input_tensors"/> are extracted first and passed to
        /// <c>ShouldRecord</c>; when recording proceeds, input usage counters are bumped,
        /// every output id is mapped to the new operation id, and an entry describing the
        /// operation is stored under that id.
        /// </summary>
        /// <param name="op_type">Operation type name stored in the tape entry.</param>
        /// <param name="input_tensors">Inputs of the operation; only their <c>Id</c> and <c>dtype</c> are read.</param>
        /// <param name="output_tensors">Outputs of the operation; a copy of this array is stored in the entry.</param>
        /// <param name="backward_function_getter">
        /// Factory for the backward function. It is invoked exactly once, after all tensor
        /// bookkeeping, while the tape entry is being built; it is not invoked when
        /// <c>ShouldRecord</c> rejects the inputs.
        /// </param>
        public void RecordOperation(string op_type,
                                    Tensor[] input_tensors,
                                    TapeTensor[] output_tensors,
                                    Func<BackwardFunction> backward_function_getter)
        {
            var source_ids    = input_tensors.Select(tensor => tensor.Id).ToArray();
            var source_dtypes = input_tensors.Select(tensor => tensor.dtype).ToArray();

            var recordable = ShouldRecord(source_ids, source_dtypes);
            if (!recordable)
            {
                return;
            }

            long op_id = next_op_id_++;

            // Snapshot of the input ids that ends up in the tape entry.
            var recorded_ids = new long[source_ids.Length];
            for (var idx = 0; idx < source_ids.Length; idx++)
            {
                var id = source_ids[idx];
                tensor_usage_[id]++;
                recorded_ids[idx] = id;
            }

            // Snapshot of the outputs that ends up in the tape entry.
            var recorded_outputs = new TapeTensor[output_tensors.Length];
            for (var idx = 0; idx < output_tensors.Length; idx++)
            {
                var output = output_tensors[idx];
                tensor_tape_[output.GetID()] = op_id;
                tf.Logger.Debug($"RecordOperation: tensor_tape_[{output.GetID()}] = {op_id}");
                tensor_usage_[output.GetID()] = 1;
                recorded_outputs[idx] = output;
            }

            var entry = new OpTapeEntry<BackwardFunction, TapeTensor>
            {
                op_type            = op_type,
                output_tensor_info = recorded_outputs,
                input_tensor_id    = recorded_ids,
                backward_function  = backward_function_getter()
            };

            op_tape_[op_id] = entry;
        }
        /// <summary>
        /// Records a single operation on the tape from caller-supplied input ids and dtypes.
        /// When <c>ShouldRecord</c> accepts them, input usage counters are bumped, every
        /// output id is mapped to the new operation id with a usage count of 1, and an entry
        /// describing the operation is stored under that id.
        /// </summary>
        /// <param name="op_type">Operation type name stored in the tape entry.</param>
        /// <param name="input_tensors">Not read by this method; the inputs are identified through <paramref name="input_tensor_id"/>.</param>
        /// <param name="output_tensors">Outputs of the operation; a copy of this array is stored in the entry.</param>
        /// <param name="input_tensor_id">Ids of the input tensors; a copy of this array is stored in the entry.</param>
        /// <param name="input_dtypes">Dtypes of the inputs; only forwarded to <c>ShouldRecord</c>.</param>
        /// <param name="backward_function_getter">
        /// Factory for the backward function. It is invoked exactly once, after all tensor
        /// bookkeeping, while the tape entry is being built; it is not invoked when
        /// <c>ShouldRecord</c> rejects the inputs.
        /// </param>
        public void RecordOperation(string op_type,
                                    Tensor[] input_tensors,
                                    TapeTensor[] output_tensors,
                                    long[] input_tensor_id,
                                    TF_DataType[] input_dtypes,
                                    Func<BackwardFunction> backward_function_getter)
        {
            var recordable = ShouldRecord(input_tensor_id, input_dtypes);
            if (!recordable)
            {
                return;
            }

            long op_id = next_op_id_++;

            // Snapshot of the input ids that ends up in the tape entry.
            var recorded_ids = new long[input_tensor_id.Length];
            for (var idx = 0; idx < input_tensor_id.Length; idx++)
            {
                var id = input_tensor_id[idx];
                tensor_usage_[id]++;
                recorded_ids[idx] = id;
            }

            // Snapshot of the outputs that ends up in the tape entry.
            var recorded_outputs = new TapeTensor[output_tensors.Length];
            for (var idx = 0; idx < output_tensors.Length; idx++)
            {
                var output = output_tensors[idx];
                tensor_tape_[output.GetID()]  = op_id;
                tensor_usage_[output.GetID()] = 1;
                recorded_outputs[idx] = output;
            }

            var entry = new OpTapeEntry<BackwardFunction, TapeTensor>
            {
                op_type            = op_type,
                output_tensor_info = recorded_outputs,
                input_tensor_id    = recorded_ids,
                backward_function  = backward_function_getter()
            };

            op_tape_[op_id] = entry;
        }