1# -*- coding: utf-8 -*-
2# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16
17"""Operations for working with string Tensors."""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import numpy as np
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_parsing_ops
32from tensorflow.python.ops import gen_string_ops
33from tensorflow.python.ops import math_ops
34
35# go/tf-wildcard-import
36# pylint: disable=wildcard-import
37# pylint: disable=g-bad-import-order
38from tensorflow.python.ops.gen_string_ops import *
39from tensorflow.python.util import compat as util_compat
40from tensorflow.python.util import deprecation
41from tensorflow.python.util import dispatch
42from tensorflow.python.util.tf_export import tf_export
43# pylint: enable=g-bad-import-order
44# pylint: enable=wildcard-import
45
46
47# pylint: disable=redefined-builtin
@tf_export("strings.regex_full_match")
@dispatch.add_dispatch_support
def regex_full_match(input, pattern, name=None):
  r"""Checks each element of `input` against the regex `pattern`.

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    name: Name of the op.

  Returns:
    bool `Tensor` of the same shape as `input` with match results.
  """
  # A Python string pattern cannot change during the life of the op, so the
  # static variant can perform the expensive regex compilation once at
  # creation time. Any other kind of pattern goes through the generic op.
  pattern_is_static = isinstance(pattern, util_compat.bytes_or_text_types)
  match_op = (gen_string_ops.static_regex_full_match if pattern_is_static
              else gen_string_ops.regex_full_match)
  return match_op(input=input, pattern=pattern, name=name)

# Publish the generated op's documentation in place of the docstring above.
regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
72
73
@tf_export(
    "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("regex_replace")
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
  r"""Replace elements of `input` matching regex `pattern` with `rewrite`.

  >>> tf.strings.regex_replace("Text with tags.<br /><b>contains html</b>",
  ...                          "<[^>]+>", " ")
  <tf.Tensor: shape=(), dtype=string, numpy=b'Text with tags.  contains html '>

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, regular expression to use,
      see more details at https://github.com/google/re2/wiki/Syntax
    rewrite: string or scalar string `Tensor`, value to use in match
      replacement. Backslash-escaped digits (\1 to \9) can be used to insert
      the text matching the corresponding parenthesized group.
    replace_global: `bool`, if `True` replace all non-overlapping matches,
      else replace only the first match.
    name: A name for the operation (optional).

  Returns:
    string `Tensor` of the same shape as `input` with specified replacements.
  """
  # When both `pattern` and `rewrite` are plain Python strings they are static
  # through the life of the op, so we can use the version which performs the
  # expensive regex compilation once at creation time.
  both_static = (
      isinstance(pattern, util_compat.bytes_or_text_types) and
      isinstance(rewrite, util_compat.bytes_or_text_types))
  if both_static:
    return gen_string_ops.static_regex_replace(
        input=input, pattern=pattern,
        rewrite=rewrite, replace_global=replace_global,
        name=name)
  # At least one of them is not a Python string: use the generic op.
  return gen_string_ops.regex_replace(
      input=input, pattern=pattern,
      rewrite=rewrite, replace_global=replace_global,
      name=name)
113
114
@tf_export("strings.format")
@dispatch.add_dispatch_support
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  r"""Formats a string template using a list of tensors.

  Each occurrence of `placeholder` in `template` is filled with the next
  tensor from `inputs`. Tensors are abbreviated by only printing the first and
  last `summarize` elements of each dimension (recursively). A single tensor
  may be passed directly; it does not have to be wrapped in a list.

  Example:
    Formatting a single-tensor template:

    >>> tensor = tf.range(5)
    >>> tf.strings.format("tensor: {}, suffix", tensor)
    <tf.Tensor: shape=(), dtype=string, numpy=b'tensor: [0 1 2 3 4], suffix'>

    Formatting a multi-tensor template:

    >>> tensor_a = tf.range(2)
    >>> tensor_b = tf.range(1, 4, 2)
    >>> tf.strings.format("a: {}, b: {}, suffix", (tensor_a, tensor_b))
    <tf.Tensor: shape=(), dtype=string, numpy=b'a: [0 1], b: [1 3], suffix'>


  Args:
    template: A string template to format tensor values into.
    inputs: A list of `Tensor` objects, or a single Tensor.
      The list of tensors to format into the template string. If a solitary
      tensor is passed in, the input tensor will automatically be wrapped as a
      list.
    placeholder: An optional `string`. Defaults to `{}`.
      At each placeholder occurring in the template, a subsequent tensor
      will be inserted.
    summarize: An optional `int`. Defaults to `3`.
      When formatting the tensors, show the first and last `summarize`
      entries of each tensor dimension (recursively). If set to -1, all
      elements of the tensor will be shown.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`.

  Raises:
    ValueError: if the number of placeholders does not match the number of
      inputs.
  """
  # A lone tensor is promoted to a one-element list so that the rest of the
  # function only has to deal with the list form.
  if tensor_util.is_tf_type(inputs):
    inputs = [inputs]

  # Every placeholder must be paired with exactly one input tensor.
  num_placeholders = template.count(placeholder)
  num_inputs = len(inputs)
  if num_placeholders != num_inputs:
    raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
                     " provided as input" % (num_placeholders, num_inputs))

  return gen_string_ops.string_format(
      inputs,
      template=template,
      placeholder=placeholder,
      summarize=summarize,
      name=name)
176
177
178# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
179# defines a wrapper for this function.
def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter` into a `SparseTensor`.

  Let N be the size of source (typically N will be the batch size). Split each
  element of `source` based on `delimiter` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  For example:
  N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
  will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  # Resolve the `sep` / `delimiter` pair into a single value, falling back to
  # a single space when neither was supplied.
  delimiter = deprecation.deprecated_argument_lookup(
      "sep", sep, "delimiter", delimiter)
  delimiter = ops.convert_to_tensor(
      " " if delimiter is None else delimiter, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  token_indices, token_values, dense_shape = gen_string_ops.string_split(
      source, delimiter=delimiter, skip_empty=skip_empty)

  # Attach the static shape information of the three sparse components before
  # assembling them: [num_tokens, 2] indices, [num_tokens] values and a
  # 2-element dense shape.
  token_indices.set_shape([None, 2])
  token_values.set_shape([None])
  dense_shape.set_shape([2])
  return sparse_tensor.SparseTensor(token_indices, token_values, dense_shape)
233
234
235# Note: tf.strings.split is exported in ragged/ragged_string_ops.py, which
236# defines a wrapper for this function.
def string_split_v2(source, sep=None, maxsplit=-1):
  """Split elements of `source` based on `sep` into a `SparseTensor`.

  Let N be the size of source (typically N will be the batch size). Split each
  element of `source` based on `sep` and return a `SparseTensor`
  containing the split tokens. Empty tokens are ignored.

  For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
  then the output will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
  sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  # A missing separator is passed to the op as the empty string.
  sep = ops.convert_to_tensor("" if sep is None else sep, dtype=dtypes.string)
  source = ops.convert_to_tensor(source, dtype=dtypes.string)

  token_indices, token_values, dense_shape = gen_string_ops.string_split_v2(
      source, sep=sep, maxsplit=maxsplit)

  # Attach the static shape information of the three sparse components before
  # assembling them: [num_tokens, 2] indices, [num_tokens] values and a
  # 2-element dense shape.
  token_indices.set_shape([None, 2])
  token_values.set_shape([None])
  dense_shape.set_shape([2])
  return sparse_tensor.SparseTensor(token_indices, token_values, dense_shape)
288
289
290def _reduce_join_reduction_dims(x, axis):
291  """Returns range(rank(x) - 1, 0, -1) if axis is None; or axis otherwise."""
292  if axis is not None:
293    return axis
294  else:
295    # Fast path: avoid creating Rank and Range ops if ndims is known.
296    if x.get_shape().ndims is not None:
297      return constant_op.constant(
298          np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32)
299
300    # Otherwise, we rely on Range and Rank to do the right thing at run-time.
301    return math_ops.range(array_ops.rank(x) - 1, -1, -1)
302
303
@tf_export(v1=["strings.reduce_join", "reduce_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
                             "keep_dims is deprecated, use keepdims instead",
                             "keep_dims")
@deprecation.deprecated_endpoints("reduce_join")
def reduce_join(inputs, axis=None,  # pylint: disable=missing-docstring
                keep_dims=None,
                separator="",
                name=None,
                reduction_indices=None,
                keepdims=None):
  # Merge each deprecated argument with its replacement: `keep_dims` is the
  # old spelling of `keepdims`, and `reduction_indices` the old spelling of
  # `axis`.
  keepdims = deprecation.deprecated_argument_lookup("keepdims", keepdims,
                                                    "keep_dims", keep_dims)
  # Apply the default to the resolved value (not to the deprecated alias) so
  # that the v2 implementation always receives an explicit bool.
  if keepdims is None:
    keepdims = False
  axis = deprecation.deprecated_argument_lookup("axis", axis,
                                                "reduction_indices",
                                                reduction_indices)
  return reduce_join_v2(
      inputs=inputs,
      axis=axis,
      keepdims=keepdims,
      separator=separator,
      name=name)
329
330
@tf_export("strings.reduce_join", v1=[])
@dispatch.add_dispatch_support
def reduce_join_v2(  # pylint: disable=missing-docstring
    inputs,
    axis=None,
    keepdims=False,
    separator="",
    name=None):
  """Joins all strings into a single string, or joins along an axis.

  This is the reduction operation for the elementwise `tf.strings.join` op.

  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']]).numpy()
  b'abc123def456'
  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']], axis=-1).numpy()
  array([b'abc123', b'def456'], dtype=object)
  >>> tf.strings.reduce_join([['abc','123'],
  ...                         ['def','456']],
  ...                        axis=-1,
  ...                        separator=" ").numpy()
  array([b'abc 123', b'def 456'], dtype=object)

  Args:
    inputs: A `tf.string` tensor.
    axis: Which axis to join along. The default behavior is to join all
      elements, producing a scalar.
    keepdims: If true, retains reduced dimensions with length 1.
    separator: a string added between each string being joined.
    name: A name for the operation (optional).

  Returns:
    A `tf.string` tensor.
  """
  with ops.name_scope(None, "ReduceJoin", [inputs, axis]):
    strings_t = ops.convert_to_tensor(inputs)
    # With `axis=None` this expands to every dimension of the input, last
    # dimension first; an explicit axis is passed through untouched.
    reduction_dims = _reduce_join_reduction_dims(strings_t, axis)
    joined = gen_string_ops.reduce_join(
        inputs=strings_t,
        reduction_indices=reduction_dims,
        keep_dims=keepdims,
        separator=separator,
        name=name)
    return joined

# The v1 endpoint shares the documentation of the v2 implementation.
reduce_join.__doc__ = reduce_join_v2.__doc__
377
378
379# This wrapper provides backwards compatibility for code that predates the
380# unit argument and that passed 'name' as a positional argument.
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
  """Computes the length of each string given in the input tensor.

  >>> strings = tf.constant(['Hello','TensorFlow', '\U0001F642'])
  >>> tf.strings.length(strings).numpy() # default counts bytes
  array([ 5, 10, 4], dtype=int32)
  >>> tf.strings.length(strings, unit="UTF8_CHAR").numpy()
  array([ 5, 10, 1], dtype=int32)

  Args:
    input: A `Tensor` of type `string`. The strings for which to compute the
      length for each element.
    name: A name for the operation (optional).
    unit: An optional `string` from: `"BYTE", "UTF8_CHAR"`. Defaults to
      `"BYTE"`. The unit that is counted to compute string length.  One of:
        `"BYTE"` (for the number of bytes in each string) or `"UTF8_CHAR"` (for
        the number of UTF-8 encoded Unicode code points in each string). Results
        are undefined if `unit=UTF8_CHAR` and the `input` strings do not contain
        structurally valid UTF-8.

  Returns:
    A `Tensor` of type `int32`, containing the length of the input string in
    the same element of the input tensor.
  """
  # `name` precedes `unit` in this signature for backwards compatibility, so
  # both are forwarded by keyword to keep them from being swapped.
  return gen_string_ops.string_length(input, unit=unit, name=name)
408
409
@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
  # Thin wrapper over the generated op; here `unit` comes before `name`.
  lengths = gen_string_ops.string_length(input, name=name, unit=unit)
  return lengths


# Reuse the generated op's documentation for this endpoint.
string_length_v2.__doc__ = gen_string_ops.string_length.__doc__
417
418
@tf_export(v1=["substr"])
@dispatch.add_dispatch_support
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
  # Deprecated endpoint: delegates straight to `substr`.
  result = substr(input, pos, len, unit=unit, name=name)
  return result

# Reuse the generated op's documentation for this endpoint.
substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
426
427
@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
  # In this signature `name` comes before `unit`; both are forwarded to the
  # generated op by keyword so the ordering difference does not matter.
  result = gen_string_ops.substr(input, pos, len, name=name, unit=unit)
  return result

# Reuse the generated op's documentation for this endpoint.
substr.__doc__ = gen_string_ops.substr.__doc__
434
435
@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
  # Thin wrapper over the generated op; here `unit` comes before `name`.
  result = gen_string_ops.substr(input, pos, len, name=name, unit=unit)
  return result

# Reuse the generated op's documentation for this endpoint.
substr_v2.__doc__ = gen_string_ops.substr.__doc__
442
443
# Register the following string op types through `ops.NotDifferentiable`.
# These calls run once, as a side effect of importing this module.
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
ops.NotDifferentiable("StringToHashBucketStrong")
ops.NotDifferentiable("ReduceJoin")
ops.NotDifferentiable("StringJoin")
ops.NotDifferentiable("StringSplit")
ops.NotDifferentiable("AsString")
ops.NotDifferentiable("EncodeBase64")
ops.NotDifferentiable("DecodeBase64")
454
455
@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
  r"""Converts each string in the input Tensor to the specified numeric type.

  (Note that int32 overflow results in an error while float overflow
  results in a rounded value.)

  Examples:

  >>> tf.strings.to_number("1.55")
  <tf.Tensor: shape=(), dtype=float32, numpy=1.55>
  >>> tf.strings.to_number("3", tf.int32)
  <tf.Tensor: shape=(), dtype=int32, numpy=3>

  Args:
    input: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32,
      tf.int64`. Defaults to `tf.float32`.
      The numeric type to interpret each string in `input` as.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  # Thin wrapper: the conversion itself is done by the generated parsing op.
  return gen_parsing_ops.string_to_number(input, out_type, name)
482
483
@tf_export(v1=["strings.to_number", "string_to_number"])
@dispatch.add_dispatch_support
def string_to_number_v1(
    string_tensor=None,
    out_type=dtypes.float32,
    name=None,
    input=None):
  # The strings may arrive under either keyword; resolve `input` and
  # `string_tensor` into a single value before calling the op.
  strings = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  converted = gen_parsing_ops.string_to_number(strings, out_type, name)
  return converted

# Reuse the generated op's documentation for this endpoint.
string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__
496
497
@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
  # pylint: disable=line-too-long
  r"""Converts each string in the input Tensor to its hash mod by a number of buckets.

  The hash function is deterministic on the content of the string within the
  process.

  Note that the hash function may change from time to time.
  This functionality will be deprecated and it's recommended to use
  `tf.strings.to_hash_bucket_fast()` or `tf.strings.to_hash_bucket_strong()`.

  Examples:

  >>> tf.strings.to_hash_bucket(["Hello", "TensorFlow", "2.x"], 3)
  <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 0, 1])>

  Args:
    input: A `Tensor` of type `string`.
    num_buckets: An `int` that is `>= 1`. The number of buckets.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int64`.
  """
  # pylint: enable=line-too-long
  # Thin wrapper: hashing and bucketing are done by the generated op.
  bucket_ids = gen_string_ops.string_to_hash_bucket(input, num_buckets, name)
  return bucket_ids
526
527
@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
@dispatch.add_dispatch_support
def string_to_hash_bucket_v1(
    string_tensor=None,
    num_buckets=None,
    name=None,
    input=None):
  # The strings may arrive under either keyword; resolve `input` and
  # `string_tensor` into a single value before calling the op.
  strings = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  bucket_ids = gen_string_ops.string_to_hash_bucket(strings, num_buckets, name)
  return bucket_ids

# Reuse the generated op's documentation for this endpoint.
string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__
540
541
@tf_export("strings.join", v1=["strings.join", "string_join"])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints("string_join")
def string_join(inputs, separator="", name=None):
  """Perform element-wise concatenation of a list of string tensors.

  Given a list of string tensors of same shape, performs element-wise
  concatenation of the strings of the same index in all tensors.


  >>> tf.strings.join(['abc','def']).numpy()
  b'abcdef'
  >>> tf.strings.join([['abc','123'],
  ...                  ['def','456'],
  ...                  ['ghi','789']]).numpy()
  array([b'abcdefghi', b'123456789'], dtype=object)
  >>> tf.strings.join([['abc','123'],
  ...                  ['def','456']],
  ...                  separator=" ").numpy()
  array([b'abc def', b'123 456'], dtype=object)

  The reduction version of this elementwise operation is
  `tf.strings.reduce_join`

  Args:
    inputs: A list of `tf.Tensor` objects of same size and `tf.string` dtype.
    separator: A string added between each string being joined.
    name: A name for the operation (optional).

  Returns:
    A `tf.string` tensor.
  """
  # Thin wrapper: the concatenation itself is done by the generated op.
  return gen_string_ops.string_join(inputs, separator=separator, name=name)
576