LSHProjectionTest.cpp 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. /*
  2. * Copyright (C) 2017 The Android Open Source Project
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "LSHProjection.h"
  17. #include "NeuralNetworksWrapper.h"
  18. #include "gmock/gmock-generated-matchers.h"
  19. #include "gmock/gmock-matchers.h"
  20. #include "gtest/gtest.h"
  21. using ::testing::FloatNear;
  22. using ::testing::Matcher;
  23. namespace android {
  24. namespace nn {
  25. namespace wrapper {
  26. using ::testing::ElementsAre;
  27. #define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
  28. ACTION(Hash, float) \
  29. ACTION(Input, int) \
  30. ACTION(Weight, float)
  31. // For all output and intermediate states
  32. #define FOR_ALL_OUTPUT_TENSORS(ACTION) ACTION(Output, int)
  33. class LSHProjectionOpModel {
  34. public:
  35. LSHProjectionOpModel(LSHProjectionType type, std::initializer_list<uint32_t> hash_shape,
  36. std::initializer_list<uint32_t> input_shape,
  37. std::initializer_list<uint32_t> weight_shape)
  38. : type_(type) {
  39. std::vector<uint32_t> inputs;
  40. OperandType HashTy(Type::TENSOR_FLOAT32, hash_shape);
  41. inputs.push_back(model_.addOperand(&HashTy));
  42. OperandType InputTy(Type::TENSOR_INT32, input_shape);
  43. inputs.push_back(model_.addOperand(&InputTy));
  44. OperandType WeightTy(Type::TENSOR_FLOAT32, weight_shape);
  45. inputs.push_back(model_.addOperand(&WeightTy));
  46. OperandType TypeParamTy(Type::INT32, {});
  47. inputs.push_back(model_.addOperand(&TypeParamTy));
  48. std::vector<uint32_t> outputs;
  49. auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
  50. uint32_t sz = 1;
  51. for (uint32_t d : dims) {
  52. sz *= d;
  53. }
  54. return sz;
  55. };
  56. uint32_t outShapeDimension = 0;
  57. if (type == LSHProjectionType_SPARSE || type == LSHProjectionType_SPARSE_DEPRECATED) {
  58. auto it = hash_shape.begin();
  59. Output_.insert(Output_.end(), *it, 0.f);
  60. outShapeDimension = *it;
  61. } else {
  62. Output_.insert(Output_.end(), multiAll(hash_shape), 0.f);
  63. outShapeDimension = multiAll(hash_shape);
  64. }
  65. OperandType OutputTy(Type::TENSOR_INT32, {outShapeDimension});
  66. outputs.push_back(model_.addOperand(&OutputTy));
  67. model_.addOperation(ANEURALNETWORKS_LSH_PROJECTION, inputs, outputs);
  68. model_.identifyInputsAndOutputs(inputs, outputs);
  69. model_.finish();
  70. }
  71. #define DefineSetter(X, T) \
  72. void Set##X(const std::vector<T>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
  73. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
  74. #undef DefineSetter
  75. const std::vector<int>& GetOutput() const { return Output_; }
  76. void Invoke() {
  77. ASSERT_TRUE(model_.isValid());
  78. Compilation compilation(&model_);
  79. compilation.finish();
  80. Execution execution(&compilation);
  81. #define SetInputOrWeight(X, T) \
  82. ASSERT_EQ( \
  83. execution.setInput(LSHProjection::k##X##Tensor, X##_.data(), sizeof(T) * X##_.size()), \
  84. Result::NO_ERROR);
  85. FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
  86. #undef SetInputOrWeight
  87. #define SetOutput(X, T) \
  88. ASSERT_EQ(execution.setOutput(LSHProjection::k##X##Tensor, X##_.data(), \
  89. sizeof(T) * X##_.size()), \
  90. Result::NO_ERROR);
  91. FOR_ALL_OUTPUT_TENSORS(SetOutput);
  92. #undef SetOutput
  93. ASSERT_EQ(execution.setInput(LSHProjection::kTypeParam, &type_, sizeof(type_)),
  94. Result::NO_ERROR);
  95. ASSERT_EQ(execution.compute(), Result::NO_ERROR);
  96. }
  97. private:
  98. Model model_;
  99. LSHProjectionType type_;
  100. std::vector<float> Hash_;
  101. std::vector<int> Input_;
  102. std::vector<float> Weight_;
  103. std::vector<int> Output_;
  104. }; // namespace wrapper
  105. TEST(LSHProjectionOpTest2, DenseWithThreeInputs) {
  106. LSHProjectionOpModel m(LSHProjectionType_DENSE, {4, 2}, {3, 2}, {3});
  107. m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
  108. m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
  109. m.SetWeight({0.12, 0.34, 0.56});
  110. m.Invoke();
  111. EXPECT_THAT(m.GetOutput(), ElementsAre(1, 1, 1, 0, 1, 1, 1, 0));
  112. }
  113. TEST(LSHProjectionOpTest2, SparseDeprecatedWithTwoInputs) {
  114. LSHProjectionOpModel m(LSHProjectionType_SPARSE_DEPRECATED, {4, 2}, {3, 2}, {0});
  115. m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
  116. m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
  117. m.Invoke();
  118. EXPECT_THAT(m.GetOutput(), ElementsAre(1, 2, 2, 0));
  119. }
  120. TEST(LSHProjectionOpTest2, SparseWithTwoInputs) {
  121. LSHProjectionOpModel m(LSHProjectionType_SPARSE, {4, 2}, {3, 2}, {0});
  122. m.SetInput({12345, 54321, 67890, 9876, -12345678, -87654321});
  123. m.SetHash({0.123, 0.456, -0.321, -0.654, 1.234, 5.678, -4.321, -8.765});
  124. m.Invoke();
  125. EXPECT_THAT(m.GetOutput(), ElementsAre(1, 6, 10, 12));
  126. }
  127. } // namespace wrapper
  128. } // namespace nn
  129. } // namespace android