// MaximumMinimum.cpp
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #define LOG_TAG "Operations"
  17. #include "MaximumMinimum.h"
  18. #include "IndexedShapeWrapper.h"
  19. #include "OperationsUtils.h"
  20. #include "Tracing.h"
  21. namespace android {
  22. namespace nn {
  23. namespace maximum_minimum {
  24. namespace {
  25. template <typename T>
  26. bool evalGeneric(const T* aData, const Shape& aShape, const T* bData, const Shape& bShape,
  27. bool isMinimum, T* outputData, const Shape& outputShape) {
  28. IndexedShapeWrapper aShapeIndexed(aShape);
  29. IndexedShapeWrapper bShapeIndexed(bShape);
  30. IndexedShapeWrapper outputShapeIndexed(outputShape);
  31. std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
  32. bool lastIndex = false;
  33. do {
  34. uint32_t outputFlatIndex;
  35. NN_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
  36. uint32_t aFlatIndex;
  37. NN_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
  38. uint32_t bFlatIndex;
  39. NN_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
  40. outputData[outputFlatIndex] = isMinimum ? std::min(aData[aFlatIndex], bData[bFlatIndex])
  41. : std::max(aData[aFlatIndex], bData[bFlatIndex]);
  42. NN_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
  43. } while (!lastIndex);
  44. return true;
  45. }
  46. bool evalQuant8(const uint8_t* aData, const Shape& aShape, const uint8_t* bData,
  47. const Shape& bShape, bool isMinimum, uint8_t* outputData,
  48. const Shape& outputShape) {
  49. IndexedShapeWrapper aShapeIndexed(aShape);
  50. IndexedShapeWrapper bShapeIndexed(bShape);
  51. IndexedShapeWrapper outputShapeIndexed(outputShape);
  52. std::vector<uint32_t> curIndex(outputShape.dimensions.size(), 0);
  53. bool lastIndex = false;
  54. do {
  55. uint32_t outputFlatIndex;
  56. NN_CHECK(outputShapeIndexed.indexToFlatIndex(curIndex, &outputFlatIndex));
  57. uint32_t aFlatIndex;
  58. NN_CHECK(aShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &aFlatIndex));
  59. uint32_t bFlatIndex;
  60. NN_CHECK(bShapeIndexed.broadcastedIndexToFlatIndex(curIndex, &bFlatIndex));
  61. uint8_t aValue = requantize(aData[aFlatIndex], aShape, outputShape);
  62. uint8_t bValue = requantize(bData[bFlatIndex], bShape, outputShape);
  63. outputData[outputFlatIndex] =
  64. isMinimum ? std::min(aValue, bValue) : std::max(aValue, bValue);
  65. NN_CHECK(outputShapeIndexed.nextIndexInplace(&curIndex, &lastIndex));
  66. } while (!lastIndex);
  67. return true;
  68. }
  69. } // namespace
// Validates the two input operands and computes the output shape.
// Requires both inputs to have the same operand type; the output shape is
// the broadcast of the two input shapes. Returns false on type mismatch or
// if the shapes cannot be broadcast together.
bool prepare(const Shape& in1, const Shape& in2, Shape* out) {
    NN_CHECK(in1.type == in2.type);
    return calculateBroadcastedShape(in1, in2, out);
}
  74. bool eval(const void* in1, const Shape& shape1, const void* in2, const Shape& shape2,
  75. bool isMinimum, void* output, const Shape& outputShape) {
  76. NNTRACE_COMP("maximum_minimum::eval");
  77. switch (shape1.type) {
  78. case OperandType::TENSOR_FLOAT16: {
  79. return evalGeneric(reinterpret_cast<const _Float16*>(in1), shape1,
  80. reinterpret_cast<const _Float16*>(in2), shape2, isMinimum,
  81. reinterpret_cast<_Float16*>(output), outputShape);
  82. }
  83. case OperandType::TENSOR_FLOAT32: {
  84. return evalGeneric(reinterpret_cast<const float*>(in1), shape1,
  85. reinterpret_cast<const float*>(in2), shape2, isMinimum,
  86. reinterpret_cast<float*>(output), outputShape);
  87. }
  88. case OperandType::TENSOR_INT32: {
  89. return evalGeneric(reinterpret_cast<const int32_t*>(in1), shape1,
  90. reinterpret_cast<const int32_t*>(in2), shape2, isMinimum,
  91. reinterpret_cast<int32_t*>(output), outputShape);
  92. }
  93. case OperandType::TENSOR_QUANT8_ASYMM: {
  94. return evalQuant8(reinterpret_cast<const uint8_t*>(in1), shape1,
  95. reinterpret_cast<const uint8_t*>(in2), shape2, isMinimum,
  96. reinterpret_cast<uint8_t*>(output), outputShape);
  97. }
  98. default: {
  99. LOG(ERROR) << "Unsupported data type: " << toString(shape1.type);
  100. return false;
  101. }
  102. }
  103. }
  104. } // namespace maximum_minimum
  105. } // namespace nn
  106. } // namespace android