Operations.h 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef ANDROID_ML_NN_COMMON_OPERATIONS_H
  17. #define ANDROID_ML_NN_COMMON_OPERATIONS_H
  18. #include "operations/BidirectionalSequenceLSTM.h"
  19. #include "operations/Cast.h"
  20. #include "operations/EmbeddingLookup.h"
  21. #include "operations/ExpandDims.h"
  22. #include "operations/HashtableLookup.h"
  23. #include "operations/LSHProjection.h"
  24. #include "operations/LSTM.h"
  25. #include "operations/MaximumMinimum.h"
  26. #include "operations/Multinomial.h"
  27. #include "operations/Pow.h"
  28. #include "operations/QuantizedLSTM.h"
  29. #include "operations/RNN.h"
  30. #include "operations/SVDF.h"
  31. #include "operations/Tile.h"
  32. #include "operations/TopK_V2.h"
  33. #include <stddef.h>
  34. #include <cstdint>
  35. #include <vector>
  36. namespace android {
  37. namespace nn {
  38. struct Shape;
  39. bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape);
  40. bool floorFloat32(const float* inputData, float* outputData, const Shape& shape);
  41. bool depthwiseConvFloat16(const _Float16* inputData, const Shape& inputShape,
  42. const _Float16* filterData, const Shape& filterShape,
  43. const _Float16* biasData, const Shape& biasShape, int32_t paddingLeft,
  44. int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
  45. int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
  46. int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
  47. _Float16* outputData, const Shape& outputShape);
  48. bool depthwiseConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
  49. const Shape& filterShape, const float* biasData, const Shape& biasShape,
  50. int32_t paddingLeft, int32_t paddingRight, int32_t paddingTop,
  51. int32_t paddingBottom, int32_t strideWidth, int32_t strideHeight,
  52. int32_t dilationWidthFactor, int32_t dilationHeightFactor,
  53. int32_t depthMultiplier, int32_t activation, float* outputData,
  54. const Shape& outputShape);
  55. bool depthwiseConvQuant8(const uint8_t* inputData, const Shape& inputShape,
  56. const uint8_t* filterData, const Shape& filterShape,
  57. const int32_t* biasData, const Shape& biasShape, int32_t paddingLeft,
  58. int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
  59. int32_t strideWidth, int32_t strideHeight, int32_t dilationWidthFactor,
  60. int32_t dilationHeightFactor, int32_t depthMultiplier, int32_t activation,
  61. uint8_t* outputData, const Shape& outputShape);
  62. bool depthwiseConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
  63. const int8_t* filterData, const Shape& filterShape,
  64. const float* filterScales, const int32_t* biasData,
  65. const Shape& biasShape, int32_t paddingLeft,
  66. int32_t paddingRight, int32_t paddingTop, int32_t paddingBottom,
  67. int32_t strideWidth, int32_t strideHeight,
  68. int32_t dilationWidthFactor, int32_t dilationHeightFactor,
  69. int32_t depthMultiplier, int32_t activation, uint8_t* outputData,
  70. const Shape& outputShape);
  71. bool localResponseNormFloat16(const _Float16* inputData, const Shape& inputShape, int32_t radius,
  72. float bias, float alpha, float beta, int32_t axis,
  73. _Float16* outputData, const Shape& outputShape);
  74. bool localResponseNormFloat32(const float* inputData, const Shape& inputShape, int32_t radius,
  75. float bias, float alpha, float beta, int32_t axis, float* outputData,
  76. const Shape& outputShape);
  77. bool copyData(const void* inputData, const Shape& inputShape, void* outputData,
  78. const Shape& outputShape);
  79. template <typename T>
  80. bool depthToSpaceGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
  81. T* outputData, const Shape& outputShape);
  82. template <typename T>
  83. bool spaceToDepthGeneric(const T* inputData, const Shape& inputShape, int32_t blockSize,
  84. T* outputData, const Shape& outputShape);
  85. template <typename T>
  86. bool padGeneric(const T* inputData, const Shape& inputShape, const int32_t* paddings, T pad_value,
  87. T* outputData, const Shape& outputShape);
  88. template <typename T>
  89. bool batchToSpaceGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
  90. T* outputData, const Shape& outputShape);
  91. template <typename T>
  92. bool spaceToBatchGeneric(const T* inputData, const Shape& inputShape, const int32_t* blockSize,
  93. const int32_t* padding, const Shape& paddingShape, T* outputData,
  94. const Shape& outputShape);
  95. bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
  96. const Shape& axisShape, bool keepDims, _Float16* outputData,
  97. const Shape& outputShape);
  98. template <typename T, typename U>
  99. bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
  100. bool keepDims, T* outputData, const Shape& outputShape);
  101. bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
  102. const int32_t* beginData, const int32_t* endData,
  103. const int32_t* stridesData, int32_t beginMask, int32_t endMask,
  104. int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape);
  105. bool argMinMaxGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
  106. bool isArgMin, uint8_t* outputData, const Shape& outputShape);
  107. bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
  108. const std::vector<_Float16*>* outputDataPtrs,
  109. const std::vector<Shape>& outputShapes);
  110. bool splitFloat32(const float* inputData, const Shape& inputShape, const int32_t axis,
  111. const std::vector<float*>* outputDataPtrs,
  112. const std::vector<Shape>& outputShapes);
  113. bool splitInt32(const int32_t* inputData, const Shape& inputShape, const int32_t axis,
  114. const std::vector<int32_t*>* outputDataPtrs,
  115. const std::vector<Shape>& outputShapes);
  116. bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, const int32_t axis,
  117. const std::vector<uint8_t*>* outputDataPtrs,
  118. const std::vector<Shape>& outputShapes);
  119. bool groupedConvFloat16(const _Float16* inputData, const Shape& inputShape,
  120. const _Float16* filterData, const Shape& filterShape,
  121. const _Float16* biasData, const Shape& biasShape, int32_t numGroups,
  122. int32_t padding_left, int32_t padding_right, int32_t padding_top,
  123. int32_t padding_bottom, int32_t stride_width, int32_t stride_height,
  124. int32_t activation, _Float16* outputData, const Shape& outputShape);
  125. bool groupedConvFloat32(const float* inputData, const Shape& inputShape, const float* filterData,
  126. const Shape& filterShape, const float* biasData, const Shape& biasShape,
  127. int32_t numGroups, int32_t padding_left, int32_t padding_right,
  128. int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
  129. int32_t stride_height, int32_t activation, float* outputData,
  130. const Shape& outputShape);
  131. bool groupedConvQuant8(const uint8_t* inputData, const Shape& inputShape, const uint8_t* filterData,
  132. const Shape& filterShape, const int32_t* biasData, const Shape& biasShape,
  133. int32_t numGroups, int32_t padding_left, int32_t padding_right,
  134. int32_t padding_top, int32_t padding_bottom, int32_t stride_width,
  135. int32_t stride_height, int32_t activation, uint8_t* outputData,
  136. const Shape& outputShape);
  137. bool groupedConvQuant8PerChannel(const uint8_t* inputData, const Shape& inputShape,
  138. const int8_t* filterData, const Shape& filterShape,
  139. const float* filterScales, const int32_t* biasData,
  140. const Shape& biasShape, int32_t padding_left,
  141. int32_t padding_right, int32_t padding_top, int32_t padding_bottom,
  142. int32_t stride_width, int32_t stride_height, int32_t numGroups,
  143. int32_t activation, uint8_t* outputData, const Shape& outputShape);
  144. bool channelShuffleGeneric(const uint8_t* inputData, const Shape& inputShape, int32_t numGroups,
  145. int32_t axis, uint8_t* outputData, const Shape& outputShape);
  146. } // namespace nn
  147. } // namespace android
  148. #endif // ANDROID_ML_NN_COMMON_OPERATIONS_H