Code example #1
0
        /// <summary>
        /// Splits the string elements of <paramref name="input"/> on <paramref name="sep"/> using the
        /// <c>StringSplitV2</c> op and returns the tokens as a <see cref="RaggedTensor"/>
        /// (one row per input string).
        /// </summary>
        /// <param name="input">
        /// String tensor to split. A rank-0 (scalar) input is wrapped into a one-element vector and
        /// split recursively. Any other rank is passed straight to the op.
        /// NOTE(review): presumably expected to be rank 1; higher ranks are not handled specially here — confirm.
        /// </param>
        /// <param name="sep">Separator string forwarded to the <c>StringSplitV2</c> op.</param>
        /// <param name="maxsplit">
        /// Value set as the op's <c>maxsplit</c> attribute. It is also reported through
        /// <c>GetGradientAttrs</c>.
        /// </param>
        /// <param name="name">
        /// Optional name. It is used for the name scope (defaulting to "StringSplit") and passed to the op.
        /// </param>
        /// <returns>
        /// A RaggedTensor built from the op's sparse output via <c>RaggedTensor.from_value_rowids</c>.
        /// </returns>
        public RaggedTensor string_split_v2(Tensor input, string sep = " ", int maxsplit = -1, string name = null)
        {
            return(tf_with(ops.name_scope(name, "StringSplit"), scope =>
            {
                // NOTE(review): sep_tensor is never used below; the raw `sep` string is what gets
                // passed to ExecuteOpArgs. Confirm whether this conversion is still needed.
                var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING);
                if (input.rank == 0)
                {
                    // Scalar input: stack it into a one-element vector and recurse.
                    // NOTE(review): the recursive result is returned as-is, so a scalar input yields a
                    // RaggedTensor with a single row rather than that row itself — confirm this is intended.
                    var parts = string_split_v2(array_ops.stack(new[] { input }),
                                                sep: sep,
                                                maxsplit: maxsplit,
                                                name: name);
                    return parts;
                }

                // Run the StringSplitV2 op. Its three outputs are unpacked below as the
                // indices / values / dense shape of a sparse result.
                var result = tf.Context.ExecuteOp("StringSplitV2", name,
                                                  new ExecuteOpArgs(input, sep)
                {
                    // Expose the op's maxsplit attribute for gradient recording.
                    GetGradientAttrs = op => new
                    {
                        maxsplit = op.get_attr <int>("maxsplit")
                    }
                }.SetAttributes(new { maxsplit }));
                var(indices, values, shape) = (result[0], result[1], result[2]);
                // Pin the static shapes of the outputs: indices is [N, 2], values is [N],
                // and the dense shape is a 2-element vector.
                indices.shape = new Shape(-1, 2);
                values.shape = new Shape(-1);
                shape.shape = new Shape(2);

                var sparse_result = new SparseTensor(indices, values, shape);
                // Column 0 of each sparse index is the row (input string) a token belongs to.
                // The row count comes from dense_shape[0]. validate: false presumably skips row-id
                // consistency checks.
                return RaggedTensor.from_value_rowids(sparse_result.values,
                                                      value_rowids: sparse_result.indices[Slice.All, 0],
                                                      nrows: sparse_result.dense_shape[0],
                                                      validate: false);
            }));
Code example #2
0
 /// <summary>
 /// Densifies <paramref name="sp_input"/>. Its indices, dense shape and values are handed,
 /// in that order, to <c>gen_sparse_ops.sparse_to_dense</c>.
 /// </summary>
 /// <param name="sp_input">The sparse tensor to convert.</param>
 /// <param name="default_value">Forwarded unchanged as the op's <c>default_value</c> argument.</param>
 /// <param name="validate_indices">Forwarded unchanged as the op's <c>validate_indices</c> argument.</param>
 /// <param name="name">Optional operation name, forwarded unchanged.</param>
 /// <returns>Whatever tensor <c>gen_sparse_ops.sparse_to_dense</c> produces.</returns>
 public Tensor sparse_tensor_to_dense(SparseTensor sp_input,
                                      Array default_value   = default,
                                      bool validate_indices = true,
                                      string name           = null)
 {
     // Read the components in the same order the original argument list evaluated them.
     var sparseIndices = sp_input.indices;
     var outputShape = sp_input.dense_shape;
     var sparseValues = sp_input.values;

     return gen_sparse_ops.sparse_to_dense(sparseIndices, outputShape, sparseValues,
                                           default_value: default_value,
                                           validate_indices: validate_indices,
                                           name: name);
 }