1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_ 17 #define TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_ 18 19 #include <vector> 20 21 #include "llvm/ADT/DenseMap.h" 22 23 namespace mlir { 24 25 // ------------------------------------------------------------------- 26 27 // This file contains a light version of GraphCycles implemented in 28 // tensorflow/compiler/jit/graphcycles/graphcycles.h 29 // 30 // We re-implement it here because we do not want to rely 31 // on TensorFlow data structures, and hence we can move 32 // corresponding passes to llvm repo. easily in case necessnary. 33 34 // -------------------------------------------------------------------- 35 36 // This is a set data structure that provides a deterministic iteration order. 37 // The iteration order of elements only depends on the sequence of 38 // inserts/deletes, so as long as the inserts/deletes happen in the same 39 // sequence, the set will have the same iteration order. 40 // 41 // Assumes that T can be cheaply copied for simplicity. 42 template <typename T> 43 class OrderedSet { 44 public: 45 // Inserts `value` into the ordered set. Returns true if the value was not 46 // present in the set before the insertion. Insert(T value)47 bool Insert(T value) { 48 bool new_insertion = 49 value_to_index_.insert({value, value_sequence_.size()}).second; 50 if (new_insertion) { 51 value_sequence_.push_back(value); 52 } 53 return new_insertion; 54 } 55 56 // Removes `value` from the set. Assumes `value` is already present in the 57 // set. Erase(T value)58 void Erase(T value) { 59 auto it = value_to_index_.find(value); 60 61 // Since we don't want to move values around in `value_sequence_` we swap 62 // the value in the last position and with value to be deleted and then 63 // pop_back. 64 value_to_index_[value_sequence_.back()] = it->second; 65 std::swap(value_sequence_[it->second], value_sequence_.back()); 66 value_sequence_.pop_back(); 67 value_to_index_.erase(it); 68 } 69 Reserve(size_t new_size)70 void Reserve(size_t new_size) { 71 value_to_index_.reserve(new_size); 72 value_sequence_.reserve(new_size); 73 } 74 Clear()75 void Clear() { 76 value_to_index_.clear(); 77 value_sequence_.clear(); 78 } 79 Contains(T value)80 bool Contains(T value) const { return value_to_index_.count(value); } Size()81 size_t Size() const { return value_sequence_.size(); } 82 GetSequence()83 const std::vector<T>& GetSequence() const { return value_sequence_; } 84 85 private: 86 // The stable order that we maintain through insertions and deletions. 87 std::vector<T> value_sequence_; 88 89 // Maps values to their indices in `value_sequence_`. 90 llvm::DenseMap<T, int> value_to_index_; 91 }; 92 93 // --------------------------------------------------------------------- 94 95 // GraphCycles detects the introduction of a cycle into a directed 96 // graph that is being built up incrementally. 97 // 98 // Nodes are identified by small integers. It is not possible to 99 // record multiple edges with the same (source, destination) pair; 100 // requests to add an edge where one already exists are silently 101 // ignored. 102 // 103 // It is also not possible to introduce a cycle; an attempt to insert 104 // an edge that would introduce a cycle fails and returns false. 105 // 106 // GraphCycles uses no internal locking; calls into it should be 107 // serialized externally. 108 109 // Performance considerations: 110 // Works well on sparse graphs, poorly on dense graphs. 111 // Extra information is maintained incrementally to detect cycles quickly. 112 // InsertEdge() is very fast when the edge already exists, and reasonably fast 113 // otherwise. 114 // FindPath() is linear in the size of the graph. 115 // The current implementation uses O(|V|+|E|) space. 116 117 class GraphCycles { 118 public: 119 explicit GraphCycles(int32_t num_nodes); 120 ~GraphCycles(); 121 122 // Attempt to insert an edge from x to y. If the 123 // edge would introduce a cycle, return false without making any 124 // changes. Otherwise add the edge and return true. 125 bool InsertEdge(int32_t x, int32_t y); 126 127 // Remove any edge that exists from x to y. 128 void RemoveEdge(int32_t x, int32_t y); 129 130 // Return whether there is an edge directly from x to y. 131 bool HasEdge(int32_t x, int32_t y) const; 132 133 // Contracts the edge from 'a' to node 'b', merging nodes 'a' and 'b'. One of 134 // the nodes is removed from the graph, and edges to/from it are added to 135 // the remaining one, which is returned. If contracting the edge would create 136 // a cycle, does nothing and return no value. 137 llvm::Optional<int32_t> ContractEdge(int32_t a, int32_t b); 138 139 // Return whether dest_node `y` is reachable from source_node `x` 140 // by following edges. This is non-thread-safe version. 141 bool IsReachable(int32_t x, int32_t y); 142 143 // Return a copy of the successors set. This is needed for code using the 144 // collection while modifying the GraphCycles. 145 std::vector<int32_t> SuccessorsCopy(int32_t node) const; 146 147 // Returns all nodes in post order. 148 // 149 // If there is a path from X to Y then X appears after Y in the 150 // returned vector. 151 std::vector<int32_t> AllNodesInPostOrder() const; 152 153 // ---------------------------------------------------- 154 struct Rep; 155 156 private: 157 GraphCycles(const GraphCycles&) = delete; 158 GraphCycles& operator=(const GraphCycles&) = delete; 159 160 Rep* rep_; // opaque representation 161 }; 162 163 } // namespace mlir 164 165 #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_UTILS_CYCLE_DETECTOR_H_ 166