LSTM.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  /*
   * Copyright (C) 2017 The Android Open Source Project
   *
   * Licensed under the Apache License, Version 2.0 (the "License");
   * you may not use this file except in compliance with the License.
   * You may obtain a copy of the License at
   *
   *      http://www.apache.org/licenses/LICENSE-2.0
   *
   * Unless required by applicable law or agreed to in writing, software
   * distributed under the License is distributed on an "AS IS" BASIS,
   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   * See the License for the specific language governing permissions and
   * limitations under the License.
   */
  16. #ifndef FRAMEWORKS_ML_NN_LSTMCELL_H
  17. #define FRAMEWORKS_ML_NN_LSTMCELL_H
  #include "ActivationFunctor.h"
  #include "HalOperation.h"
  #include "tensorflow/lite/kernels/internal/tensor_utils.h"

  #include <algorithm>
  #include <cmath>
  #include <cstdint>
  #include <vector>
  23. namespace android {
  24. namespace nn {
  25. struct LSTMParams {
  26. TfLiteFusedActivation activation;
  27. float cell_clip;
  28. float proj_clip;
  29. bool use_cifg;
  30. bool use_peephole;
  31. bool use_layer_norm;
  32. bool use_projection_weight;
  33. bool use_projection_bias;
  34. bool merge_outputs;
  35. bool time_major;
  36. };
  // Forward declarations. This header only mentions these types through
  // pointers, references and declaration-only function signatures, so their
  // complete definitions are not required here.
  struct RunTimeOperandInfo;
  struct Shape;
  39. class LSTMCell {
  40. public:
  41. LSTMCell(const Operation& operation, std::vector<RunTimeOperandInfo>& operands);
  42. bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
  43. Shape* scratchShape, Shape* outputStateShape, Shape* cellStateShape,
  44. Shape* outputShape);
  45. bool Eval();
  46. // Input Tensors of size {n_batch, n_input}
  47. static constexpr int kInputTensor = 0;
  48. // Input weight tensors of size: {n_cell, n_input}
  49. static constexpr int kInputToInputWeightsTensor = 1; // Optional
  50. static constexpr int kInputToForgetWeightsTensor = 2;
  51. static constexpr int kInputToCellWeightsTensor = 3;
  52. static constexpr int kInputToOutputWeightsTensor = 4;
  53. // Recurrent weight tensors of size {n_cell, n_output}
  54. static constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
  55. static constexpr int kRecurrentToForgetWeightsTensor = 6;
  56. static constexpr int kRecurrentToCellWeightsTensor = 7;
  57. static constexpr int kRecurrentToOutputWeightsTensor = 8;
  58. // Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
  59. static constexpr int kCellToInputWeightsTensor = 9; // Optional
  60. static constexpr int kCellToForgetWeightsTensor = 10; // Optional
  61. static constexpr int kCellToOutputWeightsTensor = 11; // Optional
  62. // Gates bias tensors of size {n_cell}
  63. static constexpr int kInputGateBiasTensor = 12; // Optional
  64. static constexpr int kForgetGateBiasTensor = 13;
  65. static constexpr int kCellGateBiasTensor = 14;
  66. static constexpr int kOutputGateBiasTensor = 15;
  67. // Projection weight tensor of size {n_output, n_cell}
  68. static constexpr int kProjectionWeightsTensor = 16; // Optional
  69. // Projection bias tensor of size {n_output}
  70. static constexpr int kProjectionBiasTensor = 17; // Optional
  71. static constexpr int kOutputStateInTensor = 18;
  72. static constexpr int kCellStateInTensor = 19;
  73. static constexpr int kActivationParam = 20;
  74. static constexpr int kCellClipParam = 21;
  75. static constexpr int kProjClipParam = 22;
  76. // Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
  77. static constexpr int kInputLayerNormWeightsTensor = 23;
  78. static constexpr int kForgetLayerNormWeightsTensor = 24;
  79. static constexpr int kCellLayerNormWeightsTensor = 25;
  80. static constexpr int kOutputLayerNormWeightsTensor = 26;
  81. // Output tensors.
  82. static constexpr int kScratchBufferTensor = 0;
  83. static constexpr int kOutputStateOutTensor = 1;
  84. static constexpr int kCellStateOutTensor = 2;
  85. static constexpr int kOutputTensor = 3;
  86. static constexpr float kLayerNormEpsilon = 1e-8;
  87. static bool LSTMEvalFloat32(
  88. const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
  89. const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
  90. const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
  91. const Shape& input_to_output_weights_shape,
  92. const float* recurrent_to_input_weights_buffer,
  93. const float* recurrent_to_forget_weights_buffer,
  94. const float* recurrent_to_cell_weights_buffer,
  95. const float* recurrent_to_output_weights_buffer,
  96. const Shape& recurrent_to_output_weights_shape,
  97. const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
  98. const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
  99. const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
  100. const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
  101. const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
  102. const float* cell_bias_buffer, const float* output_gate_bias_buffer,
  103. const float* projection_weights_buffer, const float* projection_bias_buffer,
  104. const float* output_state_in_buffer, const float* cell_state_in_buffer,
  105. const float* input_layer_norm_weights_buffer,
  106. const float* forget_layer_norm_weights_buffer,
  107. const float* cell_layer_norm_weights_buffer,
  108. const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
  109. float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer,
  110. bool timeMajor = true, bool forwardSequence = true);
  111. static bool LSTMEvalFloat16(
  112. const LSTMParams& params, const _Float16* input_buffer, const Shape& input_shape,
  113. const _Float16* input_to_input_weights_buffer,
  114. const _Float16* input_to_forget_weights_buffer,
  115. const _Float16* input_to_cell_weights_buffer,
  116. const _Float16* input_to_output_weights_buffer,
  117. const Shape& input_to_output_weights_shape,
  118. const _Float16* recurrent_to_input_weights_buffer,
  119. const _Float16* recurrent_to_forget_weights_buffer,
  120. const _Float16* recurrent_to_cell_weights_buffer,
  121. const _Float16* recurrent_to_output_weights_buffer,
  122. const Shape& recurrent_to_output_weights_shape,
  123. const _Float16* cell_to_input_weights_buffer,
  124. const _Float16* cell_to_forget_weights_buffer,
  125. const _Float16* cell_to_output_weights_buffer, const _Float16* aux_input_buffer,
  126. const _Float16* aux_input_to_input_weights, const _Float16* aux_input_to_forget_weights,
  127. const _Float16* aux_input_to_cell_weights, const _Float16* aux_input_to_output_weights,
  128. const _Float16* input_gate_bias_buffer, const _Float16* forget_gate_bias_buffer,
  129. const _Float16* cell_bias_buffer, const _Float16* output_gate_bias_buffer,
  130. const _Float16* projection_weights_buffer, const _Float16* projection_bias_buffer,
  131. const _Float16* output_state_in_buffer, const _Float16* cell_state_in_buffer,
  132. const _Float16* input_layer_norm_weights_buffer,
  133. const _Float16* forget_layer_norm_weights_buffer,
  134. const _Float16* cell_layer_norm_weights_buffer,
  135. const _Float16* output_layer_norm_weights_buffer, _Float16* output_state_out_buffer,
  136. _Float16* cell_state_out_buffer, _Float16* output_buffer,
  137. _Float16* scratch_buffer_buffer, bool timeMajor = true, bool forwardSequence = true);
  138. static bool LSTMStep(
  139. const LSTMParams& params, const float* input_buffer, const Shape& input_shape,
  140. const float* input_to_input_weights_buffer, const float* input_to_forget_weights_buffer,
  141. const float* input_to_cell_weights_buffer, const float* input_to_output_weights_buffer,
  142. const Shape& input_to_output_weights_shape,
  143. const float* recurrent_to_input_weights_buffer,
  144. const float* recurrent_to_forget_weights_buffer,
  145. const float* recurrent_to_cell_weights_buffer,
  146. const float* recurrent_to_output_weights_buffer,
  147. const Shape& recurrent_to_output_weights_shape,
  148. const float* cell_to_input_weights_buffer, const float* cell_to_forget_weights_buffer,
  149. const float* cell_to_output_weights_buffer, const float* aux_input_buffer,
  150. const float* aux_input_to_input_weights, const float* aux_input_to_forget_weights,
  151. const float* aux_input_to_cell_weights, const float* aux_input_to_output_weights,
  152. const float* input_gate_bias_buffer, const float* forget_gate_bias_buffer,
  153. const float* cell_bias_buffer, const float* output_gate_bias_buffer,
  154. const float* projection_weights_buffer, const float* projection_bias_buffer,
  155. const float* output_state_in_buffer, const float* cell_state_in_buffer,
  156. const float* input_layer_norm_weights_buffer,
  157. const float* forget_layer_norm_weights_buffer,
  158. const float* cell_layer_norm_weights_buffer,
  159. const float* output_layer_norm_weights_buffer, float* output_state_out_buffer,
  160. float* cell_state_out_buffer, float* output_buffer, float* scratch_buffer_buffer);
  161. static bool CheckInputTensorDimensions(
  162. const RunTimeOperandInfo* input_, const RunTimeOperandInfo* input_to_input_weights,
  163. const RunTimeOperandInfo* input_to_forget_weights,
  164. const RunTimeOperandInfo* input_to_cell_weights,
  165. const RunTimeOperandInfo* input_to_output_weights,
  166. const RunTimeOperandInfo* recurrent_to_input_weights,
  167. const RunTimeOperandInfo* recurrent_to_forget_weights,
  168. const RunTimeOperandInfo* recurrent_to_cell_weights,
  169. const RunTimeOperandInfo* recurrent_to_output_weights,
  170. const RunTimeOperandInfo* cell_to_input_weights,
  171. const RunTimeOperandInfo* cell_to_forget_weights,
  172. const RunTimeOperandInfo* cell_to_output_weights,
  173. const RunTimeOperandInfo* input_gate_bias, const RunTimeOperandInfo* forget_gate_bias,
  174. const RunTimeOperandInfo* cell_bias, const RunTimeOperandInfo* output_gate_bias,
  175. const RunTimeOperandInfo* projection_weights, const RunTimeOperandInfo* projection_bias,
  176. const RunTimeOperandInfo* input_layer_norm_weights,
  177. const RunTimeOperandInfo* forget_layer_norm_weights,
  178. const RunTimeOperandInfo* cell_layer_norm_weights,
  179. const RunTimeOperandInfo* output_layer_norm_weights, uint32_t n_input,
  180. uint32_t n_output, uint32_t n_cell, LSTMParams* params);
  181. private:
  182. LSTMParams params_;
  183. const RunTimeOperandInfo* input_;
  184. const RunTimeOperandInfo* input_to_input_weights_;
  185. const RunTimeOperandInfo* input_to_forget_weights_;
  186. const RunTimeOperandInfo* input_to_cell_weights_;
  187. const RunTimeOperandInfo* input_to_output_weights_;
  188. const RunTimeOperandInfo* recurrent_to_input_weights_;
  189. const RunTimeOperandInfo* recurrent_to_forget_weights_;
  190. const RunTimeOperandInfo* recurrent_to_cell_weights_;
  191. const RunTimeOperandInfo* recurrent_to_output_weights_;
  192. const RunTimeOperandInfo* cell_to_input_weights_;
  193. const RunTimeOperandInfo* cell_to_forget_weights_;
  194. const RunTimeOperandInfo* cell_to_output_weights_;
  195. const RunTimeOperandInfo* input_gate_bias_;
  196. const RunTimeOperandInfo* forget_gate_bias_;
  197. const RunTimeOperandInfo* cell_bias_;
  198. const RunTimeOperandInfo* output_gate_bias_;
  199. const RunTimeOperandInfo* projection_weights_;
  200. const RunTimeOperandInfo* projection_bias_;
  201. const RunTimeOperandInfo* output_state_in_;
  202. const RunTimeOperandInfo* cell_state_in_;
  203. const RunTimeOperandInfo* input_layer_norm_weights_;
  204. const RunTimeOperandInfo* forget_layer_norm_weights_;
  205. const RunTimeOperandInfo* cell_layer_norm_weights_;
  206. const RunTimeOperandInfo* output_layer_norm_weights_;
  207. RunTimeOperandInfo* output_state_out_;
  208. RunTimeOperandInfo* cell_state_out_;
  209. RunTimeOperandInfo* output_;
  210. RunTimeOperandInfo* scratch_buffer_;
  211. };
  212. } // namespace nn
  213. } // namespace android
  214. #endif // FRAMEWORKS_ML_NN_LSTMCELL_H