Split.cpp 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #define LOG_TAG "Operations"
  17. #include "Operations.h"
  18. #include "OperationsUtils.h"
  19. #include "Tracing.h"
  20. namespace android {
  21. namespace nn {
  22. template <typename Scalar>
  23. bool splitGeneric(const Scalar* inputData, const Shape& inputShape, int32_t axis,
  24. const std::vector<Scalar*>* outputDataPtrs,
  25. const std::vector<Shape>& outputShapes) {
  26. NN_CHECK(handleNegativeAxis(inputShape, &axis));
  27. int outerSize = 1;
  28. for (int i = 0; i < axis; ++i) {
  29. outerSize *= inputShape.dimensions[i];
  30. }
  31. int baseInnerSize = 1;
  32. int concatDimensions = getNumberOfDimensions(inputShape);
  33. for (int i = axis + 1; i < concatDimensions; ++i) {
  34. baseInnerSize *= inputShape.dimensions[i];
  35. }
  36. const Scalar* inputPtr = inputData;
  37. for (int k = 0; k < outerSize; k++) {
  38. for (int i = 0; i < outputDataPtrs->size(); ++i) {
  39. const int copySize = outputShapes[i].dimensions[axis] * baseInnerSize;
  40. memcpy(outputDataPtrs->at(i) + k * copySize, inputPtr, copySize * sizeof(Scalar));
  41. inputPtr += copySize;
  42. }
  43. }
  44. return true;
  45. }
  46. bool splitFloat16(const _Float16* inputData, const Shape& inputShape, int32_t axis,
  47. const std::vector<_Float16*>* outputDataPtrs,
  48. const std::vector<Shape>& outputShapes) {
  49. NNTRACE_COMP("splitFloat16");
  50. return splitGeneric<_Float16>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
  51. }
  52. bool splitFloat32(const float* inputData, const Shape& inputShape, int32_t axis,
  53. const std::vector<float*>* outputDataPtrs,
  54. const std::vector<Shape>& outputShapes) {
  55. NNTRACE_COMP("splitFloat32");
  56. return splitGeneric<float>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
  57. }
  58. bool splitQuant8(const uint8_t* inputData, const Shape& inputShape, int32_t axis,
  59. const std::vector<uint8_t*>* outputDataPtrs,
  60. const std::vector<Shape>& outputShapes) {
  61. NNTRACE_COMP("splitQuant8");
  62. return splitGeneric<uint8_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
  63. }
  64. bool splitInt32(const int32_t* inputData, const Shape& inputShape, int32_t axis,
  65. const std::vector<int32_t*>* outputDataPtrs,
  66. const std::vector<Shape>& outputShapes) {
  67. NNTRACE_COMP("splitInt32");
  68. return splitGeneric<int32_t>(inputData, inputShape, axis, outputDataPtrs, outputShapes);
  69. }
  70. } // namespace nn
  71. } // namespace android