Ejemplo n.º 1
0
        /// <summary>
        /// Helper function for reduction ops: computes the shape the output of a
        /// reduction would have if <c>keepdims</c> were set to true.
        /// </summary>
        /// <param name="input_shape">1-D Tensor, the shape of the Tensor being reduced.</param>
        /// <param name="axes">1-D Tensor, the reduction axes.</param>
        /// <returns>A 1-D Tensor, the output shape as if keepdims were set to True.</returns>
        public static Tensor reduced_shape(Tensor input_shape, Tensor axes)
        {
            // Eager mode: materialize the shape values and overwrite every reduced
            // axis with 1. (NOTE(review): assumes axes are non-negative here; the
            // graph path below normalizes negative axes — confirm eager callers.)
            if (tf.Context.executing_eagerly())
            {
                var shape_values = input_shape.numpy();
                foreach (var axis in axes.ToArray <int>())
                    shape_values[axis] = 1;
                return tf.constant(shape_values);
            }

            // Graph mode: build the reduced shape symbolically.
            input_shape = to_int32(input_shape);
            axes        = to_int32(axes);

            var input_rank = array_ops.size(input_shape);
            // Map negative axes into [0, input_rank).
            axes = (axes + input_rank) % input_rank;

            // Stitch: positions named in `axes` receive 1, all others keep their
            // original extent from `input_shape`.
            var ones    = gen_array_ops.fill(array_ops.shape(axes), 1);
            var indices = new Tensor[] { math_ops.range(input_rank), axes };
            var data    = new Tensor[] { input_shape, ones };

            return gen_data_flow_ops.dynamic_stitch(indices, data);
        }
Ejemplo n.º 2
0
        /// <summary>
        /// A version of <c>constant_value()</c> that returns a <c>Shape</c>.
        ///
        /// Walks the ops feeding <paramref name="tensor"/> (Cast, Shape, Pack,
        /// Concat, StridedSlice, captured Placeholders) and statically infers as
        /// many dimensions of the shape it encodes as possible; dimensions that
        /// cannot be determined are reported as -1.
        /// </summary>
        /// <param name="tensor">An integer rank-0 or rank-1 Tensor to be evaluated as a shape.</param>
        /// <returns>A <c>Shape</c> based on the constant value of <paramref name="tensor"/>.</returns>
        /// <exception cref="ValueError">If a scalar shape tensor's value is unknown or is not -1.</exception>
        public static Shape constant_value_as_shape(Tensor tensor)
        {
            // Eager tensors already carry their values; convert them directly.
            if (tensor is EagerTensor)
            {
                if (tensor.dtype == tf.int64)
                {
                    return(new Shape(tensor.ToArray <long>()));
                }
                else
                {
                    return(new Shape(tensor.ToArray <int>()));
                }
            }

            // A scalar shape tensor is only legal when it is the statically-known
            // value -1, the conventional encoding for "unknown shape".
            if (tensor.shape.ndim == 0)
            {
                var value_ = constant_value(tensor);
                if (value_ == null)
                {
                    throw new ValueError(
                              @"Received a scalar with unknown value as shape; require a statically
known scalar with value '-1' to describe an unknown shape.");
                }
                if ((int)value_ != -1)
                {
                    throw new ValueError(
                              String.Format(@"Received a scalar value {0} as shape; require a statically known
scalar with value '-1' to describe an unknown shape.", value_));
                }
                return(tensor.shape.unknown_shape(-1));
            }

            var shape = tensor.shape.with_rank(1);

            // NOTE(review): upstream TF compares against shape [0] (an empty shape
            // vector, i.e. a scalar shape) here, not [1] — confirm Shape({1}) is
            // really the intended sentinel before changing it.
            if (shape == new Shape(new int[] { 1 }))
            {
                return(new Shape(new int[] { }));
            }
            else if (tensor.op.type == "Cast")
            {
                // Shape of the Cast's input, with unknown dims already set to -1.
                var pre_cast = constant_value_as_shape(tensor.op.inputs[0]);
                if (pre_cast.dims == null)
                {
                    return(pre_cast);
                }
                var cast_dtype = dtypes.as_tf_dtype((Type)tensor.op.get_attr("DstT"));
                if (!Array.Exists(new[] { dtypes.int32, dtypes.int64 }, cast_dtype_ => cast_dtype_ == cast_dtype))
                {
                    return(tensor.shape.unknown_shape((int)shape.dims[0]));
                }

                // Known dims survive the cast unchanged and unknown dims stay -1.
                // (The original grew zero-length arrays by writing one past the
                // end — x_[x_.Length] — which always threw IndexOutOfRangeException.)
                var x_ = pre_cast.dims.ToArray();
                var dest_dtype_shape_array = np.array(x_).astype(cast_dtype);

                var y_ = dest_dtype_shape_array.ToArray <int>()
                                               .Select(y => y >= 0 ? (long)y : -1L)
                                               .ToArray();
                return(new Shape(y_));
            }
            else if (tensor.op.type == "Shape")
            {
                // Shape(x) encodes exactly the (possibly partially known) shape of x.
                return(tensor.op.inputs[0].shape);
            }
            else if (tensor.op.type == "Pack")
            {
                var ret_ = new Shape(new int[] { });
                // Packing rank-0 inputs along axis 0 yields the rank-1 shape vector;
                // any other axis would not produce rank 1.
                if ((int)tensor.op.get_attr("axis") != 0)
                {
                    throw new ValueError(String.Format(
                                             @"Since rank 1 inputs are expected, Pack's axis: {0} must be 0, otherwise it
would not be rank 1.", tensor.op.get_attr("axis")));
                }
                foreach (Tensor pack_input in tensor.op.inputs)
                {
                    // Check for an unknown constant BEFORE casting to int (the
                    // original cast first, making its null check unreachable).
                    var       pack_input_val = constant_value(pack_input);
                    Dimension new_dim;
                    if (pack_input_val == null || (int)pack_input_val < 0)
                    {
                        new_dim = new Dimension(-1);
                    }
                    else
                    {
                        new_dim = new Dimension((int)pack_input_val);
                    }
                    ret_ = ret_.concatenate(new long[] { new_dim });
                }
                return(ret_);
            }
            else if (tensor.op.type == "Concat")
            {
                var ret_ = new Shape(new int[] { });

                // inputs[0] is concat_dim; the remaining inputs are the pieces of
                // the shape vector, concatenated in order.
                var inputlist_ = new ArraySegment <Tensor>(tensor.op.inputs, 1,
                                                           tensor.op.inputs.Length - 1);
                foreach (var concat_input in inputlist_)
                {
                    ret_ = ret_.concatenate(constant_value_as_shape(concat_input));
                }
                return(ret_);
            }
            else if (tensor.op.type == "StridedSlice")
            {
                try
                {
                    var begin   = constant_value(tensor.op.inputs[1]);
                    var end     = constant_value(tensor.op.inputs[2]);
                    var strides = constant_value(tensor.op.inputs[3]);
                    // All three slice parameters must be statically known (the
                    // original tested `== null`, inverting the condition and then
                    // dereferencing the nulls it had just required).
                    if (new[] { begin, end, strides }.All(x => x is not null))
                    {
                        begin   = begin[0];
                        end     = end[0];
                        strides = strides[0];
                        var begin_mask = tensor.op.get_attr("begin_mask");
                        if ((int)begin_mask == 1)
                        {
                            begin = null;
                        }
                        var end_mask = tensor.op.get_attr("end_mask");
                        if ((int)end_mask == 1)
                        {
                            end = null;
                        }

                        var ellipsis_mask    = tensor.op.get_attr("ellipsis_mask");
                        var new_axis_mask    = tensor.op.get_attr("new_axis_mask");
                        var shrink_axis_mask = tensor.op.get_attr("shrink_axis_mask");

                        // Only the plain a[begin:end:strides] form is supported.
                        // NOTE(review): get_attr likely returns boxed ints, so the
                        // (bool) casts below may throw and be absorbed by the catch
                        // — confirm the attr types before relying on this path.
                        bool valid_attributes =
                            !(bool)ellipsis_mask && !(bool)new_axis_mask &&
                            !(bool)shrink_axis_mask && !((bool)begin_mask || (int)begin_mask == 1) &&
                            !((bool)end_mask || (int)end_mask == 1);
                        if (valid_attributes)
                        {
                            // Emulate Python's prev_[begin:end:strides]. The original
                            // wrote past the end of a fixed-size array inside a
                            // `iter != 100` loop, which could never work.
                            var prev_ = constant_value_as_shape(tensor.op.inputs[0]).dims;
                            var prev  = new long[0];
                            for (int idx = (int)begin; idx < (int)end && idx < prev_.Length; idx += (int)strides)
                            {
                                Array.Resize(ref prev, prev.Length + 1);
                                prev[prev.Length - 1] = prev_[idx];
                            }
                            return(new Shape(prev));
                        }
                    }
                }
                catch (Exception)
                {
                    // Best-effort: any failure (unknown constant, attr cast issues)
                    // means the slice is not statically resolvable. Fall through to
                    // the generic constant_value handling below (Python's `pass`).
                }
            }
            else if (tensor.op.type == "Placeholder" &&
                     tensor.op.graph.building_function &&
                     tensor.op.graph is FuncGraph func_graph)
            {
                // A placeholder capturing an outer-graph tensor: resolve the shape
                // from the first plain-Tensor external capture.
                int i = 0;
                foreach (Tensor capture in func_graph.internal_captures)
                {
                    if (capture.GetType() == typeof(Tensor))
                    {
                        var external_capture = func_graph.external_captures[i];
                        return(constant_value_as_shape(external_capture));
                    }

                    i++;
                }
            }

            // Fallback: start fully unknown at the right rank, then merge in any
            // element values that can be resolved statically (-1 for negatives).
            var ret   = tensor.shape.unknown_shape((int)shape.dims[0]);
            var value = constant_value(tensor);

            if (value is not null)
            {
                var d_ = new int[value.size];
                foreach (var(index, d) in enumerate(value.ToArray <int>()))
                {
                    d_[index] = d >= 0 ? d : -1;
                }

                ret = ret.merge_with(new Shape(d_));
            }
            return(ret);
        }
Ejemplo n.º 3
0
        /// <summary>
        /// Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
        ///
        /// `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
        /// `false_fn` must have the same non-zero number and type of outputs.
        ///
        /// **WARNING**: Any Tensors or Operations created outside of `true_fn` and
        /// `false_fn` will be executed regardless of which branch is selected at runtime.
        ///
        /// Although this behavior is consistent with the dataflow model of TensorFlow,
        /// it has frequently surprised users who expected a lazier semantics.
        /// Consider the following simple program:
        ///
        /// z = tf.multiply(a, b)
        /// result = tf.cond(x &lt; y, ()=> tf.add(x, z), ()=> tf.square(y))
        ///
        /// If `x&lt;y`, the `tf.add` operation will be executed and `tf.square`
        /// operation will not be executed.Since `z` is needed for at least one
        /// branch of the `cond`, the `tf.multiply` operation is always executed,
        /// unconditionally.
        ///
        /// Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
        /// call to `cond`, and not at all during `Session.run()`). `cond`
        /// stitches together the graph fragments created during the `true_fn` and
        /// `false_fn` calls with some additional graph nodes to ensure that the right
        /// branch gets executed depending on the value of `pred`.
        ///
        /// `tf.cond` supports nested structures as implemented in
        /// `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
        /// same(possibly nested) value structure of lists, tuples, and/or named tuples.
        /// Singleton lists and tuples form the only exceptions to this: when returned by
        /// `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
        /// This behavior is disabled by passing `strict= True`.
        /// </summary>
        /// <param name="pred"> A scalar determining whether to return the result of `true_fn` or
        /// `false_fn`.</param>
        /// <param name="true_fn">The callable to be performed if pred is true.</param>
        /// <param name="false_fn">The callable to be performed if pred is false.</param>
        /// <param name="strict"> A boolean that enables/disables 'strict' mode; see above.</param>
        /// <param name="name">Optional name prefix for the returned tensors.</param>
        /// <returns>Tensors returned by the call to either `true_fn` or `false_fn`. If the
        /// callables return a singleton list, the element is extracted from the list.</returns>
        public static Tensor cond(Tensor pred,
                                  Func <ITensorOrOperation> true_fn  = null,
                                  Func <ITensorOrOperation> false_fn = null,
                                  bool strict = false,
                                  string name = null)
        {
            return(tf_with(ops.name_scope(name, "cond", new { pred }), delegate
            {
                if (tf.context.executing_eagerly())
                {
                    if (pred.ToArray <bool>()[0])
                    {
                        return true_fn() as Tensor;
                    }
                    else
                    {
                        return false_fn() as Tensor;
                    }

                    return null;
                }

                // Add the Switch to the graph.
                var switch_result = @switch(pred, pred);
                var(p_2, p_1) = (switch_result[0], switch_result[1]);
                var pivot_1 = array_ops.identity(p_1, name: "switch_t");
                var pivot_2 = array_ops.identity(p_2, name: "switch_f");
                pred = array_ops.identity(pred, name: "pred_id");

                // Disable the fetching of tensors that are only on one branch of cond.
                foreach (var tensor in new Tensor[] { p_1, p_2, pivot_1, pivot_2, pred })
                {
                    tensor.op.graph.prevent_fetching(tensor.op);
                }

                // Build the graph for the true branch in a new context.
                var context_t = new CondContext(pred: pred, pivot: pivot_1, branch: 1);
                ITensorOrOperation orig_res_t;
                Tensor res_t;
                try
                {
                    context_t.Enter();
                    (orig_res_t, res_t) = context_t.BuildCondBranch(true_fn);
                    context_t.ExitResult(new[] { res_t });
                }
                finally
                {
                    context_t.Exit();
                }
                // Build the graph for the false branch in a new context.
                var context_f = new CondContext(pred: pred, pivot: pivot_2, branch: 0);
                ITensorOrOperation orig_res_f;
                Tensor res_f;
                try
                {
                    context_f.Enter();
                    (orig_res_f, res_f) = context_f.BuildCondBranch(false_fn);
                    context_f.ExitResult(new[] { res_f });
                }
                finally
                {
                    context_f.Exit();
                }

                var res_t_flat = new Tensor[] { res_t };
                var res_f_flat = new Tensor[] { res_f };

                var merges = zip(res_f_flat, res_t_flat)
                             .Select(pair => merge(new Tensor[] { pair.Item1, pair.Item2 })[0])
                             .ToArray();

                if (orig_res_t is Tensor orig_res_tensor)
                {
                    merges = _convert_flows_to_tensorarrays(new[] { orig_res_tensor }, merges)
                             .Select(x => x as Tensor)
                             .ToArray();
                }
                else
                {
                }

                if (context_t.outer_context == null)
                {
                    ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_t);
                    ops.add_to_collection(tf.GraphKeys.COND_CONTEXT, context_f);
                }

                return merges[0];
            }));