RoiPooling.cpp

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <vector>

#include "CpuOperationUtils.h"
#include "OperationResolver.h"
#include "OperationsUtils.h"
#include "Tracing.h"
namespace android {
namespace nn {
namespace roi_pooling {

constexpr char kOperationName[] = "ROI_POOLING";

constexpr uint32_t kNumInputs = 8;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kRoiTensor = 1;
constexpr uint32_t kBatchSplitTensor = 2;
constexpr uint32_t kOutputHeightScalar = 3;
constexpr uint32_t kOutputWidthScalar = 4;
constexpr uint32_t kHeightStrideScalar = 5;
constexpr uint32_t kWidthStrideScalar = 6;
constexpr uint32_t kLayoutScalar = 7;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {
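// Max-pools each ROI of an NHWC input tensor into a fixed-size output grid. Every ROI
// is mapped from ROI-coordinate space to feature-map space via heightStride/widthStride,
// split into outHeight x outWidth cells, and the per-channel maximum is taken in each cell.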
template <typename T_Input, typename T_Roi>
inline bool roiPoolingNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                           const Shape& roiShape, const int32_t* batchSplitData,
                           const Shape& batchSplitShape, float heightStride, float widthStride,
                           T_Input* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("RoiPooling");

    const uint32_t kRoiDim = 4;
    const T_Roi heightScale = 1.0f / heightStride;
    const T_Roi widthScale = 1.0f / widthStride;

    uint32_t numBatches = getSizeOfDimension(inputShape, 0);
    uint32_t inHeight = getSizeOfDimension(inputShape, 1);
    uint32_t inWidth = getSizeOfDimension(inputShape, 2);
    uint32_t inDepth = getSizeOfDimension(inputShape, 3);
    uint32_t outHeight = getSizeOfDimension(outputShape, 1);
    uint32_t outWidth = getSizeOfDimension(outputShape, 2);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);

    T_Input* outPtr = outputData;
    const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
    uint32_t roiIndex = 0;
    for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
        uint32_t batchId = batchSplitData[roiIndex];
        // Check for malformed data:
        // 1. Invalid batch id.
        // 2. Region out of bounds: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
        // 3. Invalid region: x2 < x1 || y2 < y1
        NN_RET_CHECK_GE(batchId, 0);
        NN_RET_CHECK_LT(batchId, numBatches);
        NN_RET_CHECK(roiInfo[0] >= 0);
        NN_RET_CHECK(roiInfo[1] >= 0);
        NN_RET_CHECK(roiInfo[2] >= 0);
        NN_RET_CHECK(roiInfo[3] >= 0);
        NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
        NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
        NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
        NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);

        int32_t wRoiStart = std::round(static_cast<float>(roiInfo[0] * widthScale));
        int32_t hRoiStart = std::round(static_cast<float>(roiInfo[1] * heightScale));
        int32_t wRoiEnd = std::round(static_cast<float>(roiInfo[2] * widthScale));
        int32_t hRoiEnd = std::round(static_cast<float>(roiInfo[3] * heightScale));

        // ROIs with width/height < 1 are considered malformed and are forced to be 1.
        T_Roi roiWidth = static_cast<T_Roi>(std::max(wRoiEnd - wRoiStart + 1, 1));
        T_Roi roiHeight = static_cast<T_Roi>(std::max(hRoiEnd - hRoiStart + 1, 1));
        T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
        T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);

        const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
        for (uint32_t i = 0; i < outHeight; i++) {
            for (uint32_t j = 0; j < outWidth; j++) {
                // Take floor on start and ceil on end; start is included, end is excluded,
                // i.e. [start, end). end is guaranteed to be larger than start by at least 1.
                uint32_t wStart = std::floor(static_cast<float>(wStepSize * j + wRoiStart));
                uint32_t wEnd = std::ceil(static_cast<float>(wStepSize * (j + 1) + wRoiStart));
                uint32_t hStart = std::floor(static_cast<float>(hStepSize * i + hRoiStart));
                uint32_t hEnd = std::ceil(static_cast<float>(hStepSize * (i + 1) + hRoiStart));
                wStart = std::min(wStart, inWidth);
                wEnd = std::min(wEnd, inWidth);
                hStart = std::min(hStart, inHeight);
                hEnd = std::min(hEnd, inHeight);

                for (uint32_t k = 0; k < inDepth; k++) {
                    T_Input maxValue = static_cast<T_Input>(inputShape.offset);
                    bool first = true;
                    for (uint32_t h = hStart; h < hEnd; h++) {
                        for (uint32_t w = wStart; w < wEnd; w++) {
                            T_Input inputValue = batchBase[h * inWidth * inDepth + w * inDepth + k];
                            if (first || inputValue > maxValue) {
                                maxValue = inputValue;
                                first = false;
                            }
                        }
                    }
                    outPtr[k] = maxValue;
                }
                outPtr += inDepth;
            }
        }
    }
    return true;
}
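// Layout-agnostic wrapper: converts NCHW inputs/outputs to NHWC buffers when requested,
// runs roiPoolingNhwc, then commits the result back to the caller's layout.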
template <typename T_Input, typename T_Roi>
inline bool roiPooling(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                       const Shape& roiShape, const int32_t* batchSplitData,
                       const Shape& batchSplitShape, float heightStride, float widthStride,
                       bool useNchw, T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    NN_RET_CHECK(roiPoolingNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                                batchSplitData, batchSplitShape, heightStride, widthStride,
                                output.getNhwcBuffer(), output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}
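// Specialization for quantized operands: the TENSOR_QUANT16_ASYMM ROI coordinates are
// dequantized to float32 before delegating to the generic implementation.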
template <>
inline bool roiPooling<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
                                          const uint16_t* roiData, const Shape& roiShape,
                                          const int32_t* batchSplitData,
                                          const Shape& batchSplitShape, float heightStride,
                                          float widthStride, bool useNchw, uint8_t* outputData,
                                          const Shape& outputShape) {
    std::vector<float> roi_float32(getNumberOfElements(roiShape));
    convertQuantToFloat32(roiData, roiShape.scale, roiShape.offset, &roi_float32);
    NN_RET_CHECK(roiPooling(inputData, inputShape, roi_float32.data(), roiShape, batchSplitData,
                            batchSplitShape, heightStride, widthStride, useNchw, outputData,
                            outputShape));
    return true;
}

}  // namespace
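// Validates operand counts and types for the supported input combinations
// (float32, float16, and quant8 asymmetric with quant16 asymmetric ROIs).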
bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    std::vector<OperandType> inExpectedTypes;
    auto inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT32,
                           OperandType::FLOAT32,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
                           OperandType::TENSOR_INT32,   OperandType::INT32,
                           OperandType::INT32,          OperandType::FLOAT16,
                           OperandType::FLOAT16,        OperandType::BOOL};
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
                           OperandType::TENSOR_QUANT16_ASYMM,
                           OperandType::TENSOR_INT32,
                           OperandType::INT32,
                           OperandType::INT32,
                           OperandType::FLOAT32,
                           OperandType::FLOAT32,
                           OperandType::BOOL};
    } else {
        LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
        return false;
    }
    NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
    NN_RET_CHECK(validateOutputTypes(context, {inputType}));
    return validateHalVersion(context, HalVersion::V1_2);
}
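// Checks tensor shapes and scalar parameters, then derives the output shape
// {numRois, outputHeight, outputWidth, depth} in the requested layout.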
bool prepare(IOperationExecutionContext* context) {
    bool useNchw = context->getInputValue<bool>(kLayoutScalar);
    Shape input = context->getInputShape(kInputTensor);
    Shape roiShape = context->getInputShape(kRoiTensor);
    Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
    NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);

    uint32_t numBatches = getSizeOfDimension(input, 0);
    uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
    uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
    uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
    uint32_t numRois = getSizeOfDimension(roiShape, 0);
    NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
    NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);

    auto outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
    auto outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
    float heightStride, widthStride;
    if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
        heightStride = context->getInputValue<_Float16>(kHeightStrideScalar);
        widthStride = context->getInputValue<_Float16>(kWidthStrideScalar);
    } else {
        heightStride = context->getInputValue<float>(kHeightStrideScalar);
        widthStride = context->getInputValue<float>(kWidthStrideScalar);
    }
    NN_RET_CHECK_GT(outputHeight, 0);
    NN_RET_CHECK_GT(outputWidth, 0);
    NN_RET_CHECK_GT(heightStride, 0);
    NN_RET_CHECK_GT(widthStride, 0);

    if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
        NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
        NN_RET_CHECK_EQ(roiShape.offset, 0);
    }

    Shape output = input;
    if (useNchw) {
        output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth)};
    } else {
        output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
                             static_cast<uint32_t>(outputWidth), inDepth};
    }
    return context->setOutputShape(kOutputTensor, output);
}
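// Dispatches to the typed roiPooling implementation based on the input tensor's operand type.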
bool execute(IOperationExecutionContext* context) {
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT16:
            return roiPooling(context->getInputBuffer<_Float16>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<_Float16>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<_Float16>(kHeightStrideScalar),
                              context->getInputValue<_Float16>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return roiPooling(context->getInputBuffer<float>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<float>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return roiPooling(context->getInputBuffer<uint8_t>(kInputTensor),
                              context->getInputShape(kInputTensor),
                              context->getInputBuffer<uint16_t>(kRoiTensor),
                              context->getInputShape(kRoiTensor),
                              context->getInputBuffer<int32_t>(kBatchSplitTensor),
                              context->getInputShape(kBatchSplitTensor),
                              context->getInputValue<float>(kHeightStrideScalar),
                              context->getInputValue<float>(kWidthStrideScalar),
                              context->getInputValue<bool>(kLayoutScalar),
                              context->getOutputBuffer<uint8_t>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}
}  // namespace roi_pooling

NN_REGISTER_OPERATION(ROI_POOLING, roi_pooling::kOperationName, roi_pooling::validate,
                      roi_pooling::prepare, roi_pooling::execute);

}  // namespace nn
}  // namespace android