1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ragged operations for working with string Tensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.framework import tensor_util
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import control_flow_ops
28from tensorflow.python.ops import gen_string_ops
29from tensorflow.python.ops import string_ops
30from tensorflow.python.ops.ragged import ragged_array_ops
31from tensorflow.python.ops.ragged import ragged_math_ops
32from tensorflow.python.ops.ragged import ragged_tensor
33from tensorflow.python.util import compat as util_compat
34from tensorflow.python.util import deprecation
35from tensorflow.python.util import dispatch
36from tensorflow.python.util.lazy_loader import LazyLoader
37from tensorflow.python.util.tf_export import tf_export
38
39
# Loaded lazily to avoid an import-time circular dependency: map_fn itself
# depends on ragged ops.
map_fn_lib = LazyLoader("map_fn_lib", globals(),
                        "tensorflow.python.ops.map_fn")
42
43
@tf_export("strings.bytes_split")
@dispatch.add_dispatch_support
def string_bytes_split(input, name=None):  # pylint: disable=redefined-builtin
  """Split string elements of `input` into bytes.

  Examples:

  >>> tf.strings.bytes_split('hello').numpy()
  array([b'h', b'e', b'l', b'l', b'o'], dtype=object)
  >>> tf.strings.bytes_split(['hello', '123'])
  <tf.RaggedTensor [[b'h', b'e', b'l', b'l', b'o'], [b'1', b'2', b'3']]>

  Note that this op splits strings into bytes, not unicode characters.  To
  split strings into unicode characters, use `tf.strings.unicode_split`.

  See also: `tf.io.decode_raw`, `tf.strings.split`, `tf.strings.unicode_split`.

  Args:
    input: A string `Tensor` or `RaggedTensor`: the strings to split.  Must
      have a statically known rank (`N`).
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` of rank `N+1`: the bytes that make up the source strings.
  """
  with ops.name_scope(name, "StringsByteSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input,
                                                             name="input")
    if isinstance(input, ragged_tensor.RaggedTensor):
      # Split the innermost strings; the outer ragged structure is unchanged.
      return input.with_flat_values(string_bytes_split(input.flat_values))

    input_rank = input.shape.ndims
    if input_rank is None:
      raise ValueError("input must have a statically-known rank.")

    if input_rank == 0:
      # Scalar: split a length-1 batch and unwrap the single row.
      return string_bytes_split(array_ops.stack([input]))[0]
    if input_rank > 1:
      # Higher rank: reduce to the rank-1 case via a ragged view.
      return string_bytes_split(ragged_tensor.RaggedTensor.from_tensor(input))

    # Rank-1 base case: an empty delimiter splits each string into its bytes.
    row_indices, byte_values, dense_shape = gen_string_ops.string_split(
        input, delimiter="", skip_empty=False)
    return ragged_tensor.RaggedTensor.from_value_rowids(
        values=byte_values,
        value_rowids=row_indices[:, 0],
        nrows=dense_shape[0],
        validate=False)
89
90
# pylint: disable=redefined-builtin
@tf_export("strings.unicode_encode")
@dispatch.add_dispatch_support
def unicode_encode(input,
                   output_encoding,
                   errors="replace",
                   replacement_char=65533,
                   name=None):
  r"""Encodes each sequence of Unicode code points in `input` into a string.

  `result[i1...iN]` is the string formed by concatenating the Unicode
  codepoints `input[i1...iN, :]`, encoded using `output_encoding`.

  Args:
    input: An `N+1` dimensional potentially ragged integer tensor with shape
      `[D1...DN, num_chars]`.
    output_encoding: Unicode encoding that should be used to encode each
      codepoint sequence.  Can be `"UTF-8"`, `"UTF-16-BE"`, or `"UTF-32-BE"`.
    errors: Specifies the response when an invalid codepoint is encountered
      (optional). One of:
            * `'replace'`: Replace invalid codepoint with the
              `replacement_char`. (default)
            * `'ignore'`: Skip invalid codepoints.
            * `'strict'`: Raise an exception for any invalid codepoint.
    replacement_char: The replacement character codepoint to be used in place of
      any invalid input when `errors='replace'`. Any valid unicode codepoint may
      be used. The default value is the default unicode replacement character
      which is 0xFFFD (decimal 65533).
    name: A name for the operation (optional).

  Returns:
    A `N` dimensional `string` tensor with shape `[D1...DN]`.

  #### Example:

  >>> input = tf.ragged.constant(
  ...     [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
  >>> print(unicode_encode(input, 'UTF-8'))
  tf.Tensor([b'G\xc3\xb6\xc3\xb6dnight' b'\xf0\x9f\x98\x8a'],
            shape=(2,), dtype=string)
  """
  with ops.name_scope(name, "UnicodeEncode", [input]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    if input_tensor.shape.ndims is None:
      raise ValueError("Rank of input_tensor must be statically known.")
    if ragged_tensor.is_ragged(input_tensor):
      if input_tensor.flat_values.shape.ndims > 1:
        # If the flat_values of our ragged tensor is multi-dimensional, we can
        # process it separately and our output will have the same nested splits
        # as our input.
        return input_tensor.with_flat_values(
            unicode_encode(input_tensor.flat_values, output_encoding, errors,
                           replacement_char))
      elif input_tensor.ragged_rank > 1:
        # Recursively process the values of the ragged tensor.
        return input_tensor.with_values(
            unicode_encode(input_tensor.values, output_encoding, errors,
                           replacement_char))
      else:
        # Our ragged tensor is of the correct shape (rank 1 flat_values tensor
        # with ragged_rank of 1) so we can process it as normal.
        return gen_string_ops.unicode_encode(
            input_values=input_tensor.values,
            input_splits=input_tensor.row_splits,
            output_encoding=output_encoding,
            errors=errors,
            replacement_char=replacement_char)
    else:
      if input_tensor.shape.ndims == 2:
        # The input tensor is of the correct 2-D shape, it's just not ragged.
        return unicode_encode(
            ragged_tensor.RaggedTensor.from_tensor(input_tensor),
            output_encoding, errors, replacement_char)
      elif input_tensor.shape.ndims > 2:
        # We need to initially flatten the input tensor to 2-D, and then can
        # reshape the output of our processed flattened tensor.
        flat_input_tensor = array_ops.reshape(
            input_tensor,
            array_ops.stack([-1, array_ops.shape(input_tensor)[-1]]))
        flat_output_tensor = unicode_encode(flat_input_tensor, output_encoding,
                                            errors, replacement_char)
        return array_ops.reshape(flat_output_tensor, input_tensor.shape[:-1])
      elif input_tensor.shape.ndims == 0:
        raise ValueError("input_tensor's rank must be at least 1.")
      else:
        # Our input tensor is rank 1, so we create a ragged tensor with an added
        # dimension to create the correct input shape & type, and then remove
        # the additional dimension from the output and return the string scalar.
        # The two-element row_splits [0, len(input)] put all codepoints in one
        # row; validate=False since that construction is trivially valid.
        ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits(
            input_tensor,
            array_ops.stack(
                [0, array_ops.shape(input_tensor, out_type=dtypes.int32)[0]]),
            validate=False)
        output_tensor = unicode_encode(ragged_input_tensor, output_encoding,
                                       errors, replacement_char)
        return array_ops.reshape(output_tensor, [])
187
188
# pylint: disable=redefined-builtin
@tf_export("strings.unicode_decode")
@dispatch.add_dispatch_support
def unicode_decode(input,
                   input_encoding,
                   errors="replace",
                   replacement_char=0xFFFD,
                   replace_control_characters=False,
                   name=None):
  r"""Decodes each string in `input` into a sequence of Unicode code points.

  `result[i1...iN, j]` is the Unicode codepoint for the `j`th character in
  `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`; and in place of C0 control
      characters in `input` when `replace_control_characters=True`.
    replace_control_characters: Whether to replace the C0 control characters
      `(U+0000 - U+001F)` with the `replacement_char`.
    name: A name for the operation (optional).

  Returns:
    A `N+1` dimensional `int32` tensor with shape `[D1...DN, (num_chars)]`.
    The returned tensor is a `tf.Tensor` if `input` is a scalar, or a
    `tf.RaggedTensor` otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> tf.strings.unicode_decode(input, 'UTF-8').to_list()
  [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]
  """
  with ops.name_scope(name, "UnicodeDecode", [input]):
    # Delegate to the shared implementation; offsets are not requested.
    return _unicode_decode(
        input,
        input_encoding,
        errors=errors,
        replacement_char=replacement_char,
        replace_control_characters=replace_control_characters,
        with_offsets=False)
234
235
@tf_export("strings.unicode_decode_with_offsets")
@dispatch.add_dispatch_support
def unicode_decode_with_offsets(input,
                                input_encoding,
                                errors="replace",
                                replacement_char=0xFFFD,
                                replace_control_characters=False,
                                name=None):
  r"""Decodes each string into a sequence of code points with start offsets.

  This op is similar to `tf.strings.unicode_decode(...)`, but it also returns
  the start offset for each character in its respective string.  This
  information can be used to align the characters with the original byte
  sequence.

  Returns a tuple `(codepoints, start_offsets)` where:

  * `codepoints[i1...iN, j]` is the Unicode codepoint for the `j`th character
    in `input[i1...iN]`, when decoded using `input_encoding`.
  * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th
    character in `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`; and in place of C0 control
      characters in `input` when `replace_control_characters=True`.
    replace_control_characters: Whether to replace the C0 control characters
      `(U+0000 - U+001F)` with the `replacement_char`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`.

    * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`.
    * `offsets` is an `int64` tensor with shape `[D1...DN, (num_chars)]`.

    The returned tensors are `tf.Tensor`s if `input` is a scalar, or
    `tf.RaggedTensor`s otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> result = tf.strings.unicode_decode_with_offsets(input, 'UTF-8')
  >>> result[0].to_list()  # codepoints
  [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]
  >>> result[1].to_list()  # offsets
  [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]]

  """
  with ops.name_scope(name, "UnicodeDecodeWithOffsets", [input]):
    # Delegate to the shared implementation, asking for byte offsets as well.
    return _unicode_decode(
        input,
        input_encoding,
        errors=errors,
        replacement_char=replacement_char,
        replace_control_characters=replace_control_characters,
        with_offsets=True)
296
297
@tf_export("strings.unicode_split")
@dispatch.add_dispatch_support
def unicode_split(input,
                  input_encoding,
                  errors="replace",
                  replacement_char=0xFFFD,
                  name=None):
  r"""Splits each string in `input` into a sequence of Unicode code points.

  `result[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its
  `j`th character, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`.
    name: A name for the operation (optional).

  Returns:
    A `N+1` dimensional `string` tensor with shape `[D1...DN, (num_chars)]`.
    The returned tensor is a `tf.Tensor` if `input` is a scalar, or a
    `tf.RaggedTensor` otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> tf.strings.unicode_split(input, 'UTF-8').to_list()
  [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'],
   [b'\xf0\x9f\x98\x8a']]
  """
  with ops.name_scope(name, "UnicodeSplit", [input]):
    # Decode to codepoints, then re-encode each codepoint on its own so the
    # result holds one encoded character per element.
    codepoints = _unicode_decode(
        input,
        input_encoding,
        errors,
        replacement_char,
        replace_control_characters=False,
        with_offsets=False)
    per_char_codepoints = ragged_array_ops.expand_dims(codepoints, -1)
    return unicode_encode(
        per_char_codepoints,
        output_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char)
344
345
@tf_export("strings.unicode_split_with_offsets")
@dispatch.add_dispatch_support
def unicode_split_with_offsets(input,
                               input_encoding,
                               errors="replace",
                               replacement_char=0xFFFD,
                               name=None):
  r"""Splits each string into a sequence of code points with start offsets.

  This op is similar to `tf.strings.unicode_decode(...)`, but it also returns
  the start offset for each character in its respective string.  This
  information can be used to align the characters with the original byte
  sequence.

  Returns a tuple `(chars, start_offsets)` where:

  * `chars[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its
    `j`th character, when decoded using `input_encoding`.
  * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th
    character in `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used to
      decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`.

    * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`.
    * `offsets` is an `int64` tensor with shape `[D1...DN, (num_chars)]`.

    The returned tensors are `tf.Tensor`s if `input` is a scalar, or
    `tf.RaggedTensor`s otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> result = tf.strings.unicode_split_with_offsets(input, 'UTF-8')
  >>> result[0].to_list()  # character substrings
  [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'],
   [b'\xf0\x9f\x98\x8a']]
  >>> result[1].to_list()  # offsets
  [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]]

  """
  with ops.name_scope(name, "UnicodeSplitWithOffsets", [input]):
    codepoints, byte_starts = _unicode_decode(
        input,
        input_encoding,
        errors,
        replacement_char,
        replace_control_characters=False,
        with_offsets=True)
    # Re-encode each codepoint individually to recover the per-character
    # substrings; the offsets from the decode step already line up with them.
    chars = unicode_encode(
        ragged_array_ops.expand_dims(codepoints, -1),
        output_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char)
    return chars, byte_starts
410
411
def _unicode_decode(input, input_encoding, errors, replacement_char,
                    replace_control_characters, with_offsets):
  """Decodes each string into a sequence of codepoints.

  Shared implementation for `unicode_decode`, `unicode_decode_with_offsets`,
  `unicode_split`, and `unicode_split_with_offsets`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor.  Its rank
      must be statically known.
    input_encoding: String name for the unicode encoding of `input`.
    errors: Error-handling policy: `'strict'`, `'replace'`, or `'ignore'`.
    replacement_char: Codepoint used in place of invalid substrings when
      `errors='replace'`, and in place of C0 control characters when
      `replace_control_characters=True`.
    replace_control_characters: Whether to replace the C0 control characters
      `(U+0000 - U+001F)` with `replacement_char`.
    with_offsets: If True, also compute each character's byte start offset and
      return a `(codepoints, offsets)` tuple.

  Returns:
    `codepoints`, or a `(codepoints, offsets)` tuple when `with_offsets` is
    true.  Results are `tf.Tensor`s if `input` is a scalar, and
    `tf.RaggedTensor`s otherwise.

  Raises:
    ValueError: If the rank of `input` is not statically known.
  """
  input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name="input")
  input_ndims = input.shape.ndims
  if input_ndims is None:
    raise ValueError("Rank of `input` must be statically known.")

  if input_ndims > 1:
    # Convert to a ragged tensor with ragged_rank = input_ndims - 1.
    if not ragged_tensor.is_ragged(input):
      input = ragged_tensor.RaggedTensor.from_tensor(
          input, ragged_rank=input_ndims - 1)
    elif input.ragged_rank < input_ndims - 1:
      # Partially ragged: make the dense inner dimensions ragged too, so that
      # flat_values below is guaranteed to be a 1-D string tensor.
      input = input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              input.flat_values,
              ragged_rank=input_ndims - input.ragged_rank - 1))

  # Reshape the input to a flat vector, and apply the gen_string_ops op.
  if ragged_tensor.is_ragged(input):
    flat_input = array_ops.reshape(input.flat_values, [-1])
  else:
    flat_input = array_ops.reshape(input, [-1])

  if with_offsets:
    decode_op = gen_string_ops.unicode_decode_with_offsets
  else:
    decode_op = gen_string_ops.unicode_decode
  flat_result = decode_op(
      input=flat_input,
      input_encoding=input_encoding,
      errors=errors,
      replacement_char=replacement_char,
      replace_control_characters=replace_control_characters)

  if input_ndims == 0:
    # Scalar input: the kernel's flat outputs are already the final results.
    codepoints = flat_result.char_values
    if with_offsets:
      offsets = flat_result.char_to_byte_starts
  else:
    # Wrap the flat per-character results with the kernel's row splits, then
    # restore the original outer (ragged) structure, if any.
    codepoints = ragged_tensor.RaggedTensor.from_row_splits(
        flat_result.char_values, flat_result.row_splits, validate=False)
    if input_ndims > 1:
      codepoints = input.with_flat_values(codepoints)
    if with_offsets:
      offsets = ragged_tensor.RaggedTensor.from_row_splits(
          flat_result.char_to_byte_starts, flat_result.row_splits,
          validate=False)
      if input_ndims > 1:
        offsets = input.with_flat_values(offsets)

  if with_offsets:
    return codepoints, offsets
  else:
    return codepoints
468
469
@tf_export("strings.split", v1=[])
@dispatch.add_dispatch_support
def string_split_v2(input, sep=None, maxsplit=-1, name=None):  # pylint: disable=redefined-builtin
  """Split elements of `input` based on `sep` into a `RaggedTensor`.

  Let N be the size of `input` (typically N will be the batch size). Split each
  element of `input` based on `sep` and return a `RaggedTensor` containing the
  split tokens. Empty tokens are ignored.

  Example:

  >>> tf.strings.split('hello world').numpy()
  array([b'hello', b'world'], dtype=object)
  >>> tf.strings.split(['hello world', 'a b c'])
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and
  `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    input: A string `Tensor` of rank `N`, the strings to split.  If
      `rank(input)` is not known statically, then it is assumed to be `1`.
    sep: `0-D` string `Tensor`, the delimiter string.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    name: A name for the operation (optional).

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `RaggedTensor` of rank `N+1`, the strings split according to the
    delimiter.
  """
  with ops.name_scope(name, "StringSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")
    if isinstance(input, ragged_tensor.RaggedTensor):
      # Split the innermost strings; the outer ragged structure is unchanged.
      return input.with_flat_values(
          string_split_v2(input.flat_values, sep, maxsplit))

    input_rank = input.shape.ndims
    if input_rank == 0:
      # Scalar: split a length-1 batch and unwrap the single row.
      return string_split_v2(array_ops.stack([input]), sep, maxsplit)[0]
    if input_rank is None or input_rank == 1:
      # Base case (unknown rank is assumed to be 1, per the docstring).
      split_result = string_ops.string_split_v2(
          input, sep=sep, maxsplit=maxsplit)
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=split_result.values,
          value_rowids=split_result.indices[:, 0],
          nrows=split_result.dense_shape[0],
          validate=False)
    # Higher rank: recurse through a ragged view of the dense input.
    return string_split_v2(
        ragged_tensor.RaggedTensor.from_tensor(input), sep, maxsplit)
530
531
@tf_export(v1=["string_split"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
                             "delimiter is deprecated, please use sep instead.",
                             "delimiter")
def string_split(source, sep=None, skip_empty=True, delimiter=None,
                 result_type="SparseTensor", name=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  or `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    name: A name for the operation (optional).

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split according
    to the delimiter.  The first column of the indices corresponds to the row
    in `source` and the second column corresponds to the index of the split
    component in this row.
  """
  with ops.name_scope(name, "StringSplit", [source]):
    # Always produce the sparse split first; a ragged result is derived from
    # its (indices, values, dense_shape) triple.
    sparse_result = string_ops.string_split(
        source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)
    if result_type == "SparseTensor":
      return sparse_result
    if result_type == "RaggedTensor":
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=sparse_result.values,
          value_rowids=sparse_result.indices[:, 0],
          nrows=sparse_result.dense_shape[0],
          validate=False)
    raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
593
594
595# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
596# but we need to add the result_type argument.
@tf_export(v1=["strings.split"])
@dispatch.add_dispatch_support
def strings_split_v1(input=None, sep=None, maxsplit=-1,  # pylint: disable=redefined-builtin
                     result_type="SparseTensor", source=None, name=None):
  """Split elements of `input` based on `sep`.

  Let N be the size of `input` (typically N will be the batch size). Split each
  element of `input` based on `sep` and return a `SparseTensor` or
  `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  Examples:

  >>> print(tf.compat.v1.strings.split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.strings.split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and
  `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    input: A string `Tensor` of rank `N`, the strings to split.  If
      `rank(input)` is not known statically, then it is assumed to be `1`.
    sep: `0-D` string `Tensor`, the delimiter character.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    source: alias for "input" argument.
    name: A name for the operation (optional).

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `N+1`, the strings split
    according to the delimiter.
  """
  # `source` is the deprecated spelling of `input`; resolve to one value.
  input = deprecation.deprecated_argument_lookup(
      "input", input, "source", source)
  with ops.name_scope(name, "StringSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")

    # Scalars are promoted to a length-1 batch before splitting.
    if input.shape.rank == 0:
      input = array_ops.expand_dims(input, 0)

    if result_type == "SparseTensor":
      if input.shape.rank == 1:
        return string_ops.string_split_v2(input, sep=sep, maxsplit=maxsplit)
      # Higher-rank input: split raggedly, then convert to sparse.
      return string_split_v2(input, sep=sep, maxsplit=maxsplit).to_sparse()
    if result_type == "RaggedTensor":
      return string_split_v2(input, sep=sep, maxsplit=maxsplit)
    raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
662
663
def reduce_join(inputs, axis=None, keepdims=None, separator="", name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  scope_name = name or "RaggedSegmentJoin"
  return ragged_math_ops.ragged_reduce_aggregate(
      string_ops.reduce_join, string_ops.unsorted_segment_join, inputs, axis,
      keepdims, separator, scope_name)
669
670
@tf_export("strings.ngrams")
@dispatch.add_dispatch_support
def ngrams(data,
           ngram_width,
           separator=" ",
           pad_values=None,
           padding_width=None,
           preserve_short_sequences=False,
           name=None):
  """Create a tensor of n-grams based on `data`.

  Creates a tensor of n-grams based on `data`. The n-grams are created by
  joining windows of `width` adjacent strings from the inner axis of `data`
  using `separator`.

  The input data can be padded on both the start and end of the sequence, if
  desired, using the `pad_values` argument. If set, `pad_values` should contain
  either a tuple of strings or a single string; the 0th element of the tuple
  will be used to pad the left side of the sequence and the 1st element of the
  tuple will be used to pad the right side of the sequence. The `padding_width`
  arg controls how many padding values are added to each side; it defaults to
  `ngram_width-1`.

  If this op is configured to not have padding, or if it is configured to add
  padding with `padding_width` set to less than ngram_width-1, it is possible
  that a sequence, or a sequence plus padding, is smaller than the ngram
  width. In that case, no ngrams will be generated for that sequence. This can
  be prevented by setting `preserve_short_sequences`, which will cause the op
  to always generate at least one ngram per non-empty sequence.

  Examples:

  >>> tf.strings.ngrams(["A", "B", "C", "D"], 2).numpy()
  array([b'A B', b'B C', b'C D'], dtype=object)
  >>> tf.strings.ngrams(["TF", "and", "keras"], 1).numpy()
  array([b'TF', b'and', b'keras'], dtype=object)

  Args:
    data: A Tensor or RaggedTensor containing the source data for the ngrams.
    ngram_width: The width(s) of the ngrams to create. If this is a list or
      tuple, the op will return ngrams of all specified arities in list order.
      Values must be non-Tensor integers greater than 0.
    separator: The separator string used between ngram elements. Must be a
      string constant, not a Tensor.
    pad_values: A tuple of (left_pad_value, right_pad_value), a single string,
      or None. If None, no padding will be added; if a single string, then that
      string will be used for both left and right padding. Values must be Python
      strings.
    padding_width: If set, `padding_width` pad values will be added to both
      sides of each sequence. Defaults to `ngram_width`-1. Must be greater than
      0. (Note that 1-grams are never padded, regardless of this value.)
    preserve_short_sequences: If true, then ensure that at least one ngram is
      generated for each input sequence.  In particular, if an input sequence is
      shorter than `min(ngram_width) + 2*pad_width`, then generate a single
      ngram containing the entire sequence.  If false, then no ngrams are
      generated for these short input sequences.
    name: The op name.

  Returns:
    A RaggedTensor of ngrams. If `data.shape=[D1...DN, S]`, then
    `output.shape=[D1...DN, NUM_NGRAMS]`, where
    `NUM_NGRAMS=S-ngram_width+1+2*padding_width`.

  Raises:
    TypeError: if `pad_values` is set to an invalid type.
    ValueError: if `pad_values`, `padding_width`, or `ngram_width` is set to an
      invalid value.
  """

  with ops.name_scope(name, "StringNGrams", [data]):
    # Normalize `pad_values` into separate left/right pad strings.  A missing
    # value becomes empty strings (with padding_width forced to 0 below).
    if pad_values is None:
      left_pad = ""
      right_pad = ""
    elif isinstance(pad_values, (list, tuple)):
      if (not isinstance(pad_values[0], util_compat.bytes_or_text_types) or
          not isinstance(pad_values[1], util_compat.bytes_or_text_types)):
        raise TypeError(
            "pad_values must be a string, tuple of strings, or None.")
      left_pad = pad_values[0]
      right_pad = pad_values[1]
    else:
      if not isinstance(pad_values, util_compat.bytes_or_text_types):
        raise TypeError(
            "pad_values must be a string, tuple of strings, or None.")
      left_pad = pad_values
      right_pad = pad_values

    # Validate padding arguments before any tensor conversion, so errors are
    # raised eagerly with Python values.
    if padding_width is not None and padding_width < 1:
      raise ValueError("padding_width must be greater than 0.")

    if padding_width is not None and pad_values is None:
      raise ValueError("pad_values must be provided if padding_width is set.")

    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        data, name="data", dtype=dtypes.string)

    # preserve the shape of the data if it is a tensor
    to_tensor = False
    if isinstance(data, ops.Tensor):
      # Remember the outer dims so the ragged result can be reshaped back to a
      # dense tensor at the end (-1 lets the ngram dim be inferred).
      dense_shape = array_ops.concat([array_ops.shape(data)[:-1], [-1]], axis=0)
      to_tensor = True

    if not isinstance(data, ragged_tensor.RaggedTensor):
      if data.shape.ndims is None:
        raise ValueError("Rank of data must be known.")
      elif data.shape.ndims == 0:
        raise ValueError("Data must have rank>0")
      elif data.shape.ndims == 1:
        # Lift a rank-1 input into a single-row RaggedTensor, recurse, and
        # unwrap the single row from the result.
        rt = ragged_tensor.RaggedTensor.from_row_starts(
            data, [0], validate=False)
        return ngrams(rt, ngram_width, separator, pad_values, padding_width,
                      preserve_short_sequences, name)[0]
      else:
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=data.shape.ndims - 1)

    if data.ragged_rank > 1:
      # Recurse on the inner values so the kernel below only ever sees a
      # ragged_rank-1 tensor; outer row partitions are preserved as-is.
      output = data.with_values(
          ngrams(data.values, ngram_width, separator, pad_values, padding_width,
                 preserve_short_sequences, name))
      return array_ops.reshape(output.flat_values,
                               dense_shape) if to_tensor else output

    # No pad values => no padding at all.
    if pad_values is None:
      padding_width = 0

    # -1 is the sentinel telling the kernel to use its default pad width
    # (ngram_width - 1, per the docstring above).
    if pad_values is not None and padding_width is None:
      padding_width = -1

    # Accept a single width or a list/tuple of widths uniformly.
    if not isinstance(ngram_width, (list, tuple)):
      ngram_widths = [ngram_width]
    else:
      ngram_widths = ngram_width
    for width in ngram_widths:
      if width < 1:
        raise ValueError("All ngram_widths must be greater than 0. Got %s" %
                         ngram_width)

    output, output_splits = gen_string_ops.string_n_grams(
        data=data.flat_values,
        data_splits=data.row_splits,
        separator=separator,
        ngram_widths=ngram_widths,
        left_pad=left_pad,
        right_pad=right_pad,
        pad_width=padding_width,
        preserve_short_sequences=preserve_short_sequences)

    # if the input is Dense tensor, the output should also be a dense tensor
    output = ragged_tensor.RaggedTensor.from_row_splits(
        values=output, row_splits=output_splits, validate=False)
    return array_ops.reshape(output.flat_values,
                             dense_shape) if to_tensor else output
824
825
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  """Version of tf.strings.format that handles RaggedTensors."""
  # A lone (ragged) tensor is treated as a one-element list of inputs.
  if ragged_tensor.is_ragged(inputs) or tensor_util.is_tf_type(inputs):
    inputs = [inputs]

  template_parts = template.split(placeholder)
  num_placeholders = len(template_parts) - 1
  if num_placeholders != len(inputs):
    raise ValueError("num placeholders in template and num inputs must match"
                     ": {} vs {}".format(num_placeholders, len(inputs)))

  with ops.name_scope(name, "StringFormat", [inputs]):
    # Interleave literal template fragments with stringified inputs.
    pieces = [constant_op.constant(template_parts[0])]
    for index, value in enumerate(inputs):
      if ragged_tensor.is_ragged(value):
        pieces.append(ragged_tensor_to_string(value, summarize))
      else:
        pieces.append(
            string_ops.string_format("{}", [value], summarize=summarize))
      pieces.append(constant_op.constant(template_parts[index + 1]))
    if len(pieces) == 1:
      return pieces[0]
    return string_ops.reduce_join(pieces)
849
850
def ragged_tensor_to_string(rt, summarize=None):
  """Returns a scalar string tensor with the contents of a RaggedTensor.

  Requires that `rt.shape.rank` is not `None`.

  Note: this converts the entire `RaggedTensor` into a single string scalar.
  If you want to convert individual elements, use `tf.strings.as_string(rt)`.

  >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]])
  >>> ragged_tensor_to_string(rt1).numpy()
  b'[[1, 2, 3], [4, 5]]'

  >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]])
  >>> ragged_tensor_to_string(rt2).numpy()
  b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]"

  >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]])
  >>> ragged_tensor_to_string(rt3, summarize=2).numpy()
  b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]'

  Args:
    rt: The RaggedTensor that should be converted to a string.
    summarize: If specified, then only the first and last `summarize` elements
      within each dimension are included in the string. If `-1` or `None`, then
      all elements are included.
  """
  # `summarize` must be None, -1 ("show everything"), or a positive int.
  summarize_is_valid = (
      summarize is None or summarize == -1 or
      (isinstance(summarize, int) and summarize > 0))
  if not summarize_is_valid:
    raise ValueError("Expected summarize to be -1 or a positive int, got %r" %
                     summarize)
  with ops.name_scope(None, "AsString", [rt]):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
    if rt.shape.rank is None:
      raise ValueError("RaggedTensor to_string requires that rt.shape.rank "
                       "is not None.")
    # Stringify every element, quoting (and escaping) string elements so the
    # rendered output is unambiguous.
    if rt.dtype != dtypes.string:
      stringified = rt.with_flat_values(string_ops.as_string(rt.flat_values))
    else:
      escaped = string_ops.regex_replace(rt.flat_values, r"(['\\])", r"\\\1")
      stringified = rt.with_flat_values("'" + escaped + "'")

    return _ragged_tensor_to_string(stringified, summarize)
894
895
def _ragged_tensor_to_string(string_tensor, summarize):
  """Returns a scalar string tensor with the contents of `string_tensor`.

  Args:
    string_tensor: A potentially ragged tensor with dtype=string.
    summarize: Include only the first and last `summarize` elements of each
      dimension.  If `-1` or `None`, then include all elements.

  Returns:
    A scalar string Tensor.
  """
  if string_tensor.shape.rank == 1:
    row_strings = string_tensor
  else:
    # Recursively render each row of the outermost dimension to a scalar
    # string, producing a rank-1 vector of row strings.
    row_strings = map_fn_lib.map_fn(
        lambda row: _ragged_tensor_to_string(row, summarize),
        string_tensor,
        fn_output_signature=tensor_spec.TensorSpec(None, dtypes.string))
  if summarize != -1 and summarize is not None:
    def _truncate():
      # Keep the first and last `summarize` rows with "..." in between.
      return array_ops.concat(
          [row_strings[:summarize], ["..."], row_strings[-summarize:]], axis=0)

    row_strings = control_flow_ops.cond(
        _nrows(string_tensor) <= 2 * summarize,
        lambda: row_strings, _truncate)
  return "[" + string_ops.reduce_join(row_strings, separator=", ") + "]"
922
923
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the number of rows (outermost dim size) of `tensor`."""
  return (tensor.nrows(out_type=out_type)
          if isinstance(tensor, ragged_tensor.RaggedTensor)
          else array_ops.shape(tensor, out_type=out_type)[0])
929