RNNTest.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "RNN.h"

  #include <algorithm>
  #include <vector>

  #include "NeuralNetworksWrapper.h"
  #include "gmock/gmock-matchers.h"
  #include "gtest/gtest.h"
  20. namespace android {
  21. namespace nn {
  22. namespace wrapper {
  23. using ::testing::Each;
  24. using ::testing::FloatNear;
  25. using ::testing::Matcher;
  26. namespace {
  27. std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
  28. float max_abs_error = 1.e-5) {
  29. std::vector<Matcher<float>> matchers;
  30. matchers.reserve(values.size());
  31. for (const float& v : values) {
  32. matchers.emplace_back(FloatNear(v, max_abs_error));
  33. }
  34. return matchers;
  35. }
  // Input sequence for BlackBoxTest: 128 floats laid out as 16 rows of 8,
  // i.e. one input_size-long vector per row. The test copies row i into every
  // batch slot of the Input tensor before time step i.
  static float rnn_input[] = {
      0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, 0.43773448, 0.60379338, 0.35562468,
      -0.69424844, -0.93421471, -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, 0.1356622,
      -0.99774903, -0.98858172, -0.38952237, -0.47685933, 0.31073618, 0.71511042, -0.63767755, -0.31729108,
      0.33468103, 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, -0.68572164, 0.0069220066,
      0.65791464, 0.35130811, 0.80834007, -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154,
      0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, 0.86684, -0.011186242, 0.10513687,
      0.87825835, 0.59929144, 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, -0.35091716,
      0.74861872, 0.17831337, 0.2755419, 0.51864719, 0.55084288, 0.58982027, -0.47443086, 0.20875752,
      -0.058871567, -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, -0.17503482, 0.22396147,
      0.19379807, 0.29120302, 0.077113032, -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374,
      0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, -0.90300065, 0.62870288, -0.23068321,
      0.27531278, -0.095755219, -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, 0.43519354,
      0.14744234, 0.62589407, 0.1653645, -0.10651493, -0.045277178, 0.99032974, -0.88255352, -0.85147917,
      0.28153265, 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, 0.47587705, -0.074295521,
      -0.12287641, 0.70117295, 0.90532446, 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017,
      -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, 0.93455386, -0.6324693, -0.083922029};
  // Expected Output values for a single batch: 16 time steps x 16 units
  // (256 floats). Each step occupies the pair of rows under its marker.
  // BlackBoxTest feeds both batches the same input, so it compares every
  // batch against the same step row.
  static float rnn_golden_output[] = {
      // step 0
      0.496726, 0, 0.965996, 0, 0.0584254, 0, 0, 0.12315,
      0, 0, 0.612266, 0.456601, 0, 0.52286, 1.16099, 0.0291232,
      // step 1
      0, 0, 0.524901, 0, 0, 0, 0, 1.02116,
      0, 1.35762, 0, 0.356909, 0.436415, 0.0355727, 0, 0,
      // step 2
      0, 0, 0, 0.262335, 0, 0, 0, 1.33992,
      0, 2.9739, 0, 0, 1.31914, 2.66147, 0, 0,
      // step 3
      0.942568, 0, 0, 0, 0.025507, 0, 0, 0,
      0.321429, 0.569141, 1.25274, 1.57719, 0.8158, 1.21805, 0.586239, 0.25427,
      // step 4
      1.04436, 0, 0.630725, 0, 0.133801, 0.210693, 0.363026, 0,
      0.533426, 0, 1.25926, 0.722707, 0, 1.22031, 1.30117, 0.495867,
      // step 5
      0.222187, 0, 0.72725, 0, 0.767003, 0, 0, 0.147835,
      0, 0, 0, 0.608758, 0.469394, 0.00720298, 0.927537, 0,
      // step 6
      0.856974, 0.424257, 0, 0, 0.937329, 0, 0, 0,
      0.476425, 0, 0.566017, 0.418462, 0.141911, 0.996214, 1.13063, 0,
      // step 7
      0.967899, 0, 0, 0, 0.0831304, 0, 0, 1.00378,
      0, 0, 0, 1.44818, 1.01768, 0.943891, 0.502745, 0,
      // step 8
      0.940135, 0, 0, 0, 0, 0, 0, 2.13243,
      0, 0.71208, 0.123918, 1.53907, 1.30225, 1.59644, 0.70222, 0,
      // step 9
      0.804329, 0, 0.430576, 0, 0.505872, 0.509603, 0.343448, 0,
      0.107756, 0.614544, 1.44549, 1.52311, 0.0454298, 0.300267, 0.562784, 0.395095,
      // step 10
      0.228154, 0, 0.675323, 0, 1.70536, 0.766217, 0, 0,
      0, 0.735363, 0.0759267, 1.91017, 0.941888, 0, 0, 0,
      // step 11
      0, 0, 1.5909, 0, 0, 0, 0, 0.5755,
      0, 0.184687, 0, 1.56296, 0.625285, 0, 0, 0,
      // step 12
      0, 0, 0.0857888, 0, 0, 0, 0, 0.488383,
      0.252786, 0, 0, 0, 1.02817, 1.85665, 0, 0,
      // step 13
      0.00981836, 0, 1.06371, 0, 0, 0, 0, 0,
      0, 0.290445, 0.316406, 0, 0.304161, 1.25079, 0.0707152, 0,
      // step 14
      0.986264, 0.309201, 0, 0, 0, 0, 0, 1.64896,
      0.346248, 0, 0.918175, 0.78884, 0.524981, 1.92076, 2.07013, 0.333244,
      // step 15
      0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616,
      0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0};
  112. } // anonymous namespace
  // X-macro listing the float tensors the test passes to the execution as
  // inputs. ACTION(Name) is expanded once per tensor; BasicRNNOpModel uses it
  // to generate the Name_ buffers, the SetName setters and the setInput calls.
  #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input) ACTION(Weights) ACTION(RecurrentWeights) \
    ACTION(Bias) ACTION(HiddenStateIn)

  // X-macro listing the float tensors the execution writes back: the updated
  // hidden state and the step output.
  #define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(HiddenStateOut) ACTION(Output)
  123. class BasicRNNOpModel {
  124. public:
  125. BasicRNNOpModel(uint32_t batches, uint32_t units, uint32_t size)
  126. : batches_(batches),
  127. units_(units),
  128. input_size_(size),
  129. activation_(kActivationRelu) {
  130. std::vector<uint32_t> inputs;
  131. OperandType InputTy(Type::TENSOR_FLOAT32, {batches_, input_size_});
  132. inputs.push_back(model_.addOperand(&InputTy));
  133. OperandType WeightTy(Type::TENSOR_FLOAT32, {units_, input_size_});
  134. inputs.push_back(model_.addOperand(&WeightTy));
  135. OperandType RecurrentWeightTy(Type::TENSOR_FLOAT32, {units_, units_});
  136. inputs.push_back(model_.addOperand(&RecurrentWeightTy));
  137. OperandType BiasTy(Type::TENSOR_FLOAT32, {units_});
  138. inputs.push_back(model_.addOperand(&BiasTy));
  139. OperandType HiddenStateTy(Type::TENSOR_FLOAT32, {batches_, units_});
  140. inputs.push_back(model_.addOperand(&HiddenStateTy));
  141. OperandType ActionParamTy(Type::INT32, {});
  142. inputs.push_back(model_.addOperand(&ActionParamTy));
  143. std::vector<uint32_t> outputs;
  144. outputs.push_back(model_.addOperand(&HiddenStateTy));
  145. OperandType OutputTy(Type::TENSOR_FLOAT32, {batches_, units_});
  146. outputs.push_back(model_.addOperand(&OutputTy));
  147. Input_.insert(Input_.end(), batches_ * input_size_, 0.f);
  148. HiddenStateIn_.insert(HiddenStateIn_.end(), batches_ * units_, 0.f);
  149. HiddenStateOut_.insert(HiddenStateOut_.end(), batches_ * units_, 0.f);
  150. Output_.insert(Output_.end(), batches_ * units_, 0.f);
  151. model_.addOperation(ANEURALNETWORKS_RNN, inputs, outputs);
  152. model_.identifyInputsAndOutputs(inputs, outputs);
  153. model_.finish();
  154. }
  155. #define DefineSetter(X) \
  156. void Set##X(const std::vector<float>& f) { \
  157. X##_.insert(X##_.end(), f.begin(), f.end()); \
  158. }
  159. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
  160. #undef DefineSetter
  161. void SetInput(int offset, float* begin, float* end) {
  162. for (; begin != end; begin++, offset++) {
  163. Input_[offset] = *begin;
  164. }
  165. }
  166. void ResetHiddenState() {
  167. std::fill(HiddenStateIn_.begin(), HiddenStateIn_.end(), 0.f);
  168. std::fill(HiddenStateOut_.begin(), HiddenStateOut_.end(), 0.f);
  169. }
  170. const std::vector<float>& GetOutput() const { return Output_; }
  171. uint32_t input_size() const { return input_size_; }
  172. uint32_t num_units() const { return units_; }
  173. uint32_t num_batches() const { return batches_; }
  174. void Invoke() {
  175. ASSERT_TRUE(model_.isValid());
  176. HiddenStateIn_.swap(HiddenStateOut_);
  177. Compilation compilation(&model_);
  178. compilation.finish();
  179. Execution execution(&compilation);
  180. #define SetInputOrWeight(X) \
  181. ASSERT_EQ(execution.setInput(RNN::k##X##Tensor, X##_.data(), \
  182. sizeof(float) * X##_.size()), \
  183. Result::NO_ERROR);
  184. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
  185. #undef SetInputOrWeight
  186. #define SetOutput(X) \
  187. ASSERT_EQ(execution.setOutput(RNN::k##X##Tensor, X##_.data(), \
  188. sizeof(float) * X##_.size()), \
  189. Result::NO_ERROR);
  190. FOR_ALL_OUTPUT_TENSORS(SetOutput);
  191. #undef SetOutput
  192. ASSERT_EQ(execution.setInput(RNN::kActivationParam, &activation_,
  193. sizeof(activation_)),
  194. Result::NO_ERROR);
  195. ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  196. }
  197. private:
  198. Model model_;
  199. const uint32_t batches_;
  200. const uint32_t units_;
  201. const uint32_t input_size_;
  202. const int activation_;
  203. #define DefineTensor(X) std::vector<float> X##_;
  204. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
  205. FOR_ALL_OUTPUT_TENSORS(DefineTensor);
  206. #undef DefineTensor
  207. };
  208. TEST(RNNOpTest, BlackBoxTest) {
  209. BasicRNNOpModel rnn(2, 16, 8);
  210. rnn.SetWeights(
  211. {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
  212. 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
  213. 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
  214. -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
  215. -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
  216. -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
  217. -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
  218. 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
  219. 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
  220. 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
  221. -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
  222. 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
  223. -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
  224. -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
  225. 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
  226. 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
  227. 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
  228. -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
  229. 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
  230. 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
  231. -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
  232. 0.277308, 0.415818});
  233. rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
  234. -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
  235. 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
  236. -0.37609905});
  237. rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  238. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  239. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  240. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  241. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  242. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  243. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  244. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  245. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  246. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  247. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  248. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  249. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  250. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  251. 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
  252. 0.1});
  253. rnn.ResetHiddenState();
  254. const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
  255. (rnn.input_size() * rnn.num_batches());
  256. for (int i = 0; i < input_sequence_size; i++) {
  257. float* batch_start = rnn_input + i * rnn.input_size();
  258. float* batch_end = batch_start + rnn.input_size();
  259. rnn.SetInput(0, batch_start, batch_end);
  260. rnn.SetInput(rnn.input_size(), batch_start, batch_end);
  261. rnn.Invoke();
  262. float* golden_start = rnn_golden_output + i * rnn.num_units();
  263. float* golden_end = golden_start + rnn.num_units();
  264. std::vector<float> expected;
  265. expected.insert(expected.end(), golden_start, golden_end);
  266. expected.insert(expected.end(), golden_start, golden_end);
  267. EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  268. }
  269. }
  270. } // namespace wrapper
  271. } // namespace nn
  272. } // namespace android