ExecutionBurstServer.h 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. /*
  2. * Copyright (C) 2019 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
  17. #define ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H
#include "HalInterfaces.h"

#include <android-base/macros.h>
#include <fmq/MessageQueue.h>
#include <hidl/MQDescriptor.h>

#include <atomic>
#include <cstdint>
#include <memory>
#include <mutex>
#include <optional>
#include <thread>
#include <tuple>
#include <vector>
  27. namespace android::nn {
  28. using ::android::hardware::MQDescriptorSync;
  29. using FmqRequestDescriptor = MQDescriptorSync<FmqRequestDatum>;
  30. using FmqResultDescriptor = MQDescriptorSync<FmqResultDatum>;
  31. /**
  32. * Function to serialize results.
  33. *
  34. * Prefer calling ResultChannelSender::send.
  35. *
  36. * @param errorStatus Status of the execution.
  37. * @param outputShapes Dynamic shapes of the output tensors.
  38. * @param timing Timing information of the execution.
  39. * @return Serialized FMQ result data.
  40. */
  41. std::vector<FmqResultDatum> serialize(ErrorStatus errorStatus,
  42. const std::vector<OutputShape>& outputShapes, Timing timing);
  43. /**
  44. * Deserialize the FMQ request data.
  45. *
  46. * The three resulting fields are the Request object (where Request::pools is
  47. * empty), slot identifiers (which are stand-ins for Request::pools), and
  48. * whether timing information must be collected for the run.
  49. *
  50. * @param data Serialized FMQ request data.
  51. * @return Request object if successfully deserialized, std::nullopt otherwise.
  52. */
  53. std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> deserialize(
  54. const std::vector<FmqRequestDatum>& data);
  55. /**
  56. * RequestChannelReceiver is responsible for waiting on the channel until the
  57. * packet is available, extracting the packet from the channel, and
  58. * deserializing the packet.
  59. *
  60. * Because the receiver can wait on a packet that may never come (e.g., because
  61. * the sending side of the packet has been closed), this object can be
  62. * invalidating, unblocking the receiver.
  63. */
  64. class RequestChannelReceiver {
  65. using FmqRequestChannel =
  66. hardware::MessageQueue<FmqRequestDatum, hardware::kSynchronizedReadWrite>;
  67. public:
  68. /**
  69. * Create the receiving end of a request channel.
  70. *
  71. * Prefer this call over the constructor.
  72. *
  73. * @param requestChannel Descriptor for the request channel.
  74. * @return RequestChannelReceiver on successful creation, nullptr otherwise.
  75. */
  76. static std::unique_ptr<RequestChannelReceiver> create(
  77. const FmqRequestDescriptor& requestChannel);
  78. /**
  79. * Get the request from the channel.
  80. *
  81. * This method will block until either:
  82. * 1) The packet has been retrieved, or
  83. * 2) The receiver has been invalidated
  84. *
  85. * @return Request object if successfully received, std::nullopt if error or
  86. * if the receiver object was invalidated.
  87. */
  88. std::optional<std::tuple<Request, std::vector<int32_t>, MeasureTiming>> getBlocking();
  89. /**
  90. * Method to mark the channel as invalid, unblocking any current or future
  91. * calls to RequestChannelReceiver::getBlocking.
  92. */
  93. void invalidate();
  94. RequestChannelReceiver(std::unique_ptr<FmqRequestChannel> fmqRequestChannel, bool blocking);
  95. private:
  96. std::optional<std::vector<FmqRequestDatum>> getPacketBlocking();
  97. const std::unique_ptr<FmqRequestChannel> mFmqRequestChannel;
  98. std::atomic<bool> mTeardown{false};
  99. const bool mBlocking;
  100. };
  101. /**
  102. * ResultChannelSender is responsible for serializing the result packet of
  103. * information, sending it on the result channel, and signaling that the data is
  104. * available.
  105. */
  106. class ResultChannelSender {
  107. using FmqResultChannel =
  108. hardware::MessageQueue<FmqResultDatum, hardware::kSynchronizedReadWrite>;
  109. public:
  110. /**
  111. * Create the sending end of a result channel.
  112. *
  113. * Prefer this call over the constructor.
  114. *
  115. * @param resultChannel Descriptor for the result channel.
  116. * @return ResultChannelSender on successful creation, nullptr otherwise.
  117. */
  118. static std::unique_ptr<ResultChannelSender> create(const FmqResultDescriptor& resultChannel);
  119. /**
  120. * Send the result to the channel.
  121. *
  122. * @param errorStatus Status of the execution.
  123. * @param outputShapes Dynamic shapes of the output tensors.
  124. * @param timing Timing information of the execution.
  125. * @return 'true' on successful send, 'false' otherwise.
  126. */
  127. bool send(ErrorStatus errorStatus, const std::vector<OutputShape>& outputShapes, Timing timing);
  128. // prefer calling ResultChannelSender::send
  129. bool sendPacket(const std::vector<FmqResultDatum>& packet);
  130. ResultChannelSender(std::unique_ptr<FmqResultChannel> fmqResultChannel, bool blocking);
  131. private:
  132. const std::unique_ptr<FmqResultChannel> mFmqResultChannel;
  133. const bool mBlocking;
  134. };
  135. /**
  136. * The ExecutionBurstServer class is responsible for waiting for and
  137. * deserializing a request object from a FMQ, performing the inference, and
  138. * serializing the result back across another FMQ.
  139. */
  140. class ExecutionBurstServer : public IBurstContext {
  141. DISALLOW_IMPLICIT_CONSTRUCTORS(ExecutionBurstServer);
  142. public:
  143. /**
  144. * IBurstExecutorWithCache is a callback object passed to
  145. * ExecutionBurstServer's factory function that is used to perform an
  146. * execution. Because some memory resources are needed across multiple
  147. * executions, this object also contains a local cache that can directly be
  148. * used in the execution.
  149. *
  150. * ExecutionBurstServer will never access its IBurstExecutorWithCache object
  151. * with concurrent calls.
  152. */
  153. class IBurstExecutorWithCache {
  154. DISALLOW_COPY_AND_ASSIGN(IBurstExecutorWithCache);
  155. public:
  156. IBurstExecutorWithCache() = default;
  157. virtual ~IBurstExecutorWithCache() = default;
  158. /**
  159. * Checks if a cache entry specified by a slot is present in the cache.
  160. *
  161. * @param slot Identifier of the cache entry.
  162. * @return 'true' if the cache entry is present in the cache, 'false'
  163. * otherwise.
  164. */
  165. virtual bool isCacheEntryPresent(int32_t slot) const = 0;
  166. /**
  167. * Adds an entry specified by a slot to the cache.
  168. *
  169. * The caller of this function must ensure that the cache entry that is
  170. * being added is not already present in the cache. This can be checked
  171. * via isCacheEntryPresent.
  172. *
  173. * @param memory Memory resource to be cached.
  174. * @param slot Slot identifier corresponding to the memory resource.
  175. */
  176. virtual void addCacheEntry(const hidl_memory& memory, int32_t slot) = 0;
  177. /**
  178. * Removes an entry specified by a slot from the cache.
  179. *
  180. * If the cache entry corresponding to the slot number does not exist,
  181. * the call does nothing.
  182. *
  183. * @param slot Slot identifier corresponding to the memory resource.
  184. */
  185. virtual void removeCacheEntry(int32_t slot) = 0;
  186. /**
  187. * Perform an execution.
  188. *
  189. * @param request Request object with inputs and outputs specified.
  190. * Request::pools is empty, and DataLocation::poolIndex instead
  191. * refers to the 'slots' argument as if it were Request::pools.
  192. * @param slots Slots corresponding to the cached memory entries to be
  193. * used.
  194. * @param measure Whether timing information is requested for the
  195. * execution.
  196. * @return Result of the execution, including the status of the
  197. * execution, dynamic output shapes, and any timing information.
  198. */
  199. virtual std::tuple<ErrorStatus, hidl_vec<OutputShape>, Timing> execute(
  200. const Request& request, const std::vector<int32_t>& slots,
  201. MeasureTiming measure) = 0;
  202. };
  203. /**
  204. * Create automated context to manage FMQ-based executions.
  205. *
  206. * This function is intended to be used by a service to automatically:
  207. * 1) Receive data from a provided FMQ
  208. * 2) Execute a model with the given information
  209. * 3) Send the result to the created FMQ
  210. *
  211. * @param callback Callback used to retrieve memories corresponding to
  212. * unrecognized slots.
  213. * @param requestChannel Input FMQ channel through which the client passes the
  214. * request to the service.
  215. * @param resultChannel Output FMQ channel from which the client can retrieve
  216. * the result of the execution.
  217. * @param executorWithCache Object which maintains a local cache of the
  218. * memory pools and executes using the cached memory pools.
  219. * @result IBurstContext Handle to the burst context.
  220. */
  221. static sp<ExecutionBurstServer> create(
  222. const sp<IBurstCallback>& callback, const FmqRequestDescriptor& requestChannel,
  223. const FmqResultDescriptor& resultChannel,
  224. std::shared_ptr<IBurstExecutorWithCache> executorWithCache);
  225. /**
  226. * Create automated context to manage FMQ-based executions.
  227. *
  228. * This function is intended to be used by a service to automatically:
  229. * 1) Receive data from a provided FMQ
  230. * 2) Execute a model with the given information
  231. * 3) Send the result to the created FMQ
  232. *
  233. * @param callback Callback used to retrieve memories corresponding to
  234. * unrecognized slots.
  235. * @param requestChannel Input FMQ channel through which the client passes the
  236. * request to the service.
  237. * @param resultChannel Output FMQ channel from which the client can retrieve
  238. * the result of the execution.
  239. * @param preparedModel PreparedModel that the burst object was created from.
  240. * IPreparedModel::executeSynchronously will be used to perform the
  241. * execution.
  242. * @result IBurstContext Handle to the burst context.
  243. */
  244. static sp<ExecutionBurstServer> create(const sp<IBurstCallback>& callback,
  245. const FmqRequestDescriptor& requestChannel,
  246. const FmqResultDescriptor& resultChannel,
  247. IPreparedModel* preparedModel);
  248. ExecutionBurstServer(const sp<IBurstCallback>& callback,
  249. std::unique_ptr<RequestChannelReceiver> requestChannel,
  250. std::unique_ptr<ResultChannelSender> resultChannel,
  251. std::shared_ptr<IBurstExecutorWithCache> cachedExecutor);
  252. ~ExecutionBurstServer();
  253. // Used by the NN runtime to preemptively remove any stored memory.
  254. Return<void> freeMemory(int32_t slot) override;
  255. private:
  256. // Ensures all cache entries contained in mExecutorWithCache are present in
  257. // the cache. If they are not present, they are retrieved (via
  258. // IBurstCallback::getMemories) and added to mExecutorWithCache.
  259. //
  260. // This method is locked via mMutex when it is called.
  261. void ensureCacheEntriesArePresentLocked(const std::vector<int32_t>& slots);
  262. // Work loop that will continue processing execution requests until the
  263. // ExecutionBurstServer object is freed.
  264. void task();
  265. std::thread mWorker;
  266. std::mutex mMutex;
  267. std::atomic<bool> mTeardown{false};
  268. const sp<IBurstCallback> mCallback;
  269. const std::unique_ptr<RequestChannelReceiver> mRequestChannelReceiver;
  270. const std::unique_ptr<ResultChannelSender> mResultChannelSender;
  271. const std::shared_ptr<IBurstExecutorWithCache> mExecutorWithCache;
  272. };
  273. } // namespace android::nn
  274. #endif // ANDROID_ML_NN_COMMON_EXECUTION_BURST_SERVER_H