FullyConnected.cpp
/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <mutex>
#include <vector>

#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "Tracing.h"

namespace android {
namespace nn {
namespace fully_connected {

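// Operand indices for ANEURALNETWORKS_FULLY_CONNECTED: an input tensor, a 2-D
// weights tensor of shape [num_units, input_size], a 1-D bias tensor of size
// num_units, and a scalar selecting the fused activation function. The
// operation produces a single output tensor.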
constexpr char kOperationName[] = "FULLY_CONNECTED";

constexpr uint32_t kNumInputs = 4;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kWeightsTensor = 1;
constexpr uint32_t kBiasTensor = 2;
constexpr uint32_t kActivationScalar = 3;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

// executionMutex is used to protect concurrent access to non-thread-safe
// resources like gemmlowp::GemmContext.
// std::mutex is safe for pthreads on Android.
static std::mutex executionMutex;

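// Float32 path. Clamps the result to the range implied by the fused
// activation and computes the layer with TFLite kernels. The reference kernel
// is used for the case where the optimized kernel is known to produce
// incorrect results (see b/80425683 below); otherwise the optimized kernel is
// used.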
bool fullyConnectedFloat32(const float* inputData, const Shape& inputShape,
                           const float* weightsData, const Shape& weightsShape,
                           const float* biasData, const Shape& biasShape, int32_t activation,
                           float* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    // b/80425683, optimized implementation produces incorrect results when the
    // number of input elements is the square of batch_size.
    uint32_t batch_size = getSizeOfDimension(outputShape, 0);
    uint32_t input_n_elements = getNumberOfElements(inputShape);
    if (batch_size * batch_size == input_n_elements) {
        NNTRACE_COMP_SWITCH("reference_ops::FullyConnected");
        tflite::reference_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
        tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape),
                                              weightsData, convertShapeToDims(weightsShape),
                                              biasData, convertShapeToDims(biasShape),
                                              output_activation_min, output_activation_max,
                                              outputData, convertShapeToDims(outputShape));
    }
    return true;
}

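// Float16 path. There is no dedicated FP16 kernel: the input, weights, and
// bias are converted to float32, the float32 path above is reused, and the
// result is converted back to float16.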
bool fullyConnectedFloat16(const _Float16* inputData, const Shape& inputShape,
                           const _Float16* weightsData, const Shape& weightsShape,
                           const _Float16* biasData, const Shape& biasShape, int32_t activation,
                           _Float16* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedFloat16");
    std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
    convertFloat16ToFloat32(inputData, &inputDataFloat32);
    std::vector<float> weightsDataFloat32(getNumberOfElements(weightsShape));
    convertFloat16ToFloat32(weightsData, &weightsDataFloat32);
    std::vector<float> biasDataFloat32(getNumberOfElements(biasShape));
    convertFloat16ToFloat32(biasData, &biasDataFloat32);

    std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
    fullyConnectedFloat32(inputDataFloat32.data(), inputShape, weightsDataFloat32.data(),
                          weightsShape, biasDataFloat32.data(), biasShape, activation,
                          outputDataFloat32.data(), outputShape);
    convertFloat32ToFloat16(outputDataFloat32, outputData);
    return true;
}

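// Quantized (asymmetric uint8) path. The real output multiplier derived from
// the input, weights, and output scales is re-expressed as a fixed-point
// multiplier plus a right shift, and the gemmlowp-backed optimized kernel is
// run while holding executionMutex because the shared GemmContext is not
// thread-safe.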
bool fullyConnectedQuant8(const uint8_t* inputData, const Shape& inputShape,
                          const uint8_t* weightsData, const Shape& weightsShape,
                          const int32_t* biasData, const Shape& biasShape, int32_t activation,
                          uint8_t* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("fullyConnectedQuant8");
    int32_t inputOffset = -inputShape.offset;
    int32_t weightsOffset = -weightsShape.offset;
    int32_t outputOffset = outputShape.offset;

    double realMultiplier = 0.0;
    int32_t outputMultiplier = 0;
    int32_t outputShift = 0;
    int32_t outputActivationMin = 0;
    int32_t outputActivationMax = 0;

    NN_RET_CHECK(GetQuantizedConvolutionMultipler(inputShape, weightsShape, biasShape, outputShape,
                                                  &realMultiplier));
    int exponent;
    NN_RET_CHECK(QuantizeMultiplier(realMultiplier, &outputMultiplier, &exponent));
    outputShift = -exponent;
    CalculateActivationRangeUint8(activation, outputShape, &outputActivationMin,
                                  &outputActivationMax);

    static gemmlowp::GemmContext gemmContext;

    // Prevent concurrent executions that access gemmContext.
    std::unique_lock<std::mutex> lock(executionMutex);
    // Allow gemmlowp to automatically decide how many threads to use.
    gemmContext.set_max_num_threads(0);

    NNTRACE_COMP_SWITCH("optimized_ops::FullyConnected");
    tflite::optimized_ops::FullyConnected(inputData, convertShapeToDims(inputShape), inputOffset,
                                          weightsData, convertShapeToDims(weightsShape),
                                          weightsOffset, biasData, convertShapeToDims(biasShape),
                                          outputOffset, outputMultiplier, outputShift,
                                          outputActivationMin, outputActivationMax, outputData,
                                          convertShapeToDims(outputShape), &gemmContext);
    return true;
}

}  // namespace

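// Verifies operand counts and types for the given input tensor type.
// TENSOR_FLOAT32, and TENSOR_QUANT8_ASYMM inputs that satisfy the pre-API-29
// scale constraint, are accepted from HAL version 1.0; TENSOR_FLOAT16 and
// quantized inputs with outputScale <= inputScale * weightsScale require HAL
// version 1.2.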
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    std::vector<OperandType> outExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::TENSOR_FLOAT32,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::TENSOR_FLOAT16,
                OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // NeuralNetworks.h specifies that ANEURALNETWORKS_FULLY_CONNECTED's output must
        // meet "outputScale > inputScale * weightsScale" for the operand type
        // ANEURALNETWORKS_TENSOR_QUANT8_ASYMM before API level 29.
        const float inputScale = context->getInputShape(kInputTensor).scale;
        const float weightsScale = context->getInputShape(kWeightsTensor).scale;
        const float outputScale = context->getOutputShape(kOutputTensor).scale;
        bool meetsQuantizedScaleConstraintBeforeV1_2 = (outputScale > inputScale * weightsScale);

        if (!meetsQuantizedScaleConstraintBeforeV1_2) {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        } else {
            NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        }

        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::TENSOR_INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return true;
}

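// Checks runtime shapes and derives the output shape. With weights of shape
// [num_units, input_size], the input is treated as [batch_size, input_size]
// and the output shape is set to [batch_size, num_units].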
bool prepare(IOperationExecutionContext* context) {
    Shape input = context->getInputShape(kInputTensor);
    Shape weights = context->getInputShape(kWeightsTensor);
    Shape bias = context->getInputShape(kBiasTensor);

    // Check that all tensor parameters are consistent with each other and with
    // the input configuration.
    NN_RET_CHECK(input.type == weights.type);
    if (input.type == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(bias.type == OperandType::TENSOR_INT32);
    } else {
        NN_RET_CHECK(input.type == bias.type);
    }
    // The TensorFlow fully connected layer specification requires the input to
    // have a rank of at least 2, so we check that here. TFLite does not check.
    NN_RET_CHECK_GE(getNumberOfDimensions(input), 2);
    NN_RET_CHECK_EQ(getNumberOfDimensions(weights), 2);
    uint32_t input_n_elements = getNumberOfElements(input);
    uint32_t num_units = getSizeOfDimension(weights, 0);
    uint32_t input_size = getSizeOfDimension(weights, 1);
    // Only batch_size can be 0.
    NN_RET_CHECK_GT(num_units, 0);
    NN_RET_CHECK_GT(input_size, 0);
    uint32_t batch_size = input_n_elements / input_size;
    NN_RET_CHECK_EQ(getSizeOfDimension(bias, 0), num_units);
    NN_RET_CHECK_EQ(input_size * batch_size, input_n_elements);

    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    output.dimensions = {batch_size, num_units};
    return context->setOutputShape(kOutputTensor, output);
}

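// Dispatches to the type-specific implementation selected by the input tensor
// type. A zero-sized output is a no-op, which is valid because the operation
// is registered with .allowZeroSizedInput = true.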
bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return fullyConnectedFloat32(context->getInputBuffer<float>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<float>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<float>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<float>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return fullyConnectedFloat16(context->getInputBuffer<_Float16>(kInputTensor),
                                         context->getInputShape(kInputTensor),
                                         context->getInputBuffer<_Float16>(kWeightsTensor),
                                         context->getInputShape(kWeightsTensor),
                                         context->getInputBuffer<_Float16>(kBiasTensor),
                                         context->getInputShape(kBiasTensor),
                                         context->getInputValue<int32_t>(kActivationScalar),
                                         context->getOutputBuffer<_Float16>(kOutputTensor),
                                         context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return fullyConnectedQuant8(context->getInputBuffer<uint8_t>(kInputTensor),
                                        context->getInputShape(kInputTensor),
                                        context->getInputBuffer<uint8_t>(kWeightsTensor),
                                        context->getInputShape(kWeightsTensor),
                                        context->getInputBuffer<int32_t>(kBiasTensor),
                                        context->getInputShape(kBiasTensor),
                                        context->getInputValue<int32_t>(kActivationScalar),
                                        context->getOutputBuffer<uint8_t>(kOutputTensor),
                                        context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace fully_connected

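// Registration with the operation resolver; allowZeroSizedInput lets the
// runtime pass zero-sized operands through to execute(), which handles them
// above.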
NN_REGISTER_OPERATION(FULLY_CONNECTED, fully_connected::kOperationName, fully_connected::validate,
                      fully_connected::prepare, fully_connected::execute,
                      .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android