// BidirectionalSequenceLSTM.cpp
  1. /*
  2. * Copyright (C) 2019 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "BidirectionalSequenceLSTM.h"
  17. #include "CpuExecutor.h"
  18. #include "CpuOperationUtils.h"
  19. #include "HalInterfaces.h"
  20. #include "OperationsUtils.h"
  21. #include "Tracing.h"
  22. namespace android {
  23. namespace nn {
  24. namespace {
  25. template <typename T>
  26. inline T* GetBuffer(RunTimeOperandInfo* operand) {
  27. return reinterpret_cast<T*>(operand->buffer);
  28. }
  29. template <typename T>
  30. inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
  31. return reinterpret_cast<const T*>(operand->buffer);
  32. }
  33. template <typename T>
  34. inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
  35. return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
  36. }
  37. } // anonymous namespace
// Binds every operand of a BIDIRECTIONAL_SEQUENCE_LSTM operation to a member
// pointer and decodes the scalar parameters into params_. Operands marked
// "optional" may be omitted by the model; IsNullInput() is true for those.
BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
                                                     std::vector<RunTimeOperandInfo>& operands) {
    input_ = GetInput(operation, operands, kInputTensor);

    // Forward-direction input and recurrent gate weights.
    fw_input_to_input_weights_ =
            GetInput(operation, operands, kFwInputToInputWeightsTensor);  // optional
    fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
    fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
    fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);
    fw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kFwRecurrentToInputWeightsTensor);  // optional
    fw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
    fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
    fw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);
    // Forward peephole connections (all optional).
    fw_cell_to_input_weights_ =
            GetInput(operation, operands, kFwCellToInputWeightsTensor);  // optional
    fw_cell_to_forget_weights_ =
            GetInput(operation, operands, kFwCellToForgetWeightsTensor);  // optional
    fw_cell_to_output_weights_ =
            GetInput(operation, operands, kFwCellToOutputWeightsTensor);  // optional
    // Forward gate biases, projection, and recurrent state.
    fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
    fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
    fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
    fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);
    fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor);  // optional
    fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor);  // optional
    fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
    fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);

    // Backward direction: same operand layout mirrored with Bw* indices.
    bw_input_to_input_weights_ =
            GetInput(operation, operands, kBwInputToInputWeightsTensor);  // optional
    bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
    bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
    bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);
    bw_recurrent_to_input_weights_ =
            GetInput(operation, operands, kBwRecurrentToInputWeightsTensor);  // optional
    bw_recurrent_to_forget_weights_ =
            GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
    bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
    bw_recurrent_to_output_weights_ =
            GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);
    bw_cell_to_input_weights_ =
            GetInput(operation, operands, kBwCellToInputWeightsTensor);  // optional
    bw_cell_to_forget_weights_ =
            GetInput(operation, operands, kBwCellToForgetWeightsTensor);  // optional
    bw_cell_to_output_weights_ =
            GetInput(operation, operands, kBwCellToOutputWeightsTensor);  // optional
    bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
    bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
    bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
    bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);
    bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor);  // optional
    bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor);  // optional
    bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
    bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);

    // Auxiliary (secondary) input sequence and its per-direction weights.
    // Presence is validated as all-or-none in Prepare().
    aux_input_ = GetInput(operation, operands, kAuxInputTensor);
    fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
    fw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
    fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
    fw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
    bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
    bw_aux_input_to_forget_weights_ =
            GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
    bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
    bw_aux_input_to_output_weights_ =
            GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);

    // Per-gate layer-normalization weights for both directions.
    fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
    fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
    fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
    fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
    bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
    bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
    bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
    bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);

    // Scalar parameters. The clip scalars are stored in the tensor element
    // type, so fp16 models carry them as _Float16 and they are widened here.
    params_.activation = static_cast<TfLiteFusedActivation>(
            getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
    if (input_->type == OperandType::TENSOR_FLOAT32) {
        params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
        params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
    } else {
        params_.cell_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
        params_.proj_clip = static_cast<float>(
                getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
    }
    params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
    params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
    // Layer norm is considered enabled iff the fw input-gate layer-norm
    // weights are present (the other layer-norm operands are checked later).
    params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);

    fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
    if (!params_.merge_outputs) {
        // The bw output tensor only exists when outputs are not merged.
        bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
    }
}
// Validates operand shapes and computes the output and scratch-buffer shapes.
// Returns false (via the NN_CHECK macros or the explicit returns) when any
// dimension constraint is violated.
bool BidirectionalSequenceLSTM::Prepare(const Operation& operation,
                                        std::vector<RunTimeOperandInfo>& operands,
                                        Shape* fwOutputShape, Shape* bwOutputShape) {
    // Inferring batch size, number of outputs and number of cells from the
    // input tensors. Input layout is [max_time, n_batch, n_input] when
    // time-major, [n_batch, max_time, n_input] otherwise.
    NN_CHECK(NumDimensions(input_) == 3);
    const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
    const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
    const uint32_t n_input = SizeOfDimension(input_, 2);

    // Forward cell/output counts come from the fw weight shapes:
    // input->output weights are [n_cell, n_input]; recurrent->output weights
    // are [n_cell, n_output].
    const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_input);
    NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);

    // Check that input tensor dimensions matches with each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
                fw_input_to_cell_weights_, fw_input_to_output_weights_,
                fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
                fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
                fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
                fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
                fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
                fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
                fw_output_layer_norm_weights_, n_input, n_fw_output, n_fw_cell, &params_)) {
        return false;
    }

    // The auxiliary input and its forget/cell/output weights must be either
    // all present or all absent.
    // NOTE(review): the *_aux_input_to_input_weights_ operands are not part of
    // this all-or-none check — presumably because they are also optional under
    // CIFG; confirm against the op specification.
    const bool aux_inputs_all_or_none =
            (!IsNullInput(aux_input_) && !IsNullInput(fw_aux_input_to_cell_weights_) &&
             !IsNullInput(fw_aux_input_to_forget_weights_) &&
             !IsNullInput(fw_aux_input_to_output_weights_) &&
             !IsNullInput(bw_aux_input_to_cell_weights_) &&
             !IsNullInput(bw_aux_input_to_forget_weights_) &&
             !IsNullInput(bw_aux_input_to_output_weights_)) ||
            (IsNullInput(fw_aux_input_to_cell_weights_) &&
             IsNullInput(fw_aux_input_to_forget_weights_) &&
             IsNullInput(fw_aux_input_to_output_weights_) &&
             IsNullInput(bw_aux_input_to_cell_weights_) &&
             IsNullInput(bw_aux_input_to_forget_weights_) &&
             IsNullInput(bw_aux_input_to_output_weights_));
    NN_CHECK(aux_inputs_all_or_none);
    if (!IsNullInput(aux_input_)) {
        // Check that aux_input has the same dimensions (except last) as the input.
        NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
        NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
    }

    // Backward cell/output counts, mirroring the forward checks above.
    const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);
    NN_CHECK_EQ(NumDimensions(bw_input_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_input_to_output_weights_, 1), n_input);
    NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
    NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);

    // Forward output keeps the input's time/batch layout; when outputs are
    // merged its last dimension holds both directions' features concatenated.
    const Shape& inputShape = input_->shape();
    fwOutputShape->type = inputShape.type;
    fwOutputShape->offset = inputShape.offset;
    fwOutputShape->scale = inputShape.scale;
    fwOutputShape->dimensions.resize(3);
    fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
    fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
    fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;

    // Check that input tensor dimensions matches with each other.
    if (!LSTMCell::CheckInputTensorDimensions(
                input_, bw_input_to_input_weights_, bw_input_to_forget_weights_,
                bw_input_to_cell_weights_, bw_input_to_output_weights_,
                bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
                bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
                bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
                bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
                bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
                bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
                bw_output_layer_norm_weights_, n_input, n_bw_output, n_bw_cell, &params_)) {
        return false;
    }

    if (!params_.merge_outputs) {
        // Separate backward output tensor, same layout as the forward one.
        bwOutputShape->type = inputShape.type;
        bwOutputShape->offset = inputShape.offset;
        bwOutputShape->scale = inputShape.scale;
        bwOutputShape->dimensions.resize(3);
        bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
        bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
        bwOutputShape->dimensions[2] = n_bw_output;
    }

    // Scratch buffers hold per-gate intermediates: 3 gates under CIFG (the
    // input gate is coupled to the forget gate), 4 gates otherwise.
    if (params_.use_cifg) {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
    } else {
        fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
        bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
    }
    fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
    fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
    fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;
    return true;
}
// Runs the forward pass over the whole sequence, then the backward pass, for
// the input's element type (fp32 or fp16). When merge_outputs is set, the
// backward results are written directly after the forward results inside
// fw_output_ and the two halves are then interleaved along the last dimension.
// Returns false only for an unsupported input type.
bool BidirectionalSequenceLSTM::Eval() {
    // Per-direction output feature sizes, read from the recurrent->output
    // weight shapes (same derivation as in Prepare()).
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    fw_output_dims[2] = n_fw_output;
    std::vector<uint32_t> bw_output_dims = fw_output_dims;
    bw_output_dims[2] = n_bw_output;
    // Element counts used to locate the backward half of a merged output and
    // to size the merge scratch buffer.
    const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    const uint32_t n_output_elements =
            fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            // Forward pass. The activation/cell state operands are passed both
            // as const inputs (previous state) and as mutable outputs (updated
            // state in place).
            // NOTE(review): optional weights go through GetBuffer() here while
            // the fp16 path uses GetOptionalBuffer(); presumably omitted
            // operands have a null buffer so the cast still yields nullptr —
            // confirm against CpuExecutor's handling of NO_VALUE operands.
            std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(fw_input_to_input_weights_),
                    GetBuffer<const float>(fw_input_to_forget_weights_),
                    GetBuffer<const float>(fw_input_to_cell_weights_),
                    GetBuffer<const float>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_recurrent_to_input_weights_),
                    GetBuffer<const float>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_cell_to_input_weights_),
                    GetBuffer<const float>(fw_cell_to_forget_weights_),
                    GetBuffer<const float>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
                    GetBuffer<const float>(fw_input_gate_bias_),
                    GetBuffer<const float>(fw_forget_gate_bias_),
                    GetBuffer<const float>(fw_cell_bias_),
                    GetBuffer<const float>(fw_output_gate_bias_),
                    GetBuffer<const float>(fw_projection_weights_),
                    GetBuffer<const float>(fw_projection_bias_),
                    GetBuffer<const float>(fw_activation_state_),
                    GetBuffer<const float>(fw_cell_state_),
                    GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
                    GetBuffer<float>(fw_activation_state_), GetBuffer<float>(fw_cell_state_),
                    GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);
            // Backward pass: same kernel, reversed sequence direction. When
            // merging, its output lands right after the forward block in
            // fw_output_; otherwise it goes to the dedicated bw_output_.
            std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(bw_input_to_input_weights_),
                    GetBuffer<const float>(bw_input_to_forget_weights_),
                    GetBuffer<const float>(bw_input_to_cell_weights_),
                    GetBuffer<const float>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_recurrent_to_input_weights_),
                    GetBuffer<const float>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_cell_to_input_weights_),
                    GetBuffer<const float>(bw_cell_to_forget_weights_),
                    GetBuffer<const float>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
                    GetBuffer<const float>(bw_input_gate_bias_),
                    GetBuffer<const float>(bw_forget_gate_bias_),
                    GetBuffer<const float>(bw_cell_bias_),
                    GetBuffer<const float>(bw_output_gate_bias_),
                    GetBuffer<const float>(bw_projection_weights_),
                    GetBuffer<const float>(bw_projection_bias_),
                    GetBuffer<const float>(bw_activation_state_),
                    GetBuffer<const float>(bw_cell_state_),
                    GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
                    GetBuffer<float>(bw_activation_state_), GetBuffer<float>(bw_cell_state_),
                    params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<float>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave the two contiguous halves along the feature axis
                // into a temp buffer, then copy back over fw_output_.
                std::vector<float> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
                                    GetBuffer<float>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<float>(fw_output_));
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // fp16 variant: identical structure to the fp32 case, with
            // optional operands passed through GetOptionalBuffer().
            std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
                    GetBuffer<const _Float16>(fw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
                    GetBuffer<const _Float16>(fw_forget_gate_bias_),
                    GetBuffer<const _Float16>(fw_cell_bias_),
                    GetBuffer<const _Float16>(fw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(fw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(fw_projection_bias_),
                    GetBuffer<const _Float16>(fw_activation_state_),
                    GetBuffer<const _Float16>(fw_cell_state_),
                    GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(fw_activation_state_), GetBuffer<_Float16>(fw_cell_state_),
                    GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);
            // Backward pass (see fp32 case for the output-placement logic).
            std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
                    GetBuffer<const _Float16>(bw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
                    GetBuffer<const _Float16>(bw_forget_gate_bias_),
                    GetBuffer<const _Float16>(bw_cell_bias_),
                    GetBuffer<const _Float16>(bw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(bw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(bw_projection_bias_),
                    GetBuffer<const _Float16>(bw_activation_state_),
                    GetBuffer<const _Float16>(bw_cell_state_),
                    GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(bw_activation_state_), GetBuffer<_Float16>(bw_cell_state_),
                    params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<_Float16>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave forward/backward halves along the feature axis.
                std::vector<_Float16> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
                                    GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<_Float16>(fw_output_));
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}
  416. } // namespace nn
  417. } // namespace android