1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
11 
12 #include <Eigen/CXX11/Tensor>
13 
14 using Eigen::Tensor;
15 using Eigen::TensorMap;
16 
test_assign()17 static void test_assign()
18 {
19   std::string data1[6];
20   TensorMap<Tensor<std::string, 2>> mat1(data1, 2, 3);
21   std::string data2[6];
22   const TensorMap<Tensor<const std::string, 2>> mat2(data2, 2, 3);
23 
24   for (int i = 0; i < 6; ++i) {
25     std::ostringstream s1;
26     s1 << "abc" << i*3;
27     data1[i] = s1.str();
28     std::ostringstream s2;
29     s2 << "def" << i*5;
30     data2[i] = s2.str();
31   }
32 
33   Tensor<std::string, 2> rslt1;
34   rslt1 = mat1;
35   Tensor<std::string, 2> rslt2;
36   rslt2 = mat2;
37 
38   Tensor<std::string, 2> rslt3 = mat1;
39   Tensor<std::string, 2> rslt4 = mat2;
40 
41   Tensor<std::string, 2> rslt5(mat1);
42   Tensor<std::string, 2> rslt6(mat2);
43 
44   for (int i = 0; i < 2; ++i) {
45     for (int j = 0; j < 3; ++j) {
46       VERIFY_IS_EQUAL(rslt1(i,j), data1[i+2*j]);
47       VERIFY_IS_EQUAL(rslt2(i,j), data2[i+2*j]);
48       VERIFY_IS_EQUAL(rslt3(i,j), data1[i+2*j]);
49       VERIFY_IS_EQUAL(rslt4(i,j), data2[i+2*j]);
50       VERIFY_IS_EQUAL(rslt5(i,j), data1[i+2*j]);
51       VERIFY_IS_EQUAL(rslt6(i,j), data2[i+2*j]);
52     }
53   }
54 }
55 
56 
test_concat()57 static void test_concat()
58 {
59   Tensor<std::string, 2> t1(2, 3);
60   Tensor<std::string, 2> t2(2, 3);
61 
62   for (int i = 0; i < 2; ++i) {
63     for (int j = 0; j < 3; ++j) {
64       std::ostringstream s1;
65       s1 << "abc" << i + j*2;
66       t1(i, j) = s1.str();
67       std::ostringstream s2;
68       s2 << "def" << i*5 + j*32;
69       t2(i, j) = s2.str();
70     }
71   }
72 
73   Tensor<std::string, 2> result = t1.concatenate(t2, 1);
74   VERIFY_IS_EQUAL(result.dimension(0), 2);
75   VERIFY_IS_EQUAL(result.dimension(1), 6);
76 
77   for (int i = 0; i < 2; ++i) {
78     for (int j = 0; j < 3; ++j) {
79       VERIFY_IS_EQUAL(result(i, j),   t1(i, j));
80       VERIFY_IS_EQUAL(result(i, j+3), t2(i, j));
81     }
82   }
83 }
84 
85 
test_slices()86 static void test_slices()
87 {
88   Tensor<std::string, 2> data(2, 6);
89   for (int i = 0; i < 2; ++i) {
90     for (int j = 0; j < 3; ++j) {
91       std::ostringstream s1;
92       s1 << "abc" << i + j*2;
93       data(i, j) = s1.str();
94     }
95   }
96 
97   const Eigen::DSizes<ptrdiff_t, 2> half_size(2, 3);
98   const Eigen::DSizes<ptrdiff_t, 2> first_half(0, 0);
99   const Eigen::DSizes<ptrdiff_t, 2> second_half(0, 3);
100 
101   Tensor<std::string, 2> t1 = data.slice(first_half, half_size);
102   Tensor<std::string, 2> t2 = data.slice(second_half, half_size);
103 
104   for (int i = 0; i < 2; ++i) {
105     for (int j = 0; j < 3; ++j) {
106       VERIFY_IS_EQUAL(data(i, j),   t1(i, j));
107       VERIFY_IS_EQUAL(data(i, j+3), t2(i, j));
108     }
109   }
110 }
111 
112 
test_additions()113 static void test_additions()
114 {
115   Tensor<std::string, 1> data1(3);
116   Tensor<std::string, 1> data2(3);
117   for (int i = 0; i < 3; ++i) {
118     data1(i) = "abc";
119     std::ostringstream s1;
120     s1 << i;
121     data2(i) = s1.str();
122   }
123 
124   Tensor<std::string, 1> sum = data1 + data2;
125   for (int i = 0; i < 3; ++i) {
126     std::ostringstream concat;
127     concat << "abc" << i;
128     std::string expected = concat.str();
129     VERIFY_IS_EQUAL(sum(i), expected);
130   }
131 }
132 
133 
test_initialization()134 static void test_initialization()
135 {
136   Tensor<std::string, 2> a(2, 3);
137   a.setConstant(std::string("foo"));
138   for (int i = 0; i < 2*3; ++i) {
139     VERIFY_IS_EQUAL(a(i), std::string("foo"));
140   }
141 }
142 
143 
test_cxx11_tensor_of_strings()144 void test_cxx11_tensor_of_strings()
145 {
146   // Beware: none of this is likely to ever work on a GPU.
147   CALL_SUBTEST(test_assign());
148   CALL_SUBTEST(test_concat());
149   CALL_SUBTEST(test_slices());
150   CALL_SUBTEST(test_additions());
151   CALL_SUBTEST(test_initialization());
152 }
153