Reshape.cpp

/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the data-layout (reshape-family) operations.

#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "Operations.h"
#include "Tracing.h"

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"

namespace android {
namespace nn {

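// Plain byte copy for operations (such as RESHAPE) that reinterpret
// dimensions without moving element data; the copy size is derived from the
// input's type and dimensions.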
bool copyData(const void* inputData, const Shape& inputShape, void* outputData,
              const Shape& outputShape) {
    NNTRACE_COMP("copyData");
    size_t count = nonExtensionOperandSizeOfData(inputShape.type, inputShape.dimensions);
    memcpy(outputData, inputData, count);
    return true;
}

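// DEPTH_TO_SPACE: rearranges blockSize x blockSize groups of channels into
// spatial blocks, delegating to the TFLite optimized kernel. blockSize is
// assumed to have been validated against the input depth during shape
// preparation.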
template <typename T>
bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
                         T* outputData, const Shape& outputShape) {
    NNTRACE_COMP("optimized_ops::DepthToSpace");
    tflite::optimized_ops::DepthToSpace(inputData, convertShapeToDims(inputShape), blockSize,
                                        outputData, convertShapeToDims(outputShape));
    return true;
}

template bool depthToSpaceGeneric<float>(const float* inputData, const Shape& inputShape,
                                         int32_t blockSize, float* outputData,
                                         const Shape& outputShape);
template bool depthToSpaceGeneric<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                            int32_t blockSize, _Float16* outputData,
                                            const Shape& outputShape);
template bool depthToSpaceGeneric<uint8_t>(const uint8_t* inputData, const Shape& inputShape,
                                           int32_t blockSize, uint8_t* outputData,
                                           const Shape& outputShape);

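// SPACE_TO_DEPTH: the inverse rearrangement, folding each
// blockSize x blockSize spatial block into the channel dimension, again via
// the TFLite optimized kernel.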
template <typename T>
bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
                         T* outputData, const Shape& outputShape) {
    NNTRACE_COMP("optimized_ops::SpaceToDepth");
    tflite::optimized_ops::SpaceToDepth(inputData, convertShapeToDims(inputShape), blockSize,
                                        outputData, convertShapeToDims(outputShape));
    return true;
}

template bool spaceToDepthGeneric<float>(const float* inputData, const Shape& inputShape,
                                         int32_t blockSize, float* outputData,
                                         const Shape& outputShape);
template bool spaceToDepthGeneric<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                            int32_t blockSize, _Float16* outputData,
                                            const Shape& outputShape);
template bool spaceToDepthGeneric<uint8_t>(const uint8_t* inputData, const Shape& inputShape,
                                           int32_t blockSize, uint8_t* outputData,
                                           const Shape& outputShape);

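// PAD: fills the output border with padValue and copies the input into the
// interior. The TFLite-style loop below walks the 4-D output in row-major
// order (batch, height, width, depth), memsetting the left/right pad regions
// at each level and memcpying one contiguous row of inputDepth elements in
// the innermost loop. As a worked example (hypothetical values): a 2-D input
// with paddings {{1, 2}, {3, 4}} is first extended to 4-D, giving
// leftPaddings = {0, 0, 1, 3} and rightPaddings = {0, 0, 2, 4}.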
template <typename T>
bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T padValue,
                T* outputData, const Shape& outputShape) {
    NNTRACE_TRANS("padGeneric");

    // Based on
    // http://google3/third_party/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h?l=6194&rcl=213557260

    // TFLite runtime calls are currently fixed at 4 dimensions. Copy inputs so
    // we can pad them to 4 dims (yes, we are "padding the padding").
    int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
    NN_OPS_CHECK(numInputDims <= 4);
    std::vector<int> leftPaddings(4 - numInputDims, 0);
    std::vector<int> rightPaddings(4 - numInputDims, 0);
    for (int32_t i = 0; i < numInputDims; ++i) {
        leftPaddings.push_back(paddings[i * 2]);
        rightPaddings.push_back(paddings[i * 2 + 1]);
    }
    const int leftBPadding = leftPaddings[0];
    const int leftHPadding = leftPaddings[1];
    const int leftWPadding = leftPaddings[2];
    const int leftDPadding = leftPaddings[3];
    const int rightBPadding = rightPaddings[0];
    const int rightHPadding = rightPaddings[1];
    const int rightWPadding = rightPaddings[2];
    const int rightDPadding = rightPaddings[3];

    const auto extInputShape =
            tflite::RuntimeShape::ExtendedShape(4, convertShapeToTflshape(inputShape));
    const auto extOutputShape =
            tflite::RuntimeShape::ExtendedShape(4, convertShapeToTflshape(outputShape));
    const int outputBatch = extOutputShape.Dims(0);
    const int outputHeight = extOutputShape.Dims(1);
    const int outputWidth = extOutputShape.Dims(2);
    const int outputDepth = extOutputShape.Dims(3);
    const int inputDepth = extInputShape.Dims(3);

    NNTRACE_COMP_SWITCH("padGeneric");

    if (leftBPadding != 0) {
        tflite::optimized_ops::TypedMemset<T>(
                outputData, padValue, leftBPadding * outputHeight * outputWidth * outputDepth);
    }
    for (int outB = leftBPadding; outB < outputBatch - rightBPadding; ++outB) {
        if (leftHPadding != 0) {
            tflite::optimized_ops::TypedMemset<T>(
                    outputData + tflite::Offset(extOutputShape, outB, 0, 0, 0), padValue,
                    leftHPadding * outputWidth * outputDepth);
        }
        for (int outH = leftHPadding; outH < outputHeight - rightHPadding; ++outH) {
            if (leftWPadding != 0) {
                tflite::optimized_ops::TypedMemset<T>(
                        outputData + tflite::Offset(extOutputShape, outB, outH, 0, 0), padValue,
                        leftWPadding * outputDepth);
            }
            for (int outW = leftWPadding; outW < outputWidth - rightWPadding; ++outW) {
                if (leftDPadding != 0) {
                    tflite::optimized_ops::TypedMemset<T>(
                            outputData + tflite::Offset(extOutputShape, outB, outH, outW, 0),
                            padValue, leftDPadding);
                }
                T* out =
                        outputData + tflite::Offset(extOutputShape, outB, outH, outW, leftDPadding);
                const T* in =
                        inputData + tflite::Offset(extInputShape, outB - leftBPadding,
                                                   outH - leftHPadding, outW - leftWPadding, 0);
                memcpy(out, in, inputDepth * sizeof(T));
                if (rightDPadding != 0) {
                    tflite::optimized_ops::TypedMemset<T>(
                            outputData + tflite::Offset(extOutputShape, outB, outH, outW,
                                                        outputDepth - rightDPadding),
                            padValue, rightDPadding);
                }
            }
            if (rightWPadding != 0) {
                tflite::optimized_ops::TypedMemset<T>(
                        outputData + tflite::Offset(extOutputShape, outB, outH,
                                                    outputWidth - rightWPadding, 0),
                        padValue, rightWPadding * outputDepth);
            }
        }
        if (rightHPadding != 0) {
            tflite::optimized_ops::TypedMemset<T>(
                    outputData + tflite::Offset(extOutputShape, outB, outputHeight - rightHPadding,
                                                0, 0),
                    padValue, rightHPadding * outputWidth * outputDepth);
        }
    }
    if (rightBPadding != 0) {
        tflite::optimized_ops::TypedMemset<T>(
                outputData + tflite::Offset(extOutputShape, outputBatch - rightBPadding, 0, 0, 0),
                padValue, rightBPadding * outputHeight * outputWidth * outputDepth);
    }
    return true;
}

template bool padGeneric<float>(const float* inputData, const Shape& inputShape,
                                const int32_t* paddings, float padValue, float* outputData,
                                const Shape& outputShape);
template bool padGeneric<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                   const int32_t* paddings, _Float16 padValue,
                                   _Float16* outputData, const Shape& outputShape);
template bool padGeneric<uint8_t>(const uint8_t* inputData, const Shape& inputShape,
                                  const int32_t* paddings, uint8_t padValue, uint8_t* outputData,
                                  const Shape& outputShape);

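// BATCH_TO_SPACE_ND: rearranges batch entries back into spatial blocks. The
// legacy TFLite kernel also accepts crop amounts; the NNAPI operation has no
// crops input, so all-zero crops are passed.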
template <typename T>
bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
                         T* outputData, const Shape& outputShape) {
    // Needed by low level implementation, but not really used.
    tflite::Dims<4> blockSizeDim, cropsDim;
    const int32 crops[4] = {0, 0, 0, 0};
    NNTRACE_COMP("optimized_ops::BatchToSpaceND");
    tflite::optimized_ops::BatchToSpaceND(inputData, convertShapeToDims(inputShape), blockSize,
                                          blockSizeDim, crops, cropsDim, outputData,
                                          convertShapeToDims(outputShape));
    return true;
}

template bool batchToSpaceGeneric<float>(const float* inputData, const Shape& inputShape,
                                         const int32_t* blockSize, float* outputData,
                                         const Shape& outputShape);
template bool batchToSpaceGeneric<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                            const int32_t* blockSize, _Float16* outputData,
                                            const Shape& outputShape);
template bool batchToSpaceGeneric<uint8_t>(const uint8_t* inputData, const Shape& inputShape,
                                           const int32_t* blockSize, uint8_t* outputData,
                                           const Shape& outputShape);

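// SPACE_TO_BATCH_ND: moves blockSize x blockSize spatial blocks into the
// batch dimension, padding as needed. The output zero point
// (outputShape.offset) is passed to the kernel so that quantized tensors are
// padded with their logical zero rather than the raw value 0.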
template <typename T>
bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
                         const int32_t* padding, const Shape& paddingShape, T* outputData,
                         const Shape& outputShape) {
    // Needed by low level implementation, but not really used.
    tflite::RuntimeShape blockSizeDim;
    NNTRACE_COMP("optimized_ops::SpaceToBatchND");
    tflite::optimized_ops::SpaceToBatchND(
            {.output_offset = outputShape.offset}, convertShapeToTflshape(inputShape), inputData,
            blockSizeDim, blockSize, convertShapeToTflshape(paddingShape), padding,
            convertShapeToTflshape(outputShape), outputData);
    return true;
}

template bool spaceToBatchGeneric<float>(const float* inputData, const Shape& inputShape,
                                         const int32_t* blockSize, const int32_t* padding,
                                         const Shape& paddingShape, float* outputData,
                                         const Shape& outputShape);
template bool spaceToBatchGeneric<_Float16>(const _Float16* inputData, const Shape& inputShape,
                                            const int32_t* blockSize, const int32_t* padding,
                                            const Shape& paddingShape, _Float16* outputData,
                                            const Shape& outputShape);
template bool spaceToBatchGeneric<uint8_t>(const uint8_t* inputData, const Shape& inputShape,
                                           const int32_t* blockSize, const int32_t* padding,
                                           const Shape& paddingShape, uint8_t* outputData,
                                           const Shape& outputShape);

}  // namespace nn
}  // namespace android