SimpleMath.cpp 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  // Contains the implementations of the FLOOR and MEAN operations.
  17. #define LOG_TAG "Operations"
  #include "CpuOperationUtils.h"
  #include "Operations.h"

  #include <memory>
  #include <new>

  #include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
  #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

  #include "Tracing.h"
  23. namespace android {
  24. namespace nn {
  25. bool floorFloat16(const _Float16* inputData, _Float16* outputData, const Shape& shape) {
  26. NNTRACE_TRANS("floorFloat16");
  27. std::vector<float> inputDataFloat32(getNumberOfElements(shape));
  28. convertFloat16ToFloat32(inputData, &inputDataFloat32);
  29. std::vector<float> outputDataFloat32(getNumberOfElements(shape));
  30. floorFloat32(inputDataFloat32.data(), outputDataFloat32.data(), shape);
  31. convertFloat32ToFloat16(outputDataFloat32, outputData);
  32. return true;
  33. }
  34. bool floorFloat32(const float* inputData, float* outputData, const Shape& shape) {
  35. NNTRACE_TRANS("floorFloat32");
  36. tflite::Dims<4> dim = convertShapeToDims(shape);
  37. NNTRACE_COMP_SWITCH("optimized_ops::Floor");
  38. tflite::optimized_ops::Floor(inputData, dim, outputData, dim);
  39. return true;
  40. }
  41. bool meanFloat16(_Float16* inputData, const Shape& inputShape, const int32_t* axis,
  42. const Shape& axisShape, bool keepDims, _Float16* outputData,
  43. const Shape& outputShape) {
  44. NNTRACE_TRANS("meanFloat16");
  45. std::vector<float> inputDataFloat32(getNumberOfElements(inputShape));
  46. convertFloat16ToFloat32(inputData, &inputDataFloat32);
  47. std::vector<float> outputDataFloat32(getNumberOfElements(outputShape));
  48. meanGeneric<float, float>(inputDataFloat32.data(), inputShape, axis, axisShape, keepDims,
  49. outputDataFloat32.data(), outputShape);
  50. convertFloat32ToFloat16(outputDataFloat32, outputData);
  51. return true;
  52. }
  53. template <typename T, typename U>
  54. bool meanGeneric(T* inputData, const Shape& inputShape, const int32_t* axis, const Shape& axisShape,
  55. bool keepDims, T* outputData, const Shape& outputShape) {
  56. NNTRACE_TRANS("meanGeneric");
  57. // Creates a temp index to iterate through input data.
  58. int32_t* scratchBuffer = new int32_t[getNumberOfDimensions(inputShape)];
  59. // Creates a temp tensor to store resolved axis given input data.
  60. int32_t axisSize = static_cast<int32_t>(getSizeOfDimension(axisShape, 0));
  61. int32_t* resolvedAxis = new int32_t[axisSize];
  62. bool result = true;
  63. U* tempSumBuffer = new (std::nothrow) U[getNumberOfElements(outputShape)];
  64. if (!tempSumBuffer) {
  65. LOG(ERROR) << "Failed to allocate tempSumBuffer for MEAN";
  66. result = false;
  67. } else {
  68. NNTRACE_COMP_SWITCH("optimized_ops::Mean");
  69. tflite::reference_ops::Mean<T, U>(
  70. inputData, reinterpret_cast<const int*>(inputShape.dimensions.data()),
  71. getNumberOfDimensions(inputShape), outputData,
  72. reinterpret_cast<const int*>(outputShape.dimensions.data()),
  73. getNumberOfDimensions(outputShape), axis, axisSize, keepDims, scratchBuffer,
  74. resolvedAxis, tempSumBuffer);
  75. delete[] tempSumBuffer;
  76. }
  77. delete[] scratchBuffer;
  78. delete[] resolvedAxis;
  79. return result;
  80. }
  81. template bool meanGeneric<float, float>(float* inputData, const Shape& inputShape,
  82. const int32_t* axis, const Shape& axisShape, bool keepDims,
  83. float* outputData, const Shape& outputShape);
  84. template bool meanGeneric<uint8_t, int32_t>(uint8_t* inputData, const Shape& inputShape,
  85. const int32_t* axis, const Shape& axisShape,
  86. bool keepDims, uint8_t* outputData,
  87. const Shape& outputShape);
  88. } // namespace nn
  89. } // namespace android