- /*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- // Contains the implementation of the operations.
- #define LOG_TAG "Operations"
- #include "CpuOperationUtils.h"
- #include "Operations.h"
- #include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"
- #include "Tracing.h"
- namespace android {
- namespace nn {
- bool stridedSliceGeneric(const uint8_t* inputData, const Shape& inputShape,
-                          const int32_t* beginData, const int32_t* endData,
-                          const int32_t* stridesData, int32_t beginMask, int32_t endMask,
-                          int32_t shrinkAxisMask, uint8_t* outputData, const Shape& outputShape) {
-     NNTRACE_TRANS("stridedSliceGeneric");
-     // This Op only supports 1-4D cases and since we use the reference 4D
-     // implementation, the 1-3D tensors are mapped to 4D.
-     const int kMaxDim = 4;
-
-     std::vector<int> starts;
-     std::vector<int> stops;
-     std::vector<int> strides;
-
-     int32_t numInputDims = static_cast<int32_t>(getNumberOfDimensions(inputShape));
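-     // Copy the begin/end/stride values in reverse dimension order, since the
-     // legacy reference kernel indexes dimensions innermost-first.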
-     for (int32_t idx = numInputDims - 1; idx >= 0; --idx) {
-         starts.emplace_back(beginData[idx]);
-         stops.emplace_back(endData[idx]);
-         strides.emplace_back(stridesData[idx]);
-     }
-
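-     // Pad the remaining axes up to 4D; each padded axis is expected to have
-     // size 1, so a [0, 1) slice with stride 1 covers it.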
-     for (int i = numInputDims; i < kMaxDim; i++) {
-         starts.emplace_back(0);
-         stops.emplace_back(1);
-         strides.emplace_back(1);
-     }
-
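-     // The mask bits refer to the original dimension order, so reverse them to
-     // match the reversed begin/end/stride vectors built above.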
-     beginMask = ReverseMaskBits(beginMask, numInputDims);
-     endMask = ReverseMaskBits(endMask, numInputDims);
-     shrinkAxisMask = ReverseMaskBits(shrinkAxisMask, numInputDims);
-
-     if (inputShape.type == OperandType::TENSOR_FLOAT32) {
-         NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::float");
-         tflite::reference_ops::StridedSlice(
-                 reinterpret_cast<const float*>(inputData), convertShapeToDims(inputShape),
-                 beginMask, endMask, shrinkAxisMask, starts, stops, strides,
-                 reinterpret_cast<float*>(outputData), convertShapeToDims(outputShape));
-     } else if (inputShape.type == OperandType::TENSOR_FLOAT16) {
-         NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::float16");
-         tflite::reference_ops::StridedSlice(
-                 reinterpret_cast<const _Float16*>(inputData), convertShapeToDims(inputShape),
-                 beginMask, endMask, shrinkAxisMask, starts, stops, strides,
-                 reinterpret_cast<_Float16*>(outputData), convertShapeToDims(outputShape));
-     } else if (inputShape.type == OperandType::TENSOR_QUANT8_ASYMM) {
-         NNTRACE_COMP_SWITCH("reference_ops::StridedSlice::uint8");
-         tflite::reference_ops::StridedSlice(
-                 reinterpret_cast<const uint8_t*>(inputData), convertShapeToDims(inputShape),
-                 beginMask, endMask, shrinkAxisMask, starts, stops, strides,
-                 reinterpret_cast<uint8_t*>(outputData), convertShapeToDims(outputShape));
-     } else {
-         LOG(ERROR) << "Unsupported data type";
-         return false;
-     }
-     return true;
- }
- } // namespace nn
- } // namespace android