SVDF.h 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef FRAMEWORKS_ML_NN_SVDF_H
  17. #define FRAMEWORKS_ML_NN_SVDF_H
  #include <algorithm>
  #include <cmath>
  #include <vector>

  #include "HalOperation.h"
  #include "tensorflow/lite/kernels/internal/tensor_utils.h"
  22. namespace android {
  23. namespace nn {
  24. struct SVDFParams {
  25. int rank_;
  26. TfLiteFusedActivation activation_;
  27. };
  28. struct RunTimeOperandInfo;
  29. struct Shape;
  30. class SVDF {
  31. public:
  32. SVDF(const Operation& operation, std::vector<RunTimeOperandInfo>& operands);
  33. static bool Prepare(const Operation& operation, std::vector<RunTimeOperandInfo>& operands,
  34. Shape* stateShape, Shape* outputShape);
  35. bool Eval();
  36. static constexpr int kInputTensor = 0;
  37. static constexpr int kWeightsFeatureTensor = 1;
  38. static constexpr int kWeightsTimeTensor = 2;
  39. static constexpr int kBiasTensor = 3; // Optional
  40. static constexpr int kStateInTensor = 4;
  41. static constexpr int kRankParam = 5;
  42. static constexpr int kActivationParam = 6;
  43. static constexpr int kStateOutTensor = 0;
  44. static constexpr int kOutputTensor = 1;
  45. private:
  46. void EvalFloat32(const float* inputData, const float* inputStateData, const float* biasData,
  47. const float* weightsFeatureData, const float* weightsTimeData,
  48. float* outputData, float* outputStateData);
  49. SVDFParams params_;
  50. const RunTimeOperandInfo* input_;
  51. const RunTimeOperandInfo* weights_feature_;
  52. const RunTimeOperandInfo* weights_time_;
  53. const RunTimeOperandInfo* bias_;
  54. const RunTimeOperandInfo* state_in_;
  55. RunTimeOperandInfo* state_out_;
  56. RunTimeOperandInfo* output_;
  57. };
  58. } // namespace nn
  59. } // namespace android
  60. #endif