HashtableLookupTest.cpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "HashtableLookup.h"
  17. #include "NeuralNetworksWrapper.h"
  18. #include "gmock/gmock-matchers.h"
  19. #include "gtest/gtest.h"
  20. using ::testing::FloatNear;
  21. using ::testing::Matcher;
  22. namespace android {
  23. namespace nn {
  24. namespace wrapper {
  25. namespace {
  26. std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
  27. float max_abs_error=1.e-6) {
  28. std::vector<Matcher<float>> matchers;
  29. matchers.reserve(values.size());
  30. for (const float& v : values) {
  31. matchers.emplace_back(FloatNear(v, max_abs_error));
  32. }
  33. return matchers;
  34. }
  35. } // namespace
  36. using ::testing::ElementsAreArray;
  // X-macro listing every input operand of the HASHTABLE_LOOKUP test model as
  // (Name, C++ element type) pairs. Each ACTION(Name, Type) expansion is used to
  // generate the `Name_` host buffer, its `SetName` setter and the matching
  // setInput() call indexed by HashtableLookup::kNameTensor.
  // Lookup and Key are TENSOR_INT32 operands, so the element type is int32_t
  // (not plain `int`): the byte size passed to setInput is sizeof(T) * count and
  // must match the operand's 4-byte elements on every platform.
  #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Lookup, int32_t)                        \
    ACTION(Key, int32_t)                           \
    ACTION(Value, float)

  // X-macro listing every output operand as (Name, C++ element type):
  // Output is TENSOR_FLOAT32 and Hits is a one-byte TENSOR_QUANT8_ASYMM tensor.
  #define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(Output, float)                \
    ACTION(Hits, uint8_t)
  45. class HashtableLookupOpModel {
  46. public:
  47. HashtableLookupOpModel(std::initializer_list<uint32_t> lookup_shape,
  48. std::initializer_list<uint32_t> key_shape,
  49. std::initializer_list<uint32_t> value_shape) {
  50. auto it_vs = value_shape.begin();
  51. rows_ = *it_vs++;
  52. features_ = *it_vs;
  53. std::vector<uint32_t> inputs;
  54. // Input and weights
  55. OperandType LookupTy(Type::TENSOR_INT32, lookup_shape);
  56. inputs.push_back(model_.addOperand(&LookupTy));
  57. OperandType KeyTy(Type::TENSOR_INT32, key_shape);
  58. inputs.push_back(model_.addOperand(&KeyTy));
  59. OperandType ValueTy(Type::TENSOR_FLOAT32, value_shape);
  60. inputs.push_back(model_.addOperand(&ValueTy));
  61. // Output and other intermediate state
  62. std::vector<uint32_t> outputs;
  63. std::vector<uint32_t> out_dim(lookup_shape.begin(), lookup_shape.end());
  64. out_dim.push_back(features_);
  65. OperandType OutputOpndTy(Type::TENSOR_FLOAT32, out_dim);
  66. outputs.push_back(model_.addOperand(&OutputOpndTy));
  67. OperandType HitsOpndTy(Type::TENSOR_QUANT8_ASYMM, lookup_shape, 1.f, 0);
  68. outputs.push_back(model_.addOperand(&HitsOpndTy));
  69. auto multiAll = [](const std::vector<uint32_t> &dims) -> uint32_t {
  70. uint32_t sz = 1;
  71. for (uint32_t d : dims) { sz *= d; }
  72. return sz;
  73. };
  74. Value_.insert(Value_.end(), multiAll(value_shape), 0.f);
  75. Output_.insert(Output_.end(), multiAll(out_dim), 0.f);
  76. Hits_.insert(Hits_.end(), multiAll(lookup_shape), 0);
  77. model_.addOperation(ANEURALNETWORKS_HASHTABLE_LOOKUP, inputs, outputs);
  78. model_.identifyInputsAndOutputs(inputs, outputs);
  79. model_.finish();
  80. }
  81. void Invoke() {
  82. ASSERT_TRUE(model_.isValid());
  83. Compilation compilation(&model_);
  84. compilation.finish();
  85. Execution execution(&compilation);
  86. #define SetInputOrWeight(X, T) \
  87. ASSERT_EQ(execution.setInput(HashtableLookup::k##X##Tensor, X##_.data(), \
  88. sizeof(T) * X##_.size()), \
  89. Result::NO_ERROR);
  90. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
  91. #undef SetInputOrWeight
  92. #define SetOutput(X, T) \
  93. ASSERT_EQ(execution.setOutput(HashtableLookup::k##X##Tensor, X##_.data(), \
  94. sizeof(T) * X##_.size()), \
  95. Result::NO_ERROR);
  96. FOR_ALL_OUTPUT_TENSORS(SetOutput);
  97. #undef SetOutput
  98. ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  99. }
  100. #define DefineSetter(X, T) \
  101. void Set##X(const std::vector<T>& f) { \
  102. X##_.insert(X##_.end(), f.begin(), f.end()); \
  103. }
  104. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
  105. #undef DefineSetter
  106. void SetHashtableValue(const std::function<float(uint32_t, uint32_t)>& function) {
  107. for (uint32_t i = 0; i < rows_; i++) {
  108. for (uint32_t j = 0; j < features_; j++) {
  109. Value_[i * features_ + j] = function(i, j);
  110. }
  111. }
  112. }
  113. const std::vector<float>& GetOutput() const { return Output_; }
  114. const std::vector<uint8_t>& GetHits() const { return Hits_; }
  115. private:
  116. Model model_;
  117. uint32_t rows_;
  118. uint32_t features_;
  119. #define DefineTensor(X, T) std::vector<T> X##_;
  120. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
  121. FOR_ALL_OUTPUT_TENSORS(DefineTensor);
  122. #undef DefineTensor
  123. };
  124. TEST(HashtableLookupOpTest, BlackBoxTest) {
  125. HashtableLookupOpModel m({4}, {3}, {3, 2});
  126. m.SetLookup({1234, -292, -11, 0});
  127. m.SetKey({-11, 0, 1234});
  128. m.SetHashtableValue([](int i, int j) { return i + j / 10.0f; });
  129. m.Invoke();
  130. EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({
  131. 2.0, 2.1, // 2-rd item
  132. 0, 0, // Not found
  133. 0.0, 0.1, // 0-th item
  134. 1.0, 1.1, // 1-st item
  135. })));
  136. EXPECT_EQ(m.GetHits(), std::vector<uint8_t>({
  137. 1, 0, 1, 1,
  138. }));
  139. }
  140. } // namespace wrapper
  141. } // namespace nn
  142. } // namespace android