MultinomialTest.cpp

/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
  16. #include "Multinomial.h"
  17. #include "HalInterfaces.h"
  18. #include "NeuralNetworksWrapper.h"
  19. #include "gmock/gmock-matchers.h"
  20. #include "gtest/gtest.h"
  21. #include "philox_random.h"
  22. #include "simple_philox.h"
  23. #include "unsupported/Eigen/CXX11/Tensor"
namespace android {
namespace nn {
namespace wrapper {

using ::testing::FloatNear;

constexpr int kFixedRandomSeed1 = 37;
constexpr int kFixedRandomSeed2 = 42;
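
// Test helper wrapping a single ANEURALNETWORKS_RANDOM_MULTINOMIAL operation:
// the inputs are a {batch_size, class_size} TENSOR_FLOAT32 of unnormalized log
// probabilities (logits), an INT32 sample count, and a {2} TENSOR_INT32 seed
// tensor; the output is a {batch_size, sample_size} TENSOR_INT32 of sampled
// class indices.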
class MultinomialOpModel {
   public:
    MultinomialOpModel(uint32_t batch_size, uint32_t class_size, uint32_t sample_size)
        : batch_size_(batch_size), class_size_(class_size), sample_size_(sample_size) {
        std::vector<uint32_t> inputs;
        OperandType logitsType(Type::TENSOR_FLOAT32, {batch_size_, class_size_});
        inputs.push_back(model_.addOperand(&logitsType));
        OperandType samplesType(Type::INT32, {});
        inputs.push_back(model_.addOperand(&samplesType));
        OperandType seedsType(Type::TENSOR_INT32, {2});
        inputs.push_back(model_.addOperand(&seedsType));

        std::vector<uint32_t> outputs;
        OperandType outputType(Type::TENSOR_INT32, {batch_size_, sample_size_});
        outputs.push_back(model_.addOperand(&outputType));

        model_.addOperation(ANEURALNETWORKS_RANDOM_MULTINOMIAL, inputs, outputs);
        model_.identifyInputsAndOutputs(inputs, outputs);
        model_.finish();
    }
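
    // Compiles the model, fills the logits tensor with reproducible values from
    // TensorFlow's Philox-based generator, binds all inputs and the output
    // buffer, and runs the computation.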
    void Invoke() {
        ASSERT_TRUE(model_.isValid());
        Compilation compilation(&model_);
        compilation.finish();
        Execution execution(&compilation);

        // Fill the logits tensor with deterministic pseudo-random values.
        tensorflow::random::PhiloxRandom rng(kFixedRandomSeed1);
        tensorflow::random::SimplePhilox srng(&rng);
        const int logits_count = batch_size_ * class_size_;
        for (int i = 0; i < logits_count; ++i) {
            input_.push_back(srng.RandDouble());
        }
        ASSERT_EQ(execution.setInput(Multinomial::kInputTensor, input_.data(),
                                     sizeof(float) * input_.size()),
                  Result::NO_ERROR);
        ASSERT_EQ(execution.setInput(Multinomial::kSampleCountParam, &sample_size_,
                                     sizeof(sample_size_)),
                  Result::NO_ERROR);
        std::vector<uint32_t> seeds{kFixedRandomSeed1, kFixedRandomSeed2};
        ASSERT_EQ(execution.setInput(Multinomial::kRandomSeedsTensor, seeds.data(),
                                     sizeof(uint32_t) * seeds.size()),
                  Result::NO_ERROR);

        output_.insert(output_.end(), batch_size_ * sample_size_, 0);
        ASSERT_EQ(execution.setOutput(Multinomial::kOutputTensor, output_.data(),
                                      sizeof(uint32_t) * output_.size()),
                  Result::NO_ERROR);

        ASSERT_EQ(execution.compute(), Result::NO_ERROR);
    }

    const std::vector<float>& GetInput() const { return input_; }
    const std::vector<uint32_t>& GetOutput() const { return output_; }

   private:
    Model model_;
    const uint32_t batch_size_;
    const uint32_t class_size_;
    const uint32_t sample_size_;
    std::vector<float> input_;
    std::vector<uint32_t> output_;
};
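
// Draws kNumSamples samples per batch row and checks that each class's empirical
// sampling frequency is within kMaxProbabilityDelta of its softmax probability
// exp(logit_i) / sum_j exp(logit_j).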
TEST(MultinomialOpTest, ProbabilityDeltaWithinTolerance) {
    constexpr int kBatchSize = 8;
    constexpr int kNumClasses = 10000;
    constexpr int kNumSamples = 128;
    constexpr float kMaxProbabilityDelta = 0.025f;

    MultinomialOpModel multinomial(kBatchSize, kNumClasses, kNumSamples);
    multinomial.Invoke();

    std::vector<uint32_t> output = multinomial.GetOutput();
    std::vector<float> input = multinomial.GetInput();
    for (int b = 0; b < kBatchSize; ++b) {
        // Count how often each class was sampled for this batch row.
        std::vector<int> class_counts(kNumClasses, 0);
        for (int s = 0; s < kNumSamples; ++s) {
            class_counts[output[b * kNumSamples + s]]++;
        }
        // The logits for batch row b start at offset b * kNumClasses.
        const int batch_index = kNumClasses * b;
        float probability_sum = 0;
        for (int i = 0; i < kNumClasses; ++i) {
            probability_sum += expf(input[batch_index + i]);
        }
        for (int i = 0; i < kNumClasses; ++i) {
            float probability =
                    static_cast<float>(class_counts[i]) / static_cast<float>(kNumSamples);
            float probability_expected = expf(input[batch_index + i]) / probability_sum;
            EXPECT_THAT(probability, FloatNear(probability_expected, kMaxProbabilityDelta));
        }
    }
}

}  // namespace wrapper
}  // namespace nn
}  // namespace android