/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Provides a set of matchers for tensorflow nodes.
//
// Example usage:
//
//  tensorflow::Node* node = ...;
//  EXPECT_THAT(node, NodeWith(Name("name"), Op("op"),
//                             Inputs(Out(3, NodeWith(Name("input"))))));
//
// Matchable node properties (the expressions that go inside NodeWith(...))
// are:
//
//  - Name(string): matches the node name exactly.  We will probably need to
//    have this take a string matcher in the near future.
//
//  - Op(string): matches the op exactly.
//
//  - AssignedDevice(string): matches the assigned device exactly.
//
//  - Inputs(<output matchers>...): matches the list of non-control inputs to
//    the node exactly (i.e. does not match a suffix or a prefix) where each
//    element matches an output of a node (see Out(idx, node) below).
//
//  - CtrlDeps(<node matchers>...): matches the list of control dependences on
//    the node exactly but in any order.
//
//  - ConstantValue(tensorflow::Input::Initializer init): matches a Const node
//    with the constant value `init`.  Implies Op("Const").
//
//  - Attr(name, value): matches a single attribute with name `name` and value
//    `value`.  Supported value types are those that have an
//    impl::AttrLiteralHelper overload (currently bool, integer lists and
//    string lists).
//
// Overlapping node properties may not be repeated in a single NodeWith(...)
// matcher.  E.g. NodeWith(Op("Foo"), Op("Bar")) will CHECK-fail.
Since // ConstantValue implies Op("Const"), a single NodeWith matcher can't have both // ConstantValue(...) and Op(...). Multiple Attr() values can be combined as // long as the attribute names are different. // // Out(idx, node) matches the `idx`'th output of a node that matches `node`. #ifndef TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ #define TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_ #include #include #include #include "absl/algorithm/container.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/cc/framework/ops.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/core/graph/graph.h" namespace tensorflow { namespace testing { namespace matchers { namespace impl { using OutEdge = std::pair; // ----------------------------------------------------------------------------- // Implementation details. // Properties that we match on for a particular Node. If a particular property // is nullopt then any value for it is allowed. 
class NodeMatcherProperties { public: using NodeSeqMatcher = std::vector<::testing::Matcher>; using InputSeqMatcher = std::vector<::testing::Matcher>; using AttrKeyValuePair = std::pair>; const absl::optional& name() const { return name_; } const absl::optional& op() const { return op_; } const absl::optional& assigned_device() const { return assigned_device_; } const absl::optional& constant_value() const { return constant_value_; } const absl::optional& inputs() const { return input_matchers_; } const absl::optional& control_deps() const { return control_deps_; } const absl::optional& attr() const { return attr_; } void set_name(string name) { DCHECK(IsEmpty()); name_ = std::move(name); } void set_op(string op) { DCHECK(IsEmpty()); op_ = std::move(op); } void set_assigned_device(string assigned_device) { DCHECK(IsEmpty()); assigned_device_ = std::move(assigned_device); } void set_constant_value(Tensor constant_value) { DCHECK(IsEmpty()); constant_value_ = std::move(constant_value); op_ = "Const"; } void set_inputs(InputSeqMatcher inputs) { DCHECK(IsEmpty()); input_matchers_ = std::move(inputs); } void set_control_deps(NodeSeqMatcher control_deps) { DCHECK(IsEmpty()); control_deps_ = std::move(control_deps); } void set_attr(AttrKeyValuePair attr) { DCHECK(IsEmpty()); attr_ = std::move(attr); } bool IsEmpty() const { return !name().has_value() && !op().has_value() && !inputs().has_value() && !control_deps().has_value() && !attr().has_value(); } private: absl::optional name_; absl::optional op_; absl::optional assigned_device_; absl::optional constant_value_; absl::optional input_matchers_; absl::optional control_deps_; absl::optional attr_; }; ::testing::Matcher NodeWith( absl::Span props); impl::NodeMatcherProperties Inputs( absl::Span> inputs); impl::NodeMatcherProperties CtrlDeps( absl::Span> control_deps); impl::NodeMatcherProperties Attr(std::pair attrs); impl::NodeMatcherProperties Attr(string name); std::pair AttrLiteralHelper( const std::pair& bool_attr); 
std::pair AttrLiteralHelper( const std::pair>& int_list_attr); std::pair AttrLiteralHelper( const std::pair>& string_list_attr); } // namespace impl // ----------------------------------------------------------------------------- // Public interface. // Matches a node with name `name`. impl::NodeMatcherProperties Name(string name); // Matches a node with op `op`. impl::NodeMatcherProperties Op(string op); // Matches a node with assigned device `assigned_device`. impl::NodeMatcherProperties AssignedDevice(string assigned_device); // Matches a node with a boolean typed attrbute named `name` and with value // `value`. template impl::NodeMatcherProperties Attr(const string& name, ValueTy value) { return impl::Attr({impl::AttrLiteralHelper({name, value})}); } inline impl::NodeMatcherProperties Attr(const string& name) { return impl::Attr(name); } // Matches a node with inputs `inputs`. // // `inputs` are ordered; `inputs`[i] must match input i. template impl::NodeMatcherProperties Inputs(Ts... inputs) { return impl::Inputs({inputs...}); } // Matches the `idx`'th output of a node that matches `node`. ::testing::Matcher Out(int oidx, ::testing::Matcher node); // Matches the first output of a node that matches `node`. inline ::testing::Matcher Out( ::testing::Matcher node) { return Out(0, node); } // Matches a node with control dependences `control_deps`. // // `control_deps` are unordered and will match the control deps of a node in any // order. template impl::NodeMatcherProperties CtrlDeps(Ts... control_deps) { return impl::CtrlDeps({control_deps...}); } // Matches a constant node with value `val`. impl::NodeMatcherProperties ConstantValue( const ::tensorflow::Input::Initializer& val); // The main gmock matcher. See file comment for example usage. template ::testing::Matcher NodeWith(Ts... 
args) { std::array array = {args...}; return impl::NodeWith(array); } ::testing::Matcher Const( const ::tensorflow::Input::Initializer& val); } // namespace matchers // If `g` has a node named `name` returns it, otherwise returns null. Node* FindNodeByName(Graph* g, absl::string_view name); } // namespace testing void PrintTo(const Node* n, ::std::ostream* os); void PrintTo(Node* n, ::std::ostream* os); } // namespace tensorflow #endif // TENSORFLOW_COMPILER_JIT_NODE_MATCHERS_H_