// RNN.cpp
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "RNN.h"
  17. #include "CpuExecutor.h"
  18. #include "CpuOperationUtils.h"
  19. #include "HalInterfaces.h"
  20. #include "Tracing.h"
  21. namespace android {
  22. namespace nn {
  23. RNN::RNN(const Operation& operation,
  24. std::vector<RunTimeOperandInfo>& operands) {
  25. NNTRACE_TRANS("RNN::RNN");
  26. input_ = GetInput(operation, operands, kInputTensor);
  27. weights_ = GetInput(operation, operands, kWeightsTensor);
  28. recurrent_weights_ = GetInput(operation, operands, kRecurrentWeightsTensor);
  29. hidden_state_in_ = GetInput(operation, operands, kHiddenStateInTensor);
  30. bias_ = GetInput(operation, operands, kBiasTensor);
  31. activation_ = static_cast<ActivationFn>(
  32. getScalarData<int32_t>(operands[operation.inputs[kActivationParam]]));
  33. hidden_state_out_ = GetOutput(operation, operands, kHiddenStateOutTensor);
  34. output_ = GetOutput(operation, operands, kOutputTensor);
  35. }
  36. bool RNN::Prepare(const Operation &operation,
  37. std::vector<RunTimeOperandInfo> &operands,
  38. Shape *hiddenStateShape,
  39. Shape *outputShape) {
  40. NNTRACE_TRANS("RNN::Prepare");
  41. // Check we have all the inputs and outputs we need.
  42. const int num_inputs = NumInputsWithValues(operation, operands);
  43. NN_CHECK(num_inputs == 5 || num_inputs == 6);
  44. NN_CHECK_EQ(NumOutputs(operation), 2);
  45. const RunTimeOperandInfo *input =
  46. GetInput(operation, operands, kInputTensor);
  47. const RunTimeOperandInfo *input_weights =
  48. GetInput(operation, operands, kWeightsTensor);
  49. const RunTimeOperandInfo *recurrent_weights =
  50. GetInput(operation, operands, kRecurrentWeightsTensor);
  51. const RunTimeOperandInfo *bias =
  52. GetInput(operation, operands, kBiasTensor);
  53. // Check all the parameters of tensor match within themselves and match the
  54. // input configuration.
  55. const uint32_t batch_size = SizeOfDimension(input, 0);
  56. const uint32_t num_units = SizeOfDimension(input_weights, 0);
  57. NN_CHECK_EQ(SizeOfDimension(input, 1), SizeOfDimension(input_weights, 1));
  58. NN_CHECK_EQ(SizeOfDimension(input_weights, 0), SizeOfDimension(bias, 0));
  59. NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 0), SizeOfDimension(bias, 0));
  60. NN_CHECK_EQ(SizeOfDimension(recurrent_weights, 1), SizeOfDimension(bias, 0));
  61. const Shape &inputShape = input->shape();
  62. // Resize state.
  63. hiddenStateShape->type = inputShape.type;
  64. hiddenStateShape->dimensions = { batch_size, num_units };
  65. // Resize output.
  66. outputShape->type = inputShape.type;
  67. outputShape->dimensions = { batch_size, num_units };
  68. return true;
  69. }
  70. bool RNN::Eval() {
  71. switch (input_->type) {
  72. case OperandType::TENSOR_FLOAT16: {
  73. RNNStep<_Float16>(reinterpret_cast<_Float16*>(input_->buffer), input_->shape(),
  74. reinterpret_cast<_Float16*>(hidden_state_in_->buffer),
  75. reinterpret_cast<_Float16*>(bias_->buffer),
  76. reinterpret_cast<_Float16*>(weights_->buffer), weights_->shape(),
  77. reinterpret_cast<_Float16*>(recurrent_weights_->buffer),
  78. recurrent_weights_->shape(), activation_,
  79. reinterpret_cast<_Float16*>(output_->buffer));
  80. memcpy(hidden_state_out_->buffer, output_->buffer,
  81. sizeof(_Float16) * getNumberOfElements(output_->shape()));
  82. break;
  83. }
  84. case OperandType::TENSOR_FLOAT32: {
  85. RNNStep<float>(reinterpret_cast<float*>(input_->buffer), input_->shape(),
  86. reinterpret_cast<float*>(hidden_state_in_->buffer),
  87. reinterpret_cast<float*>(bias_->buffer),
  88. reinterpret_cast<float*>(weights_->buffer), weights_->shape(),
  89. reinterpret_cast<float*>(recurrent_weights_->buffer),
  90. recurrent_weights_->shape(), activation_,
  91. reinterpret_cast<float*>(output_->buffer));
  92. memcpy(hidden_state_out_->buffer, output_->buffer,
  93. sizeof(float) * getNumberOfElements(output_->shape()));
  94. break;
  95. }
  96. default: {
  97. LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
  98. return false;
  99. }
  100. }
  101. return true;
  102. }
  103. template <typename T>
  104. bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* hiddenStateInputData,
  105. const T* biasData, const T* weightsData, const Shape& weightsShape,
  106. const T* recurrentWeightsData, const Shape& recurrentWeightsShape,
  107. const int32_t activation, T* outputData) {
  108. NNTRACE_COMP("RNN::Eval");
  109. Shape dummyShape;
  110. uint32_t numUnits = weightsShape.dimensions[0];
  111. return RNNStep<T>(inputData, inputShape, /*auxInputData=*/nullptr, /*auxInputShape=*/dummyShape,
  112. hiddenStateInputData, biasData, weightsData, weightsShape,
  113. /*auxWeightsData=*/nullptr, /*auxWeightsShape=*/dummyShape,
  114. recurrentWeightsData, recurrentWeightsShape, activation,
  115. /*outputBatchStride=*/numUnits, /*outputBatchOffset=*/0, outputData);
  116. }
  117. // A more general version of the RNNStep function.
  118. // Auxiliary input is treated as if it was concatenated to a regular input and
  119. // the result was multiplied by the weights matrix which was also concatenated
  120. // with auxiliary weights.
  121. template <typename T>
  122. bool RNN::RNNStep(const T* inputData, const Shape& inputShape, const T* auxInputData,
  123. const Shape& auxInputShape, const T* hiddenStateInputData, const T* biasData,
  124. const T* weightsData, const Shape& weightsShape, const T* auxWeightsData,
  125. const Shape& auxWeightsShape, const T* recurrentWeightsData,
  126. const Shape& recurrentWeightsShape, const int32_t activation,
  127. const uint32_t outputBatchStride, const uint32_t outputBatchOffset, T* outputData,
  128. T* hiddenStateOutput) {
  129. NNTRACE_COMP("RNN::Eval");
  130. const uint32_t batch_size = inputShape.dimensions[0];
  131. const uint32_t num_units = weightsShape.dimensions[0];
  132. const uint32_t input_size = inputShape.dimensions[1];
  133. const uint32_t input_weights_stride = weightsShape.dimensions[1];
  134. const uint32_t recurrent_weights_stride = recurrentWeightsShape.dimensions[1];
  135. uint32_t aux_input_size = 0;
  136. uint32_t aux_input_weights_stride = 0;
  137. bool hasAuxInput = (auxInputData != nullptr);
  138. if (hasAuxInput) {
  139. aux_input_size = auxInputShape.dimensions[1];
  140. aux_input_weights_stride = auxWeightsShape.dimensions[1];
  141. }
  142. // For each batch
  143. for (uint32_t b = 0; b < batch_size; b++) {
  144. // Initialize the pointer to input, output and bias.
  145. const T* input_ptr_batch = inputData + b * input_size;
  146. const T* hidden_state_in_ptr_batch = hiddenStateInputData + b * num_units;
  147. const T* aux_input_ptr_batch = nullptr;
  148. if (hasAuxInput) {
  149. aux_input_ptr_batch = auxInputData + b * aux_input_size;
  150. }
  151. T* output_ptr_batch = outputData + b * outputBatchStride + outputBatchOffset;
  152. // Initialize input_weights and recurrent_weights.
  153. const T* input_weights_ptr = weightsData;
  154. const T* recurrent_weights_ptr = recurrentWeightsData;
  155. const T* aux_input_weights_ptr = nullptr;
  156. if (hasAuxInput) {
  157. aux_input_weights_ptr = auxWeightsData;
  158. }
  159. // Output = bias
  160. for (uint32_t o = 0; o < num_units; o++) {
  161. output_ptr_batch[o] = biasData[o];
  162. }
  163. // Output += input * input_weights
  164. for (uint32_t o = 0; o < num_units; o++) {
  165. for (uint32_t i = 0; i < input_size; i++) {
  166. output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i];
  167. }
  168. input_weights_ptr += input_weights_stride;
  169. }
  170. if (hasAuxInput) {
  171. // Output += aux_input * aux_input_weights
  172. for (uint32_t o = 0; o < num_units; o++) {
  173. for (uint32_t i = 0; i < input_size; i++) {
  174. output_ptr_batch[o] += aux_input_ptr_batch[i] * aux_input_weights_ptr[i];
  175. }
  176. aux_input_weights_ptr += aux_input_weights_stride;
  177. }
  178. }
  179. // Output += recurrent_weights * hidden_state
  180. for (uint32_t o = 0; o < num_units; o++) {
  181. for (uint32_t h = 0; h < num_units; h++) {
  182. output_ptr_batch[o] += hidden_state_in_ptr_batch[h] * recurrent_weights_ptr[h];
  183. }
  184. recurrent_weights_ptr += recurrent_weights_stride;
  185. }
  186. // Output = activation(Output)
  187. for (uint32_t o = 0; o < num_units; o++) {
  188. output_ptr_batch[o] =
  189. (ActivationFunctor(static_cast<ActivationFn>(activation)))(output_ptr_batch[o]);
  190. if (hiddenStateOutput != nullptr) {
  191. *hiddenStateOutput = output_ptr_batch[o];
  192. ++hiddenStateOutput;
  193. }
  194. }
  195. }
  196. return true;
  197. }
  198. } // namespace nn
  199. } // namespace android