// QuantizedLSTMTest.cpp
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "QuantizedLSTM.h"

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"
  21. namespace android {
  22. namespace nn {
  23. namespace wrapper {
  24. namespace {
  25. struct OperandTypeParams {
  26. Type type;
  27. std::vector<uint32_t> shape;
  28. float scale;
  29. int32_t zeroPoint;
  30. OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
  31. : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
  32. };
  33. } // namespace
  34. using ::testing::Each;
  35. using ::testing::ElementsAreArray;
  36. using ::testing::FloatNear;
  37. using ::testing::Matcher;
  38. class QuantizedLSTMOpModel {
  39. public:
  40. QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
  41. std::vector<uint32_t> inputs;
  42. for (int i = 0; i < NUM_INPUTS; ++i) {
  43. const auto& curOTP = inputOperandTypeParams[i];
  44. OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
  45. inputs.push_back(model_.addOperand(&curType));
  46. }
  47. const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
  48. inputSize_ = inputOperandTypeParams[0].shape[0];
  49. const uint32_t outputSize =
  50. inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
  51. outputSize_ = outputSize;
  52. std::vector<uint32_t> outputs;
  53. OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
  54. 1. / 2048., 0);
  55. outputs.push_back(model_.addOperand(&cellStateOutOperandType));
  56. OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
  57. 1. / 128., 128);
  58. outputs.push_back(model_.addOperand(&outputOperandType));
  59. model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
  60. model_.identifyInputsAndOutputs(inputs, outputs);
  61. initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
  62. initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
  63. &prevOutput_);
  64. initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
  65. &prevCellState_);
  66. cellStateOut_.resize(numBatches * outputSize, 0);
  67. output_.resize(numBatches * outputSize, 0);
  68. model_.finish();
  69. }
  70. void invoke() {
  71. ASSERT_TRUE(model_.isValid());
  72. Compilation compilation(&model_);
  73. compilation.finish();
  74. Execution execution(&compilation);
  75. // Set all the inputs.
  76. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
  77. Result::NO_ERROR);
  78. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
  79. inputToInputWeights_),
  80. Result::NO_ERROR);
  81. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
  82. inputToForgetWeights_),
  83. Result::NO_ERROR);
  84. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
  85. inputToCellWeights_),
  86. Result::NO_ERROR);
  87. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
  88. inputToOutputWeights_),
  89. Result::NO_ERROR);
  90. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
  91. recurrentToInputWeights_),
  92. Result::NO_ERROR);
  93. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
  94. recurrentToForgetWeights_),
  95. Result::NO_ERROR);
  96. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
  97. recurrentToCellWeights_),
  98. Result::NO_ERROR);
  99. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
  100. recurrentToOutputWeights_),
  101. Result::NO_ERROR);
  102. ASSERT_EQ(
  103. setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
  104. Result::NO_ERROR);
  105. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
  106. forgetGateBias_),
  107. Result::NO_ERROR);
  108. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
  109. Result::NO_ERROR);
  110. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
  111. outputGateBias_),
  112. Result::NO_ERROR);
  113. ASSERT_EQ(
  114. setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
  115. Result::NO_ERROR);
  116. ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
  117. Result::NO_ERROR);
  118. // Set all the outputs.
  119. ASSERT_EQ(
  120. setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
  121. Result::NO_ERROR);
  122. ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
  123. Result::NO_ERROR);
  124. ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  125. // Put state outputs into inputs for the next step
  126. prevOutput_ = output_;
  127. prevCellState_ = cellStateOut_;
  128. }
  129. int inputSize() { return inputSize_; }
  130. int outputSize() { return outputSize_; }
  131. void setInput(const std::vector<uint8_t>& input) { input_ = input; }
  132. void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
  133. std::vector<uint8_t> inputToForgetWeights,
  134. std::vector<uint8_t> inputToCellWeights,
  135. std::vector<uint8_t> inputToOutputWeights,
  136. std::vector<uint8_t> recurrentToInputWeights,
  137. std::vector<uint8_t> recurrentToForgetWeights,
  138. std::vector<uint8_t> recurrentToCellWeights,
  139. std::vector<uint8_t> recurrentToOutputWeights,
  140. std::vector<int32_t> inputGateBias,
  141. std::vector<int32_t> forgetGateBias,
  142. std::vector<int32_t> cellGateBias, //
  143. std::vector<int32_t> outputGateBias) {
  144. inputToInputWeights_ = inputToInputWeights;
  145. inputToForgetWeights_ = inputToForgetWeights;
  146. inputToCellWeights_ = inputToCellWeights;
  147. inputToOutputWeights_ = inputToOutputWeights;
  148. recurrentToInputWeights_ = recurrentToInputWeights;
  149. recurrentToForgetWeights_ = recurrentToForgetWeights;
  150. recurrentToCellWeights_ = recurrentToCellWeights;
  151. recurrentToOutputWeights_ = recurrentToOutputWeights;
  152. inputGateBias_ = inputGateBias;
  153. forgetGateBias_ = forgetGateBias;
  154. cellGateBias_ = cellGateBias;
  155. outputGateBias_ = outputGateBias;
  156. }
  157. template <typename T>
  158. void initializeInputData(OperandTypeParams params, std::vector<T>* vec) {
  159. int size = 1;
  160. for (int d : params.shape) {
  161. size *= d;
  162. }
  163. vec->clear();
  164. vec->resize(size, params.zeroPoint);
  165. }
  166. std::vector<uint8_t> getOutput() { return output_; }
  167. private:
  168. static constexpr int NUM_INPUTS = 15;
  169. static constexpr int NUM_OUTPUTS = 2;
  170. Model model_;
  171. // Inputs
  172. std::vector<uint8_t> input_;
  173. std::vector<uint8_t> inputToInputWeights_;
  174. std::vector<uint8_t> inputToForgetWeights_;
  175. std::vector<uint8_t> inputToCellWeights_;
  176. std::vector<uint8_t> inputToOutputWeights_;
  177. std::vector<uint8_t> recurrentToInputWeights_;
  178. std::vector<uint8_t> recurrentToForgetWeights_;
  179. std::vector<uint8_t> recurrentToCellWeights_;
  180. std::vector<uint8_t> recurrentToOutputWeights_;
  181. std::vector<int32_t> inputGateBias_;
  182. std::vector<int32_t> forgetGateBias_;
  183. std::vector<int32_t> cellGateBias_;
  184. std::vector<int32_t> outputGateBias_;
  185. std::vector<int16_t> prevCellState_;
  186. std::vector<uint8_t> prevOutput_;
  187. // Outputs
  188. std::vector<int16_t> cellStateOut_;
  189. std::vector<uint8_t> output_;
  190. int inputSize_;
  191. int outputSize_;
  192. template <typename T>
  193. Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
  194. return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
  195. }
  196. template <typename T>
  197. Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
  198. return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
  199. }
  200. };
  201. class QuantizedLstmTest : public ::testing::Test {
  202. protected:
  203. void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
  204. const std::vector<std::vector<uint8_t>>& output,
  205. QuantizedLSTMOpModel* lstm) {
  206. const int numBatches = input.size();
  207. EXPECT_GT(numBatches, 0);
  208. const int inputSize = lstm->inputSize();
  209. EXPECT_GT(inputSize, 0);
  210. const int inputSequenceSize = input[0].size() / inputSize;
  211. EXPECT_GT(inputSequenceSize, 0);
  212. for (int i = 0; i < inputSequenceSize; ++i) {
  213. std::vector<uint8_t> inputStep;
  214. for (int b = 0; b < numBatches; ++b) {
  215. const uint8_t* batchStart = input[b].data() + i * inputSize;
  216. const uint8_t* batchEnd = batchStart + inputSize;
  217. inputStep.insert(inputStep.end(), batchStart, batchEnd);
  218. }
  219. lstm->setInput(inputStep);
  220. lstm->invoke();
  221. const int outputSize = lstm->outputSize();
  222. std::vector<float> expected;
  223. for (int b = 0; b < numBatches; ++b) {
  224. const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
  225. const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
  226. expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
  227. }
  228. EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
  229. }
  230. }
  231. };
  232. // Inputs and weights in this test are random and the test only checks that the
  233. // outputs are equal to outputs obtained from running TF Lite version of
  234. // quantized LSTM on the same inputs.
  235. TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
  236. const int numBatches = 2;
  237. const int inputSize = 2;
  238. const int outputSize = 4;
  239. float weightsScale = 0.00408021;
  240. int weightsZeroPoint = 100;
  241. // OperandType biasOperandType(Type::TENSOR_INT32, input_shapes[3],
  242. // weightsScale / 128., 0);
  243. // inputs.push_back(model_.addOperand(&biasOperandType));
  244. // OperandType prevCellStateOperandType(Type::TENSOR_QUANT16_SYMM, input_shapes[4],
  245. // 1. / 2048., 0);
  246. // inputs.push_back(model_.addOperand(&prevCellStateOperandType));
  247. QuantizedLSTMOpModel lstm({
  248. // input
  249. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
  250. // inputToInputWeights
  251. // inputToForgetWeights
  252. // inputToCellWeights
  253. // inputToOutputWeights
  254. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
  255. weightsZeroPoint),
  256. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
  257. weightsZeroPoint),
  258. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
  259. weightsZeroPoint),
  260. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
  261. weightsZeroPoint),
  262. // recurrentToInputWeights
  263. // recurrentToForgetWeights
  264. // recurrentToCellWeights
  265. // recurrentToOutputWeights
  266. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
  267. weightsZeroPoint),
  268. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
  269. weightsZeroPoint),
  270. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
  271. weightsZeroPoint),
  272. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
  273. weightsZeroPoint),
  274. // inputGateBias
  275. // forgetGateBias
  276. // cellGateBias
  277. // outputGateBias
  278. OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
  279. OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
  280. OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
  281. OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
  282. // prevCellState
  283. OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
  284. // prevOutput
  285. OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
  286. });
  287. lstm.setWeightsAndBiases(
  288. // inputToInputWeights
  289. {146, 250, 235, 171, 10, 218, 171, 108},
  290. // inputToForgetWeights
  291. {24, 50, 132, 179, 158, 110, 3, 169},
  292. // inputToCellWeights
  293. {133, 34, 29, 49, 206, 109, 54, 183},
  294. // inputToOutputWeights
  295. {195, 187, 11, 99, 109, 10, 218, 48},
  296. // recurrentToInputWeights
  297. {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
  298. // recurrentToForgetWeights
  299. {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
  300. // recurrentToCellWeights
  301. {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
  302. // recurrentToOutputWeights
  303. {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
  304. // inputGateBias
  305. {-7876, 13488, -726, 32839},
  306. // forgetGateBias
  307. {9206, -46884, -11693, -38724},
  308. // cellGateBias
  309. {39481, 48624, 48976, -21419},
  310. // outputGateBias
  311. {-58999, -17050, -41852, -40538});
  312. // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
  313. std::vector<std::vector<uint8_t>> lstmInput;
  314. // clang-format off
  315. lstmInput = {{154, 166,
  316. 166, 179,
  317. 141, 141},
  318. {100, 200,
  319. 50, 150,
  320. 111, 222}};
  321. // clang-format on
  322. // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
  323. std::vector<std::vector<uint8_t>> lstmGoldenOutput;
  324. // clang-format off
  325. lstmGoldenOutput = {{136, 150, 140, 115,
  326. 140, 151, 146, 112,
  327. 139, 153, 146, 114},
  328. {135, 152, 138, 112,
  329. 136, 156, 142, 112,
  330. 141, 154, 146, 108}};
  331. // clang-format on
  332. VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
  333. };
  334. } // namespace wrapper
  335. } // namespace nn
  336. } // namespace android