RoiAlign.cpp 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "CpuOperationUtils.h"
  17. #include "OperationResolver.h"
  18. #include "OperationsUtils.h"
  19. #include <cfloat>
  20. #include <cmath>
  21. #include "Tracing.h"
  22. #include "tensorflow/lite/kernels/internal/common.h"
  23. namespace android {
  24. namespace nn {
  25. namespace roi_align {
  26. constexpr char kOperationName[] = "ROI_ALIGN";
  27. constexpr uint32_t kNumInputs = 10;
  28. constexpr uint32_t kInputTensor = 0;
  29. constexpr uint32_t kRoiTensor = 1;
  30. constexpr uint32_t kBatchSplitTensor = 2;
  31. constexpr uint32_t kOutputHeightScalar = 3;
  32. constexpr uint32_t kOutputWidthScalar = 4;
  33. constexpr uint32_t kHeightStrideSalar = 5;
  34. constexpr uint32_t kWidthStrideScalar = 6;
  35. constexpr uint32_t kHeightSamplingRatioScalar = 7;
  36. constexpr uint32_t kWidthSamplingRatioScalar = 8;
  37. constexpr uint32_t kLayoutScalar = 9;
  38. constexpr uint32_t kNumOutputs = 1;
  39. constexpr uint32_t kOutputTensor = 0;
  40. namespace {
  41. template <typename T_Input, typename T_Roi>
  42. inline bool roiAlignNhwc(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
  43. const Shape& roiShape, const int32_t* batchSplitData,
  44. const Shape& batchSplitShape, float heightStride, float widthStride,
  45. int32_t heightSamplingRatio, int32_t widthSamplingRatio,
  46. T_Input* outputData, const Shape& outputShape) {
  47. NNTRACE_TRANS("RoiAlign");
  48. const uint32_t kRoiDim = 4;
  49. const T_Roi heightScale = 1.0f / heightStride;
  50. const T_Roi widthScale = 1.0f / widthStride;
  51. uint32_t numBatches = getSizeOfDimension(inputShape, 0);
  52. uint32_t inHeight = getSizeOfDimension(inputShape, 1);
  53. uint32_t inWidth = getSizeOfDimension(inputShape, 2);
  54. uint32_t inDepth = getSizeOfDimension(inputShape, 3);
  55. uint32_t outHeight = getSizeOfDimension(outputShape, 1);
  56. uint32_t outWidth = getSizeOfDimension(outputShape, 2);
  57. uint32_t numRois = getSizeOfDimension(roiShape, 0);
  58. uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
  59. T_Input* outPtr = outputData;
  60. const T_Roi* roiDataEnd = roiData + numRois * roiInfoLength;
  61. uint32_t roiIndex = 0;
  62. for (const T_Roi* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
  63. uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
  64. // Check for malformed data
  65. // 1. invalid batch id
  66. // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
  67. // 3. Invalid region: x2 < x1 || y2 < y1
  68. NN_RET_CHECK_GE(batchId, 0);
  69. NN_RET_CHECK_LT(batchId, numBatches);
  70. NN_RET_CHECK(roiInfo[0] >= 0);
  71. NN_RET_CHECK(roiInfo[1] >= 0);
  72. NN_RET_CHECK(roiInfo[2] >= 0);
  73. NN_RET_CHECK(roiInfo[3] >= 0);
  74. NN_RET_CHECK(roiInfo[0] * widthScale <= inWidth);
  75. NN_RET_CHECK(roiInfo[1] * heightScale <= inHeight);
  76. NN_RET_CHECK(roiInfo[2] * widthScale <= inWidth);
  77. NN_RET_CHECK(roiInfo[3] * heightScale <= inHeight);
  78. NN_RET_CHECK(roiInfo[0] <= roiInfo[2]);
  79. NN_RET_CHECK(roiInfo[1] <= roiInfo[3]);
  80. T_Roi wRoiStart = roiInfo[0] * widthScale;
  81. T_Roi hRoiStart = roiInfo[1] * heightScale;
  82. T_Roi wRoiEnd = roiInfo[2] * widthScale;
  83. T_Roi hRoiEnd = roiInfo[3] * heightScale;
  84. T_Roi roiWidth = std::max(static_cast<float>(wRoiEnd - wRoiStart), 1.0f);
  85. T_Roi roiHeight = std::max(static_cast<float>(hRoiEnd - hRoiStart), 1.0f);
  86. T_Roi wStepSize = roiWidth / static_cast<T_Roi>(outWidth);
  87. T_Roi hStepSize = roiHeight / static_cast<T_Roi>(outHeight);
  88. // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
  89. uint32_t wSamplingRatio = widthSamplingRatio > 0 ? widthSamplingRatio
  90. : std::ceil(static_cast<float>(wStepSize));
  91. uint32_t hSamplingRatio = heightSamplingRatio > 0
  92. ? heightSamplingRatio
  93. : std::ceil(static_cast<float>(hStepSize));
  94. int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
  95. T_Roi wBinSize = wStepSize / static_cast<T_Roi>(wSamplingRatio);
  96. T_Roi hBinSize = hStepSize / static_cast<T_Roi>(hSamplingRatio);
  97. const T_Input* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
  98. for (uint32_t i = 0; i < outHeight; i++) {
  99. for (uint32_t j = 0; j < outWidth; j++) {
  100. T_Roi wStart = wStepSize * j + wRoiStart;
  101. T_Roi wEnd = wStepSize * (j + 1) + wRoiStart;
  102. T_Roi hStart = hStepSize * i + hRoiStart;
  103. T_Roi hEnd = hStepSize * (i + 1) + hRoiStart;
  104. // initialize output to zero
  105. for (uint32_t k = 0; k < inDepth; k++) outPtr[k] = 0;
  106. // calculate the sum of the sampling points
  107. for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
  108. for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
  109. T_Roi y = hStart + hBinSize / 2 + hBinSize * yInd;
  110. T_Roi x = wStart + wBinSize / 2 + wBinSize * xInd;
  111. // bilinear interpolation of point (x,y)
  112. // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
  113. uint32_t x1 = std::floor(static_cast<float>(x));
  114. uint32_t y1 = std::floor(static_cast<float>(y));
  115. uint32_t x2 = x1 + 1, y2 = y1 + 1;
  116. T_Roi dx1 = x - static_cast<T_Roi>(x1);
  117. T_Roi dy1 = y - static_cast<T_Roi>(y1);
  118. // dealing with out of bound samples
  119. if (x1 >= inWidth - 1) {
  120. x1 = x2 = inWidth - 1;
  121. dx1 = 0;
  122. }
  123. if (y1 >= inHeight - 1) {
  124. y1 = y2 = inHeight - 1;
  125. dy1 = 0;
  126. }
  127. T_Roi dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
  128. T_Roi ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
  129. uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
  130. y1 * inWidth * inDepth + x2 * inDepth,
  131. y2 * inWidth * inDepth + x1 * inDepth,
  132. y2 * inWidth * inDepth + x2 * inDepth};
  133. for (uint32_t k = 0; k < inDepth; k++) {
  134. T_Input interpolation = 0;
  135. for (uint32_t c = 0; c < 4; c++) {
  136. interpolation += ws[c] * batchBase[offsets[c] + k];
  137. }
  138. outPtr[k] += interpolation;
  139. }
  140. }
  141. }
  142. // take average
  143. for (uint32_t k = 0; k < inDepth; k++)
  144. outPtr[k] /= static_cast<T_Input>(numSamplingPoints);
  145. outPtr += inDepth;
  146. }
  147. }
  148. }
  149. return true;
  150. }
  151. template <>
  152. inline bool roiAlignNhwc<uint8_t, uint16_t>(const uint8_t* inputData, const Shape& inputShape,
  153. const uint16_t* roiData, const Shape& roiShape,
  154. const int32_t* batchSplitData,
  155. const Shape& batchSplitShape, float heightStride,
  156. float widthStride, int32_t heightSamplingRatio,
  157. int32_t widthSamplingRatio, uint8_t* outputData,
  158. const Shape& outputShape) {
  159. NNTRACE_TRANS("RoiAlignQuant8");
  160. constexpr float wScale = 1.0f / 255.0f;
  161. constexpr uint32_t kRoiDim = 4;
  162. const float heightScale = 1.0f / heightStride;
  163. const float widthScale = 1.0f / widthStride;
  164. uint32_t numBatches = getSizeOfDimension(inputShape, 0);
  165. uint32_t inHeight = getSizeOfDimension(inputShape, 1);
  166. uint32_t inWidth = getSizeOfDimension(inputShape, 2);
  167. uint32_t inDepth = getSizeOfDimension(inputShape, 3);
  168. uint32_t outHeight = getSizeOfDimension(outputShape, 1);
  169. uint32_t outWidth = getSizeOfDimension(outputShape, 2);
  170. uint32_t numRois = getSizeOfDimension(roiShape, 0);
  171. uint32_t roiInfoLength = getSizeOfDimension(roiShape, 1);
  172. uint8_t* outPtr = outputData;
  173. const uint16_t* roiDataEnd = roiData + numRois * roiInfoLength;
  174. uint32_t roiIndex = 0;
  175. for (const uint16_t* roiInfo = roiData; roiInfo < roiDataEnd; roiInfo += kRoiDim, roiIndex++) {
  176. uint32_t batchId = static_cast<uint32_t>(batchSplitData[roiIndex]);
  177. float wRoiStart = static_cast<float>(roiInfo[0]) * widthScale * 0.125f;
  178. float hRoiStart = static_cast<float>(roiInfo[1]) * heightScale * 0.125f;
  179. float wRoiEnd = static_cast<float>(roiInfo[2]) * widthScale * 0.125f;
  180. float hRoiEnd = static_cast<float>(roiInfo[3]) * heightScale * 0.125f;
  181. // Check for malformed data
  182. // 1. invalid batch id
  183. // 2. Region out of bound: x1|x2|y1|y2 < 0 || x1|x2 > inWidth || y1|y2 > inHeight
  184. // 3. Invalid region: x2 < x1 || y2 < y1
  185. NN_RET_CHECK_GE(batchId, 0);
  186. NN_RET_CHECK_LT(batchId, numBatches);
  187. NN_RET_CHECK(wRoiStart <= inWidth);
  188. NN_RET_CHECK(hRoiStart <= inHeight);
  189. NN_RET_CHECK(wRoiEnd <= inWidth);
  190. NN_RET_CHECK(hRoiEnd <= inHeight);
  191. NN_RET_CHECK_LE(wRoiStart, wRoiEnd);
  192. NN_RET_CHECK_LE(hRoiStart, hRoiEnd);
  193. float roiWidth = std::max(wRoiEnd - wRoiStart, 1.0f);
  194. float roiHeight = std::max(hRoiEnd - hRoiStart, 1.0f);
  195. float wStepSize = roiWidth / static_cast<float>(outWidth);
  196. float hStepSize = roiHeight / static_cast<float>(outHeight);
  197. // if samplingRatio = 0, use adaptive value of ceil(roiWidth/outWidth), same for height
  198. uint32_t wSamplingRatio =
  199. widthSamplingRatio > 0 ? widthSamplingRatio : std::ceil(wStepSize);
  200. uint32_t hSamplingRatio =
  201. heightSamplingRatio > 0 ? heightSamplingRatio : std::ceil(hStepSize);
  202. int32_t numSamplingPoints = wSamplingRatio * hSamplingRatio;
  203. float wBinSize = wStepSize / static_cast<float>(wSamplingRatio);
  204. float hBinSize = hStepSize / static_cast<float>(hSamplingRatio);
  205. float realMultiplier = inputShape.scale * wScale / outputShape.scale / numSamplingPoints;
  206. int32_t outputMultiplier = 0;
  207. int32_t outputShift = 0;
  208. if (!QuantizeMultiplierSmallerThanOne(realMultiplier, &outputMultiplier, &outputShift)) {
  209. return false;
  210. }
  211. const uint8_t* batchBase = inputData + batchId * inHeight * inWidth * inDepth;
  212. for (uint32_t i = 0; i < outHeight; i++) {
  213. for (uint32_t j = 0; j < outWidth; j++) {
  214. float wStart = wStepSize * j + wRoiStart;
  215. float wEnd = wStepSize * (j + 1) + wRoiStart;
  216. float hStart = hStepSize * i + hRoiStart;
  217. float hEnd = hStepSize * (i + 1) + hRoiStart;
  218. std::vector<int32_t> outTemp(inDepth, 0);
  219. // calculate the sum of the sampling points
  220. for (uint32_t yInd = 0; yInd < hSamplingRatio; yInd++) {
  221. for (uint32_t xInd = 0; xInd < wSamplingRatio; xInd++) {
  222. float y = hStart + hBinSize / 2 + hBinSize * yInd;
  223. float x = wStart + wBinSize / 2 + wBinSize * xInd;
  224. // bilinear interpolation of point (x,y)
  225. // w.r.t box [(x1,y1), (x1,y2), (x2,y1), (x2,y2)]
  226. uint32_t x1 = std::floor(x), y1 = std::floor(y);
  227. uint32_t x2 = x1 + 1, y2 = y1 + 1;
  228. float dx1 = x - static_cast<float>(x1);
  229. float dy1 = y - static_cast<float>(y1);
  230. // dealing with out of bound samples
  231. if (x1 >= inWidth - 1) {
  232. x1 = x2 = inWidth - 1;
  233. dx1 = 0;
  234. }
  235. if (y1 >= inHeight - 1) {
  236. y1 = y2 = inHeight - 1;
  237. dy1 = 0;
  238. }
  239. float dx2 = 1.0f - dx1, dy2 = 1.0f - dy1;
  240. float ws[] = {dx2 * dy2, dx1 * dy2, dx2 * dy1, dx1 * dy1};
  241. uint32_t offsets[] = {y1 * inWidth * inDepth + x1 * inDepth,
  242. y1 * inWidth * inDepth + x2 * inDepth,
  243. y2 * inWidth * inDepth + x1 * inDepth,
  244. y2 * inWidth * inDepth + x2 * inDepth};
  245. for (uint32_t k = 0; k < inDepth; k++) {
  246. int32_t interpolation = 0;
  247. for (uint32_t c = 0; c < 4; c++) {
  248. int32_t wQuant = static_cast<int32_t>(std::round(ws[c] / wScale));
  249. interpolation +=
  250. wQuant * (static_cast<int32_t>(batchBase[offsets[c] + k]) -
  251. inputShape.offset);
  252. }
  253. outTemp[k] += interpolation;
  254. }
  255. }
  256. }
  257. // take average and cast to output quantization
  258. for (uint32_t k = 0; k < inDepth; k++) {
  259. int32_t raw_out = tflite::MultiplyByQuantizedMultiplier(
  260. outTemp[k], outputMultiplier, -outputShift) +
  261. outputShape.offset;
  262. int32_t clamped_out = std::min(255, std::max(0, raw_out));
  263. outPtr[k] = static_cast<uint8_t>(clamped_out);
  264. }
  265. outPtr += inDepth;
  266. }
  267. }
  268. }
  269. return true;
  270. }
// Layout-aware entry point for ROI Align.
//
// Wraps the caller's input and output buffers in InputWithLayout / OutputWithLayout
// (constructed with useNchw), runs roiAlignNhwc on the NHWC views they expose, and then
// commits the output wrapper. Returns true only if every step succeeds.
//
// NOTE(review): the wrappers presumably convert between NCHW and NHWC when useNchw is true,
// and commit() presumably writes the result back in the caller's layout -- confirm against
// CpuOperationUtils.h.
template <typename T_Input, typename T_Roi>
inline bool roiAlign(const T_Input* inputData, const Shape& inputShape, const T_Roi* roiData,
                     const Shape& roiShape, const int32_t* batchSplitData,
                     const Shape& batchSplitShape, float heightStride, float widthStride,
                     int32_t heightSamplingRatio, int32_t widthSamplingRatio, bool useNchw,
                     T_Input* outputData, const Shape& outputShape) {
    InputWithLayout<T_Input> input(useNchw);
    OutputWithLayout<T_Input> output(useNchw);
    NN_RET_CHECK(input.initialize(inputData, inputShape));
    NN_RET_CHECK(output.initialize(outputData, outputShape));
    // The ROI and batch-split tensors are passed through unchanged; only the feature map
    // and the output go through the layout wrappers.
    NN_RET_CHECK(roiAlignNhwc(input.getNhwcBuffer(), input.getNhwcShape(), roiData, roiShape,
                              batchSplitData, batchSplitShape, heightStride, widthStride,
                              heightSamplingRatio, widthSamplingRatio, output.getNhwcBuffer(),
                              output.getNhwcShape()));
    NN_RET_CHECK(output.commit());
    return true;
}
  288. } // namespace
  289. bool validate(const IOperationValidationContext* context) {
  290. NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
  291. NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
  292. std::vector<OperandType> inExpectedTypes;
  293. auto inputType = context->getInputType(kInputTensor);
  294. if (inputType == OperandType::TENSOR_FLOAT32) {
  295. inExpectedTypes = {OperandType::TENSOR_FLOAT32, OperandType::TENSOR_FLOAT32,
  296. OperandType::TENSOR_INT32, OperandType::INT32,
  297. OperandType::INT32, OperandType::FLOAT32,
  298. OperandType::FLOAT32, OperandType::INT32,
  299. OperandType::INT32, OperandType::BOOL};
  300. } else if (inputType == OperandType::TENSOR_FLOAT16) {
  301. inExpectedTypes = {OperandType::TENSOR_FLOAT16, OperandType::TENSOR_FLOAT16,
  302. OperandType::TENSOR_INT32, OperandType::INT32,
  303. OperandType::INT32, OperandType::FLOAT16,
  304. OperandType::FLOAT16, OperandType::INT32,
  305. OperandType::INT32, OperandType::BOOL};
  306. } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
  307. inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM,
  308. OperandType::TENSOR_QUANT16_ASYMM,
  309. OperandType::TENSOR_INT32,
  310. OperandType::INT32,
  311. OperandType::INT32,
  312. OperandType::FLOAT32,
  313. OperandType::FLOAT32,
  314. OperandType::INT32,
  315. OperandType::INT32,
  316. OperandType::BOOL};
  317. } else {
  318. LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
  319. return false;
  320. }
  321. NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
  322. NN_RET_CHECK(validateOutputTypes(context, {inputType}));
  323. return validateHalVersion(context, HalVersion::V1_2);
  324. }
  325. bool prepare(IOperationExecutionContext* context) {
  326. bool useNchw = context->getInputValue<bool>(kLayoutScalar);
  327. Shape input = context->getInputShape(kInputTensor);
  328. Shape roiShape = context->getInputShape(kRoiTensor);
  329. Shape batchSplitShape = context->getInputShape(kBatchSplitTensor);
  330. NN_RET_CHECK_EQ(getNumberOfDimensions(input), 4);
  331. NN_RET_CHECK_EQ(getNumberOfDimensions(roiShape), 2);
  332. uint32_t numBatches = getSizeOfDimension(input, 0);
  333. uint32_t inHeight = getSizeOfDimension(input, useNchw ? 2 : 1);
  334. uint32_t inWidth = getSizeOfDimension(input, useNchw ? 3 : 2);
  335. uint32_t inDepth = getSizeOfDimension(input, useNchw ? 1 : 3);
  336. uint32_t numRois = getSizeOfDimension(roiShape, 0);
  337. // Every dimension must be positive except for numRois.
  338. NN_RET_CHECK_GT(numBatches, 0);
  339. NN_RET_CHECK_GT(inHeight, 0);
  340. NN_RET_CHECK_GT(inWidth, 0);
  341. NN_RET_CHECK_GT(inDepth, 0);
  342. NN_RET_CHECK_EQ(getSizeOfDimension(roiShape, 1), 4);
  343. NN_RET_CHECK_EQ(getSizeOfDimension(batchSplitShape, 0), numRois);
  344. int32_t outputHeight = context->getInputValue<int32_t>(kOutputHeightScalar);
  345. int32_t outputWidth = context->getInputValue<int32_t>(kOutputWidthScalar);
  346. int32_t heightSamplingRatio = context->getInputValue<int32_t>(kHeightSamplingRatioScalar);
  347. int32_t widthSamplingRatio = context->getInputValue<int32_t>(kWidthSamplingRatioScalar);
  348. float heightScale, widthScale;
  349. if (context->getInputType(kInputTensor) == OperandType::TENSOR_FLOAT16) {
  350. heightScale = context->getInputValue<_Float16>(kHeightStrideSalar);
  351. widthScale = context->getInputValue<_Float16>(kWidthStrideScalar);
  352. } else {
  353. heightScale = context->getInputValue<float>(kHeightStrideSalar);
  354. widthScale = context->getInputValue<float>(kWidthStrideScalar);
  355. }
  356. NN_RET_CHECK_GT(outputHeight, 0);
  357. NN_RET_CHECK_GT(outputWidth, 0);
  358. NN_RET_CHECK_GT(heightScale, 0);
  359. NN_RET_CHECK_GT(widthScale, 0);
  360. // Sampling ratio can set to 0 for adaptive value.
  361. NN_RET_CHECK_GE(heightSamplingRatio, 0);
  362. NN_RET_CHECK_GE(widthSamplingRatio, 0);
  363. if (roiShape.type == OperandType::TENSOR_QUANT16_ASYMM) {
  364. NN_RET_CHECK_EQ(roiShape.scale, 0.125f);
  365. NN_RET_CHECK_EQ(roiShape.offset, 0);
  366. }
  367. Shape output = context->getOutputShape(kOutputTensor);
  368. output.type = input.type;
  369. if (useNchw) {
  370. output.dimensions = {numRois, inDepth, static_cast<uint32_t>(outputHeight),
  371. static_cast<uint32_t>(outputWidth)};
  372. } else {
  373. output.dimensions = {numRois, static_cast<uint32_t>(outputHeight),
  374. static_cast<uint32_t>(outputWidth), inDepth};
  375. }
  376. return context->setOutputShape(kOutputTensor, output);
  377. }
  378. bool execute(IOperationExecutionContext* context) {
  379. // Bypass execution in the case of zero-sized input.
  380. if (getNumberOfElements(context->getInputShape(kRoiTensor)) == 0) return true;
  381. switch (context->getInputType(kInputTensor)) {
  382. case OperandType::TENSOR_FLOAT16:
  383. return roiAlign(context->getInputBuffer<_Float16>(kInputTensor),
  384. context->getInputShape(kInputTensor),
  385. context->getInputBuffer<_Float16>(kRoiTensor),
  386. context->getInputShape(kRoiTensor),
  387. context->getInputBuffer<int32_t>(kBatchSplitTensor),
  388. context->getInputShape(kBatchSplitTensor),
  389. context->getInputValue<_Float16>(kHeightStrideSalar),
  390. context->getInputValue<_Float16>(kWidthStrideScalar),
  391. context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
  392. context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
  393. context->getInputValue<bool>(kLayoutScalar),
  394. context->getOutputBuffer<_Float16>(kOutputTensor),
  395. context->getOutputShape(kOutputTensor));
  396. case OperandType::TENSOR_FLOAT32:
  397. return roiAlign(context->getInputBuffer<float>(kInputTensor),
  398. context->getInputShape(kInputTensor),
  399. context->getInputBuffer<float>(kRoiTensor),
  400. context->getInputShape(kRoiTensor),
  401. context->getInputBuffer<int32_t>(kBatchSplitTensor),
  402. context->getInputShape(kBatchSplitTensor),
  403. context->getInputValue<float>(kHeightStrideSalar),
  404. context->getInputValue<float>(kWidthStrideScalar),
  405. context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
  406. context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
  407. context->getInputValue<bool>(kLayoutScalar),
  408. context->getOutputBuffer<float>(kOutputTensor),
  409. context->getOutputShape(kOutputTensor));
  410. case OperandType::TENSOR_QUANT8_ASYMM:
  411. return roiAlign(context->getInputBuffer<uint8_t>(kInputTensor),
  412. context->getInputShape(kInputTensor),
  413. context->getInputBuffer<uint16_t>(kRoiTensor),
  414. context->getInputShape(kRoiTensor),
  415. context->getInputBuffer<int32_t>(kBatchSplitTensor),
  416. context->getInputShape(kBatchSplitTensor),
  417. context->getInputValue<float>(kHeightStrideSalar),
  418. context->getInputValue<float>(kWidthStrideScalar),
  419. context->getInputValue<int32_t>(kHeightSamplingRatioScalar),
  420. context->getInputValue<int32_t>(kWidthSamplingRatioScalar),
  421. context->getInputValue<bool>(kLayoutScalar),
  422. context->getOutputBuffer<uint8_t>(kOutputTensor),
  423. context->getOutputShape(kOutputTensor));
  424. default:
  425. NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
  426. }
  427. }
  428. } // namespace roi_align
// Registers ROI_ALIGN with its validate / prepare / execute entry points.
// allowZeroSizedInput is set to true; execute() returns early when the ROI tensor has no
// elements.
NN_REGISTER_OPERATION(ROI_ALIGN, roi_align::kOperationName, roi_align::validate, roi_align::prepare,
                      roi_align::execute, .allowZeroSizedInput = true);
  431. } // namespace nn
  432. } // namespace android