123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377 |
- /*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
#include "QuantizedLSTM.h"

#include "NeuralNetworksWrapper.h"
#include "gmock/gmock-matchers.h"
#include "gtest/gtest.h"

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
- namespace android {
- namespace nn {
- namespace wrapper {
- namespace {
- struct OperandTypeParams {
- Type type;
- std::vector<uint32_t> shape;
- float scale;
- int32_t zeroPoint;
- OperandTypeParams(Type type, std::vector<uint32_t> shape, float scale, int32_t zeroPoint)
- : type(type), shape(shape), scale(scale), zeroPoint(zeroPoint) {}
- };
- } // namespace
- using ::testing::Each;
- using ::testing::ElementsAreArray;
- using ::testing::FloatNear;
- using ::testing::Matcher;
- class QuantizedLSTMOpModel {
- public:
- QuantizedLSTMOpModel(const std::vector<OperandTypeParams>& inputOperandTypeParams) {
- std::vector<uint32_t> inputs;
- for (int i = 0; i < NUM_INPUTS; ++i) {
- const auto& curOTP = inputOperandTypeParams[i];
- OperandType curType(curOTP.type, curOTP.shape, curOTP.scale, curOTP.zeroPoint);
- inputs.push_back(model_.addOperand(&curType));
- }
- const uint32_t numBatches = inputOperandTypeParams[0].shape[0];
- inputSize_ = inputOperandTypeParams[0].shape[0];
- const uint32_t outputSize =
- inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor].shape[1];
- outputSize_ = outputSize;
- std::vector<uint32_t> outputs;
- OperandType cellStateOutOperandType(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize},
- 1. / 2048., 0);
- outputs.push_back(model_.addOperand(&cellStateOutOperandType));
- OperandType outputOperandType(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize},
- 1. / 128., 128);
- outputs.push_back(model_.addOperand(&outputOperandType));
- model_.addOperation(ANEURALNETWORKS_QUANTIZED_16BIT_LSTM, inputs, outputs);
- model_.identifyInputsAndOutputs(inputs, outputs);
- initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kInputTensor], &input_);
- initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevOutputTensor],
- &prevOutput_);
- initializeInputData(inputOperandTypeParams[QuantizedLSTMCell::kPrevCellStateTensor],
- &prevCellState_);
- cellStateOut_.resize(numBatches * outputSize, 0);
- output_.resize(numBatches * outputSize, 0);
- model_.finish();
- }
- void invoke() {
- ASSERT_TRUE(model_.isValid());
- Compilation compilation(&model_);
- compilation.finish();
- Execution execution(&compilation);
- // Set all the inputs.
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputTensor, input_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToInputWeightsTensor,
- inputToInputWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToForgetWeightsTensor,
- inputToForgetWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToCellWeightsTensor,
- inputToCellWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kInputToOutputWeightsTensor,
- inputToOutputWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToInputWeightsTensor,
- recurrentToInputWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToForgetWeightsTensor,
- recurrentToForgetWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToCellWeightsTensor,
- recurrentToCellWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kRecurrentToOutputWeightsTensor,
- recurrentToOutputWeights_),
- Result::NO_ERROR);
- ASSERT_EQ(
- setInputTensor(&execution, QuantizedLSTMCell::kInputGateBiasTensor, inputGateBias_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kForgetGateBiasTensor,
- forgetGateBias_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kCellGateBiasTensor, cellGateBias_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kOutputGateBiasTensor,
- outputGateBias_),
- Result::NO_ERROR);
- ASSERT_EQ(
- setInputTensor(&execution, QuantizedLSTMCell::kPrevCellStateTensor, prevCellState_),
- Result::NO_ERROR);
- ASSERT_EQ(setInputTensor(&execution, QuantizedLSTMCell::kPrevOutputTensor, prevOutput_),
- Result::NO_ERROR);
- // Set all the outputs.
- ASSERT_EQ(
- setOutputTensor(&execution, QuantizedLSTMCell::kCellStateOutTensor, &cellStateOut_),
- Result::NO_ERROR);
- ASSERT_EQ(setOutputTensor(&execution, QuantizedLSTMCell::kOutputTensor, &output_),
- Result::NO_ERROR);
- ASSERT_EQ(execution.compute(), Result::NO_ERROR);
- // Put state outputs into inputs for the next step
- prevOutput_ = output_;
- prevCellState_ = cellStateOut_;
- }
- int inputSize() { return inputSize_; }
- int outputSize() { return outputSize_; }
- void setInput(const std::vector<uint8_t>& input) { input_ = input; }
- void setWeightsAndBiases(std::vector<uint8_t> inputToInputWeights,
- std::vector<uint8_t> inputToForgetWeights,
- std::vector<uint8_t> inputToCellWeights,
- std::vector<uint8_t> inputToOutputWeights,
- std::vector<uint8_t> recurrentToInputWeights,
- std::vector<uint8_t> recurrentToForgetWeights,
- std::vector<uint8_t> recurrentToCellWeights,
- std::vector<uint8_t> recurrentToOutputWeights,
- std::vector<int32_t> inputGateBias,
- std::vector<int32_t> forgetGateBias,
- std::vector<int32_t> cellGateBias, //
- std::vector<int32_t> outputGateBias) {
- inputToInputWeights_ = inputToInputWeights;
- inputToForgetWeights_ = inputToForgetWeights;
- inputToCellWeights_ = inputToCellWeights;
- inputToOutputWeights_ = inputToOutputWeights;
- recurrentToInputWeights_ = recurrentToInputWeights;
- recurrentToForgetWeights_ = recurrentToForgetWeights;
- recurrentToCellWeights_ = recurrentToCellWeights;
- recurrentToOutputWeights_ = recurrentToOutputWeights;
- inputGateBias_ = inputGateBias;
- forgetGateBias_ = forgetGateBias;
- cellGateBias_ = cellGateBias;
- outputGateBias_ = outputGateBias;
- }
- template <typename T>
- void initializeInputData(OperandTypeParams params, std::vector<T>* vec) {
- int size = 1;
- for (int d : params.shape) {
- size *= d;
- }
- vec->clear();
- vec->resize(size, params.zeroPoint);
- }
- std::vector<uint8_t> getOutput() { return output_; }
- private:
- static constexpr int NUM_INPUTS = 15;
- static constexpr int NUM_OUTPUTS = 2;
- Model model_;
- // Inputs
- std::vector<uint8_t> input_;
- std::vector<uint8_t> inputToInputWeights_;
- std::vector<uint8_t> inputToForgetWeights_;
- std::vector<uint8_t> inputToCellWeights_;
- std::vector<uint8_t> inputToOutputWeights_;
- std::vector<uint8_t> recurrentToInputWeights_;
- std::vector<uint8_t> recurrentToForgetWeights_;
- std::vector<uint8_t> recurrentToCellWeights_;
- std::vector<uint8_t> recurrentToOutputWeights_;
- std::vector<int32_t> inputGateBias_;
- std::vector<int32_t> forgetGateBias_;
- std::vector<int32_t> cellGateBias_;
- std::vector<int32_t> outputGateBias_;
- std::vector<int16_t> prevCellState_;
- std::vector<uint8_t> prevOutput_;
- // Outputs
- std::vector<int16_t> cellStateOut_;
- std::vector<uint8_t> output_;
- int inputSize_;
- int outputSize_;
- template <typename T>
- Result setInputTensor(Execution* execution, int tensor, const std::vector<T>& data) {
- return execution->setInput(tensor, data.data(), sizeof(T) * data.size());
- }
- template <typename T>
- Result setOutputTensor(Execution* execution, int tensor, std::vector<T>* data) {
- return execution->setOutput(tensor, data->data(), sizeof(T) * data->size());
- }
- };
- class QuantizedLstmTest : public ::testing::Test {
- protected:
- void VerifyGoldens(const std::vector<std::vector<uint8_t>>& input,
- const std::vector<std::vector<uint8_t>>& output,
- QuantizedLSTMOpModel* lstm) {
- const int numBatches = input.size();
- EXPECT_GT(numBatches, 0);
- const int inputSize = lstm->inputSize();
- EXPECT_GT(inputSize, 0);
- const int inputSequenceSize = input[0].size() / inputSize;
- EXPECT_GT(inputSequenceSize, 0);
- for (int i = 0; i < inputSequenceSize; ++i) {
- std::vector<uint8_t> inputStep;
- for (int b = 0; b < numBatches; ++b) {
- const uint8_t* batchStart = input[b].data() + i * inputSize;
- const uint8_t* batchEnd = batchStart + inputSize;
- inputStep.insert(inputStep.end(), batchStart, batchEnd);
- }
- lstm->setInput(inputStep);
- lstm->invoke();
- const int outputSize = lstm->outputSize();
- std::vector<float> expected;
- for (int b = 0; b < numBatches; ++b) {
- const uint8_t* goldenBatchStart = output[b].data() + i * outputSize;
- const uint8_t* goldenBatchEnd = goldenBatchStart + outputSize;
- expected.insert(expected.end(), goldenBatchStart, goldenBatchEnd);
- }
- EXPECT_THAT(lstm->getOutput(), ElementsAreArray(expected));
- }
- }
- };
// Inputs and weights in this test are random, and the test only checks that
// the outputs are equal to the outputs obtained from running the TF Lite
// version of the quantized LSTM on the same inputs.
- TEST_F(QuantizedLstmTest, BasicQuantizedLstmTest) {
- const int numBatches = 2;
- const int inputSize = 2;
- const int outputSize = 4;
- float weightsScale = 0.00408021;
- int weightsZeroPoint = 100;
- // OperandType biasOperandType(Type::TENSOR_INT32, input_shapes[3],
- // weightsScale / 128., 0);
- // inputs.push_back(model_.addOperand(&biasOperandType));
- // OperandType prevCellStateOperandType(Type::TENSOR_QUANT16_SYMM, input_shapes[4],
- // 1. / 2048., 0);
- // inputs.push_back(model_.addOperand(&prevCellStateOperandType));
- QuantizedLSTMOpModel lstm({
- // input
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, inputSize}, 1. / 128., 128),
- // inputToInputWeights
- // inputToForgetWeights
- // inputToCellWeights
- // inputToOutputWeights
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
- weightsZeroPoint),
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
- weightsZeroPoint),
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
- weightsZeroPoint),
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, inputSize}, weightsScale,
- weightsZeroPoint),
- // recurrentToInputWeights
- // recurrentToForgetWeights
- // recurrentToCellWeights
- // recurrentToOutputWeights
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
- weightsZeroPoint),
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
- weightsZeroPoint),
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
- weightsZeroPoint),
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {outputSize, outputSize}, weightsScale,
- weightsZeroPoint),
- // inputGateBias
- // forgetGateBias
- // cellGateBias
- // outputGateBias
- OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
- OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
- OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
- OperandTypeParams(Type::TENSOR_INT32, {outputSize}, weightsScale / 128., 0),
- // prevCellState
- OperandTypeParams(Type::TENSOR_QUANT16_SYMM, {numBatches, outputSize}, 1. / 2048., 0),
- // prevOutput
- OperandTypeParams(Type::TENSOR_QUANT8_ASYMM, {numBatches, outputSize}, 1. / 128., 128),
- });
- lstm.setWeightsAndBiases(
- // inputToInputWeights
- {146, 250, 235, 171, 10, 218, 171, 108},
- // inputToForgetWeights
- {24, 50, 132, 179, 158, 110, 3, 169},
- // inputToCellWeights
- {133, 34, 29, 49, 206, 109, 54, 183},
- // inputToOutputWeights
- {195, 187, 11, 99, 109, 10, 218, 48},
- // recurrentToInputWeights
- {254, 206, 77, 168, 71, 20, 215, 6, 223, 7, 118, 225, 59, 130, 174, 26},
- // recurrentToForgetWeights
- {137, 240, 103, 52, 68, 51, 237, 112, 0, 220, 89, 23, 69, 4, 207, 253},
- // recurrentToCellWeights
- {172, 60, 205, 65, 14, 0, 140, 168, 240, 223, 133, 56, 142, 64, 246, 216},
- // recurrentToOutputWeights
- {106, 214, 67, 23, 59, 158, 45, 3, 119, 132, 49, 205, 129, 218, 11, 98},
- // inputGateBias
- {-7876, 13488, -726, 32839},
- // forgetGateBias
- {9206, -46884, -11693, -38724},
- // cellGateBias
- {39481, 48624, 48976, -21419},
- // outputGateBias
- {-58999, -17050, -41852, -40538});
- // LSTM input is stored as numBatches x (sequenceLength x inputSize) vector.
- std::vector<std::vector<uint8_t>> lstmInput;
- // clang-format off
- lstmInput = {{154, 166,
- 166, 179,
- 141, 141},
- {100, 200,
- 50, 150,
- 111, 222}};
- // clang-format on
- // LSTM output is stored as numBatches x (sequenceLength x outputSize) vector.
- std::vector<std::vector<uint8_t>> lstmGoldenOutput;
- // clang-format off
- lstmGoldenOutput = {{136, 150, 140, 115,
- 140, 151, 146, 112,
- 139, 153, 146, 114},
- {135, 152, 138, 112,
- 136, 156, 142, 112,
- 141, 154, 146, 108}};
- // clang-format on
- VerifyGoldens(lstmInput, lstmGoldenOutput, &lstm);
- };
- } // namespace wrapper
- } // namespace nn
- } // namespace android
|