// Pooling.cpp — NNAPI CPU implementations of AVERAGE_POOL_2D, L2_POOL_2D, MAX_POOL_2D.
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "CpuOperationUtils.h"
  17. #include "OperationResolver.h"
  18. #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
  19. #include "Tracing.h"
  20. namespace android {
  21. namespace nn {
  22. namespace pooling {
// Operand indices: one input tensor, one output tensor.
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;
  26. namespace {
// Scalar parameters shared by AVERAGE_POOL_2D, L2_POOL_2D and MAX_POOL_2D.
struct PoolingParam {
    int32_t padding_left, padding_right;  // explicit horizontal padding
    int32_t padding_top, padding_bottom;  // explicit vertical padding
    int32_t stride_width, stride_height;
    int32_t filter_width, filter_height;
    int32_t activation;    // fused activation function code
    bool useNchw = false;  // data layout; default is NHWC

    // Reads the operation's scalar operands from |context| and validates them.
    // Two signatures are supported:
    //   >= 10 inputs: explicit padding (left/right/top/bottom given directly);
    //   <  10 inputs: implicit padding scheme, resolved below via
    //                 calculateExplicitPadding.
    // An optional trailing BOOL operand selects NCHW layout.
    // Returns true on success, false (with a logged message) otherwise.
    bool initialize(const IOperationExecutionContext* context) {
        uint32_t inCount = context->getNumInputs();
        int32_t padding_implicit = 0;
        if (inCount >= 10) {
            // Explicit-padding signature.
            padding_left = context->getInputValue<int32_t>(1);
            padding_right = context->getInputValue<int32_t>(2);
            padding_top = context->getInputValue<int32_t>(3);
            padding_bottom = context->getInputValue<int32_t>(4);
            stride_width = context->getInputValue<int32_t>(5);
            stride_height = context->getInputValue<int32_t>(6);
            filter_width = context->getInputValue<int32_t>(7);
            filter_height = context->getInputValue<int32_t>(8);
            activation = context->getInputValue<int32_t>(9);
            if (inCount == 11) {
                useNchw = context->getInputValue<bool>(10);
            }
        } else {
            // Implicit-padding signature.
            padding_implicit = context->getInputValue<int32_t>(1);
            stride_width = context->getInputValue<int32_t>(2);
            stride_height = context->getInputValue<int32_t>(3);
            filter_width = context->getInputValue<int32_t>(4);
            filter_height = context->getInputValue<int32_t>(5);
            activation = context->getInputValue<int32_t>(6);
            if (inCount == 8) {
                useNchw = context->getInputValue<bool>(7);
            }
        }
        if (inCount <= 8) {
            // Implicit padding: derive the explicit paddings from the input's
            // spatial dimensions (dimension order depends on the layout flag,
            // which has already been read above).
            Shape inputShape = context->getInputShape(kInputTensor);
            int32_t input_height = getSizeOfDimension(inputShape, useNchw ? 2 : 1);
            int32_t input_width = getSizeOfDimension(inputShape, useNchw ? 3 : 2);
            calculateExplicitPadding(input_width, stride_width, filter_width, padding_implicit,
                                     &padding_left, &padding_right);
            calculateExplicitPadding(input_height, stride_height, filter_height, padding_implicit,
                                     &padding_top, &padding_bottom);
        }
        NN_RET_CHECK_GE(padding_left, 0);
        NN_RET_CHECK_GE(padding_right, 0);
        NN_RET_CHECK_GE(padding_top, 0);
        NN_RET_CHECK_GE(padding_bottom, 0);
        NN_RET_CHECK_GT(stride_width, 0);
        NN_RET_CHECK_GT(stride_height, 0);
        NN_RET_CHECK_GT(filter_width, 0);
        NN_RET_CHECK_GT(filter_height, 0);
        NN_RET_CHECK_GE(activation, 0);
        // Padding on each side must be strictly smaller than the filter, so
        // every filter window overlaps at least one real input element.
        NN_RET_CHECK_GT(filter_width, padding_left);
        NN_RET_CHECK_GT(filter_width, padding_right);
        NN_RET_CHECK_GT(filter_height, padding_top);
        NN_RET_CHECK_GT(filter_height, padding_bottom);
        return true;
    }

    // Converts these parameters into the tflite kernel's PoolParams, filling
    // in the activation clamp range appropriate for |output|'s element type
    // (quantized uint8 vs. float).
    tflite::PoolParams toTfliteParam(const Shape& output) const {
        tflite::PoolParams params = {
                .stride_height = stride_height,
                .stride_width = stride_width,
                .filter_height = filter_height,
                .filter_width = filter_width,
                .padding_values = {.height = static_cast<int16_t>(padding_top),
                                   .width = static_cast<int16_t>(padding_left)}};
        if (output.type == OperandType::TENSOR_QUANT8_ASYMM) {
            int32_t output_activation_min = 0;
            int32_t output_activation_max = 0;
            CalculateActivationRangeUint8(activation, output, &output_activation_min,
                                          &output_activation_max);
            params.quantized_activation_min = output_activation_min;
            params.quantized_activation_max = output_activation_max;
        } else {
            float output_activation_min, output_activation_max;
            CalculateActivationRangeFloat(activation, &output_activation_min,
                                          &output_activation_max);
            params.float_activation_min = output_activation_min;
            params.float_activation_max = output_activation_max;
        }
        return params;
    }
};
  110. bool averagePoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
  111. float* outputData, const Shape& outputShape) {
  112. NNTRACE_TRANS("averagePoolFloat32");
  113. auto op_params = param.toTfliteParam(outputShape);
  114. NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
  115. tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
  116. convertShapeToTflshape(outputShape), outputData);
  117. return true;
  118. }
  119. bool averagePoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
  120. _Float16* outputData, const Shape& outputShape) {
  121. NNTRACE_TRANS("averagePoolFloat16");
  122. std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
  123. std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
  124. convertFloat16ToFloat32(inputData, &inputDataFloat32);
  125. averagePoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(),
  126. outputShape);
  127. convertFloat32ToFloat16(outputDataFloat32, outputData);
  128. return true;
  129. }
  130. bool averagePoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
  131. uint8_t* outputData, const Shape& outputShape) {
  132. NNTRACE_TRANS("averagePoolQuant8");
  133. auto op_params = param.toTfliteParam(outputShape);
  134. NNTRACE_COMP_SWITCH("optimized_ops::AveragePool");
  135. tflite::optimized_ops::AveragePool(op_params, convertShapeToTflshape(inputShape), inputData,
  136. convertShapeToTflshape(outputShape), outputData);
  137. return true;
  138. }
  139. bool l2PoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
  140. float* outputData, const Shape& outputShape) {
  141. NNTRACE_TRANS("l2PoolFloat32");
  142. auto op_params = param.toTfliteParam(outputShape);
  143. NNTRACE_COMP_SWITCH("optimized_ops::L2Pool");
  144. tflite::optimized_ops::L2Pool(op_params, convertShapeToTflshape(inputShape), inputData,
  145. convertShapeToTflshape(outputShape), outputData);
  146. return true;
  147. }
  148. bool l2PoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
  149. _Float16* outputData, const Shape& outputShape) {
  150. NNTRACE_TRANS("l2PoolFloat16");
  151. std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
  152. std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
  153. convertFloat16ToFloat32(inputData, &inputDataFloat32);
  154. l2PoolNhwc(inputDataFloat32.data(), inputShape, param, outputDataFloat32.data(), outputShape);
  155. convertFloat32ToFloat16(outputDataFloat32, outputData);
  156. return true;
  157. }
  158. bool maxPoolNhwc(const float* inputData, const Shape& inputShape, const PoolingParam& param,
  159. float* outputData, const Shape& outputShape) {
  160. NNTRACE_TRANS("maxPoolFloat32");
  161. auto op_params = param.toTfliteParam(outputShape);
  162. NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
  163. tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
  164. convertShapeToTflshape(outputShape), outputData);
  165. return true;
  166. }
  167. bool maxPoolNhwc(const uint8_t* inputData, const Shape& inputShape, const PoolingParam& param,
  168. uint8_t* outputData, const Shape& outputShape) {
  169. NNTRACE_TRANS("maxPoolQuant8");
  170. auto op_params = param.toTfliteParam(outputShape);
  171. NNTRACE_COMP_SWITCH("optimized_ops::MaxPool");
  172. tflite::optimized_ops::MaxPool(op_params, convertShapeToTflshape(inputShape), inputData,
  173. convertShapeToTflshape(outputShape), outputData);
  174. return true;
  175. }
  176. bool maxPoolNhwc(const _Float16* inputData, const Shape& inputShape, const PoolingParam& param,
  177. _Float16* outputData, const Shape& outputShape) {
  178. NNTRACE_TRANS("maxPoolFloat16");
  179. std::vector<float> inputData_float32(getNumberOfElements(inputShape));
  180. std::vector<float> outputData_float32(getNumberOfElements(outputShape));
  181. convertFloat16ToFloat32(inputData, &inputData_float32);
  182. maxPoolNhwc(inputData_float32.data(), inputShape, param, outputData_float32.data(),
  183. outputShape);
  184. convertFloat32ToFloat16(outputData_float32, outputData);
  185. return true;
  186. }
  187. template <typename T>
  188. bool averagePool(const T* inputData, const Shape& inputShape, const PoolingParam& param,
  189. T* outputData, const Shape& outputShape) {
  190. InputWithLayout<T> input(param.useNchw);
  191. OutputWithLayout<T> output(param.useNchw);
  192. NN_RET_CHECK(input.initialize(inputData, inputShape));
  193. NN_RET_CHECK(output.initialize(outputData, outputShape));
  194. NN_RET_CHECK(averagePoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
  195. output.getNhwcBuffer(), output.getNhwcShape()));
  196. NN_RET_CHECK(output.commit());
  197. return true;
  198. }
  199. template <typename T>
  200. bool l2Pool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
  201. const Shape& outputShape) {
  202. InputWithLayout<T> input(param.useNchw);
  203. OutputWithLayout<T> output(param.useNchw);
  204. NN_RET_CHECK(input.initialize(inputData, inputShape));
  205. NN_RET_CHECK(output.initialize(outputData, outputShape));
  206. NN_RET_CHECK(l2PoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
  207. output.getNhwcBuffer(), output.getNhwcShape()));
  208. NN_RET_CHECK(output.commit());
  209. return true;
  210. }
  211. template <typename T>
  212. bool maxPool(const T* inputData, const Shape& inputShape, const PoolingParam& param, T* outputData,
  213. const Shape& outputShape) {
  214. InputWithLayout<T> input(param.useNchw);
  215. OutputWithLayout<T> output(param.useNchw);
  216. NN_RET_CHECK(input.initialize(inputData, inputShape));
  217. NN_RET_CHECK(output.initialize(outputData, outputShape));
  218. NN_RET_CHECK(maxPoolNhwc(input.getNhwcBuffer(), input.getNhwcShape(), param,
  219. output.getNhwcBuffer(), output.getNhwcShape()));
  220. NN_RET_CHECK(output.commit());
  221. return true;
  222. }
  223. } // namespace
// Validates operand counts, operand types, and the required HAL version for
// one of the pooling operations (|opType|). Signatures: explicit padding
// (10/11 inputs) or implicit padding (7/8 inputs); an optional trailing BOOL
// selects NCHW layout and requires HAL 1.2.
bool validate(OperationType opType, const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputCount = context->getNumInputs();
    NN_RET_CHECK(inputCount == 11 || inputCount == 10 || inputCount == 8 || inputCount == 7);
    auto inputType = context->getInputType(kInputTensor);
    std::vector<OperandType> inExpectedTypes;
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                inputType, OperandType::INT32, OperandType::INT32, OperandType::INT32,
                OperandType::INT32, OperandType::INT32, OperandType::INT32,
        };
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        // Float16 support was introduced in HAL 1.2.
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
        inExpectedTypes = {
                OperandType::TENSOR_FLOAT16, OperandType::INT32, OperandType::INT32,
                OperandType::INT32, OperandType::INT32, OperandType::INT32,
                OperandType::INT32,
        };
    } else if (opType != OperationType::L2_POOL_2D &&
               inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        // Quantized input is accepted by AVERAGE_POOL_2D and MAX_POOL_2D only;
        // L2_POOL_2D has no quantized variant.
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
        inExpectedTypes = {
                OperandType::TENSOR_QUANT8_ASYMM,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
                OperandType::INT32,
        };
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported input tensor type for operation "
                            << getOperationName(opType);
    }
    // Explicit-padding signature carries three extra INT32 scalars.
    if (inputCount >= 10) {
        std::vector<OperandType> explicitScalarTypes(3, OperandType::INT32);
        inExpectedTypes.insert(inExpectedTypes.end(), explicitScalarTypes.begin(),
                               explicitScalarTypes.end());
    }
    // The optional NCHW layout flag requires HAL 1.2.
    if (inputCount == 11 || inputCount == 8) {
        inExpectedTypes.push_back(OperandType::BOOL);
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_0));
    }
    return validateInputTypes(context, inExpectedTypes) &&
           validateOutputTypes(context, {inputType});
}
  273. bool prepare(IOperationExecutionContext* context) {
  274. Shape input = context->getInputShape(kInputTensor);
  275. NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
  276. PoolingParam param;
  277. NN_RET_CHECK(param.initialize(context));
  278. // Only batches can be zero.
  279. uint32_t batches = getSizeOfDimension(input, 0);
  280. uint32_t height = getSizeOfDimension(input, param.useNchw ? 2 : 1);
  281. uint32_t width = getSizeOfDimension(input, param.useNchw ? 3 : 2);
  282. uint32_t channels = getSizeOfDimension(input, param.useNchw ? 1 : 3);
  283. NN_RET_CHECK_GT(height, 0);
  284. NN_RET_CHECK_GT(width, 0);
  285. NN_RET_CHECK_GT(channels, 0);
  286. uint32_t outWidth = computeOutSize(width, param.filter_width, param.stride_width,
  287. param.padding_left, param.padding_right);
  288. uint32_t outHeight = computeOutSize(height, param.filter_height, param.stride_height,
  289. param.padding_top, param.padding_bottom);
  290. Shape output = input;
  291. if (param.useNchw) {
  292. output.dimensions = {batches, channels, outHeight, outWidth};
  293. } else {
  294. output.dimensions = {batches, outHeight, outWidth, channels};
  295. }
  296. return context->setOutputShape(kOutputTensor, output);
  297. }
// Expands to one `case OperandType::type: return name<cppType>(...);` arm for
// the type-dispatch switches in the execute* functions below.
#define POOLING_DISPATCH_INPUT_TYPE(name, type, cppType)                \
    case OperandType::type:                                             \
        return name(context->getInputBuffer<cppType>(kInputTensor),     \
                    context->getInputShape(kInputTensor), param,        \
                    context->getOutputBuffer<cppType>(kOutputTensor),   \
                    context->getOutputShape(kOutputTensor))
// Executes AVERAGE_POOL_2D, dispatching on the input tensor's element type.
bool executeAveragePool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input. This must happen
    // before param.initialize so a zero-sized execution still succeeds.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(averagePool, TENSOR_QUANT8_ASYMM, uint8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation AVERAGE_POOL_2D";
    }
}
// Executes L2_POOL_2D, dispatching on the input tensor's element type.
// Note: no quantized case — L2_POOL_2D supports float types only.
bool executeL2Pool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input. This must happen
    // before param.initialize so a zero-sized execution still succeeds.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(l2Pool, TENSOR_FLOAT16, _Float16);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation L2_POOL_2D";
    }
}
// Executes MAX_POOL_2D, dispatching on the input tensor's element type.
bool executeMaxPool(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input. This must happen
    // before param.initialize so a zero-sized execution still succeeds.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    PoolingParam param;
    NN_RET_CHECK(param.initialize(context));
    switch (context->getInputType(kInputTensor)) {
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT32, float);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_FLOAT16, _Float16);
        POOLING_DISPATCH_INPUT_TYPE(maxPool, TENSOR_QUANT8_ASYMM, uint8_t);
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MAX_POOL_2D";
    }
}
  342. #undef POOLING_DISPATCH_INPUT_TYPE
  343. } // namespace pooling
using std::placeholders::_1;

// Register the three pooling operations. allowZeroSizedInput permits
// zero-batch tensors; the execute* functions bypass the kernels in that case.
NN_REGISTER_OPERATION(AVERAGE_POOL_2D, "AVERAGE_POOL_2D",
                      std::bind(pooling::validate, OperationType::AVERAGE_POOL_2D, _1),
                      pooling::prepare, pooling::executeAveragePool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(L2_POOL_2D, "L2_POOL_2D",
                      std::bind(pooling::validate, OperationType::L2_POOL_2D, _1), pooling::prepare,
                      pooling::executeL2Pool, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MAX_POOL_2D, "MAX_POOL_2D",
                      std::bind(pooling::validate, OperationType::MAX_POOL_2D, _1),
                      pooling::prepare, pooling::executeMaxPool, .allowZeroSizedInput = true);
  354. } // namespace nn
  355. } // namespace android