1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for AutomaticControlDependencies."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.util import object_identity
23
# Name of the op attribute listing the indices (into `op.inputs`) of resource
# inputs that the op only reads. It is read via `op.get_attr` in this module;
# ops that lack it are conservatively treated as writing every resource input.
READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
# Op type names registered via `register_read_only_resource_op`. Every resource
# input of an op whose `op.type` is in this set is treated as read-only, and
# the attribute above is not consulted for such ops.
RESOURCE_READ_OPS = set()


# Attribute name for collective manager ids. Not referenced in this module.
# NOTE(review): presumably consumed by other auto-control-deps code; confirm.
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"
29
30
def register_read_only_resource_op(op_type):
  """Marks all resource inputs of ops of type `op_type` as read-only.

  Registering the same type more than once has no additional effect.

  Args:
    op_type: Op type name; compared against `op.type` by the helpers here.
  """
  RESOURCE_READ_OPS.update((op_type,))
34
35
def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource-typed input of `func_graph` is reported as read-only when every
  op consuming it lists it among that op's read-only resource inputs (as
  computed by `_get_read_only_resource_input_indices_op`). A resource input
  with no consumers is therefore reported as read-only.

  Args:
    func_graph: Graph whose `inputs` are scanned.

  Returns:
    List of indices into `func_graph.inputs`, in increasing order.
  """
  # Memo: consuming Operation -> ObjectIdentitySet of the resource handles it
  # only reads. Each consumer op is analyzed at most once across all inputs.
  read_only_handles_by_op = {}

  def _op_only_reads(handle, op):
    """Returns whether `op` treats `handle` as a read-only input."""
    if op not in read_only_handles_by_op:
      read_only_handles_by_op[op] = object_identity.ObjectIdentitySet(
          [op.inputs[i] for i in _get_read_only_resource_input_indices_op(op)])
    return handle in read_only_handles_by_op[op]

  # `all` stops at the first consumer that may write the handle.
  return [
      index for index, handle in enumerate(func_graph.inputs)
      if handle.dtype == dtypes.resource and
      all(_op_only_reads(handle, op) for op in handle.consumers())
  ]
61
62
def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  An input index is read-only when that input has dtype `resource` and either
  `op.type` was registered via `register_read_only_resource_op`, or the index
  appears in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR` attribute. Ops with
  neither are conservatively assumed to write every resource input, so an
  empty list is returned for them.

  Args:
    op: Operation.

  Returns:
    List of indices into `op.inputs`, in increasing order.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set: no resource input is known to be read-only.
    return []

  # Use a membership test (the same semantics as `_op_writes_to_resource`)
  # rather than walking the attr in lockstep with `op.inputs`. Lockstep
  # matching silently dropped later read-only indices when the attr was
  # unsorted or also listed a non-resource index.
  read_only_input_indices = set(read_only_input_indices)
  return [
      i for i, t in enumerate(op.inputs)
      if t.dtype == dtypes.resource and i in read_only_input_indices
  ]
87
88
def get_read_write_resource_inputs(op):
  """Returns a tuple of resource reads, writes in op.inputs.

  If `op.type` was registered via `register_read_only_resource_op`, every
  resource input is a read. Otherwise an input is a read exactly when its
  index appears in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR` attribute; if
  that attribute is absent, every resource input is conservatively a write.

  Args:
    op: Operation

  Returns:
    A 2-tuple of ObjectIdentitySets, the first entry containing read-only
    resource handles and the second containing read-write resource handles in
    `op.inputs`.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Add all resource inputs to `reads` and return.
    reads.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Add all resource inputs to `writes` and return.
    writes.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  # Membership test, consistent with `_op_writes_to_resource`. Lockstep
  # matching against the attr misclassified declared read-only inputs as
  # writes whenever the attr was unsorted or listed a non-resource index.
  read_only_input_indices = set(read_only_input_indices)
  for i, t in enumerate(op.inputs):
    if t.dtype != dtypes.resource:
      continue
    if i in read_only_input_indices:
      reads.add(t)
    else:
      writes.add(t)
  return (reads, writes)
126
127
def _op_writes_to_resource(handle, op):
  """Reports whether `op` may write to the resource behind `handle`.

  Args:
    handle: Resource handle; expected to be one of `op.inputs`.
    op: Operation.

  Returns:
    False when `op.type` was registered with `register_read_only_resource_op`
    (checked before `handle` is located), or when the position of `handle` in
    `op.inputs` is listed in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR`
    attribute. True otherwise, including when that attribute is missing.

  Raises:
    ValueError: if `handle` is not an input of `op` (only checked for ops
      whose type is not registered as read-only).
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  # Resolved outside the `try` so that a missing handle raises instead of
  # being mistaken for a missing attribute.
  position = _input_index(op, handle)
  try:
    declared_read_only = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # No read-only annotation: assume the worst and report a write.
    return True
  is_read_only = position in declared_read_only
  return not is_read_only
153
154
def _input_index(op, handle):
  """Finds the position of `handle` among the inputs of `op`.

  Inputs are compared by object identity (`is`), not equality, and the first
  matching position is returned.

  Args:
    op: Operation.
    handle: Resource handle.

  Returns:
    The smallest index `i` such that `op.inputs[i] is handle`.

  Raises:
    ValueError: If `handle` is not one of `op.inputs`.
  """
  position = next(
      (i for i, candidate in enumerate(op.inputs) if candidate is handle),
      None)
  if position is None:
    raise ValueError("%s not in list" % str(handle))
  return position
173