1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_NODE_INDEXING_EVALUATION_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_NODE_INDEXING_EVALUATION_H_
18 
19 #include "absl/container/flat_hash_map.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/types.h"
23 
24 namespace xla {
25 class FusionNodeIndexingEvaluation {
26  public:
27   explicit FusionNodeIndexingEvaluation(const HloInstruction* fusion,
28                                         int64 root_usage_count = 1);
29 
30   // Evaluate the number of times 'producer' would be emitted if it is fused
31   // into 'fusion_'. If the duplication is "too high" (some arbitrary chosen
32   // constant), returns true.
33   bool CodeDuplicationTooHigh(const HloInstruction* producer) const;
34 
35   // Evaluate the maximum code duplication inside the fusion node. If the
36   // maximum code duplication is "too high" (some arbitrary chosen constant),
37   // returns true.
38   bool MaxCodeDuplicationTooHigh() const;
39 
40   // Evaluate the number of times 'producer' would be emitted if it is fused
41   // into 'fusion_'.
42   int64 EvaluateEmittedInstructions(const HloInstruction* producer) const;
43 
44   // Update the evaluation cache after having fused 'producer' into 'fusion_'.
45   // 'producer' is the cloned instruction which is now part of the fusion
46   // computation. 'indexing_users_of_producer' are the direct or indirect users
47   // of 'producer' which pass index values created by them.
48   void UpdateEvaluationCache(
49       const HloInstruction* producer,
50       absl::flat_hash_set<const HloInstruction*> indexing_users_of_producer);
51 
52   // Prior to fusing, we need to erase the indexing_users_ entry of the
53   // producer to be fused, because the HloInstruction pointer will be
54   // invalidated. We return the set of direct or indirect users which pass index
55   // values created by them to the fusion parameter corresponding to this
56   // producer. This will be needed for updating the evaluation cache (see
57   // UpdateEvaluationCache).
58   absl::flat_hash_set<const HloInstruction*> RemoveFusionOperand(
59       HloInstruction* fusion_operand);
60 
61  private:
62   static const int64 kAllowedCodeDuplication;
63 
64   // Computes the 'indexing_users_' and 'index_usage_count_' maps based on the
65   // current instructions inside the fusion node. Also updates
66   // 'total_emitted_instructions_' accordingly.
67   void RecomputeCache();
68 
69   // Computes the 'index_usage_count_' entry for 'instruction'.
70   void UpdateIndexUsageCount(const HloInstruction* instruction);
71 
72   // Updates the 'indexing_users_' entry of the operands of 'instruction'.
73   void UpdateIndexingUsersOfOperands(const HloInstruction* instruction);
74 
75   // Collects for each instruction in a fusion node from which direct or
76   // indirect users newly created index values are passed. Roughly speaking, we
77   // reuse index values if the shapes are equal when ignoring the element type
78   // (we may reuse also if the shape change is a bitcast, but we don't consider
79   // that here). By ignoring potential reuses our estimate of which instruction
80   // generates a new index value is a bit more conservative than necessary.
81   absl::flat_hash_map<const HloInstruction*,
82                       absl::flat_hash_set<const HloInstruction*>>
83       indexing_users_;
84 
85   // Stores the number of different index accesses for each instruction in a
86   // fusion node. The fusion emitter caches access with the same index, so this
87   // value indicates how many times a specific instruction will be emitted.
88   absl::flat_hash_map<const HloInstruction*, int64> index_usage_count_;
89 
90   // The fusion instruction.
91   const HloInstruction* fusion_;
92 };
93 }  // namespace xla
94 
95 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_FUSION_NODE_INDEXING_EVALUATION_H_
96