IndexedShapeWrapper.h 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef FRAMEWORKS_ML_NN_INDEXED_SHAPE_WRAPPER_H
  17. #define FRAMEWORKS_ML_NN_INDEXED_SHAPE_WRAPPER_H
  18. #include "OperationsUtils.h"
  19. namespace android {
  20. namespace nn {
  21. // A wrapper over a Shape class implementing some indexing logic for a wrapped
  22. // shape.
  23. // To get an offset for an element in a tensor from vector index, one needs to
  24. // calculate strides first. This class removes the need to recalculate strides
  25. // for every indexing and also provides some utility functions.
  26. class IndexedShapeWrapper {
  27. public:
  28. IndexedShapeWrapper(const Shape& wrapped_shape);
  29. // Calculates the next index in a lexicograpical order for a wrapped shape
  30. // inplace. Only accepts valid index for a given shape as an input.
  31. // Sets lastIndex to true if the received index was the last in a
  32. // lexicographical order for a given shape. In this case, index stays the
  33. // same.
  34. bool nextIndexInplace(std::vector<uint32_t>* index, bool* lastIndex) const;
  35. // Given an index as a vector with per-dimension indices, calculates an
  36. // offset of the element in a flattened tensor.
  37. bool indexToFlatIndex(const std::vector<uint32_t>& index, uint32_t* flatIndex) const;
  38. // Same as indexToFlatIndex, only ignores first dimensions of an index if
  39. // they are not present in the shape. Also ignores dimensions of a shape of
  40. // size 1.
  41. // For example:
  42. // for shape: [3, 1, 2]
  43. // and index: [4, 2, 5, 1]
  44. // the function will ignore dimensions with indices 4 and 5 and set
  45. // flatIndex to 5 as a result.
  46. bool broadcastedIndexToFlatIndex(const std::vector<uint32_t>& index, uint32_t* flatIndex) const;
  47. private:
  48. const Shape* const shape;
  49. std::vector<uint32_t> strides;
  50. bool isValid(const std::vector<uint32_t>& index) const;
  51. };
  52. } // namespace nn
  53. } // namespace android
  54. #endif // FRAMEWORKS_ML_NN_INDEXED_SHAPE_WRAPPER_H