// LayerNormLSTMTest.cpp
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "LSTM.h"

#include <algorithm>
#include <sstream>
#include <string>
#include <vector>

#include <android-base/logging.h>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
  24. namespace android {
  25. namespace nn {
  26. namespace wrapper {
  27. using ::testing::Each;
  28. using ::testing::FloatNear;
  29. using ::testing::Matcher;
  30. namespace {
  31. std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
  32. float max_abs_error = 1.e-6) {
  33. std::vector<Matcher<float>> matchers;
  34. matchers.reserve(values.size());
  35. for (const float& v : values) {
  36. matchers.emplace_back(FloatNear(v, max_abs_error));
  37. }
  38. return matchers;
  39. }
  40. } // anonymous namespace
// Expands ACTION(X) once per LSTM data input, in operand order. This order
// must match both the shape list handed to the LayerNormLSTMOpModel
// constructor and the LSTMCell::k<X>Tensor indices used when binding buffers.
// Note the ordering: Cell weights/bias come before Forget.
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
ACTION(Input) \
ACTION(InputToInputWeights) \
ACTION(InputToCellWeights) \
ACTION(InputToForgetWeights) \
ACTION(InputToOutputWeights) \
ACTION(RecurrentToInputWeights) \
ACTION(RecurrentToCellWeights) \
ACTION(RecurrentToForgetWeights) \
ACTION(RecurrentToOutputWeights) \
ACTION(CellToInputWeights) \
ACTION(CellToForgetWeights) \
ACTION(CellToOutputWeights) \
ACTION(InputGateBias) \
ACTION(CellGateBias) \
ACTION(ForgetGateBias) \
ACTION(OutputGateBias) \
ACTION(ProjectionWeights) \
ACTION(ProjectionBias) \
ACTION(OutputStateIn) \
ACTION(CellStateIn)
// Expands ACTION(X) for the four layer-normalization weight vectors, which
// are appended after the scalar parameters in the operand list.
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
ACTION(InputLayerNormWeights) \
ACTION(ForgetLayerNormWeights) \
ACTION(CellLayerNormWeights) \
ACTION(OutputLayerNormWeights)
// For all output and intermediate states
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
ACTION(ScratchBuffer) \
ACTION(OutputStateOut) \
ACTION(CellStateOut) \
ACTION(Output)
  73. class LayerNormLSTMOpModel {
  74. public:
  75. LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
  76. bool use_cifg, bool use_peephole, bool use_projection_weights,
  77. bool use_projection_bias, float cell_clip, float proj_clip,
  78. const std::vector<std::vector<uint32_t>>& input_shapes0)
  79. : n_input_(n_input),
  80. n_output_(n_output),
  81. use_cifg_(use_cifg),
  82. use_peephole_(use_peephole),
  83. use_projection_weights_(use_projection_weights),
  84. use_projection_bias_(use_projection_bias),
  85. activation_(ActivationFn::kActivationTanh),
  86. cell_clip_(cell_clip),
  87. proj_clip_(proj_clip) {
  88. std::vector<uint32_t> inputs;
  89. std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
  90. auto it = input_shapes.begin();
  91. // Input and weights
  92. #define AddInput(X) \
  93. CHECK(it != input_shapes.end()); \
  94. OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
  95. inputs.push_back(model_.addOperand(&X##OpndTy));
  96. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
  97. // Parameters
  98. OperandType ActivationOpndTy(Type::INT32, {});
  99. inputs.push_back(model_.addOperand(&ActivationOpndTy));
  100. OperandType CellClipOpndTy(Type::FLOAT32, {});
  101. inputs.push_back(model_.addOperand(&CellClipOpndTy));
  102. OperandType ProjClipOpndTy(Type::FLOAT32, {});
  103. inputs.push_back(model_.addOperand(&ProjClipOpndTy));
  104. FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
  105. #undef AddOperand
  106. // Output and other intermediate state
  107. std::vector<std::vector<uint32_t>> output_shapes{
  108. {n_batch, n_cell * (use_cifg ? 3 : 4)},
  109. {n_batch, n_output},
  110. {n_batch, n_cell},
  111. {n_batch, n_output},
  112. };
  113. std::vector<uint32_t> outputs;
  114. auto it2 = output_shapes.begin();
  115. #define AddOutput(X) \
  116. CHECK(it2 != output_shapes.end()); \
  117. OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
  118. outputs.push_back(model_.addOperand(&X##OpndTy));
  119. FOR_ALL_OUTPUT_TENSORS(AddOutput);
  120. #undef AddOutput
  121. model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
  122. model_.identifyInputsAndOutputs(inputs, outputs);
  123. Input_.insert(Input_.end(), n_batch * n_input, 0.f);
  124. OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
  125. CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
  126. auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
  127. uint32_t sz = 1;
  128. for (uint32_t d : dims) {
  129. sz *= d;
  130. }
  131. return sz;
  132. };
  133. it2 = output_shapes.begin();
  134. #define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
  135. FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
  136. #undef ReserveOutput
  137. model_.finish();
  138. }
  139. #define DefineSetter(X) \
  140. void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
  141. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
  142. FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
  143. #undef DefineSetter
  144. void ResetOutputState() {
  145. std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
  146. std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
  147. }
  148. void ResetCellState() {
  149. std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
  150. std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
  151. }
  152. void SetInput(int offset, const float* begin, const float* end) {
  153. for (; begin != end; begin++, offset++) {
  154. Input_[offset] = *begin;
  155. }
  156. }
  157. uint32_t num_inputs() const { return n_input_; }
  158. uint32_t num_outputs() const { return n_output_; }
  159. const std::vector<float>& GetOutput() const { return Output_; }
  160. void Invoke() {
  161. ASSERT_TRUE(model_.isValid());
  162. OutputStateIn_.swap(OutputStateOut_);
  163. CellStateIn_.swap(CellStateOut_);
  164. Compilation compilation(&model_);
  165. compilation.finish();
  166. Execution execution(&compilation);
  167. #define SetInputOrWeight(X) \
  168. ASSERT_EQ( \
  169. execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
  170. Result::NO_ERROR);
  171. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
  172. FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
  173. #undef SetInputOrWeight
  174. #define SetOutput(X) \
  175. ASSERT_EQ( \
  176. execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
  177. Result::NO_ERROR);
  178. FOR_ALL_OUTPUT_TENSORS(SetOutput);
  179. #undef SetOutput
  180. if (use_cifg_) {
  181. execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
  182. execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
  183. }
  184. if (use_peephole_) {
  185. if (use_cifg_) {
  186. execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
  187. }
  188. } else {
  189. execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
  190. execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
  191. execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
  192. }
  193. if (use_projection_weights_) {
  194. if (!use_projection_bias_) {
  195. execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
  196. }
  197. } else {
  198. execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
  199. execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
  200. }
  201. ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
  202. Result::NO_ERROR);
  203. ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
  204. Result::NO_ERROR);
  205. ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
  206. Result::NO_ERROR);
  207. ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  208. }
  209. private:
  210. Model model_;
  211. // Execution execution_;
  212. const uint32_t n_input_;
  213. const uint32_t n_output_;
  214. const bool use_cifg_;
  215. const bool use_peephole_;
  216. const bool use_projection_weights_;
  217. const bool use_projection_bias_;
  218. const int activation_;
  219. const float cell_clip_;
  220. const float proj_clip_;
  221. #define DefineTensor(X) std::vector<float> X##_;
  222. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
  223. FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
  224. FOR_ALL_OUTPUT_TENSORS(DefineTensor);
  225. #undef DefineTensor
  226. };
  227. TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
  228. const int n_batch = 2;
  229. const int n_input = 5;
  230. // n_cell and n_output have the same size when there is no projection.
  231. const int n_cell = 4;
  232. const int n_output = 3;
  233. LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
  234. /*use_cifg=*/false, /*use_peephole=*/true,
  235. /*use_projection_weights=*/true,
  236. /*use_projection_bias=*/false,
  237. /*cell_clip=*/0.0, /*proj_clip=*/0.0,
  238. {
  239. {n_batch, n_input}, // input tensor
  240. {n_cell, n_input}, // input_to_input_weight tensor
  241. {n_cell, n_input}, // input_to_forget_weight tensor
  242. {n_cell, n_input}, // input_to_cell_weight tensor
  243. {n_cell, n_input}, // input_to_output_weight tensor
  244. {n_cell, n_output}, // recurrent_to_input_weight tensor
  245. {n_cell, n_output}, // recurrent_to_forget_weight tensor
  246. {n_cell, n_output}, // recurrent_to_cell_weight tensor
  247. {n_cell, n_output}, // recurrent_to_output_weight tensor
  248. {n_cell}, // cell_to_input_weight tensor
  249. {n_cell}, // cell_to_forget_weight tensor
  250. {n_cell}, // cell_to_output_weight tensor
  251. {n_cell}, // input_gate_bias tensor
  252. {n_cell}, // forget_gate_bias tensor
  253. {n_cell}, // cell_bias tensor
  254. {n_cell}, // output_gate_bias tensor
  255. {n_output, n_cell}, // projection_weight tensor
  256. {0}, // projection_bias tensor
  257. {n_batch, n_output}, // output_state_in tensor
  258. {n_batch, n_cell}, // cell_state_in tensor
  259. {n_cell}, // input_layer_norm_weights tensor
  260. {n_cell}, // forget_layer_norm_weights tensor
  261. {n_cell}, // cell_layer_norm_weights tensor
  262. {n_cell}, // output_layer_norm_weights tensor
  263. });
  264. lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
  265. -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});
  266. lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
  267. -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5});
  268. lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
  269. 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6});
  270. lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
  271. 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4});
  272. lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});
  273. lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});
  274. lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});
  275. lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});
  276. lstm.SetRecurrentToInputWeights(
  277. {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});
  278. lstm.SetRecurrentToCellWeights(
  279. {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});
  280. lstm.SetRecurrentToForgetWeights(
  281. {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});
  282. lstm.SetRecurrentToOutputWeights(
  283. {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});
  284. lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
  285. lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
  286. lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});
  287. lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});
  288. lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
  289. lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
  290. lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
  291. lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});
  292. const std::vector<std::vector<float>> lstm_input = {
  293. { // Batch0: 3 (input_sequence_size) * 5 (n_input)
  294. 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
  295. 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
  296. 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
  297. { // Batch1: 3 (input_sequence_size) * 5 (n_input)
  298. 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
  299. 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
  300. 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
  301. };
  302. const std::vector<std::vector<float>> lstm_golden_output = {
  303. {
  304. // Batch0: 3 (input_sequence_size) * 3 (n_output)
  305. 0.0244077, 0.128027, -0.00170918, // seq 0
  306. 0.0137642, 0.140751, 0.0395835, // seq 1
  307. -0.00459231, 0.155278, 0.0837377, // seq 2
  308. },
  309. {
  310. // Batch1: 3 (input_sequence_size) * 3 (n_output)
  311. -0.00692428, 0.0848741, 0.063445, // seq 0
  312. -0.00403912, 0.139963, 0.072681, // seq 1
  313. 0.00752706, 0.161903, 0.0561371, // seq 2
  314. }};
  315. // Resetting cell_state and output_state
  316. lstm.ResetCellState();
  317. lstm.ResetOutputState();
  318. const int input_sequence_size = lstm_input[0].size() / n_input;
  319. for (int i = 0; i < input_sequence_size; i++) {
  320. for (int b = 0; b < n_batch; ++b) {
  321. const float* batch_start = lstm_input[b].data() + i * n_input;
  322. const float* batch_end = batch_start + n_input;
  323. lstm.SetInput(b * n_input, batch_start, batch_end);
  324. }
  325. lstm.Invoke();
  326. std::vector<float> expected;
  327. for (int b = 0; b < n_batch; ++b) {
  328. const float* golden_start = lstm_golden_output[b].data() + i * n_output;
  329. const float* golden_end = golden_start + n_output;
  330. expected.insert(expected.end(), golden_start, golden_end);
  331. }
  332. EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
  333. }
  334. }
  335. } // namespace wrapper
  336. } // namespace nn
  337. } // namespace android