// StridedSlice.cpp
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.
#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "Operations.h"

#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

#include "Tracing.h"
  22. namespace android {
  23. namespace nn {
  24. bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
  25. const int32_t* beginData, const int32_t* endData,
  26. const int32_t* stridesData, int32_t beginMask, int32_t endMask,
  27. int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape) {
  28. NNTRACE_TRANS("stridedSliceGeneric");
  29. // This Op only supports 1-4D cases and since we use the reference 4D
  30. // implementation, the 1-3D tensors are mapped to 4D.
  31. const int kMaxDim = 4;
  32. std::vector<int> starts;
  33. std::vector<int> stops;
  34. std::vector<int> strides;
  35. int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
  36. for (int32_t idx = numInputDims - 1; idx >= 0; --idx) {
  37. starts.emplace_back(beginData[idx]);
  38. stops.emplace_back(endData[idx]);
  39. strides.emplace_back(stridesData[idx]);
  40. }
  41. for (int i = numInputDims; i < kMaxDim; i++) {
  42. starts.emplace_back(0);
  43. stops.emplace_back(1);
  44. strides.emplace_back(1);
  45. }
  46. beginMask = ReverseMaskBits(beginMask, numInputDims);
  47. endMask = ReverseMaskBits(endMask, numInputDims);
  48. shrinkAxisMask = ReverseMaskBits(shrinkAxisMask, numInputDims);
  49. if (inputShape.type == OperandType::TENSOR_FLOAT32) {
  50. NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::float");
  51. tflite::reference_ops::StridedSlice(
  52. reinterpret_cast<const float*>(inputData), convertShapeToDims(inputShape),
  53. beginMask, endMask, shrinkAxisMask, starts, stops, strides,
  54. reinterpret_cast<float*>(outputData), convertShapeToDims(outputShape));
  55. } else if (inputShape.type == OperandType::TENSOR_FLOAT16) {
  56. NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::float16");
  57. tflite::reference_ops::StridedSlice(
  58. reinterpret_cast<const _Float16*>(inputData), convertShapeToDims(inputShape),
  59. beginMask, endMask, shrinkAxisMask, starts, stops, strides,
  60. reinterpret_cast<_Float16*>(outputData), convertShapeToDims(outputShape));
  61. } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
  62. NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::uint8");
  63. tflite::reference_ops::StridedSlice(
  64. reinterpret_cast<const uint8_t*>(inputData), convertShapeToDims(inputShape),
  65. beginMask, endMask, shrinkAxisMask, starts, stops, strides,
  66. reinterpret_cast<uint8_t*>(outputData), convertShapeToDims(outputShape));
  67. } else {
  68. LOG(ERROR) << "Unsupported data type";
  69. return false;
  70. }
  71. return true;
  72. }
  73. } // namespace nn
  74. } // namespace android