HeatmapMaxKeypoint.cpp 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "CpuOperationUtils.h"
  17. #include "OperationResolver.h"
  18. #include "OperationsUtils.h"
  19. #include <cfloat>
  20. #include <cmath>
  21. #include "Tracing.h"
  22. namespace android {
  23. namespace nn {
  24. namespace heatmap_max_keypoint {
  // Name used for this operation in log and error messages and at registration.
  constexpr char kOperationName[] = "HEATMAP_MAX_KEYPOINT";

  // Input operands: total count, then the index of each operand.
  constexpr uint32_t kNumInputs{3};
  constexpr uint32_t kHeatmapTensor{0};
  constexpr uint32_t kBoxesTensor{1};
  constexpr uint32_t kLayoutScalar{2};

  // Output operands: total count, then the index of each operand.
  constexpr uint32_t kNumOutputs{2};
  constexpr uint32_t kOutputScoreTensor{0};
  constexpr uint32_t kOutputKeypointTensor{1};
  33. namespace {
  // Refines the position and score of a heatmap peak from its 3x3 neighbourhood, using a
  // Taylor expansion up to the quadratic term to approximate the bicubic upscaling result.
  //
  // 2nd order Taylor expansion: D(x) = D - b'x + 1/2 * x'Ax
  //   D    = grid[1][1], the expansion center (the original score)
  //   b    = negative 1st order derivative at the center
  //   A    = Hessian matrix at the center (2nd order derivative)
  //   x    = delta, the correction on the max keypoint position
  //   D(x) = deltaScore, the score after the correction
  //
  // delta[0] is the correction along the second index of grid, delta[1] the correction along
  // the first index. Each component is limited to the range [-1.5, 1.5]; when clipping is
  // needed both components are scaled by the same factor.
  //
  // If A is considered non-invertible, i.e. |det A| < fpAtol + fpRtol * A[0][0] * A[1][1],
  // the function returns without writing to delta or deltaScore, so callers must initialise
  // both before the call.
  static void solveForDelta(const float grid[3][3], float* delta, float* deltaScore,
                            float fpAtol = 1e-5f, float fpRtol = 1e-5f) {
      // b: negated central differences at the center.
      const float negGradX = -(grid[1][2] - grid[1][0]) / 2.0f;
      const float negGradY = -(grid[2][1] - grid[0][1]) / 2.0f;

      // A: second-order differences at the center (symmetric, so one mixed term suffices).
      const float hessXX = grid[1][0] - 2.0f * grid[1][1] + grid[1][2];
      const float hessXY = (grid[2][2] - grid[2][0] - grid[0][2] + grid[0][0]) / 4.0f;
      const float hessYY = grid[0][1] - 2.0f * grid[1][1] + grid[2][1];

      // Solve A x = b through the closed-form 2x2 inverse, unless A is (nearly) singular.
      const float diagProduct = hessXX * hessYY;
      const float offDiagProduct = hessXY * hessXY;
      const float determinant = diagProduct - offDiagProduct;
      if (std::abs(determinant) < (fpAtol + fpRtol * diagProduct)) {
          return;
      }
      float dx = (hessYY * negGradX - hessXY * negGradY) / determinant;
      float dy = (hessXX * negGradY - hessXY * negGradX) / determinant;

      // Pull an out-of-range correction back so that its larger component has magnitude 3/2.
      const float absDx = std::abs(dx);
      const float absDy = std::abs(dy);
      if (absDx > 1.5f || absDy > 1.5f) {
          const float shrink = 1.5f / std::max(absDx, absDy);
          dx *= shrink;
          dy *= shrink;
      }

      // Evaluate the expansion at the (possibly clipped) correction.
      const float quadraticForm =
              (hessXX * dx + hessXY * dy) * dx + (hessXY * dx + hessYY * dy) * dy;
      delta[0] = dx;
      delta[1] = dy;
      *deltaScore = grid[1][1] - negGradX * dx - negGradY * dy + quadraticForm / 2.0f;
  }
  69. inline bool heatmapMaxKeypointFloat32Nhwc(const float* heatmap, const Shape& heatmapShape,
  70. const float* boxes, const Shape& boxesShape,
  71. float* outputScoreData, const Shape& outputScoreShape,
  72. float* outputKeypointData,
  73. const Shape& outputKeypointShape, float fpAtol,
  74. float fpRtol) {
  75. NNTRACE_TRANS("HeatmapMaxKeypoint");
  76. uint32_t numBoxes = getSizeOfDimension(heatmapShape, 0);
  77. uint32_t heatmapSize = getSizeOfDimension(heatmapShape, 1);
  78. uint32_t numKeypoints = getSizeOfDimension(heatmapShape, 3);
  79. uint32_t boxInfoLength = getSizeOfDimension(boxesShape, 1);
  80. const float* heatmapBase = heatmap;
  81. const float* boxInfoBase = boxes;
  82. float* outputScoreBase = outputScoreData;
  83. float* outputKeypointBase = outputKeypointData;
  84. for (uint32_t i = 0; i < numBoxes; i++) {
  85. NN_RET_CHECK_LE(boxInfoBase[0], boxInfoBase[2]);
  86. NN_RET_CHECK_LE(boxInfoBase[1], boxInfoBase[3]);
  87. for (uint32_t j = 0; j < numKeypoints; j++) {
  88. // find max score and its index
  89. uint32_t maxIndex = 0;
  90. float maxScore = -FLT_MAX;
  91. for (uint32_t k = 0; k < heatmapSize * heatmapSize; k++) {
  92. float val = heatmapBase[k * numKeypoints + j];
  93. if (maxScore < val) {
  94. maxScore = val;
  95. maxIndex = k;
  96. }
  97. }
  98. uint32_t maxIndexWidth = maxIndex % heatmapSize;
  99. uint32_t maxIndexHeight = maxIndex / heatmapSize;
  100. // get local 3x3 grid
  101. float localGrid[3][3];
  102. for (int32_t dh = -1; dh <= 1; dh++) {
  103. for (int32_t dw = -1; dw <= 1; dw++) {
  104. // cast uint32_t to int32_t
  105. int32_t h = static_cast<int32_t>(maxIndexHeight) + dh;
  106. int32_t w = static_cast<int32_t>(maxIndexWidth) + dw;
  107. // use mirroring for out of bound indexing
  108. // need to ensure heatmapSize >= 2
  109. h = h < 0 ? 1 : (h >= heatmapSize ? heatmapSize - 2 : h);
  110. w = w < 0 ? 1 : (w >= heatmapSize ? heatmapSize - 2 : w);
  111. uint32_t heatmapIndex = static_cast<uint32_t>(h) * heatmapSize * numKeypoints +
  112. static_cast<uint32_t>(w) * numKeypoints + j;
  113. localGrid[dh + 1][dw + 1] = heatmapBase[heatmapIndex];
  114. }
  115. }
  116. float delta[2] = {0.0f, 0.0f}, deltaScore = maxScore;
  117. solveForDelta(localGrid, delta, &deltaScore, fpAtol, fpRtol);
  118. float wRoiStart = boxInfoBase[0];
  119. float hRoiStart = boxInfoBase[1];
  120. float wRoiEnd = boxInfoBase[2];
  121. float hRoiEnd = boxInfoBase[3];
  122. float roiWidth = wRoiEnd - wRoiStart;
  123. float roiHeight = hRoiEnd - hRoiStart;
  124. float wRelativePos = (static_cast<float>(maxIndexWidth) + delta[0] + 0.5f) /
  125. static_cast<float>(heatmapSize);
  126. float hRelativePos = (static_cast<float>(maxIndexHeight) + delta[1] + 0.5f) /
  127. static_cast<float>(heatmapSize);
  128. *outputScoreBase++ = deltaScore;
  129. outputKeypointBase[0] = wRelativePos * roiWidth + wRoiStart;
  130. outputKeypointBase[1] = hRelativePos * roiHeight + hRoiStart;
  131. outputKeypointBase += 2;
  132. }
  133. boxInfoBase += boxInfoLength;
  134. heatmapBase += heatmapSize * heatmapSize * numKeypoints;
  135. }
  136. return true;
  137. }
  138. inline bool heatmapMaxKeypointFloat32(const float* heatmap, const Shape& heatmapShape,
  139. const float* boxes, const Shape& boxesShape, bool layout,
  140. float* outputScoreData, const Shape& outputScoreShape,
  141. float* outputKeypointData, const Shape& outputKeypointShape,
  142. float fpAtol, float fpRtol) {
  143. std::vector<float> heatmap_nhwc;
  144. Shape heatmapShape_nhwc;
  145. if (layout) {
  146. NN_RET_CHECK(convertNchwToNhwc(heatmap, heatmapShape, &heatmap_nhwc, &heatmapShape_nhwc));
  147. }
  148. const float* heatmap_tmp = layout ? heatmap_nhwc.data() : heatmap;
  149. const Shape& heatmapShape_tmp = layout ? heatmapShape_nhwc : heatmapShape;
  150. return heatmapMaxKeypointFloat32Nhwc(heatmap_tmp, heatmapShape_tmp, boxes, boxesShape,
  151. outputScoreData, outputScoreShape, outputKeypointData,
  152. outputKeypointShape, fpAtol, fpRtol);
  153. }
  154. inline bool heatmapMaxKeypointQuant(const uint8_t* heatmap, const Shape& heatmapShape,
  155. const uint16_t* boxes, const Shape& boxesShape, bool layout,
  156. uint8_t* outputScoreData, const Shape& outputScoreShape,
  157. uint16_t* outputKeypointData, const Shape& outputKeypointShape,
  158. float fpAtol, float fpRtol) {
  159. std::vector<float> heatmap_float32(getNumberOfElements(heatmapShape));
  160. convertQuantToFloat32(heatmap, heatmapShape.scale, heatmapShape.offset, &heatmap_float32);
  161. std::vector<float> boxes_float32(getNumberOfElements(boxesShape));
  162. convertQuantToFloat32(boxes, boxesShape.scale, boxesShape.offset, &boxes_float32);
  163. std::vector<float> outputScore_float32(getNumberOfElements(outputScoreShape));
  164. std::vector<float> outputKeypoint_float32(getNumberOfElements(outputKeypointShape));
  165. NN_RET_CHECK(heatmapMaxKeypointFloat32(
  166. heatmap_float32.data(), heatmapShape, boxes_float32.data(), boxesShape, layout,
  167. outputScore_float32.data(), outputScoreShape, outputKeypoint_float32.data(),
  168. outputKeypointShape, fpAtol, fpRtol));
  169. convertFloat32ToQuant(outputScore_float32, outputScoreShape.scale, outputScoreShape.offset,
  170. outputScoreData);
  171. convertFloat32ToQuant(outputKeypoint_float32, outputKeypointShape.scale,
  172. outputKeypointShape.offset, outputKeypointData);
  173. return true;
  174. }
  175. } // namespace
  176. bool validate(const IOperationValidationContext* context) {
  177. NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
  178. NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
  179. std::vector<OperandType> inExpectedTypes;
  180. std::vector<OperandType> outExpectedTypes;
  181. auto inputType = context->getInputType(kHeatmapTensor);
  182. if (inputType == OperandType::TENSOR_FLOAT32 || inputType == OperandType::TENSOR_FLOAT16) {
  183. inExpectedTypes = {inputType, inputType, OperandType::BOOL};
  184. outExpectedTypes = {inputType, inputType};
  185. } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
  186. inExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT16_ASYMM,
  187. OperandType::BOOL};
  188. outExpectedTypes = {OperandType::TENSOR_QUANT8_ASYMM, OperandType::TENSOR_QUANT16_ASYMM};
  189. } else {
  190. LOG(ERROR) << "Unsupported input tensor type for operation " << kOperationName;
  191. return false;
  192. }
  193. NN_RET_CHECK(validateInputTypes(context, inExpectedTypes));
  194. NN_RET_CHECK(validateOutputTypes(context, outExpectedTypes));
  195. return validateHalVersion(context, HalVersion::V1_2);
  196. }
  197. bool prepare(IOperationExecutionContext* context) {
  198. bool layout = context->getInputValue<bool>(kLayoutScalar);
  199. Shape heatmapShape = context->getInputShape(kHeatmapTensor);
  200. Shape boxesShape = context->getInputShape(kBoxesTensor);
  201. NN_RET_CHECK_EQ(getNumberOfDimensions(heatmapShape), 4);
  202. NN_RET_CHECK_EQ(getNumberOfDimensions(boxesShape), 2);
  203. uint32_t numBoxes = getSizeOfDimension(heatmapShape, 0);
  204. uint32_t heatmapSize = getSizeOfDimension(heatmapShape, 2);
  205. uint32_t numKeypoints = getSizeOfDimension(heatmapShape, layout ? 1 : 3);
  206. uint32_t boxInfoLength = getSizeOfDimension(boxesShape, 1);
  207. NN_RET_CHECK_EQ(getSizeOfDimension(heatmapShape, layout ? 3 : 1), heatmapSize);
  208. NN_RET_CHECK_GE(heatmapSize, 2);
  209. NN_RET_CHECK_EQ(getSizeOfDimension(boxesShape, 0), numBoxes);
  210. NN_RET_CHECK_EQ(boxInfoLength, 4);
  211. if (heatmapShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
  212. NN_RET_CHECK_EQ(boxesShape.scale, 0.125f);
  213. NN_RET_CHECK_EQ(boxesShape.offset, 0);
  214. }
  215. Shape outputScore = context->getOutputShape(kOutputScoreTensor);
  216. outputScore.type = heatmapShape.type;
  217. outputScore.dimensions = {numBoxes, numKeypoints};
  218. NN_RET_CHECK(context->setOutputShape(kOutputScoreTensor, outputScore));
  219. Shape outputKeypoint = context->getOutputShape(kOutputKeypointTensor);
  220. outputKeypoint.type = boxesShape.type;
  221. outputKeypoint.dimensions = {numBoxes, numKeypoints, 2};
  222. outputKeypoint.offset = 0;
  223. outputKeypoint.scale = 0.125f;
  224. NN_RET_CHECK(context->setOutputShape(kOutputKeypointTensor, outputKeypoint));
  225. return true;
  226. }
  227. bool execute(IOperationExecutionContext* context) {
  228. bool layout = context->getInputValue<bool>(kLayoutScalar);
  229. switch (context->getInputType(kHeatmapTensor)) {
  230. case OperandType::TENSOR_FLOAT16: {
  231. const auto heatmap = context->getInputBuffer<_Float16>(kHeatmapTensor);
  232. const auto heatmapShape = context->getInputShape(kHeatmapTensor);
  233. const auto boxes = context->getInputBuffer<_Float16>(kBoxesTensor);
  234. const auto boxesShape = context->getInputShape(kBoxesTensor);
  235. auto outputScoreData = context->getOutputBuffer<_Float16>(kOutputScoreTensor);
  236. const auto outputScoreShape = context->getOutputShape(kOutputScoreTensor);
  237. auto outputKeypointData = context->getOutputBuffer<_Float16>(kOutputKeypointTensor);
  238. const auto outputKeypointShape = context->getOutputShape(kOutputKeypointTensor);
  239. std::vector<float> heatmap_float32(getNumberOfElements(heatmapShape));
  240. convertFloat16ToFloat32(heatmap, &heatmap_float32);
  241. std::vector<float> boxes_float32(getNumberOfElements(boxesShape));
  242. convertFloat16ToFloat32(boxes, &boxes_float32);
  243. std::vector<float> outputScore_float32(getNumberOfElements(outputScoreShape));
  244. std::vector<float> outputKeypoint_float32(getNumberOfElements(outputKeypointShape));
  245. NN_RET_CHECK(heatmapMaxKeypointFloat32(
  246. heatmap_float32.data(), heatmapShape, boxes_float32.data(), boxesShape, layout,
  247. outputScore_float32.data(), outputScoreShape, outputKeypoint_float32.data(),
  248. outputKeypointShape, 1e-3f, 1e-3f));
  249. convertFloat32ToFloat16(outputScore_float32, outputScoreData);
  250. convertFloat32ToFloat16(outputKeypoint_float32, outputKeypointData);
  251. return true;
  252. }
  253. case OperandType::TENSOR_FLOAT32: {
  254. return heatmapMaxKeypointFloat32(context->getInputBuffer<float>(kHeatmapTensor),
  255. context->getInputShape(kHeatmapTensor),
  256. context->getInputBuffer<float>(kBoxesTensor),
  257. context->getInputShape(kBoxesTensor), layout,
  258. context->getOutputBuffer<float>(kOutputScoreTensor),
  259. context->getOutputShape(kOutputScoreTensor),
  260. context->getOutputBuffer<float>(kOutputKeypointTensor),
  261. context->getOutputShape(kOutputKeypointTensor), 1e-5f,
  262. 1e-5f);
  263. }
  264. case OperandType::TENSOR_QUANT8_ASYMM: {
  265. return heatmapMaxKeypointQuant(
  266. context->getInputBuffer<uint8_t>(kHeatmapTensor),
  267. context->getInputShape(kHeatmapTensor),
  268. context->getInputBuffer<uint16_t>(kBoxesTensor),
  269. context->getInputShape(kBoxesTensor), layout,
  270. context->getOutputBuffer<uint8_t>(kOutputScoreTensor),
  271. context->getOutputShape(kOutputScoreTensor),
  272. context->getOutputBuffer<uint16_t>(kOutputKeypointTensor),
  273. context->getOutputShape(kOutputKeypointTensor), 1e-5f, 1e-5f);
  274. }
  275. default:
  276. NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
  277. }
  278. }
  279. } // namespace heatmap_max_keypoint
  280. NN_REGISTER_OPERATION(HEATMAP_MAX_KEYPOINT, heatmap_max_keypoint::kOperationName,
  281. heatmap_max_keypoint::validate, heatmap_max_keypoint::prepare,
  282. heatmap_max_keypoint::execute);
  283. } // namespace nn
  284. } // namespace android