Example No. 1
0
        /// <summary>
        /// Returns a `TensorShape` combining the information in `self` and `other`.
        /// </summary>
        /// <param name="other">The shape to merge with this one.</param>
        /// <returns>A new `TensorShape` whose dimensions combine both shapes.</returns>
        public TensorShape merge_with(TensorShape other)
        {
            // A fully-unknown shape carries no information; defer entirely to `other`.
            if (dims == null)
            {
                return(other);
            }

            var merged_dims = new List <int>();

            for (int axis = 0; axis < ndim; axis++)
            {
                var self_dim = new Dimension(dims[axis]);
                merged_dims.Add(self_dim.merge_with(new Dimension(other.dims[axis])).value);
            }

            return(new TensorShape(merged_dims.ToArray()));
        }
Example No. 2
0
        /// <summary>
        /// map on the list of tensors unpacked from `elems` on dimension 0.
        /// </summary>
        /// <param name="fn">Function mapped over each tensor unpacked from `elems` along dimension 0.</param>
        /// <param name="elems">Tensor (or nested sequence of tensors) unpacked along its first dimension.</param>
        /// <param name="dtype">Output dtype of `fn`; when `DtInvalid` the dtype of `elems` is assumed.</param>
        /// <param name="parallel_iterations">Number of loop iterations allowed to run in parallel.</param>
        /// <param name="back_prop">True enables support for back propagation.</param>
        /// <param name="swap_memory">True enables GPU-CPU memory swapping.</param>
        /// <param name="infer_shape">Whether the accumulator TensorArrays infer element shapes.</param>
        /// <param name="name">Optional name prefix for the created ops.</param>
        /// <returns>A tensor or (possibly nested) sequence of tensors.</returns>
        public static Tensor map_fn(Func <Tensor, Tensor> fn,
                                    Tensor elems,
                                    TF_DataType dtype       = TF_DataType.DtInvalid,
                                    int parallel_iterations = 10,
                                    bool back_prop          = true,
                                    bool swap_memory        = false,
                                    bool infer_shape        = true,
                                    string name             = null)
        {
            bool input_is_sequence = nest.is_sequence(elems);

            // Flatten/pack helpers for the input structure.
            Tensor[] input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x).ToArray() : new [] { x };
            Tensor input_pack(Tensor[] x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0];

            bool output_is_sequence;
            Func <Tensor, Tensor[]> output_flatten;
            Func <Tensor[], Tensor> output_pack;

            if (dtype == TF_DataType.DtInvalid)
            {
                // No explicit output dtype: output structure mirrors the input structure.
                output_is_sequence = input_is_sequence;
                output_flatten     = input_flatten;
                output_pack        = input_pack;
            }
            else
            {
                output_is_sequence = nest.is_sequence(dtype);
                output_flatten     = (x) => output_is_sequence ? nest.flatten(x).ToArray() : new [] { x };
                output_pack        = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(dtype, x) : x[0];
            }

            var elems_flat = input_flatten(elems);

            return(tf_with(ops.name_scope(name, "map", elems_flat), delegate
            {
                //if in_graph_mode:
                //# Any get_variable calls in fn will cache the first call locally
                //# and not issue repeated network I/O requests for each iteration.
                //varscope = vs.get_variable_scope()
                //varscope_caching_device_was_none = False
                //if varscope.caching_device is None:
                //  # TODO(ebrevdo): Change to using colocate_with here and in other
                //  # methods.
                //  varscope.set_caching_device(lambda op: op.device)
                //  varscope_caching_device_was_none = True

                elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem"))
                             .ToArray();

                // BUGFIX: a caller-supplied dtype was previously overwritten unconditionally
                // with the element dtype; only fall back when no dtype was given.
                if (dtype == TF_DataType.DtInvalid)
                    dtype = elems_flat.Select(elem => elem.dtype).First();
                var dtype_flat = new[] { dtype };

                // Convert elems to tensor array. n may be known statically.
                var static_shape = elems_flat[0].shape;

                var n = static_shape[0];

                // TensorArrays are always flat
                var elems_ta = elems_flat.Select(elem => new TensorArray(dtype: elem.dtype,
                                                                         size: ops.convert_to_tensor(n),
                                                                         dynamic_size: false,
                                                                         infer_shape: true)).ToArray();

                // Unpack elements
                var elems_ta_1 = new List <TensorArray>();
                foreach (var(elem_ta, elem) in zip(elems_ta, elems_flat))
                {
                    elems_ta_1.Add(elem_ta.unstack(elem));
                }

                elems_ta = elems_ta_1.ToArray();

                var i = constant_op.constant(0);

                var accs_ta = dtype_flat.Select(dt => new TensorArray(dtype: dt,
                                                                      size: ops.convert_to_tensor(n),
                                                                      dynamic_size: false,
                                                                      infer_shape: infer_shape)).ToArray();

                // Loop body: read one slice, apply fn, accumulate the flattened result.
                BodyItem compute(BodyItem item)
                {
                    var packed_values = input_pack(elems_ta.Select(elem_ta => elem_ta.read(item.I)).ToArray());
                    var packed_fn_values = fn(packed_values);
                    //nest.assert_same_structure(dtype or elems, packed_fn_values)

                    var flat_fn_values = output_flatten(packed_fn_values);
                    for (int j = 0; j < item.Accs_ta.Length; j++)
                    {
                        // NOTE(review): the value returned by write() is discarded here;
                        // confirm TensorArray.write mutates the array in place.
                        item.Accs_ta[j].write(item.I, flat_fn_values[j]);
                    }

                    return new BodyItem(item.I + 1, item.Accs_ta);
                }

                var r_a = control_flow_ops.while_loop(
                    (x) => x.I < n,
                    compute,
                    new BodyItem(i, accs_ta),
                    parallel_iterations: parallel_iterations,
                    back_prop: back_prop,
                    swap_memory: swap_memory,
                    maximum_iterations: tf.constant(n));
                var results_flat = r_a.Accs_ta.Select(r => r.stack()).ToArray();

                var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));

                foreach (var elem in elems_flat.Skip(1))
                {
                    // BUGFIX: merge_with returns the merged Dimension (cf. its use inside
                    // TensorShape.merge_with); the result was previously discarded.
                    n_static = n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0])));
                }

                foreach (Tensor r in results_flat)
                {
                    r.set_shape(new TensorShape(n_static).concatenate(r.dims.Skip(1).ToArray()));
                }

                // todo get working when the above caching_device is fixed
                //if (in_graph_mode && varscope_caching_device_was_none) {
                //    varscope.set_caching_device(None);
                //}

                return output_pack(results_flat);
            }));
        }
Example No. 3
0
        /// <summary>
        /// scan on the list of tensors unpacked from `elems` on dimension 0,
        /// carrying an accumulator from one element to the next.
        /// </summary>
        /// <param name="fn">Callable invoked as fn(accumulator, element) for each unpacked element.</param>
        /// <param name="elems">Tensor (or nested sequence of tensors) unpacked along its first dimension.</param>
        /// <param name="initializer">Initial accumulator; when null the first element of `elems` is used.</param>
        /// <param name="parallel_iterations">Number of loop iterations allowed to run in parallel.</param>
        /// <param name="back_prop">True enables support for back propagation.</param>
        /// <param name="swap_memory">True enables GPU-CPU memory swapping.</param>
        /// <param name="infer_shape">Whether the accumulator TensorArrays infer element shapes.</param>
        /// <param name="reverse">True scans from the last element to the first.</param>
        /// <param name="name">Optional name prefix for the created ops.</param>
        /// <returns>A tensor (or sequence) of all intermediate accumulator values.</returns>
        public static Tensor scan(
            Func <Tensor, Tensor, Tensor> fn,
            Tensor elems,
            Tensor initializer      = null,
            int parallel_iterations = 10,
            bool back_prop          = true,
            bool swap_memory        = false,
            bool infer_shape        = true,
            bool reverse            = false,
            string name             = null)
        {
            bool input_is_sequence = nest.is_sequence(elems);

            // Flatten/pack helpers for the input structure.
            Tensor[] input_flatten(Tensor x) => input_is_sequence ? nest.flatten(x).ToArray() : new[] { x };
            Tensor input_pack(Tensor[] x) => input_is_sequence ? (Tensor)nest.pack_sequence_as(elems, x) : x[0];

            bool output_is_sequence;
            Func <Tensor, Tensor[]> output_flatten;
            Func <Tensor[], Tensor> output_pack;

            if (initializer == null)
            {
                // Without an initializer the accumulator takes the input's structure.
                output_is_sequence = input_is_sequence;
                output_flatten     = input_flatten;
                output_pack        = input_pack;
            }
            else
            {
                output_is_sequence = nest.is_sequence(initializer);
                output_flatten     = (x) => output_is_sequence ? nest.flatten(x).ToArray() : new[] { x };
                output_pack        = (x) => output_is_sequence ? (Tensor)nest.pack_sequence_as(initializer, x) : x[0];
            }

            var elems_flat = input_flatten(elems);

            // BUGFIX: graph mode is the opposite of eager execution; the flag was
            // previously assigned executing_eagerly() directly. (The guarded block
            // below is currently all commented out, so this is not yet observable.)
            bool in_graph_mode = !tf.Context.executing_eagerly();

            return(tf_with(ops.name_scope(name, "scan", new { elems_flat }), scope =>
            {
                if (in_graph_mode)
                {
                    // todo tf.net doesn't expose .caching_device
                    //// Any get_variable calls in fn will cache the first call locally
                    //// and not issue repeated network I/O requests for each iteration.
                    //var varscope = variable_scope.get_variable_scope();
                    //bool varscope_caching_device_was_none = false;
                    //if (varscope.caching_device = null)
                    //{
                    //    //      varscope.set_caching_device(lambda op: op.device)
                    //    //      varscope_caching_device_was_none = True
                    //}
                }

                elems_flat = elems_flat.Select(elem => ops.convert_to_tensor(elem, name: "elem")).ToArray();

                var n = tensor_shape.dimension_value(elems_flat[0].shape[0]);

                // todo python had the below but dimension_value returns int which can't be null
                //if (n == null)
                //{
                //    n = array_ops.shape(elems_flat[0])[0];
                //}

                var elems_ta = elems_flat.Select(elem => new TensorArray(
                                                     elem.dtype,
                                                     size: tf.constant(n),
                                                     dynamic_size: false,
                                                     element_shape: elem.shape.Skip(1).ToArray(),
                                                     infer_shape: true)).ToList();

                for (int index = 0; index < elems_ta.Count; index++)
                {
                    // BUGFIX: keep the TensorArray returned by unstack (map_fn consumes
                    // unstack's return value); it was previously discarded.
                    elems_ta[index] = elems_ta[index].unstack(elems_flat[index]);
                }

                Tensor[] a_flat;
                int i;
                if (initializer == null)
                {
                    // Seed the accumulator with the first (or last, when reversed) element
                    // and start the loop one step in.
                    a_flat = elems_ta.Select(elem => elem.read(tf.constant(reverse ? n - 1 : 0))).ToArray();
                    i = 1;
                }
                else
                {
                    Tensor[] initializer_flat = output_flatten(initializer);
                    a_flat = initializer_flat.Select(init => ops.convert_to_tensor(init)).ToArray();
                    i = 0;
                }

                var accs_ta = a_flat.Select(init => new TensorArray(
                                                dtype: init.dtype,
                                                size: tf.constant(n),
                                                element_shape: infer_shape ? init.shape : null,
                                                dynamic_size: false,
                                                infer_shape: infer_shape)).ToArray();

                if (initializer == null)
                {
                    for (int index = 0; index < accs_ta.Length; index++)
                    {
                        // NOTE(review): the value returned by write() is discarded here and in
                        // the loop body; confirm TensorArray.write mutates in place.
                        accs_ta[index].write(tf.constant(reverse ? n - 1 : 0), a_flat[index]);
                    }
                }

                // Loop body: fold the current element into the accumulator and record it.
                BodyItem compute(BodyItem item)
                {
                    var packed_elems = input_pack(elems_ta.Select(elem_ta => elem_ta.read(item.I)).ToArray());
                    var packed_a = output_pack(item.A_Flat);
                    var a_out = fn(packed_a, packed_elems);

                    var flat_a_out = output_flatten(a_out);
                    for (int j = 0; j < item.Accs_ta.Length; j++)
                    {
                        item.Accs_ta[j].write(item.I, flat_a_out[j]);
                    }

                    var next_i = reverse ? item.I - 1 : item.I + 1;
                    return new BodyItem(next_i, flat_a_out, item.Accs_ta);
                }

                int initial_i;
                Func <BodyItem, Tensor> condition;
                if (reverse)
                {
                    initial_i = n - 1 - i;
                    condition = x => x.I >= 0;
                }
                else
                {
                    initial_i = i;
                    condition = x => x.I < n;
                }

                BodyItem bodyItem =
                    control_flow_ops.while_loop(
                        condition,
                        compute,
                        new BodyItem(tf.constant(initial_i), a_flat, accs_ta),
                        parallel_iterations: parallel_iterations,
                        back_prop: back_prop,
                        swap_memory: swap_memory,
                        maximum_iterations: tf.constant(n));

                var results_flat = bodyItem.Accs_ta.Select(r => r.stack()).ToArray();

                var n_static = new Dimension(tensor_shape.dimension_value(elems_flat[0].TensorShape.with_rank_at_least(1).dims[0]));

                foreach (var elem in elems_flat.Skip(1))
                {
                    // BUGFIX: merge_with returns the merged Dimension (cf. its use inside
                    // TensorShape.merge_with); the result was previously discarded.
                    n_static = n_static.merge_with(new Dimension(tensor_shape.dimension_value(elem.TensorShape.with_rank_at_least(1).dims[0])));
                }

                foreach (Tensor r in results_flat)
                {
                    r.set_shape(new TensorShape(n_static).concatenate(r.dims.Skip(1).ToArray()));
                }

                // todo get working when the above caching_device is fixed
                //if (in_graph_mode && varscope_caching_device_was_none) {
                //    varscope.set_caching_device(None);
                //}

                return output_pack(results_flat);
            }));
        }