/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Contains the implementation of the operations.

#define LOG_TAG "Operations"

#include "CpuOperationUtils.h"
#include "OperationResolver.h"

#include "tensorflow/lite/kernels/internal/optimized/legacy_optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/legacy_reference_ops.h"

#include "Tracing.h"

#include <algorithm>
#include <functional>  // std::function, std::bind
#include <vector>      // std::vector

namespace android {
namespace nn {
namespace broadcast {

constexpr uint32_t kNumInputs = 3;
constexpr uint32_t kInputTensor1 = 0;
constexpr uint32_t kInputTensor2 = 1;
constexpr uint32_t kActivationScalar = 2;

constexpr uint32_t kNumOutputs = 1;
constexpr uint32_t kOutputTensor = 0;

namespace {

#define ANDROID_NN_MACRO_DISPATCH(macro)                                 \
    switch (activation) {                                                \
        case (int32_t)FusedActivationFunc::NONE:                         \
            macro(kNone);                                                \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU:                         \
            macro(kRelu);                                                \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU1:                        \
            macro(kRelu1);                                               \
            break;                                                       \
        case (int32_t)FusedActivationFunc::RELU6:                        \
            macro(kRelu6);                                               \
            break;                                                       \
        default:                                                         \
            LOG(ERROR) << "Unsupported fused activation function type";  \
            return false;                                                \
    }
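
// Note on the dispatch above: TFLite bakes the fused activation into each
// kernel as a compile-time template parameter, so the switch maps the runtime
// activation scalar onto one kernel instantiation per case.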

using binaryFunctionFloat32 = std::function<bool(
        const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
        int32_t activation, float* out, const Shape& shapeOut)>;
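
// _Float16 variants are computed by widening both inputs to float, running the
// corresponding float32 kernel, and narrowing the result back to half precision.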
bool binaryOperationFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2,
                            const Shape& shape2, int32_t activation, _Float16* out,
                            const Shape& shapeOut, binaryFunctionFloat32 operationFloat32) {
    std::vector<float> in1_float32(getNumberOfElements(shape1));
    convertFloat16ToFloat32(in1, &in1_float32);
    std::vector<float> in2_float32(getNumberOfElements(shape2));
    convertFloat16ToFloat32(in2, &in2_float32);
    std::vector<float> out_float32(getNumberOfElements(shapeOut));

    operationFloat32(in1_float32.data(), shape1, in2_float32.data(), shape2, activation,
                     out_float32.data(), shapeOut);
    convertFloat32ToFloat16(out_float32, out);

    return true;
}

bool addFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                               \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))
        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_ADD(activation)                                                  \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(   \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out, \
            convertShapeToDims(shapeOut))
        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_ADD)
#undef ANDROID_NN_ADD
    }

    return true;
}

bool addFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &addFloat32);
}

bool addQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("addQuant8");
    bool needBroadcast = !SameShape(shape1, shape2);

    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);
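    // Worked example of the rescaling above (assumed, illustrative scales):
    // with shape1.scale = 0.5, shape2.scale = 0.25, shapeOut.scale = 1.0,
    //   twice_max_input_scale  = 2 * 0.5 = 1.0
    //   real_input1_multiplier = 0.5 / 1.0 = 0.5
    //   real_input2_multiplier = 0.25 / 1.0 = 0.25
    //   real_output_multiplier = 1.0 / (2^20 * 1.0) ~= 9.54e-7
    // Inputs are pre-shifted left by 20 bits so the fixed-point multiplies
    // keep precision before the final requantization to the output scale.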
    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
                                          &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
                                          &input2_shift)) {
        return false;
    }
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                                  \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(    \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,   \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier, \
            input2_shift, output_offset, output_multiplier, output_shift,                    \
            output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut))
        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Add");
#define ANDROID_NN_NORMAL_ADD(activation)                                                     \
    tflite::optimized_ops::Add<tflite::FusedActivationFunctionType::activation>(             \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,   \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier, \
            input2_shift, output_offset, output_multiplier, output_shift,                    \
            output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut))
        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_NORMAL_ADD)
#undef ANDROID_NN_NORMAL_ADD
    }

    return true;
}

bool mulFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat32");
    bool needBroadcast = !SameShape(shape1, shape2);

    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
#define ANDROID_NN_BROADCAST_MUL(activation)                                               \
    tflite::optimized_ops::BroadcastMul<tflite::FusedActivationFunctionType::activation>( \
            in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2), out,        \
            convertShapeToDims(shapeOut))
        ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_MUL)
#undef ANDROID_NN_BROADCAST_MUL
    } else {
        float output_activation_min, output_activation_max;
        CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
        NNTRACE_COMP_SWITCH("optimized_ops::Mul");
        tflite::optimized_ops::Mul(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool mulFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &mulFloat32);
}

bool mulQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("mulQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const double input_product_scale = shape1.scale * shape2.scale;
    const double real_multiplier = input_product_scale / shapeOut.scale;
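    // Worked example (assumed, illustrative scales): with shape1.scale = 0.5,
    // shape2.scale = 0.25, and shapeOut.scale = 0.25,
    //   input_product_scale = 0.125 and real_multiplier = 0.5.
    // validate() guarantees output.scale > input1.scale * input2.scale, so
    // real_multiplier < 1 and QuantizeMultiplierSmallerThanOne succeeds.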
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_multiplier, &output_multiplier, &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    // Use BROADCAST version to handle the normal case.
    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastMul");
    tflite::optimized_ops::BroadcastMul(in1, convertShapeToDims(shape1), input1_offset, in2,
                                        convertShapeToDims(shape2), input2_offset, output_offset,
                                        output_multiplier, output_shift, output_activation_min,
                                        output_activation_max, out, convertShapeToDims(shapeOut));

    return true;
}

bool subFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat32");
    NNTRACE_COMP_SWITCH("optimized_ops::Sub");
    tflite::optimized_ops::Sub(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                               out, convertShapeToDims(shapeOut));

    // TFLite does not apply activation to broadcast sub, so clamp to the
    // activation range here.
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);
    uint32_t numOutputElements = getNumberOfElements(shapeOut);
    for (uint32_t i = 0; i < numOutputElements; i++) {
        out[i] = std::min(std::max(out[i], output_activation_min), output_activation_max);
    }

    return true;
}

bool subFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &subFloat32);
}

bool subQuant8(const uint8_t* in1, const Shape& shape1, const uint8_t* in2, const Shape& shape2,
               int32_t activation, uint8_t* out, const Shape& shapeOut) {
    NNTRACE_TRANS("subQuant8");
    const int32_t input1_offset = -shape1.offset;
    const int32_t input2_offset = -shape2.offset;
    const int32_t output_offset = shapeOut.offset;
    const int left_shift = 20;
    const double twice_max_input_scale = 2 * std::max(shape1.scale, shape2.scale);
    const double real_input1_multiplier = shape1.scale / twice_max_input_scale;
    const double real_input2_multiplier = shape2.scale / twice_max_input_scale;
    const double real_output_multiplier =
            twice_max_input_scale / ((1 << left_shift) * shapeOut.scale);
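    // Same fixed-point rescaling scheme as addQuant8 above; see the worked
    // example there for how the multipliers relate to the tensor scales.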
    int32_t input1_multiplier;
    int32_t input1_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
                                          &input1_shift)) {
        return false;
    }
    int32_t input2_multiplier;
    int32_t input2_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
                                          &input2_shift)) {
        return false;
    }
    input2_multiplier *= -1;
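    // Negating the second multiplier turns the fixed-point add kernel into a
    // subtraction: out = requantize(in1 * m1 + in2 * (-m2)).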
    int32_t output_multiplier;
    int32_t output_shift;
    if (!QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
                                          &output_shift)) {
        return false;
    }
    int32_t output_activation_min;
    int32_t output_activation_max;
    CalculateActivationRangeUint8(activation, shapeOut, &output_activation_min,
                                  &output_activation_max);

    // We are using tflite::optimized_ops::BroadcastAdd unconditionally here
    // because tflite::optimized_ops::Add fails to pass some of the
    // sub_quantized_different_scales tests.
    NNTRACE_COMP_SWITCH("optimized_ops::BroadcastAdd");
#define ANDROID_NN_BROADCAST_ADD(activation)                                                  \
    tflite::optimized_ops::BroadcastAdd<tflite::FusedActivationFunctionType::activation>(    \
            left_shift, in1, convertShapeToDims(shape1), input1_offset, input1_multiplier,   \
            input1_shift, in2, convertShapeToDims(shape2), input2_offset, input2_multiplier, \
            input2_shift, output_offset, output_multiplier, output_shift,                    \
            output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut))
    ANDROID_NN_MACRO_DISPATCH(ANDROID_NN_BROADCAST_ADD)
#undef ANDROID_NN_BROADCAST_ADD

    return true;
}

bool divFloat32(const float* in1, const Shape& shape1, const float* in2, const Shape& shape2,
                int32_t activation, float* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat32");
    float output_activation_min, output_activation_max;
    CalculateActivationRangeFloat(activation, &output_activation_min, &output_activation_max);

    bool needBroadcast = !SameShape(shape1, shape2);
    if (needBroadcast) {
        NNTRACE_COMP_SWITCH("optimized_ops::BroadcastDiv");
        tflite::optimized_ops::BroadcastDiv(
                in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                output_activation_min, output_activation_max, out, convertShapeToDims(shapeOut));
    } else {
        NNTRACE_COMP_SWITCH("optimized_ops::Div");
        tflite::optimized_ops::Div(in1, convertShapeToDims(shape1), in2, convertShapeToDims(shape2),
                                   output_activation_min, output_activation_max, out,
                                   convertShapeToDims(shapeOut));
    }

    return true;
}

bool divFloat16(const _Float16* in1, const Shape& shape1, const _Float16* in2, const Shape& shape2,
                int32_t activation, _Float16* out, const Shape& shapeOut) {
    NNTRACE_TRANS("divFloat16");
    return binaryOperationFloat16(in1, shape1, in2, shape2, activation, out, shapeOut, &divFloat32);
}

}  // namespace

bool validate(OperationType opType, const IOperationValidationContext* context) {
    const HalVersion opIntroducedAt = (opType == OperationType::DIV || opType == OperationType::SUB)
                                              ? HalVersion::V1_1
                                              : HalVersion::V1_0;
    NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
    NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
    auto inputType = context->getInputType(kInputTensor1);
    if (inputType == OperandType::TENSOR_FLOAT32) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_FLOAT16) {
        NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
    } else if (inputType == OperandType::TENSOR_QUANT8_ASYMM) {
        if (opType == OperationType::SUB) {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_2, opIntroducedAt)));
        } else if (opType == OperationType::DIV) {
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
        } else if (opType == OperationType::MUL) {
            Shape output = context->getOutputShape(kOutputTensor);
            Shape input1 = context->getInputShape(kInputTensor1);
            Shape input2 = context->getInputShape(kInputTensor2);
            NN_RET_CHECK_GT(output.scale, input1.scale * input2.scale);
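            // The check above guarantees real_multiplier < 1 in mulQuant8,
            // which QuantizeMultiplierSmallerThanOne requires.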
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        } else {
            NN_RET_CHECK(validateHalVersion(context, std::max(HalVersion::V1_0, opIntroducedAt)));
        }
    } else {
        NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << getOperationName(opType);
    }
    return validateInputTypes(context, {inputType, inputType, OperandType::INT32}) &&
           validateOutputTypes(context, {inputType});
}

bool prepare(IOperationExecutionContext* context) {
    Shape input1 = context->getInputShape(kInputTensor1);
    Shape input2 = context->getInputShape(kInputTensor2);
    Shape output = context->getOutputShape(kOutputTensor);
    NN_RET_CHECK_LE(getNumberOfDimensions(input1), 4);
    NN_RET_CHECK_LE(getNumberOfDimensions(input2), 4);
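    // Shapes broadcast NumPy-style from trailing dimensions; e.g. (assumed,
    // illustrative shapes) [2, 3] paired with [1, 3] yields output [2, 3].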
    NN_RET_CHECK(calculateBroadcastedShape(input1, input2, &output));
    return context->setOutputShape(kOutputTensor, output);
}

bool executeAdd(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return addFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return addFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return addQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation ADD";
    }
}

bool executeMul(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return mulFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return mulFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return mulQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation MUL";
    }
}

bool executeSub(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return subFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return subFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_QUANT8_ASYMM:
            return subQuant8(context->getInputBuffer<uint8_t>(kInputTensor1),
                             context->getInputShape(kInputTensor1),
                             context->getInputBuffer<uint8_t>(kInputTensor2),
                             context->getInputShape(kInputTensor2),
                             context->getInputValue<int32_t>(kActivationScalar),
                             context->getOutputBuffer<uint8_t>(kOutputTensor),
                             context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation SUB";
    }
}

bool executeDiv(IOperationExecutionContext* context) {
    // Bypass execution in the case of zero-sized input.
    if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
    switch (context->getInputType(kInputTensor1)) {
        case OperandType::TENSOR_FLOAT16:
            return divFloat16(context->getInputBuffer<_Float16>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<_Float16>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<_Float16>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        case OperandType::TENSOR_FLOAT32:
            return divFloat32(context->getInputBuffer<float>(kInputTensor1),
                              context->getInputShape(kInputTensor1),
                              context->getInputBuffer<float>(kInputTensor2),
                              context->getInputShape(kInputTensor2),
                              context->getInputValue<int32_t>(kActivationScalar),
                              context->getOutputBuffer<float>(kOutputTensor),
                              context->getOutputShape(kOutputTensor));
        default:
            NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation DIV";
    }
}

}  // namespace broadcast

using std::placeholders::_1;
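
// std::bind curries the OperationType argument so all four operations share
// the single broadcast::validate implementation while still matching the
// resolver's expected one-argument validation signature.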
NN_REGISTER_OPERATION(ADD, "ADD", std::bind(broadcast::validate, OperationType::ADD, _1),
                      broadcast::prepare, broadcast::executeAdd, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(MUL, "MUL", std::bind(broadcast::validate, OperationType::MUL, _1),
                      broadcast::prepare, broadcast::executeMul, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(SUB, "SUB", std::bind(broadcast::validate, OperationType::SUB, _1),
                      broadcast::prepare, broadcast::executeSub, .allowZeroSizedInput = true);
NN_REGISTER_OPERATION(DIV, "DIV", std::bind(broadcast::validate, OperationType::DIV, _1),
                      broadcast::prepare, broadcast::executeDiv, .allowZeroSizedInput = true);

}  // namespace nn
}  // namespace android