OperationResolver.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. /*
  2. * Copyright (C) 2018 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
  17. #define ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
  #include <functional>
  #include <utility>

  #include "HalInterfaces.h"
  #include "OperationsUtils.h"
  20. namespace android {
  21. namespace nn {
  22. // Encapsulates an operation implementation.
  23. struct OperationRegistration {
  24. OperationType type;
  25. const char* name;
  26. // Validates operand types, shapes, and any values known during graph creation.
  27. std::function<bool(const IOperationValidationContext*)> validate;
  28. // prepare is called when the inputs this operation depends on have been
  29. // computed. Typically, prepare does any remaining validation and sets
  30. // output shapes via context->setOutputShape(...).
  31. std::function<bool(IOperationExecutionContext*)> prepare;
  32. // Executes the operation, reading from context->getInputBuffer(...)
  33. // and writing to context->getOutputBuffer(...).
  34. std::function<bool(IOperationExecutionContext*)> execute;
  35. struct Flag {
  36. // Whether the operation allows at least one operand to be omitted.
  37. bool allowOmittedOperand = false;
  38. // Whether the operation allows at least one input operand to be a zero-sized tensor.
  39. bool allowZeroSizedInput = false;
  40. } flags;
  41. OperationRegistration(OperationType type, const char* name,
  42. std::function<bool(const IOperationValidationContext*)> validate,
  43. std::function<bool(IOperationExecutionContext*)> prepare,
  44. std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
  45. : type(type),
  46. name(name),
  47. validate(validate),
  48. prepare(prepare),
  49. execute(execute),
  50. flags(flags) {}
  51. };
  52. // A registry of operation implementations.
  53. class IOperationResolver {
  54. public:
  55. virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
  56. virtual ~IOperationResolver() {}
  57. };
  58. // A registry of builtin operation implementations.
  59. //
  60. // Note that some operations bypass BuiltinOperationResolver (b/124041202).
  61. //
  62. // Usage:
  63. // const OperationRegistration* operationRegistration =
  64. // BuiltinOperationResolver::get()->findOperation(operationType);
  65. // NN_RET_CHECK(operationRegistration != nullptr);
  66. // NN_RET_CHECK(operationRegistration->validate != nullptr);
  67. // NN_RET_CHECK(operationRegistration->validate(&context));
  68. //
  69. class BuiltinOperationResolver : public IOperationResolver {
  70. DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
  71. public:
  72. static const BuiltinOperationResolver* get() {
  73. static BuiltinOperationResolver instance;
  74. return &instance;
  75. }
  76. const OperationRegistration* findOperation(OperationType operationType) const override;
  77. private:
  78. BuiltinOperationResolver();
  79. void registerOperation(const OperationRegistration* operationRegistration);
  80. const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
  81. };
  // NN_REGISTER_OPERATION(identifier, ...) defines a function
  // register_<identifier>() that returns a pointer to a function-local static
  // OperationRegistration, for consumption by OperationResolver.
  //
  // Trailing arguments are forwarded as a brace initializer for
  // OperationRegistration::Flag; see that struct for the available fields and
  // their default values.
  //
  // Examples:
  //
  // - Default flags:
  //     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
  //                           foo_op::prepare, foo_op::execute);
  //
  // - One customized flag:
  //     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
  //                           foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
  //
  // - Several customized flags:
  //     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
  //                           foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
  //                           .allowZeroSizedInput = true);
  //
  #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
  #define NN_REGISTER_OPERATION(identifier, operationName, validateFn, prepareFn, executeFn, ...) \
      const OperationRegistration* register_##identifier() { \
          static OperationRegistration sRegistration(OperationType::identifier, operationName, \
                                                     validateFn, prepareFn, executeFn, \
                                                     {__VA_ARGS__}); \
          return &sRegistration; \
      }
  #else
  // Validation-only variant: the prepare and execute arguments are dropped and
  // nullptr is stored in their place. The compiler is supposed to omit the
  // unreferenced CPU execution code, so that only validation logic makes it
  // into libneuralnetworks_utils.
  #define NN_REGISTER_OPERATION(identifier, operationName, validateFn, ignoredPrepareFn, \
                                ignoredExecuteFn, ...) \
      const OperationRegistration* register_##identifier() { \
          static OperationRegistration sRegistration(OperationType::identifier, operationName, \
                                                     validateFn, nullptr, nullptr, \
                                                     {__VA_ARGS__}); \
          return &sRegistration; \
      }
  #endif
  120. } // namespace nn
  121. } // namespace android
  122. #endif // ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H