QuantizedLSTM.cpp

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "QuantizedLSTM.h"

#include "CpuExecutor.h"
#include "CpuOperationUtils.h"
#include "Tracing.h"

#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

namespace android {
namespace nn {

namespace {

template <typename T>
inline T* GetBuffer(RunTimeOperandInfo* operand) {
    return reinterpret_cast<T*>(operand->buffer);
}

template <typename T>
inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
    return reinterpret_cast<const T*>(operand->buffer);
}

using tflite::Dims;

// The function below is taken from the TF Lite implementation in order to decouple
// the NN API from the TF Lite dependency. The original function, with a description
// of its parameters and types, can be found at:
// https://github.com/tensorflow/tensorflow/blob/0d697e5fc4c05c699eea0764364104ea500ccc68/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h#L1926
//
// clang-format off
template <int StateIntegerBits>
void quantizedLstmStep(const uint8_t* input_data_uint8, const Dims<4>& input_dims,
                       const uint8_t* prev_activ_data_uint8, const Dims<4>& prev_activ_dims,
                       const uint8_t* weights_data_uint8, const Dims<4>& weights_dims,
                       const int32_t* bias_data_int32, const Dims<4>& bias_dims,
                       const int16_t* prevCellState_data_int16, const Dims<4>& prevCellState_dims,
                       int16_t* output_state_data_int16, const Dims<4>& output_state_dims,
                       uint8_t* output_activ_data_uint8, const Dims<4>& output_activ_dims,
                       uint8_t* concat_temp_data_uint8, const Dims<4>& concat_temp_dims,
                       int16_t* activ_temp_data_int16, const Dims<4>& activ_temp_dims,
                       int32_t weights_zero_point, int32_t accum_multiplier, int accum_shift) {
  // Gather dimensions information, and perform consistency checks.
  const int outer_size =
      MatchingFlatSizeSkipDim(input_dims, 0, prev_activ_dims, prevCellState_dims,
                              output_state_dims, output_activ_dims);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1);
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1);
  const int input_depth = ArraySize(input_dims, 0);
  const int prev_activ_depth = ArraySize(prev_activ_dims, 0);
  const int total_input_depth = prev_activ_depth + input_depth;
  TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth);
  TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), 1);
  const int intern_activ_depth = MatchingArraySize(weights_dims, 1, bias_dims, 0);
  TFLITE_CHECK_EQ(intern_activ_depth % 4, 0);
  const int output_depth = MatchingArraySize(prevCellState_dims, 0, prev_activ_dims, 0,
                                             output_state_dims, 0, output_activ_dims, 0);
  TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4);
  const int fc_batches = FlatSizeSkipDim(activ_temp_dims, 0);
  const int fc_output_depth = MatchingArraySize(weights_dims, 1, activ_temp_dims, 0);
  const int fc_accum_depth = ArraySize(weights_dims, 0);
  TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth);
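  // Note: fc_output_depth is 4 * output_depth because the single matrix multiply
  // below produces the pre-activations for all four LSTM gates (input, input
  // modulation, forget, output) at once; they are sliced apart by gate index
  // further down.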
  // Depth-concatenate prev_activ and input data together.
  uint8_t const* concat_input_arrays_data[2] = {input_data_uint8, prev_activ_data_uint8};
  Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims};
  tflite::reference_ops::Concatenation<tflite::FusedActivationFunctionType::kNone, uint8_t>(
      0, concat_input_arrays_data, concat_input_arrays_dims, 2, concat_temp_data_uint8,
      concat_temp_dims);
  // Implementation of the fully connected node inside the LSTM cell.
  // The operands are 8-bit integers, the accumulators are internally 32-bit
  // integers, and the output is 16-bit fixed-point with 3 integer bits, so
  // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that is
  // explained in the comment on the original TF Lite function (see the link above).
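  // As a worked example of that format: with 3 integer bits there are 12 fractional
  // bits, so a raw int16 value v represents v * 2^-12; e.g. raw 4096 is 1.0, and the
  // saturation bounds -32768/32767 used below correspond to roughly -8.0/+8.0.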
  for (int b = 0; b < fc_batches; ++b) {
    for (int out_c = 0; out_c < fc_output_depth; ++out_c) {
      // Internal accumulation.
      // Initialize accumulator with the bias-value.
      int32_t accum = bias_data_int32[out_c];
      // Accumulation loop.
      for (int d = 0; d < fc_accum_depth; ++d) {
        int16_t input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128;
        int16_t weights_val =
            weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point;
        accum += input_val * weights_val;
      }
      // Down-scale the final int32 accumulator to the scale used by our
      // (16-bit, using 3 integer bits) fixed-point format. The quantized
      // multiplier and shift here have been pre-computed offline
      // (e.g. by toco).
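      // In effect (assuming the usual TF Lite convention of a Q31 multiplier plus a
      // power-of-two shift), this scales accum by roughly
      // accum_multiplier * 2^(accum_shift - 31); both values are computed from the
      // operand scales in QuantizedLSTMCell::eval() below.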
      accum = tflite::MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift);
      // Saturate, cast to int16, and store to the temporary activations array.
      accum = std::max(-32768, std::min(32767, accum));
      activ_temp_data_int16[out_c + fc_output_depth * b] = accum;
    }
  }
  // Rest of the LSTM cell: tanh and logistic math functions, and some adds
  // and muls, all done in 16-bit fixed-point.
  for (int b = 0; b < outer_size; ++b) {
    for (int c = 0; c < output_depth; ++c) {
      // Define the fixed-point data types that we will use here. All use
      // int16 as the underlying integer type, i.e. all are 16-bit fixed-point.
      // They only differ by the number of integral vs. fractional bits,
      // determining the range of values that they can represent.
      //
      // F0 uses 0 integer bits, range [-1, 1].
      // This is the return type of math functions such as tanh, logistic,
      // whose range is in [-1, 1].
      using F0 = gemmlowp::FixedPoint<std::int16_t, 0>;
      // F3 uses 3 integer bits, range [-8, 8].
      // This is the range of the previous fully-connected node's output,
      // which is our input here.
      using F3 = gemmlowp::FixedPoint<std::int16_t, 3>;
      // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits,
      // 2^StateIntegerBits]. It's used to represent the internal state, whose
      // number of integer bits is currently dictated by the model. See the comment
      // on the StateIntegerBits template parameter in the original TF Lite function
      // (linked above).
      using FS = gemmlowp::FixedPoint<std::int16_t, StateIntegerBits>;
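      // As a worked example for FS: with the only supported configuration,
      // StateIntegerBits == 4 (enforced in QuantizedLSTMCell::prepare below),
      // a raw int16 value v represents v * 2^-11, so raw 2048 is 1.0 and the
      // representable state range is roughly [-16, +16).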
      // Implementation of input gate, using fixed-point logistic function.
      F3 input_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]);
      F0 input_gate_output = gemmlowp::logistic(input_gate_input);
      // Implementation of input modulation gate, using fixed-point tanh
      // function.
      F3 input_modulation_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]);
      F0 input_modulation_gate_output = gemmlowp::tanh(input_modulation_gate_input);
      // Implementation of forget gate, using fixed-point logistic function.
      F3 forget_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]);
      F0 forget_gate_output = gemmlowp::logistic(forget_gate_input);
      // Implementation of output gate, using fixed-point logistic function.
      F3 output_gate_input = F3::FromRaw(
          activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]);
      F0 output_gate_output = gemmlowp::logistic(output_gate_input);
      // Implementation of internal multiplication nodes, still in fixed-point.
      F0 input_times_input_modulation = input_gate_output * input_modulation_gate_output;
      FS prevCellState = FS::FromRaw(prevCellState_data_int16[b * output_depth + c]);
      FS prevCellState_times_forget_state = forget_gate_output * prevCellState;
      // Implementation of internal addition node, saturating.
      FS new_state = gemmlowp::SaturatingAdd(
          gemmlowp::Rescale<StateIntegerBits>(input_times_input_modulation),
          prevCellState_times_forget_state);
      // Implementation of the last internal Tanh node, still in fixed-point.
      // Since a Tanh fixed-point implementation is specialized for a given
      // number of integer bits, and each specialization can have a substantial
      // code size, and we already used a Tanh above on an input with 3 integer
      // bits, and per the accuracy table in the original TF Lite function's comment
      // there is no significant accuracy to be lost by clamping to [-8, +8] for a
      // 3-integer-bits representation, let us just do that. This helps people
      // porting this to targets where code footprint must be minimized.
      F3 new_state_f3 = gemmlowp::Rescale<3>(new_state);
      F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3);
      // Store the new internal state back to memory, as 16-bit integers.
      // Note: here we store the original value with StateIntegerBits, not
      // the rescaled 3-integer-bits value fed to tanh.
      output_state_data_int16[b * output_depth + c] = new_state.raw();
      // Down-scale the output activations to 8-bit integers, saturating,
      // and store back to memory.
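      // The raw F0 value has 15 fractional bits; RoundingDivideByPOT by 2^8 leaves
      // 7 fractional bits, i.e. the activation value scaled by 128, matching the
      // 1/128 scale used for the 8-bit activations (checked on prevOutput in
      // prepare()). Adding 128 then applies the uint8 zero point. For example, a
      // tanh output of 0.5 (raw 16384) becomes 64 and is stored as 192.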
      int16_t rescaled_output_activ =
          gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8);
      int16_t clamped_output_activ =
          std::max<int16_t>(-128, std::min<int16_t>(127, rescaled_output_activ));
      output_activ_data_uint8[b * output_depth + c] = 128 + clamped_output_activ;
    }
  }
}
// clang-format on

// The function assigns a 2D matrix to a submatrix of the weights at the given row
// and column offsets.
void assignWeightsSubmatrix(const RunTimeOperandInfo* submatrix, const int32_t offset_row,
                            const int32_t offset_column, const std::vector<uint32_t>& weightsDims,
                            uint8_t* weights) {
    const uint8_t* submatrixValues = GetBuffer<uint8_t>(submatrix);
    const std::vector<uint32_t> submatrixDims = submatrix->shape().dimensions;
    for (uint32_t i = 0; i < submatrixDims[0] * submatrixDims[1]; ++i) {
        const uint32_t row = i / submatrixDims[1];
        const uint32_t column = i % submatrixDims[1];
        weights[(row + offset_row) * weightsDims[1] + column + offset_column] = submatrixValues[i];
    }
}

}  // namespace

QuantizedLSTMCell::QuantizedLSTMCell(const Operation& operation,
                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    inputToInputWeights_ = GetInput(operation, operands, kInputToInputWeightsTensor);
    inputToForgetWeights_ = GetInput(operation, operands, kInputToForgetWeightsTensor);
    inputToCellWeights_ = GetInput(operation, operands, kInputToCellWeightsTensor);
    inputToOutputWeights_ = GetInput(operation, operands, kInputToOutputWeightsTensor);

    recurrentToInputWeights_ = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    recurrentToForgetWeights_ = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    recurrentToCellWeights_ = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    recurrentToOutputWeights_ = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);

    inputGateBias_ = GetInput(operation, operands, kInputGateBiasTensor);
    forgetGateBias_ = GetInput(operation, operands, kForgetGateBiasTensor);
    cellGateBias_ = GetInput(operation, operands, kCellGateBiasTensor);
    outputGateBias_ = GetInput(operation, operands, kOutputGateBiasTensor);

    prevCellState_ = GetInput(operation, operands, kPrevCellStateTensor);
    prevOutput_ = GetInput(operation, operands, kPrevOutputTensor);

    cellStateOut_ = GetOutput(operation, operands, kCellStateOutTensor);
    output_ = GetOutput(operation, operands, kOutputTensor);
}

bool QuantizedLSTMCell::prepare(const Operation& operation,
                                std::vector<RunTimeOperandInfo>& operands,
                                Shape* cellStateOutShape, Shape* outputShape) {
    auto input = GetInput(operation, operands, kInputTensor);
    NN_RET_CHECK_EQ(NumDimensions(input), 2);
    NN_RET_CHECK_EQ(input->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(input->zeroPoint, 128);
    const uint32_t numBatches = SizeOfDimension(input, 0);
    const uint32_t inputSize = SizeOfDimension(input, 1);

    auto prevOutput = GetInput(operation, operands, kPrevOutputTensor);
    NN_RET_CHECK_EQ(NumDimensions(prevOutput), 2);
    NN_RET_CHECK_EQ(SizeOfDimension(prevOutput, 0), numBatches);
    NN_RET_CHECK_EQ(prevOutput->scale, 1. / 128.0);
    NN_RET_CHECK_EQ(prevOutput->zeroPoint, 128);
    const uint32_t outputSize = SizeOfDimension(prevOutput, 1);

    auto inputToInputWeights = GetInput(operation, operands, kInputToInputWeightsTensor);
    const float weightsScale = inputToInputWeights->scale;
    NN_RET_CHECK(weightsScale != 0);
    const float weightsZeroPoint = inputToInputWeights->zeroPoint;

    auto checkWeightsShape = [&](const RunTimeOperandInfo* weights, uint32_t columns) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(weights), 2);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 0), outputSize);
        NN_RET_CHECK_EQ(SizeOfDimension(weights, 1), columns);
        NN_RET_CHECK_EQ(weights->scale, weightsScale);
        NN_RET_CHECK_EQ(weights->zeroPoint, weightsZeroPoint);
        return true;
    };

    auto inputToForgetWeights = GetInput(operation, operands, kInputToForgetWeightsTensor);
    auto inputToCellWeights = GetInput(operation, operands, kInputToCellWeightsTensor);
    auto inputToOutputWeights = GetInput(operation, operands, kInputToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(inputToInputWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToForgetWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToCellWeights, inputSize));
    NN_RET_CHECK(checkWeightsShape(inputToOutputWeights, inputSize));

    auto recurrentToInputWeights = GetInput(operation, operands, kRecurrentToInputWeightsTensor);
    auto recurrentToForgetWeights = GetInput(operation, operands, kRecurrentToForgetWeightsTensor);
    auto recurrentToCellWeights = GetInput(operation, operands, kRecurrentToCellWeightsTensor);
    auto recurrentToOutputWeights = GetInput(operation, operands, kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK(checkWeightsShape(recurrentToInputWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToForgetWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToCellWeights, outputSize));
    NN_RET_CHECK(checkWeightsShape(recurrentToOutputWeights, outputSize));

    auto inputGateBias = GetInput(operation, operands, kInputGateBiasTensor);
    const float biasScale = inputGateBias->scale;
    NN_RET_CHECK_EQ(biasScale, weightsScale / 128.0);
    const float biasZeroPoint = inputGateBias->zeroPoint;
    NN_RET_CHECK_EQ(biasZeroPoint, 0);

    auto checkBiasShape = [&](const RunTimeOperandInfo* bias) -> bool {
        NN_RET_CHECK_EQ(NumDimensions(bias), 1);
        NN_RET_CHECK_EQ(SizeOfDimension(bias, 0), outputSize);
        NN_RET_CHECK_EQ(bias->scale, biasScale);
        NN_RET_CHECK_EQ(bias->zeroPoint, biasZeroPoint);
        return true;
    };

    auto forgetGateBias = GetInput(operation, operands, kForgetGateBiasTensor);
    auto cellGateBias = GetInput(operation, operands, kCellGateBiasTensor);
    auto outputGateBias = GetInput(operation, operands, kOutputGateBiasTensor);
    NN_RET_CHECK(checkBiasShape(inputGateBias));
    NN_RET_CHECK(checkBiasShape(forgetGateBias));
    NN_RET_CHECK(checkBiasShape(cellGateBias));
    NN_RET_CHECK(checkBiasShape(outputGateBias));

    auto prevCellState = GetInput(operation, operands, kPrevCellStateTensor);
    NN_CHECK_EQ(NumDimensions(prevCellState), 2);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 0), numBatches);
    NN_CHECK_EQ(SizeOfDimension(prevCellState, 1), outputSize);
    NN_CHECK_EQ(prevCellState->zeroPoint, 0);
    // Cell state range for quantized LSTM is a function of StateIntegerBits and
    // can be calculated as:
    // [-2^StateIntegerBits, 2^StateIntegerBits * 32767/32768].
    // Therefore, for a fixed StateIntegerBits parameter, cell state scale is
    // equal to 2^StateIntegerBits * 2^(-15) = 2^(StateIntegerBits - 15) and
    // therefore:
    // StateIntegerBits = log2(cell state scale) + 15
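    // For example, a cell state scale of 2^-11 == 0.00048828125 gives
    // stateScaleLog2Rounded == -11 and stateIntegerBits == 15 + (-11) == 4,
    // which is the only value accepted below.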
    int stateScaleLog2Rounded;
    NN_CHECK(tflite::CheckedLog2(prevCellState->scale, &stateScaleLog2Rounded));
    const int stateIntegerBits = 15 + stateScaleLog2Rounded;
    // We only support StateIntegerBits == 4
    NN_CHECK(stateIntegerBits == 4);

    *cellStateOutShape = prevCellState->shape();
    *outputShape = prevOutput->shape();
    return true;
}

// The function concatenates 8 input weight matrices into one. The resulting matrix
// has the shape [4 * outputSize, outputSize + inputSize]. The matrix is
// constructed as follows:
// +-----------------------------------+
// | recurrentToInput  | inputToInput  |
// |-------------------+---------------|
// | recurrentToCell   | inputToCell   |
// |-------------------+---------------|
// | recurrentToForget | inputToForget |
// |-------------------+---------------|
// | recurrentToOutput | inputToOutput |
// +-----------------------------------+
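// Each recurrentTo* block is [outputSize, outputSize] and each inputTo* block is
// [outputSize, inputSize] (shapes enforced in prepare()), so the recurrent weights
// occupy columns [0, outputSize) and the input weights occupy columns
// [outputSize, outputSize + inputSize), matching the column offsets used below.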
void QuantizedLSTMCell::concatenateWeights(const std::vector<uint32_t>& weightsDims,
                                           uint8_t* weights) {
    const int outputSize = SizeOfDimension(inputToInputWeights_, 0);

    assignWeightsSubmatrix(inputToInputWeights_, 0 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToCellWeights_, 1 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToForgetWeights_, 2 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(inputToOutputWeights_, 3 * outputSize, outputSize, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToInputWeights_, 0 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToCellWeights_, 1 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToForgetWeights_, 2 * outputSize, 0, weightsDims, weights);
    assignWeightsSubmatrix(recurrentToOutputWeights_, 3 * outputSize, 0, weightsDims, weights);
}

// The function concatenates four bias vectors of shape [outputSize] into one
// vector of shape [4 * outputSize].
void QuantizedLSTMCell::concatenateBiases(uint32_t outputSize, int32_t* bias) {
    memcpy(bias + 0 * outputSize, GetBuffer<int32_t>(inputGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 1 * outputSize, GetBuffer<int32_t>(cellGateBias_), sizeof(int32_t) * outputSize);
    memcpy(bias + 2 * outputSize, GetBuffer<int32_t>(forgetGateBias_),
           sizeof(int32_t) * outputSize);
    memcpy(bias + 3 * outputSize, GetBuffer<int32_t>(outputGateBias_),
           sizeof(int32_t) * outputSize);
}

bool QuantizedLSTMCell::eval() {
    NNTRACE_COMP("QuantizedLSTM::eval");

    Shape weightsShape;
    weightsShape.dimensions = {4 * SizeOfDimension(prevOutput_, 1),
                               SizeOfDimension(input_, 1) + SizeOfDimension(prevOutput_, 1)};
    std::vector<uint8_t> weights(getNumberOfElements(weightsShape));
    concatenateWeights(weightsShape.dimensions, weights.data());

    Shape biasShape;
    biasShape.dimensions = {getSizeOfDimension(weightsShape, 0)};
    std::vector<int32_t> bias(getNumberOfElements(biasShape));
    concatenateBiases(SizeOfDimension(prevOutput_, 1), bias.data());

    Shape concatTempShape;
    concatTempShape.dimensions = {SizeOfDimension(input_, 0), getSizeOfDimension(weightsShape, 1)};

    Shape activationTempShape;
    activationTempShape.dimensions = {SizeOfDimension(input_, 0),
                                      getSizeOfDimension(weightsShape, 0)};

    std::vector<uint8_t> concatTemp(getNumberOfElements(concatTempShape));
    std::vector<int16_t> activationTemp(getNumberOfElements(activationTempShape));

    // From https://arxiv.org/pdf/1712.05877, for a fully-connected layer, the
    // accumulator multiplier is equal to:
    // (input scale) * (weights scale) / (fully-connected output scale)
    // In our case the fully-connected output scale is fixed and equal to
    // 2^(-12) (see the LSTMCell definition in TF Lite for more details on that).
    // But the bias scale is set to (input scale) * (weights scale) (also from the
    // paper), so we can multiply it by the inverse of the fc-output scale to get
    // the multiplier value:
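    // As a worked example: prepare() requires the input scale to be 1/128 and the
    // bias scale to be (weights scale) / 128, so
    // realAccumMultiplier = 2^12 * biasScale = 32 * weightsScale, which is exactly
    // (input scale) * (weights scale) / 2^(-12).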
    double realAccumMultiplier = 4096 * inputGateBias_->scale;
    int32_t accumMultiplier;
    int accumShift;
    tflite::QuantizeMultiplier(realAccumMultiplier, &accumMultiplier, &accumShift);
    quantizedLstmStep<4>(
            // Inputs.
            GetBuffer<const uint8_t>(input_), convertShapeToDims(input_->shape()),
            GetBuffer<const uint8_t>(prevOutput_), convertShapeToDims(prevOutput_->shape()),
            weights.data(), convertShapeToDims(weightsShape), bias.data(),
            convertShapeToDims(biasShape), GetBuffer<const int16_t>(prevCellState_),
            convertShapeToDims(prevCellState_->shape()),
            // Outputs.
            GetBuffer<int16_t>(cellStateOut_), convertShapeToDims(cellStateOut_->shape()),
            GetBuffer<uint8_t>(output_), convertShapeToDims(output_->shape()), concatTemp.data(),
            convertShapeToDims(concatTempShape), activationTemp.data(),
            convertShapeToDims(activationTempShape), inputToInputWeights_->zeroPoint,
            accumMultiplier, accumShift);
    return true;
}

}  // namespace nn
}  // namespace android