123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428 |
- /*
- * Copyright (C) 2017 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "LSTM.h"
- #include <android-base/logging.h>
- #include "NeuralNetworksWrapper.h"
- #include "gmock/gmock-matchers.h"
- #include "gtest/gtest.h"
- #include <sstream>
- #include <string>
- #include <vector>
- namespace android {
- namespace nn {
- namespace wrapper {
- using ::testing::Each;
- using ::testing::FloatNear;
- using ::testing::Matcher;
- namespace {
- std::vector<Matcher<float>> ArrayFloatNear(const std::vector<float>& values,
- float max_abs_error = 1.e-6) {
- std::vector<Matcher<float>> matchers;
- matchers.reserve(values.size());
- for (const float& v : values) {
- matchers.emplace_back(FloatNear(v, max_abs_error));
- }
- return matchers;
- }
- } // anonymous namespace
// X-macro lists used by LayerNormLSTMOpModel to generate, for every tensor of
// the LSTM operation, its model operand, its Set<Name>() setter, its execution
// binding and its backing std::vector<float> member (named <Name>_).
//
// FOR_ALL_INPUT_AND_WEIGHT_TENSORS lists the input/weight tensor operands in
// the order they are added to the model. The model constructor pairs the i-th
// entry here with the i-th shape it is given, so this order mirrors the shape
// list used by the tests: input_to_forget before input_to_cell,
// recurrent_to_forget before recurrent_to_cell, forget_gate_bias before
// cell_bias (which is also the NNAPI ANEURALNETWORKS_LSTM operand order).
#define FOR_ALL_INPUT_AND_WEIGHT_TENSORS(ACTION) \
    ACTION(Input)                                \
    ACTION(InputToInputWeights)                  \
    ACTION(InputToForgetWeights)                 \
    ACTION(InputToCellWeights)                   \
    ACTION(InputToOutputWeights)                 \
    ACTION(RecurrentToInputWeights)              \
    ACTION(RecurrentToForgetWeights)             \
    ACTION(RecurrentToCellWeights)               \
    ACTION(RecurrentToOutputWeights)             \
    ACTION(CellToInputWeights)                   \
    ACTION(CellToForgetWeights)                  \
    ACTION(CellToOutputWeights)                  \
    ACTION(InputGateBias)                        \
    ACTION(ForgetGateBias)                       \
    ACTION(CellGateBias)                         \
    ACTION(OutputGateBias)                       \
    ACTION(ProjectionWeights)                    \
    ACTION(ProjectionBias)                       \
    ACTION(OutputStateIn)                        \
    ACTION(CellStateIn)

// Layer-normalization weight tensors, added to the model after the scalar
// activation / cell-clip / projection-clip parameters.
#define FOR_ALL_LAYER_NORM_WEIGHTS(ACTION) \
    ACTION(InputLayerNormWeights)          \
    ACTION(ForgetLayerNormWeights)         \
    ACTION(CellLayerNormWeights)           \
    ACTION(OutputLayerNormWeights)

// Output tensors of the operation, including the scratch buffer and the
// output/cell state produced by each step.
#define FOR_ALL_OUTPUT_TENSORS(ACTION) \
    ACTION(ScratchBuffer)              \
    ACTION(OutputStateOut)             \
    ACTION(CellStateOut)               \
    ACTION(Output)
- class LayerNormLSTMOpModel {
- public:
- LayerNormLSTMOpModel(uint32_t n_batch, uint32_t n_input, uint32_t n_cell, uint32_t n_output,
- bool use_cifg, bool use_peephole, bool use_projection_weights,
- bool use_projection_bias, float cell_clip, float proj_clip,
- const std::vector<std::vector<uint32_t>>& input_shapes0)
- : n_input_(n_input),
- n_output_(n_output),
- use_cifg_(use_cifg),
- use_peephole_(use_peephole),
- use_projection_weights_(use_projection_weights),
- use_projection_bias_(use_projection_bias),
- activation_(ActivationFn::kActivationTanh),
- cell_clip_(cell_clip),
- proj_clip_(proj_clip) {
- std::vector<uint32_t> inputs;
- std::vector<std::vector<uint32_t>> input_shapes(input_shapes0);
- auto it = input_shapes.begin();
- // Input and weights
- #define AddInput(X) \
- CHECK(it != input_shapes.end()); \
- OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it++); \
- inputs.push_back(model_.addOperand(&X##OpndTy));
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(AddInput);
- // Parameters
- OperandType ActivationOpndTy(Type::INT32, {});
- inputs.push_back(model_.addOperand(&ActivationOpndTy));
- OperandType CellClipOpndTy(Type::FLOAT32, {});
- inputs.push_back(model_.addOperand(&CellClipOpndTy));
- OperandType ProjClipOpndTy(Type::FLOAT32, {});
- inputs.push_back(model_.addOperand(&ProjClipOpndTy));
- FOR_ALL_LAYER_NORM_WEIGHTS(AddInput);
- #undef AddOperand
- // Output and other intermediate state
- std::vector<std::vector<uint32_t>> output_shapes{
- {n_batch, n_cell * (use_cifg ? 3 : 4)},
- {n_batch, n_output},
- {n_batch, n_cell},
- {n_batch, n_output},
- };
- std::vector<uint32_t> outputs;
- auto it2 = output_shapes.begin();
- #define AddOutput(X) \
- CHECK(it2 != output_shapes.end()); \
- OperandType X##OpndTy(Type::TENSOR_FLOAT32, *it2++); \
- outputs.push_back(model_.addOperand(&X##OpndTy));
- FOR_ALL_OUTPUT_TENSORS(AddOutput);
- #undef AddOutput
- model_.addOperation(ANEURALNETWORKS_LSTM, inputs, outputs);
- model_.identifyInputsAndOutputs(inputs, outputs);
- Input_.insert(Input_.end(), n_batch * n_input, 0.f);
- OutputStateIn_.insert(OutputStateIn_.end(), n_batch * n_output, 0.f);
- CellStateIn_.insert(CellStateIn_.end(), n_batch * n_cell, 0.f);
- auto multiAll = [](const std::vector<uint32_t>& dims) -> uint32_t {
- uint32_t sz = 1;
- for (uint32_t d : dims) {
- sz *= d;
- }
- return sz;
- };
- it2 = output_shapes.begin();
- #define ReserveOutput(X) X##_.insert(X##_.end(), multiAll(*it2++), 0.f);
- FOR_ALL_OUTPUT_TENSORS(ReserveOutput);
- #undef ReserveOutput
- model_.finish();
- }
- #define DefineSetter(X) \
- void Set##X(const std::vector<float>& f) { X##_.insert(X##_.end(), f.begin(), f.end()); }
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineSetter);
- FOR_ALL_LAYER_NORM_WEIGHTS(DefineSetter);
- #undef DefineSetter
- void ResetOutputState() {
- std::fill(OutputStateIn_.begin(), OutputStateIn_.end(), 0.f);
- std::fill(OutputStateOut_.begin(), OutputStateOut_.end(), 0.f);
- }
- void ResetCellState() {
- std::fill(CellStateIn_.begin(), CellStateIn_.end(), 0.f);
- std::fill(CellStateOut_.begin(), CellStateOut_.end(), 0.f);
- }
- void SetInput(int offset, const float* begin, const float* end) {
- for (; begin != end; begin++, offset++) {
- Input_[offset] = *begin;
- }
- }
- uint32_t num_inputs() const { return n_input_; }
- uint32_t num_outputs() const { return n_output_; }
- const std::vector<float>& GetOutput() const { return Output_; }
- void Invoke() {
- ASSERT_TRUE(model_.isValid());
- OutputStateIn_.swap(OutputStateOut_);
- CellStateIn_.swap(CellStateOut_);
- Compilation compilation(&model_);
- compilation.finish();
- Execution execution(&compilation);
- #define SetInputOrWeight(X) \
- ASSERT_EQ( \
- execution.setInput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
- Result::NO_ERROR);
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(SetInputOrWeight);
- FOR_ALL_LAYER_NORM_WEIGHTS(SetInputOrWeight);
- #undef SetInputOrWeight
- #define SetOutput(X) \
- ASSERT_EQ( \
- execution.setOutput(LSTMCell::k##X##Tensor, X##_.data(), sizeof(float) * X##_.size()), \
- Result::NO_ERROR);
- FOR_ALL_OUTPUT_TENSORS(SetOutput);
- #undef SetOutput
- if (use_cifg_) {
- execution.setInput(LSTMCell::kInputToInputWeightsTensor, nullptr, 0);
- execution.setInput(LSTMCell::kRecurrentToInputWeightsTensor, nullptr, 0);
- }
- if (use_peephole_) {
- if (use_cifg_) {
- execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
- }
- } else {
- execution.setInput(LSTMCell::kCellToInputWeightsTensor, nullptr, 0);
- execution.setInput(LSTMCell::kCellToForgetWeightsTensor, nullptr, 0);
- execution.setInput(LSTMCell::kCellToOutputWeightsTensor, nullptr, 0);
- }
- if (use_projection_weights_) {
- if (!use_projection_bias_) {
- execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
- }
- } else {
- execution.setInput(LSTMCell::kProjectionWeightsTensor, nullptr, 0);
- execution.setInput(LSTMCell::kProjectionBiasTensor, nullptr, 0);
- }
- ASSERT_EQ(execution.setInput(LSTMCell::kActivationParam, &activation_, sizeof(activation_)),
- Result::NO_ERROR);
- ASSERT_EQ(execution.setInput(LSTMCell::kCellClipParam, &cell_clip_, sizeof(cell_clip_)),
- Result::NO_ERROR);
- ASSERT_EQ(execution.setInput(LSTMCell::kProjClipParam, &proj_clip_, sizeof(proj_clip_)),
- Result::NO_ERROR);
- ASSERT_EQ(execution.compute(), Result::NO_ERROR);
- }
- private:
- Model model_;
- // Execution execution_;
- const uint32_t n_input_;
- const uint32_t n_output_;
- const bool use_cifg_;
- const bool use_peephole_;
- const bool use_projection_weights_;
- const bool use_projection_bias_;
- const int activation_;
- const float cell_clip_;
- const float proj_clip_;
- #define DefineTensor(X) std::vector<float> X##_;
- FOR_ALL_INPUT_AND_WEIGHT_TENSORS(DefineTensor);
- FOR_ALL_LAYER_NORM_WEIGHTS(DefineTensor);
- FOR_ALL_OUTPUT_TENSORS(DefineTensor);
- #undef DefineTensor
- };
- TEST(LSTMOpTest, LayerNormNoCifgPeepholeProjectionNoClipping) {
- const int n_batch = 2;
- const int n_input = 5;
- // n_cell and n_output have the same size when there is no projection.
- const int n_cell = 4;
- const int n_output = 3;
- LayerNormLSTMOpModel lstm(n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/false, /*use_peephole=*/true,
- /*use_projection_weights=*/true,
- /*use_projection_bias=*/false,
- /*cell_clip=*/0.0, /*proj_clip=*/0.0,
- {
- {n_batch, n_input}, // input tensor
- {n_cell, n_input}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
- {n_cell}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
- {n_cell}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
- {n_output, n_cell}, // projection_weight tensor
- {0}, // projection_bias tensor
- {n_batch, n_output}, // output_state_in tensor
- {n_batch, n_cell}, // cell_state_in tensor
- {n_cell}, // input_layer_norm_weights tensor
- {n_cell}, // forget_layer_norm_weights tensor
- {n_cell}, // cell_layer_norm_weights tensor
- {n_cell}, // output_layer_norm_weights tensor
- });
- lstm.SetInputToInputWeights({0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2, 0.3, -0.4, 0.5,
- -0.8, 0.7, -0.6, 0.5, -0.4, -0.5, -0.4, -0.3, -0.2, -0.1});
- lstm.SetInputToForgetWeights({-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2, -0.4, 0.3, -0.8,
- -0.4, 0.3, -0.5, -0.4, -0.6, 0.3, -0.4, -0.6, -0.5, -0.5});
- lstm.SetInputToCellWeights({-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2, -0.3, -0.2, -0.6,
- 0.6, -0.1, -0.4, -0.3, -0.7, 0.7, -0.9, -0.5, 0.8, 0.6});
- lstm.SetInputToOutputWeights({-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3, -0.3, -0.8, -0.2,
- 0.6, -0.2, 0.4, -0.7, -0.3, -0.5, 0.1, 0.5, -0.6, -0.4});
- lstm.SetInputGateBias({0.03, 0.15, 0.22, 0.38});
- lstm.SetForgetGateBias({0.1, -0.3, -0.2, 0.1});
- lstm.SetCellGateBias({-0.05, 0.72, 0.25, 0.08});
- lstm.SetOutputGateBias({0.05, -0.01, 0.2, 0.1});
- lstm.SetRecurrentToInputWeights(
- {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9, -0.2, -0.3, -0.7, 0.05, -0.2, -0.6});
- lstm.SetRecurrentToCellWeights(
- {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08, -0.2, 0.3, 0.8, -0.6, -0.1, 0.2});
- lstm.SetRecurrentToForgetWeights(
- {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4, 0.9, 0.3, -0.1, 0.2, 0.5, 0.2});
- lstm.SetRecurrentToOutputWeights(
- {0.3, -0.1, 0.1, -0.2, -0.5, -0.7, -0.2, -0.6, -0.1, -0.4, -0.7, -0.2});
- lstm.SetCellToInputWeights({0.05, 0.1, 0.25, 0.15});
- lstm.SetCellToForgetWeights({-0.02, -0.15, -0.25, -0.03});
- lstm.SetCellToOutputWeights({0.1, -0.1, -0.5, 0.05});
- lstm.SetProjectionWeights({-0.1, 0.2, 0.01, -0.2, 0.1, 0.5, 0.3, 0.08, 0.07, 0.2, -0.4, 0.2});
- lstm.SetInputLayerNormWeights({0.1, 0.2, 0.3, 0.5});
- lstm.SetForgetLayerNormWeights({0.2, 0.2, 0.4, 0.3});
- lstm.SetCellLayerNormWeights({0.7, 0.2, 0.3, 0.8});
- lstm.SetOutputLayerNormWeights({0.6, 0.2, 0.2, 0.5});
- const std::vector<std::vector<float>> lstm_input = {
- { // Batch0: 3 (input_sequence_size) * 5 (n_input)
- 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
- 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
- 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
- { // Batch1: 3 (input_sequence_size) * 5 (n_input)
- 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
- 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
- 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
- };
- const std::vector<std::vector<float>> lstm_golden_output = {
- {
- // Batch0: 3 (input_sequence_size) * 3 (n_output)
- 0.0244077, 0.128027, -0.00170918, // seq 0
- 0.0137642, 0.140751, 0.0395835, // seq 1
- -0.00459231, 0.155278, 0.0837377, // seq 2
- },
- {
- // Batch1: 3 (input_sequence_size) * 3 (n_output)
- -0.00692428, 0.0848741, 0.063445, // seq 0
- -0.00403912, 0.139963, 0.072681, // seq 1
- 0.00752706, 0.161903, 0.0561371, // seq 2
- }};
- // Resetting cell_state and output_state
- lstm.ResetCellState();
- lstm.ResetOutputState();
- const int input_sequence_size = lstm_input[0].size() / n_input;
- for (int i = 0; i < input_sequence_size; i++) {
- for (int b = 0; b < n_batch; ++b) {
- const float* batch_start = lstm_input[b].data() + i * n_input;
- const float* batch_end = batch_start + n_input;
- lstm.SetInput(b * n_input, batch_start, batch_end);
- }
- lstm.Invoke();
- std::vector<float> expected;
- for (int b = 0; b < n_batch; ++b) {
- const float* golden_start = lstm_golden_output[b].data() + i * n_output;
- const float* golden_end = golden_start + n_output;
- expected.insert(expected.end(), golden_start, golden_end);
- }
- EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected)));
- }
- }
- } // namespace wrapper
- } // namespace nn
- } // namespace android
|