Transpose.cpp

/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "CpuOperationUtils.h"
#include "OperationResolver.h"

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

#include "Tracing.h"

namespace android {
namespace nn {
namespace transpose {

constexpr char kOperationName[] = "TRANSPOSE";

constexpr uint32_t kNumInputs = 2;
constexpr uint32_t kInputTensor = 0;
constexpr uint32_t kPermTensor = 1;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

template <typename T>
bool transposeGeneric(const T* inputData, const Shape& inputShape, const int32_t* perm,
                      const Shape& permShape, T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("transposeGeneric");
    // Reverse the permuted axes and convert to 4D due to the way Dims are
    // constructed.
    const int32_t kOutputDimensionNum = 4;
    // permData can be NO_VALUE representing a regular 2D matrix transpose
    int32_t permSize =
            perm == nullptr ? 2 : static_cast<int32_t>(getSizeOfDimension(permShape, 0));
    int32_t perm_tmp[2] = {1, 0};
    if (perm == nullptr) {
        perm = perm_tmp;
    }
    int32_t reversed_perm[kOutputDimensionNum];
    for (int32_t output_k = 0, input_k = permSize - 1; output_k < permSize;
         ++output_k, --input_k) {
        reversed_perm[output_k] = permSize - perm[input_k] - 1;
    }
    for (int32_t k = permSize; k < kOutputDimensionNum; ++k) {
        reversed_perm[k] = k;
    }
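    // Example: perm = {2, 0, 1} on a 3-D input (permSize == 3) yields
    // reversed_perm = {1, 2, 0, 3}: the same permutation expressed in the reversed axis
    // order that convertShapeToDims() uses, with the padded trailing axis left in place.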

    NNTRACE_COMP_SWITCH("reference_ops::Transpose");
    tflite::reference_ops::Transpose(inputData, convertShapeToDims(inputShape), outputData,
                                     convertShapeToDims(outputShape), reversed_perm);
    return true;
}

}  // namespace

bool validate(const IOperationValidationContext* context) {
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    const OperandType inputType = context->getInputType(kInputTensor);
    if (inputType == OperandType::TENSOR_FLOAT32 ||
        inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_1));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
    return validateInputTypes(context, {inputType, OperandType::TENSOR_INT32}) &&
           validateOutputTypes(context, {inputType});
}

bool prepare(IOperationExecutionContext* context) {
    // Only the permutation tensor can be omitted.
    NN_RET_CHECK(!context->isOmittedInput(kInputTensor));
    NN_RET_CHECK(!context->isOmittedOutput(kOutputTensor));

    const Shape& input = context->getInputShape(kInputTensor);
    uint32_t numInputDims = getNumberOfDimensions(input);
    Shape output = context->getOutputShape(kOutputTensor);
    output.type = input.type;
    output.offset = input.offset;
    output.scale = input.scale;

    // permData can be NO_VALUE representing a regular 2D matrix transpose
    if (context->isOmittedInput(kPermTensor)) {
        NN_RET_CHECK_EQ(numInputDims, 2);
        output.dimensions = {getSizeOfDimension(input, 1), getSizeOfDimension(input, 0)};
    } else {
        const Shape& permShape = context->getInputShape(kPermTensor);
        const int32_t* permData = context->getInputBuffer<int32_t>(kPermTensor);

        // Transpose op only supports 1D-4D input arrays.
        NN_RET_CHECK_LE(numInputDims, 4);

        // perm needs to be provided as a 1-D int32 tensor.
        NN_RET_CHECK(permShape.type == OperandType::TENSOR_INT32);
        NN_RET_CHECK_EQ(getNumberOfDimensions(permShape), 1);
        NN_RET_CHECK_EQ(numInputDims, getSizeOfDimension(permShape, 0));

        std::vector<uint32_t> outDims(numInputDims);
        for (int32_t idx = 0; idx < static_cast<int32_t>(numInputDims); ++idx) {
            NN_RET_CHECK(permData[idx] >= 0 && permData[idx] < static_cast<int32_t>(numInputDims));
            outDims[idx] = getSizeOfDimension(input, permData[idx]);
        }
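        // Example: input dimensions {2, 3, 4} with perm = {2, 0, 1} produce output
        // dimensions {4, 2, 3}.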
        output.dimensions = outDims;
    }
    return context->setOutputShape(kOutputTensor, output);
}

bool execute(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
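    // A null perm buffer (the omitted-permutation case) is handled inside
    // transposeGeneric() as a plain 2-D {1, 0} transpose.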
    switch (context->getInputType(kInputTensor)) {
        case OperandType::TENSOR_FLOAT32:
            return transposeGeneric(context->getInputBuffer<float>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getInputBuffer<int32_t>(kPermTensor),
                                    context->getInputShape(kPermTensor),
                                    context->getOutputBuffer<float>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT16:
            return transposeGeneric(context->getInputBuffer<_Float16>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getInputBuffer<int32_t>(kPermTensor),
                                    context->getInputShape(kPermTensor),
                                    context->getOutputBuffer<_Float16>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return transposeGeneric(context->getInputBuffer<uint8_t>(kInputTensor),
                                    context->getInputShape(kInputTensor),
                                    context->getInputBuffer<int32_t>(kPermTensor),
                                    context->getInputShape(kPermTensor),
                                    context->getOutputBuffer<uint8_t>(kOutputTensor),
                                    context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
    }
}

}  // namespace transpose
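
// The registration flags mirror the behavior above: allowOmittedOperand covers the optional
// permutation tensor handled in prepare(), and allowZeroSizedInput covers the zero-sized-input
// bypass at the top of execute().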
NN_REGISTER_OPERATION(TRANSPOSE, transpose::kOperationName, transpose::validate, transpose::prepare,
                      transpose::execute, .allowOmittedOperand = true, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android