예제 #1
0
        /// <summary>
        /// Splits each element of a string tensor on <paramref name="sep"/> using the
        /// <c>StringSplitV2</c> op and returns the pieces as a <see cref="RaggedTensor"/>.
        /// </summary>
        /// <param name="input">String tensor to split. A rank-0 (scalar) input is first wrapped
        /// into a one-element vector with <c>array_ops.stack</c> and split recursively.</param>
        /// <param name="sep">Separator, passed to the op as its second input.</param>
        /// <param name="maxsplit">Set as the op's <c>maxsplit</c> attribute (default -1).</param>
        /// <param name="name">Optional name, used for the name scope and for the op itself.</param>
        /// <returns>A ragged tensor built from the op's (indices, values, dense_shape) outputs,
        /// using column 0 of the indices as the row id of each value.</returns>
        public RaggedTensor string_split_v2(Tensor input, string sep = " ", int maxsplit = -1, string name = null)
        {
            return tf_with(ops.name_scope(name, "StringSplit"), scope =>
            {
                // Scalar input: wrap it into a rank-1 tensor and split that instead.
                // NOTE(review): TF's Python string_split_v2 returns row [0] of this result for a
                // scalar; here the one-row RaggedTensor is returned as-is — confirm callers expect that.
                if (input.rank == 0)
                {
                    return string_split_v2(array_ops.stack(new[] { input }),
                                           sep: sep,
                                           maxsplit: maxsplit,
                                           name: name);
                }

                // The separator is handed to ExecuteOpArgs directly, so no separate string
                // constant is created up front (such a constant would be an unused graph node).
                var result = tf.Context.ExecuteOp("StringSplitV2", name,
                    new ExecuteOpArgs(input, sep)
                    {
                        GetGradientAttrs = op => new
                        {
                            maxsplit = op.get_attr<int>("maxsplit")
                        }
                    }.SetAttributes(new { maxsplit }));

                // Pin the static shapes of the three outputs: indices is [?, 2],
                // values is [?], and the dense shape is a length-2 vector.
                var (indices, values, shape) = (result[0], result[1], result[2]);
                indices.shape = new Shape(-1, 2);
                values.shape = new Shape(-1);
                shape.shape = new Shape(2);

                // Column 0 of each index row becomes the ragged row id; the number of rows
                // comes from the first entry of the dense shape.
                var sparse_result = new SparseTensor(indices, values, shape);
                return RaggedTensor.from_value_rowids(sparse_result.values,
                                                      value_rowids: sparse_result.indices[Slice.All, 0],
                                                      nrows: sparse_result.dense_shape[0],
                                                      validate: false);
            });
예제 #2
0
 /// <summary>
 /// Pass-through stub: returns <paramref name="input"/> unchanged.
 /// </summary>
 /// <remarks>
 /// <paramref name="slices"/> is not consulted at all.
 /// NOTE(review): the name suggests this is meant to apply the slices to the inner
 /// dimensions of the ragged tensor; confirm whether that is still to be implemented.
 /// </remarks>
 /// <param name="input">The ragged tensor being indexed.</param>
 /// <param name="slices">Inner-dimension slices (currently ignored).</param>
 /// <returns>The same <paramref name="input"/> instance.</returns>
 RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) => input;