123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141 |
- /*
- * Copyright (C) 2018 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
- #define ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
#include <functional>
#include <utility>

#include "HalInterfaces.h"
#include "OperationsUtils.h"
- namespace android {
- namespace nn {
- // Encapsulates an operation implementation.
- struct OperationRegistration {
- OperationType type;
- const char* name;
- // Validates operand types, shapes, and any values known during graph creation.
- std::function<bool(const IOperationValidationContext*)> validate;
- // prepare is called when the inputs this operation depends on have been
- // computed. Typically, prepare does any remaining validation and sets
- // output shapes via context->setOutputShape(...).
- std::function<bool(IOperationExecutionContext*)> prepare;
- // Executes the operation, reading from context->getInputBuffer(...)
- // and writing to context->getOutputBuffer(...).
- std::function<bool(IOperationExecutionContext*)> execute;
- struct Flag {
- // Whether the operation allows at least one operand to be omitted.
- bool allowOmittedOperand = false;
- // Whether the operation allows at least one input operand to be a zero-sized tensor.
- bool allowZeroSizedInput = false;
- } flags;
- OperationRegistration(OperationType type, const char* name,
- std::function<bool(const IOperationValidationContext*)> validate,
- std::function<bool(IOperationExecutionContext*)> prepare,
- std::function<bool(IOperationExecutionContext*)> execute, Flag flags)
- : type(type),
- name(name),
- validate(validate),
- prepare(prepare),
- execute(execute),
- flags(flags) {}
- };
- // A registry of operation implementations.
- class IOperationResolver {
- public:
- virtual const OperationRegistration* findOperation(OperationType operationType) const = 0;
- virtual ~IOperationResolver() {}
- };
- // A registry of builtin operation implementations.
- //
- // Note that some operations bypass BuiltinOperationResolver (b/124041202).
- //
- // Usage:
- // const OperationRegistration* operationRegistration =
- // BuiltinOperationResolver::get()->findOperation(operationType);
- // NN_RET_CHECK(operationRegistration != nullptr);
- // NN_RET_CHECK(operationRegistration->validate != nullptr);
- // NN_RET_CHECK(operationRegistration->validate(&context));
- //
- class BuiltinOperationResolver : public IOperationResolver {
- DISALLOW_COPY_AND_ASSIGN(BuiltinOperationResolver);
- public:
- static const BuiltinOperationResolver* get() {
- static BuiltinOperationResolver instance;
- return &instance;
- }
- const OperationRegistration* findOperation(OperationType operationType) const override;
- private:
- BuiltinOperationResolver();
- void registerOperation(const OperationRegistration* operationRegistration);
- const OperationRegistration* mRegistrations[kNumberOfOperationTypes] = {};
- };
// NN_REGISTER_OPERATION defines a function register_<identifier>() that returns
// a pointer to a function-local static OperationRegistration. The registration
// is built on the first call, and every call returns the same pointer. These
// functions are meant for consumption by an operation resolver
// (e.g. BuiltinOperationResolver).
//
// Usage:
// (check OperationRegistration::Flag for available fields and default values.)
//
// - With default flags.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute);
//
// - With a customized flag.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute, .allowZeroSizedInput = true);
//
// - With multiple customized flags.
//     NN_REGISTER_OPERATION(FOO_OP, foo_op::kOperationName, foo_op::validate,
//                           foo_op::prepare, foo_op::execute, .allowOmittedOperand = true,
//                           .allowZeroSizedInput = true);
//
// The trailing arguments brace-initialize OperationRegistration::Flag.
// Designated initializers are standard only from C++20; earlier language
// modes rely on a compiler extension.
//
#ifdef NN_INCLUDE_CPU_IMPLEMENTATION
#define NN_REGISTER_OPERATION(identifier, operationName, validate, prepare, execute, ...) \
    const OperationRegistration* register_##identifier() { \
        static OperationRegistration registration(OperationType::identifier, operationName, \
                                                  validate, prepare, execute, {__VA_ARGS__}); \
        return &registration; \
    }
#else
// This version ignores the CPU execution logic (prepare and execute) and
// passes nullptr instead. That lets the compiler omit that code, so only
// validation logic makes it into libneuralnetworks_utils.
#define NN_REGISTER_OPERATION(identifier, operationName, validate, unused_prepare, unused_execute, \
                              ...) \
    const OperationRegistration* register_##identifier() { \
        static OperationRegistration registration(OperationType::identifier, operationName, \
                                                  validate, nullptr, nullptr, {__VA_ARGS__}); \
        return &registration; \
    }
#endif
- } // namespace nn
- } // namespace android
- #endif // ANDROID_ML_NN_COMMON_OPERATION_RESOLVER_H
|