1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
19 #include "tensorflow/core/platform/logging.h"
20 #include "tensorflow/core/platform/test.h"
21 #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h"
22 #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h"
23 #include "tensorflow/python/tools/aot_compiled_x_matmul_y_large.h"
24 #include "tensorflow/python/tools/aot_compiled_x_matmul_y_large_multithreaded.h"
25 #include "tensorflow/python/tools/aot_compiled_x_matmul_y_small.h"
26 #include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
27
28 namespace tensorflow {
29 namespace {
// Feeds two scalars to the compiled XPlusY model and checks the fetched sum.
TEST(AOTCompiledSavedModelTest, XPlusY) {
  XPlusY model;
  // Calculation is: output_0 = x + y.
  *model.arg_feed_x_data() = 3.0f;
  *model.arg_feed_y_data() = 4.0f;
  // A gtest assertion (rather than CHECK) reports a failed Run() as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(model.Run());
  ASSERT_NEAR(model.result_fetch_output_0(), 7.0f, /*abs_error=*/1e-6f);
}
38
// Runs the compiled XMatmulYLarge model on random inputs and compares two
// corner entries of its output against an Eigen contraction of the same inputs.
TEST(AOTCompiledSavedModelTest, XMatmulYLarge) {
  // x is kRows x kInner, y is kInner x kCols, output_0 is kRows x kCols.
  constexpr int kRows = 3000;
  constexpr int kInner = 5000;
  constexpr int kCols = 4000;

  XMatmulYLarge model;
  // Calculation is: output_0 = x @ y.
  // Fatal on mismatch: the tensors built below are sized from these same
  // constants, so their buffers must not be handed to the model if it reports
  // different element counts.
  ASSERT_EQ(model.arg_feed_x_count(), kRows * kInner);
  ASSERT_EQ(model.arg_feed_y_count(), kInner * kCols);
  ASSERT_EQ(model.result0_count(), kRows * kCols);

  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(kRows, kInner);
  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(kInner, kCols);
  arg_feed_x.setRandom();
  arg_feed_y.setRandom();

  // Set up dimensions for standard matmul: contract dimension 1 of x with
  // dimension 0 of y.
  const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
      Eigen::IndexPair<int>(1, 0)};
  // Ground truth matmul.
  const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
      arg_feed_x.contract(arg_feed_y, product_dims);

  // The model reads directly from the tensors' storage; both tensors stay
  // alive until the end of the test.
  model.set_arg_feed_x_data(arg_feed_x.data());
  model.set_arg_feed_y_data(arg_feed_y.data());
  // A gtest assertion (rather than CHECK) reports a failed Run() as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(model.Run());

  // Spot-check the first and last output entries.
  // NOTE(review): the multithreaded variant of this test uses 1e-3f here;
  // 1e-6f looks very tight for a float sum of kInner products -- confirm that
  // near-exact agreement with Eigen is intended.
  EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
              /*abs_error=*/1e-6f);
  EXPECT_NEAR(model.result_fetch_output_0(kRows - 1, kCols - 1),
              expected_output0(kRows - 1, kCols - 1),
              /*abs_error=*/1e-6f);
}
67
// Same check as the large matmul test, but the model is given a two-thread
// Eigen thread pool device before running.
TEST(AOTCompiledSavedModelTest, XMatmulYLargeMultithreaded) {
  // x is kRows x kInner, y is kInner x kCols, output_0 is kRows x kCols.
  constexpr int kRows = 3000;
  constexpr int kInner = 5000;
  constexpr int kCols = 4000;

  XMatmulYLargeMultithreaded model;

  // The pool and device are declared before Run() and outlive it, so the
  // pointer passed to the model stays valid for the whole test.
  Eigen::ThreadPool pool(2);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
  model.set_thread_pool(&device);

  // Calculation is: output_0 = x @ y.
  // Fatal on mismatch: the tensors built below are sized from these same
  // constants, so their buffers must not be handed to the model if it reports
  // different element counts.
  ASSERT_EQ(model.arg_feed_x_count(), kRows * kInner);
  ASSERT_EQ(model.arg_feed_y_count(), kInner * kCols);
  ASSERT_EQ(model.result0_count(), kRows * kCols);

  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(kRows, kInner);
  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(kInner, kCols);
  arg_feed_x.setRandom();
  arg_feed_y.setRandom();

  // Set up dimensions for standard matmul: contract dimension 1 of x with
  // dimension 0 of y.
  const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
      Eigen::IndexPair<int>(1, 0)};
  // Ground truth matmul.
  const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
      arg_feed_x.contract(arg_feed_y, product_dims);

  model.set_arg_feed_x_data(arg_feed_x.data());
  model.set_arg_feed_y_data(arg_feed_y.data());
  // A gtest assertion (rather than CHECK) reports a failed Run() as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(model.Run());

  // Spot-check the first and last output entries.
  EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
              /*abs_error=*/1e-3f);
  EXPECT_NEAR(model.result_fetch_output_0(kRows - 1, kCols - 1),
              expected_output0(kRows - 1, kCols - 1),
              /*abs_error=*/1e-3f);
}
101
// Runs the compiled XMatmulYSmall model on random inputs and compares two
// corner entries of its output against an Eigen contraction of the same inputs.
TEST(AOTCompiledSavedModelTest, XMatmulYSmall) {
  // x is kRows x kInner, y is kInner x kCols, output_0 is kRows x kCols.
  constexpr int kRows = 3;
  constexpr int kInner = 5;
  constexpr int kCols = 4;

  XMatmulYSmall model;
  // Calculation is: output_0 = x @ y.
  // Fatal on mismatch: the tensors built below are sized from these same
  // constants, so their buffers must not be handed to the model if it reports
  // different element counts.
  ASSERT_EQ(model.arg_feed_x_count(), kRows * kInner);
  ASSERT_EQ(model.arg_feed_y_count(), kInner * kCols);
  ASSERT_EQ(model.result0_count(), kRows * kCols);

  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(kRows, kInner);
  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(kInner, kCols);
  arg_feed_x.setRandom();
  arg_feed_y.setRandom();

  // Set up dimensions for standard matmul: contract dimension 1 of x with
  // dimension 0 of y.
  const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
      Eigen::IndexPair<int>(1, 0)};
  // Ground truth matmul.
  const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
      arg_feed_x.contract(arg_feed_y, product_dims);

  model.set_arg_feed_x_data(arg_feed_x.data());
  model.set_arg_feed_y_data(arg_feed_y.data());
  // A gtest assertion (rather than CHECK) reports a failed Run() as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(model.Run());

  // Spot-check the first and last output entries.
  EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
              /*abs_error=*/1e-6f);
  EXPECT_NEAR(model.result_fetch_output_0(kRows - 1, kCols - 1),
              expected_output0(kRows - 1, kCols - 1),
              /*abs_error=*/1e-6f);
}
129
// Checks the arithmetic model in two forms: one with its variables frozen, and
// one where variable_x is supplied by the caller.
TEST(AOTCompiledSavedModelTest, VarsAndArithmetic) {
  VarsAndArithmeticFrozen frozen_model;
  // Calculation is:
  //   output_0 = [(a + variable_x) * (b + variable_y) / child_variable] + 5.0
  // where {variable_x, variable_y, child_variable} = {1.0, 2.0, 3.0} when
  // initialized (frozen).
  *frozen_model.arg_feed_a_data() = 1.0f;
  *frozen_model.arg_feed_b_data() = 2.0f;
  // A gtest assertion (rather than CHECK) reports a failed Run() as a test
  // failure instead of aborting the whole test binary.
  ASSERT_TRUE(frozen_model.Run());
  ASSERT_NEAR(frozen_model.result_fetch_output_0(),
              (1.0f + 1.0f) * (2.0f + 2.0f) / 3.0f + 5.0f, /*abs_error=*/1e-6f);

  VarsAndArithmetic nonfrozen_model;
  *nonfrozen_model.arg_feed_a_data() = 1.0f;
  *nonfrozen_model.arg_feed_b_data() = 2.0f;
  // variable_x is no longer frozen; set it to 4.0. The model is given a
  // pointer to this local, which stays in scope through Run().
  float new_variable_x = 4.0f;
  nonfrozen_model.set_var_param_variable_x_data(&new_variable_x);
  ASSERT_TRUE(nonfrozen_model.Run());
  // Same formula as above with variable_x = 4.0.
  ASSERT_NEAR(nonfrozen_model.result_fetch_output_0(),
              (1.0f + 4.0f) * (2.0f + 2.0f) / 3.0f + 5.0f, /*abs_error=*/1e-6f);
}
152 } // namespace
153 } // namespace tensorflow
154