1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Operations for working with string Tensors."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import numpy as np
23
24from tensorflow.python.compat import compat
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import sparse_tensor
29from tensorflow.python.framework import tensor_util
30from tensorflow.python.ops import array_ops
31from tensorflow.python.ops import gen_parsing_ops
32from tensorflow.python.ops import gen_string_ops
33from tensorflow.python.ops import math_ops
34
35# go/tf-wildcard-import
36# pylint: disable=wildcard-import
37# pylint: disable=g-bad-import-order
38from tensorflow.python.ops.gen_string_ops import *
39from tensorflow.python.util import compat as util_compat
40from tensorflow.python.util import deprecation
41from tensorflow.python.util import dispatch
42from tensorflow.python.util.tf_export import tf_export
43# pylint: enable=g-bad-import-order
44# pylint: enable=wildcard-import
45
46
47# pylint: disable=redefined-builtin
@tf_export("strings.regex_full_match")
@dispatch.add_dispatch_support
def regex_full_match(input, pattern, name=None):
  r"""Tests whether each element of `input` fully matches regex `pattern`.

  Args:
    input: string `Tensor`, the strings to test.
    pattern: string or scalar string `Tensor`, the regular expression to
      apply; the syntax is described at
      https://github.com/google/re2/wiki/Syntax
    name: Name of the op.

  Returns:
    bool `Tensor` of the same shape as `input` holding the match results.
  """
  # A pattern given as a Python string cannot change over the life of the op,
  # so the static kernel can compile the regex once at creation time. That
  # kernel is only selected once the forward-compatibility window has passed.
  # TODO(b/112455102): Remove compat.forward_compatible once past the horizon.
  pattern_is_static = (
      compat.forward_compatible(2018, 11, 10) and
      isinstance(pattern, util_compat.bytes_or_text_types))
  if pattern_is_static:
    match_op = gen_string_ops.static_regex_full_match
  else:
    match_op = gen_string_ops.regex_full_match
  return match_op(input=input, pattern=pattern, name=name)


# The public docstring is the generated op's docstring; it replaces the one
# written inline above.
regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
76
77
@tf_export(
    "strings.regex_replace", v1=["strings.regex_replace", "regex_replace"])
@deprecation.deprecated_endpoints("regex_replace")
@dispatch.add_dispatch_support
def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
  r"""Replaces matches of regex `pattern` in `input` with `rewrite`.

  Args:
    input: string `Tensor`, the source strings to process.
    pattern: string or scalar string `Tensor`, the regular expression to use;
      the syntax is described at https://github.com/google/re2/wiki/Syntax
    rewrite: string or scalar string `Tensor`, the replacement text.
      Backslash-escaped digits (\1 to \9) can be used to insert the text
      matched by the corresponding parenthesized group.
    replace_global: `bool`. If `True`, replace all non-overlapping matches,
      otherwise replace only the first match.
    name: A name for the operation (optional).

  Returns:
    string `Tensor` of the same shape as `input` with the replacements applied.
  """
  static_types = util_compat.bytes_or_text_types
  if isinstance(pattern, static_types) and isinstance(rewrite, static_types):
    # Both `pattern` and `rewrite` are Python strings, i.e. fixed for the life
    # of the op, so use the kernel that performs the expensive regex
    # compilation once at creation time.
    replace_op = gen_string_ops.static_regex_replace
  else:
    replace_op = gen_string_ops.regex_replace
  return replace_op(
      input=input,
      pattern=pattern,
      rewrite=rewrite,
      replace_global=replace_global,
      name=name)
112
113
@tf_export("strings.format")
def string_format(template, inputs, placeholder="{}", summarize=3, name=None):
  r"""Formats a string template using a list of tensors.

  Each occurrence of `placeholder` in `template` is filled with the next
  tensor from `inputs`. Tensors are abbreviated by printing only the first and
  last `summarize` elements of each dimension (recursively). A single tensor
  may be passed directly; it does not have to be wrapped in a list.

  Example:
    Formatting a single-tensor template:
    ```python
    sess = tf.Session()
    with sess.as_default():
        tensor = tf.range(10)
        formatted = tf.strings.format("tensor: {}, suffix", tensor)
        out = sess.run(formatted)
        expected = "tensor: [0 1 2 ... 7 8 9], suffix"

        assert(out.decode() == expected)
    ```

    Formatting a multi-tensor template:
    ```python
    sess = tf.Session()
    with sess.as_default():
        tensor_one = tf.reshape(tf.range(100), [10, 10])
        tensor_two = tf.range(10)
        formatted = tf.strings.format("first: {}, second: {}, suffix",
          (tensor_one, tensor_two))

        out = sess.run(formatted)
        expected = ("first: [[0 1 2 ... 7 8 9]\n"
              " [10 11 12 ... 17 18 19]\n"
              " [20 21 22 ... 27 28 29]\n"
              " ...\n"
              " [70 71 72 ... 77 78 79]\n"
              " [80 81 82 ... 87 88 89]\n"
              " [90 91 92 ... 97 98 99]], second: [0 1 2 ... 7 8 9], suffix")

        assert(out.decode() == expected)
    ```

  Args:
    template: A string template to format tensor values into.
    inputs: A list of `Tensor` objects, or a single `Tensor`.
      The tensors to format into the template string. A solitary tensor is
      automatically wrapped in a list.
    placeholder: An optional `string`. Defaults to `{}`.
      Each occurrence of the placeholder in the template is replaced by the
      next tensor.
    summarize: An optional `int`. Defaults to `3`.
      When formatting the tensors, show the first and last `summarize`
      entries of each tensor dimension (recursively). If set to -1, all
      elements of the tensor will be shown.
    name: A name for the operation (optional).

  Returns:
    A scalar `Tensor` of type `string`.

  Raises:
    ValueError: if the number of placeholders does not match the number of
      inputs.
  """
  # Accept a bare tensor as shorthand for a one-element list.
  if tensor_util.is_tensor(inputs):
    inputs = [inputs]

  # Validate here, at graph-construction time, that every placeholder has
  # exactly one tensor to fill it.
  num_placeholders = template.count(placeholder)
  num_inputs = len(inputs)
  if num_placeholders != num_inputs:
    raise ValueError("%s placeholder(s) in template does not match %s tensor(s)"
                     " provided as input" % (num_placeholders, num_inputs))

  return gen_string_ops.string_format(
      inputs,
      template=template,
      placeholder=placeholder,
      summarize=summarize,
      name=name)
193
194
@tf_export(v1=["string_split"])
@deprecation.deprecated_args(None,
                             "delimiter is deprecated, please use sep instead.",
                             "delimiter")
def string_split(source, sep=None, skip_empty=True, delimiter=None):  # pylint: disable=invalid-name
  """Splits each element of `source` on `sep` into a `SparseTensor`.

  Let N be the size of `source` (typically N will be the batch size). Each
  element of `source` is split on the delimiter and the resulting tokens are
  returned as a `SparseTensor`. With the default `skip_empty=True`, empty
  tokens are dropped.

  If `sep` is an empty string, each element of `source` is split into
  individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If the delimiter contains multiple bytes, it
  is treated as a set of delimiters, each of which is a potential split point.

  For example:
  N = 2, source[0] is 'hello world' and source[1] is 'a b c', then the output
  will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  # Resolve the new `sep` argument against its deprecated alias `delimiter`.
  resolved_sep = deprecation.deprecated_argument_lookup(
      "sep", sep, "delimiter", delimiter)
  # Neither argument supplied: fall back to splitting on a single space.
  resolved_sep = " " if resolved_sep is None else resolved_sep

  delimiter_t = ops.convert_to_tensor(resolved_sep, dtype=dtypes.string)
  source_t = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, dense_shape = gen_string_ops.string_split(
      source_t, delimiter=delimiter_t, skip_empty=skip_empty)
  # Pin the static shapes of the three sparse components: indices are
  # [num_tokens, 2], values are [num_tokens], and the dense shape has 2 entries.
  indices.set_shape([None, 2])
  values.set_shape([None])
  dense_shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, dense_shape)
252
253
@tf_export("strings.split")
def string_split_v2(source, sep=None, maxsplit=-1):
  """Splits each element of `source` on `sep` into a `SparseTensor`.

  Let N be the size of `source` (typically N will be the batch size). Each
  element of `source` is split on `sep` and the resulting tokens are returned
  as a `SparseTensor`.

  For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
  then the output will be

  st.indices = [0, 0;
                0, 1;
                1, 0;
                1, 1;
                1, 2]
  st.shape = [2, 3]
  st.values = ['hello', 'world', 'a', 'b', 'c']

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
  sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace is regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace (that is, empty tokens are ignored in this
  whitespace mode).

  Note that the above mentioned behavior matches python's str.split.

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter string. `None` is treated as
      the empty string.
    maxsplit: An `int`. If `maxsplit > 0`, it limits the number of splits in
      the result.

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` of rank `2`, the strings split according to the delimiter.
    The first column of the indices corresponds to the row in `source` and the
    second column corresponds to the index of the split component in this row.
  """
  # The op has no notion of None; it receives the empty string instead.
  sep_t = ops.convert_to_tensor("" if sep is None else sep,
                                dtype=dtypes.string)
  source_t = ops.convert_to_tensor(source, dtype=dtypes.string)

  indices, values, dense_shape = gen_string_ops.string_split_v2(
      source_t, sep=sep_t, maxsplit=maxsplit)
  # Pin the static shapes of the three sparse components: indices are
  # [num_tokens, 2], values are [num_tokens], and the dense shape has 2 entries.
  indices.set_shape([None, 2])
  values.set_shape([None])
  dense_shape.set_shape([2])
  return sparse_tensor.SparseTensor(indices, values, dense_shape)
306
307
308def _reduce_join_reduction_dims(x, axis, reduction_indices):
309  """Returns range(rank(x) - 1, 0, -1) if reduction_indices is None."""
310  # TODO(aselle): Remove this after deprecation
311  if reduction_indices is not None:
312    if axis is not None:
313      raise ValueError("Can't specify both 'axis' and 'reduction_indices'.")
314    axis = reduction_indices
315  if axis is not None:
316    return axis
317  else:
318    # Fast path: avoid creating Rank and Range ops if ndims is known.
319    if x.get_shape().ndims is not None:
320      return constant_op.constant(
321          np.arange(x.get_shape().ndims - 1, -1, -1), dtype=dtypes.int32)
322
323    # Otherwise, we rely on Range and Rank to do the right thing at run-time.
324    return math_ops.range(array_ops.rank(x) - 1, -1, -1)
325
326
@tf_export(v1=["strings.reduce_join", "reduce_join"])
@deprecation.deprecated_endpoints("reduce_join")
def reduce_join(inputs, axis=None,  # pylint: disable=missing-docstring
                keep_dims=None,
                separator="",
                name=None,
                reduction_indices=None,
                keepdims=None):
  # v1 wrapper around the ReduceJoin op. The public docstring is attached
  # after the definition from the generated op's docstring.
  #
  # Arguments:
  #   inputs: value converted to a `Tensor` and passed to the op.
  #   axis: dimensions to reduce; None means every dimension, in reverse order.
  #   keep_dims: deprecated alias for `keepdims`.
  #   separator: forwarded to the op unchanged.
  #   name: optional name for the op.
  #   reduction_indices: deprecated alias for `axis`.
  #   keepdims: forwarded to the op as `keep_dims`; treated as False when
  #     neither it nor `keep_dims` is given.
  #
  # Raises ValueError if both members of an alias pair are supplied.
  #
  # `deprecated_argument_lookup` treats any non-None legacy value as
  # "explicitly supplied by the caller". The legacy `keep_dims` parameter must
  # therefore default to None: with a default of False, every call that passed
  # `keepdims=...` was rejected with "Cannot specify both".
  keep_dims = deprecation.deprecated_argument_lookup(
      "keepdims", keepdims, "keep_dims", keep_dims)
  if keep_dims is None:
    # Neither spelling was supplied: use the op's documented default.
    keep_dims = False
  inputs_t = ops.convert_to_tensor(inputs)
  reduction_indices = _reduce_join_reduction_dims(
      inputs_t, axis, reduction_indices)
  return gen_string_ops.reduce_join(
      inputs=inputs_t,
      reduction_indices=reduction_indices,
      keep_dims=keep_dims,
      separator=separator,
      name=name)
346
347
@tf_export("strings.reduce_join", v1=[])
def reduce_join_v2(
    inputs,
    axis=None,
    keepdims=False,
    separator="",
    name=None):
  """TF2 endpoint for `reduce_join`; forwards to the v1 wrapper.

  Args:
    inputs: The value to reduce; forwarded unchanged.
    axis: The dimensions to reduce over; forwarded unchanged.
    keepdims: A `bool`, passed to the v1 wrapper as its legacy `keep_dims`
      argument.
    separator: Forwarded unchanged.
    name: A name for the operation (optional).

  Returns:
    Whatever the v1 `reduce_join` wrapper returns for these arguments.
  """
  # Only the spelling of `keepdims` differs between the two endpoints.
  return reduce_join(
      inputs,
      axis,
      name=name,
      separator=separator,
      keep_dims=keepdims)
357
358
# Build the v1 docstring from the generated op's docstring in one pass:
# document the `reduction_indices` argument under its new name `axis`, then
# point the usage examples at the `tf.strings.reduce_join` endpoint.
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
    gen_string_ops.reduce_join.__doc__, "reduction_indices",
    "axis").replace("tf.reduce_join(", "tf.strings.reduce_join(")
363
364
365# This wrapper provides backwards compatibility for code that predates the
366# unit argument and that passed 'name' as a positional argument.
@tf_export(v1=["strings.length"])
@dispatch.add_dispatch_support
def string_length(input, name=None, unit="BYTE"):
  # `name` precedes `unit` in this v1 signature; both are forwarded by keyword,
  # so the generated op's own parameter order is irrelevant here.
  return gen_string_ops.string_length(input, name=name, unit=unit)
371
372
@tf_export("strings.length", v1=[])
@dispatch.add_dispatch_support
def string_length_v2(input, unit="BYTE", name=None):
  # TF2 endpoint: same behavior as the v1 wrapper, but with `unit` ahead of
  # `name` in the signature. Forward by keyword so the differing parameter
  # order cannot cause a mix-up.
  return string_length(input, name=name, unit=unit)


# Both endpoints wrap the same generated op, so both expose its docstring.
# Previously only the v1 wrapper received one, leaving `tf.strings.length`
# undocumented (unlike `substr_v2`, which does get the generated docstring).
string_length.__doc__ = gen_string_ops.string_length.__doc__
string_length_v2.__doc__ = gen_string_ops.string_length.__doc__
380
381
@tf_export(v1=["substr"])
@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
  # Deprecated `tf.substr` alias: forwards everything to `substr` unchanged.
  return substr(input, pos, len, unit=unit, name=name)


# Expose the generated op's docstring on the deprecated alias.
substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
388
389
@tf_export(v1=["strings.substr"])
@dispatch.add_dispatch_support
def substr(input, pos, len, name=None, unit="BYTE"):
  # v1 signature (`name` before `unit`); `name` and `unit` are forwarded by
  # keyword to the generated op.
  return gen_string_ops.substr(input, pos, len, name=name, unit=unit)


# Expose the generated op's docstring on this wrapper.
substr.__doc__ = gen_string_ops.substr.__doc__
396
397
@tf_export("strings.substr", v1=[])
@dispatch.add_dispatch_support
def substr_v2(input, pos, len, unit="BYTE", name=None):
  # TF2 signature (`unit` before `name`); `name` and `unit` are forwarded by
  # keyword to the generated op.
  return gen_string_ops.substr(input, pos, len, name=name, unit=unit)


# Expose the generated op's docstring on this wrapper.
substr_v2.__doc__ = gen_string_ops.substr.__doc__
404
405
# Register these op types as having no gradient, in the listed order.
for _op_type in (
    "RegexReplace",
    "StringToHashBucket",
    "StringToHashBucketFast",
    "StringToHashBucketStrong",
    "ReduceJoin",
    "StringJoin",
    "StringSplit",
    "AsString",
    "EncodeBase64",
    "DecodeBase64",
):
  ops.NotDifferentiable(_op_type)
del _op_type  # Keep the loop variable out of the module namespace.
416
417
@tf_export("strings.to_number", v1=[])
@dispatch.add_dispatch_support
def string_to_number(input, out_type=dtypes.float32, name=None):
  r"""Converts each string in the input Tensor to the specified numeric type.

  (Note that int32 overflow results in an error while float overflow
  results in a rounded value.)

  Args:
    input: A `Tensor` of type `string`.
    out_type: An optional `tf.DType` from: `tf.float32, tf.float64, tf.int32,
      tf.int64`. Defaults to `tf.float32`.
      The numeric type to interpret each string in `input` as.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `out_type`.
  """
  # Thin TF2 wrapper: the generated parsing op does all the work.
  return gen_parsing_ops.string_to_number(input, out_type=out_type, name=name)
437
438
@tf_export(v1=["strings.to_number", "string_to_number"])
def string_to_number_v1(
    string_tensor=None,
    out_type=dtypes.float32,
    name=None,
    input=None):
  # `input` is the new name for the deprecated `string_tensor` argument;
  # resolve whichever one the caller supplied before invoking the op.
  strings_to_parse = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_parsing_ops.string_to_number(
      strings_to_parse, out_type=out_type, name=name)


# Expose the generated op's docstring on the v1 wrapper.
string_to_number_v1.__doc__ = gen_parsing_ops.string_to_number.__doc__
450
451
@tf_export("strings.to_hash_bucket", v1=[])
@dispatch.add_dispatch_support
def string_to_hash_bucket(input, num_buckets, name=None):
  # pylint: disable=line-too-long
  r"""Converts each string in the input Tensor to its hash mod by a number of buckets.

  The hash function is deterministic on the content of the string within the
  process.

  Note that the hash function may change from time to time.
  This functionality will be deprecated and it's recommended to use
  `tf.string_to_hash_bucket_fast()` or `tf.string_to_hash_bucket_strong()`.

  Args:
    input: A `Tensor` of type `string`.
    num_buckets: An `int` that is `>= 1`. The number of buckets.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int64`.
  """
  # pylint: enable=line-too-long
  # Thin TF2 wrapper: the generated string op does all the work.
  return gen_string_ops.string_to_hash_bucket(
      input, num_buckets=num_buckets, name=name)
475
476
@tf_export(v1=["strings.to_hash_bucket", "string_to_hash_bucket"])
def string_to_hash_bucket_v1(
    string_tensor=None,
    num_buckets=None,
    name=None,
    input=None):
  # `input` is the new name for the deprecated `string_tensor` argument;
  # resolve whichever one the caller supplied before invoking the op.
  strings_to_hash = deprecation.deprecated_argument_lookup(
      "input", input, "string_tensor", string_tensor)
  return gen_string_ops.string_to_hash_bucket(
      strings_to_hash, num_buckets=num_buckets, name=name)


# Expose the generated op's docstring on the v1 wrapper.
string_to_hash_bucket_v1.__doc__ = gen_string_ops.string_to_hash_bucket.__doc__
488