123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460 |
- /*
- * Copyright (C) 2019 The Android Open Source Project
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include "BidirectionalSequenceLSTM.h"
- #include "CpuExecutor.h"
- #include "CpuOperationUtils.h"
- #include "HalInterfaces.h"
- #include "OperationsUtils.h"
- #include "Tracing.h"
- namespace android {
- namespace nn {
- namespace {
- template <typename T>
- inline T* GetBuffer(RunTimeOperandInfo* operand) {
- return reinterpret_cast<T*>(operand->buffer);
- }
- template <typename T>
- inline const T* GetBuffer(const RunTimeOperandInfo* operand) {
- return reinterpret_cast<const T*>(operand->buffer);
- }
- template <typename T>
- inline const T* GetOptionalBuffer(const RunTimeOperandInfo* operand) {
- return !IsNullInput(operand) ? reinterpret_cast<const T*>(operand->buffer) : nullptr;
- }
- } // anonymous namespace
- BidirectionalSequenceLSTM::BidirectionalSequenceLSTM(const Operation& operation,
- std::vector<RunTimeOperandInfo>& operands) {
- input_ = GetInput(operation, operands, kInputTensor);
- fw_input_to_input_weights_ =
- GetInput(operation, operands, kFwInputToInputWeightsTensor); // optional
- fw_input_to_forget_weights_ = GetInput(operation, operands, kFwInputToForgetWeightsTensor);
- fw_input_to_cell_weights_ = GetInput(operation, operands, kFwInputToCellWeightsTensor);
- fw_input_to_output_weights_ = GetInput(operation, operands, kFwInputToOutputWeightsTensor);
- fw_recurrent_to_input_weights_ =
- GetInput(operation, operands, kFwRecurrentToInputWeightsTensor); // optional
- fw_recurrent_to_forget_weights_ =
- GetInput(operation, operands, kFwRecurrentToForgetWeightsTensor);
- fw_recurrent_to_cell_weights_ = GetInput(operation, operands, kFwRecurrentToCellWeightsTensor);
- fw_recurrent_to_output_weights_ =
- GetInput(operation, operands, kFwRecurrentToOutputWeightsTensor);
- fw_cell_to_input_weights_ =
- GetInput(operation, operands, kFwCellToInputWeightsTensor); // optional
- fw_cell_to_forget_weights_ =
- GetInput(operation, operands, kFwCellToForgetWeightsTensor); // optional
- fw_cell_to_output_weights_ =
- GetInput(operation, operands, kFwCellToOutputWeightsTensor); // optional
- fw_input_gate_bias_ = GetInput(operation, operands, kFwInputGateBiasTensor);
- fw_forget_gate_bias_ = GetInput(operation, operands, kFwForgetGateBiasTensor);
- fw_cell_bias_ = GetInput(operation, operands, kFwCellGateBiasTensor);
- fw_output_gate_bias_ = GetInput(operation, operands, kFwOutputGateBiasTensor);
- fw_projection_weights_ = GetInput(operation, operands, kFwProjectionWeightsTensor); // optional
- fw_projection_bias_ = GetInput(operation, operands, kFwProjectionBiasTensor); // optional
- fw_activation_state_ = GetInput(operation, operands, kFwInputActivationStateTensor);
- fw_cell_state_ = GetInput(operation, operands, kFwInputCellStateTensor);
- bw_input_to_input_weights_ =
- GetInput(operation, operands, kBwInputToInputWeightsTensor); // optional
- bw_input_to_forget_weights_ = GetInput(operation, operands, kBwInputToForgetWeightsTensor);
- bw_input_to_cell_weights_ = GetInput(operation, operands, kBwInputToCellWeightsTensor);
- bw_input_to_output_weights_ = GetInput(operation, operands, kBwInputToOutputWeightsTensor);
- bw_recurrent_to_input_weights_ =
- GetInput(operation, operands, kBwRecurrentToInputWeightsTensor); // optional
- bw_recurrent_to_forget_weights_ =
- GetInput(operation, operands, kBwRecurrentToForgetWeightsTensor);
- bw_recurrent_to_cell_weights_ = GetInput(operation, operands, kBwRecurrentToCellWeightsTensor);
- bw_recurrent_to_output_weights_ =
- GetInput(operation, operands, kBwRecurrentToOutputWeightsTensor);
- bw_cell_to_input_weights_ =
- GetInput(operation, operands, kBwCellToInputWeightsTensor); // optional
- bw_cell_to_forget_weights_ =
- GetInput(operation, operands, kBwCellToForgetWeightsTensor); // optional
- bw_cell_to_output_weights_ =
- GetInput(operation, operands, kBwCellToOutputWeightsTensor); // optional
- bw_input_gate_bias_ = GetInput(operation, operands, kBwInputGateBiasTensor);
- bw_forget_gate_bias_ = GetInput(operation, operands, kBwForgetGateBiasTensor);
- bw_cell_bias_ = GetInput(operation, operands, kBwCellGateBiasTensor);
- bw_output_gate_bias_ = GetInput(operation, operands, kBwOutputGateBiasTensor);
- bw_projection_weights_ = GetInput(operation, operands, kBwProjectionWeightsTensor); // optional
- bw_projection_bias_ = GetInput(operation, operands, kBwProjectionBiasTensor); // optional
- bw_activation_state_ = GetInput(operation, operands, kBwInputActivationStateTensor);
- bw_cell_state_ = GetInput(operation, operands, kBwInputCellStateTensor);
- aux_input_ = GetInput(operation, operands, kAuxInputTensor);
- fw_aux_input_to_input_weights_ = GetInput(operation, operands, kFwAuxInputToInputWeightsTensor);
- fw_aux_input_to_forget_weights_ =
- GetInput(operation, operands, kFwAuxInputToForgetWeightsTensor);
- fw_aux_input_to_cell_weights_ = GetInput(operation, operands, kFwAuxInputToCellWeightsTensor);
- fw_aux_input_to_output_weights_ =
- GetInput(operation, operands, kFwAuxInputToOutputWeightsTensor);
- bw_aux_input_to_input_weights_ = GetInput(operation, operands, kBwAuxInputToInputWeightsTensor);
- bw_aux_input_to_forget_weights_ =
- GetInput(operation, operands, kBwAuxInputToForgetWeightsTensor);
- bw_aux_input_to_cell_weights_ = GetInput(operation, operands, kBwAuxInputToCellWeightsTensor);
- bw_aux_input_to_output_weights_ =
- GetInput(operation, operands, kBwAuxInputToOutputWeightsTensor);
- fw_input_layer_norm_weights_ = GetInput(operation, operands, kFwInputLayerNormWeightsTensor);
- fw_forget_layer_norm_weights_ = GetInput(operation, operands, kFwForgetLayerNormWeightsTensor);
- fw_cell_layer_norm_weights_ = GetInput(operation, operands, kFwCellLayerNormWeightsTensor);
- fw_output_layer_norm_weights_ = GetInput(operation, operands, kFwOutputLayerNormWeightsTensor);
- bw_input_layer_norm_weights_ = GetInput(operation, operands, kBwInputLayerNormWeightsTensor);
- bw_forget_layer_norm_weights_ = GetInput(operation, operands, kBwForgetLayerNormWeightsTensor);
- bw_cell_layer_norm_weights_ = GetInput(operation, operands, kBwCellLayerNormWeightsTensor);
- bw_output_layer_norm_weights_ = GetInput(operation, operands, kBwOutputLayerNormWeightsTensor);
- params_.activation = static_cast<TfLiteFusedActivation>(
- getScalarData<int32_t>(*GetInput(operation, operands, kActivationParam)));
- if (input_->type == OperandType::TENSOR_FLOAT32) {
- params_.cell_clip = getScalarData<float>(*GetInput(operation, operands, kCellClipParam));
- params_.proj_clip = getScalarData<float>(*GetInput(operation, operands, kProjClipParam));
- } else {
- params_.cell_clip = static_cast<float>(
- getScalarData<_Float16>(*GetInput(operation, operands, kCellClipParam)));
- params_.proj_clip = static_cast<float>(
- getScalarData<_Float16>(*GetInput(operation, operands, kProjClipParam)));
- }
- params_.merge_outputs = getScalarData<bool>(*GetInput(operation, operands, kMergeOutputsParam));
- params_.time_major = getScalarData<bool>(*GetInput(operation, operands, kTimeMajorParam));
- params_.use_layer_norm = !IsNullInput(fw_input_layer_norm_weights_);
- fw_output_ = GetOutput(operation, operands, kFwOutputTensor);
- if (!params_.merge_outputs) {
- bw_output_ = GetOutput(operation, operands, kBwOutputTensor);
- }
- }
- bool BidirectionalSequenceLSTM::Prepare(const Operation& operation,
- std::vector<RunTimeOperandInfo>& operands,
- Shape* fwOutputShape, Shape* bwOutputShape) {
- // Inferring batch size, number of outputs and number of cells from the
- // input tensors.
- NN_CHECK(NumDimensions(input_) == 3);
- const uint32_t max_time = SizeOfDimension(input_, params_.time_major ? 0 : 1);
- const uint32_t n_batch = SizeOfDimension(input_, params_.time_major ? 1 : 0);
- const uint32_t n_input = SizeOfDimension(input_, 2);
- const uint32_t n_fw_cell = SizeOfDimension(fw_input_to_output_weights_, 0);
- NN_CHECK_EQ(NumDimensions(fw_input_to_output_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(fw_input_to_output_weights_, 1), n_input);
- NN_CHECK_EQ(NumDimensions(fw_recurrent_to_output_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(fw_recurrent_to_output_weights_, 0), n_fw_cell);
- const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
- // Check that input tensor dimensions matches with each other.
- if (!LSTMCell::CheckInputTensorDimensions(
- input_, fw_input_to_input_weights_, fw_input_to_forget_weights_,
- fw_input_to_cell_weights_, fw_input_to_output_weights_,
- fw_recurrent_to_input_weights_, fw_recurrent_to_forget_weights_,
- fw_recurrent_to_cell_weights_, fw_recurrent_to_output_weights_,
- fw_cell_to_input_weights_, fw_cell_to_forget_weights_, fw_cell_to_output_weights_,
- fw_input_gate_bias_, fw_forget_gate_bias_, fw_cell_bias_, fw_output_gate_bias_,
- fw_projection_weights_, fw_projection_bias_, fw_input_layer_norm_weights_,
- fw_forget_layer_norm_weights_, fw_cell_layer_norm_weights_,
- fw_output_layer_norm_weights_, n_input, n_fw_output, n_fw_cell, ¶ms_)) {
- return false;
- }
- const bool aux_inputs_all_or_none =
- (!IsNullInput(aux_input_) && !IsNullInput(fw_aux_input_to_cell_weights_) &&
- !IsNullInput(fw_aux_input_to_forget_weights_) &&
- !IsNullInput(fw_aux_input_to_output_weights_) &&
- !IsNullInput(bw_aux_input_to_cell_weights_) &&
- !IsNullInput(bw_aux_input_to_forget_weights_) &&
- !IsNullInput(bw_aux_input_to_output_weights_)) ||
- (IsNullInput(fw_aux_input_to_cell_weights_) &&
- IsNullInput(fw_aux_input_to_forget_weights_) &&
- IsNullInput(fw_aux_input_to_output_weights_) &&
- IsNullInput(bw_aux_input_to_cell_weights_) &&
- IsNullInput(bw_aux_input_to_forget_weights_) &&
- IsNullInput(bw_aux_input_to_output_weights_));
- NN_CHECK(aux_inputs_all_or_none);
- if (!IsNullInput(aux_input_)) {
- // Check that aux_input has the same dimensions (except last) as the input.
- NN_CHECK_EQ(aux_input_->shape().dimensions[0], input_->shape().dimensions[0]);
- NN_CHECK_EQ(aux_input_->shape().dimensions[1], input_->shape().dimensions[1]);
- }
- const uint32_t n_bw_cell = SizeOfDimension(bw_input_to_output_weights_, 0);
- NN_CHECK_EQ(NumDimensions(bw_input_to_output_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(bw_input_to_output_weights_, 1), n_input);
- NN_CHECK_EQ(NumDimensions(bw_recurrent_to_output_weights_), 2);
- NN_CHECK_EQ(SizeOfDimension(bw_recurrent_to_output_weights_, 0), n_bw_cell);
- const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
- const Shape& inputShape = input_->shape();
- fwOutputShape->type = inputShape.type;
- fwOutputShape->offset = inputShape.offset;
- fwOutputShape->scale = inputShape.scale;
- fwOutputShape->dimensions.resize(3);
- fwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
- fwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
- fwOutputShape->dimensions[2] = params_.merge_outputs ? n_fw_output + n_bw_output : n_fw_output;
- // Check that input tensor dimensions matches with each other.
- if (!LSTMCell::CheckInputTensorDimensions(
- input_, bw_input_to_input_weights_, bw_input_to_forget_weights_,
- bw_input_to_cell_weights_, bw_input_to_output_weights_,
- bw_recurrent_to_input_weights_, bw_recurrent_to_forget_weights_,
- bw_recurrent_to_cell_weights_, bw_recurrent_to_output_weights_,
- bw_cell_to_input_weights_, bw_cell_to_forget_weights_, bw_cell_to_output_weights_,
- bw_input_gate_bias_, bw_forget_gate_bias_, bw_cell_bias_, bw_output_gate_bias_,
- bw_projection_weights_, bw_projection_bias_, bw_input_layer_norm_weights_,
- bw_forget_layer_norm_weights_, bw_cell_layer_norm_weights_,
- bw_output_layer_norm_weights_, n_input, n_bw_output, n_bw_cell, ¶ms_)) {
- return false;
- }
- if (!params_.merge_outputs) {
- bwOutputShape->type = inputShape.type;
- bwOutputShape->offset = inputShape.offset;
- bwOutputShape->scale = inputShape.scale;
- bwOutputShape->dimensions.resize(3);
- bwOutputShape->dimensions[0] = params_.time_major ? max_time : n_batch;
- bwOutputShape->dimensions[1] = params_.time_major ? n_batch : max_time;
- bwOutputShape->dimensions[2] = n_bw_output;
- }
- if (params_.use_cifg) {
- fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 3};
- bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 3};
- } else {
- fw_scratch_shape_.dimensions = {n_batch, n_fw_cell * 4};
- bw_scratch_shape_.dimensions = {n_batch, n_bw_cell * 4};
- }
- fw_scratch_shape_.type = bw_scratch_shape_.type = inputShape.type;
- fw_scratch_shape_.offset = bw_scratch_shape_.offset = inputShape.offset;
- fw_scratch_shape_.scale = bw_scratch_shape_.scale = inputShape.scale;
- return true;
- }
// Runs the LSTM over the input sequence in both directions and writes the
// results to fw_output_ (and bw_output_ when outputs are not merged).
//
// When params_.merge_outputs is set, the backward pass writes directly after
// the forward results inside fw_output_'s buffer, and the two halves are then
// interleaved along the last dimension via mergeThirdDimension.
// Returns false for unsupported input operand types.
bool BidirectionalSequenceLSTM::Eval() {
    const uint32_t n_fw_output = SizeOfDimension(fw_recurrent_to_output_weights_, 1);
    const uint32_t n_bw_output = SizeOfDimension(bw_recurrent_to_output_weights_, 1);
    // Per-direction output dims: same [time, batch] (or [batch, time]) layout
    // as the input, with the last dimension replaced by the output size.
    std::vector<uint32_t> fw_output_dims = input_->shape().dimensions;
    fw_output_dims[2] = n_fw_output;
    std::vector<uint32_t> bw_output_dims = fw_output_dims;
    bw_output_dims[2] = n_bw_output;
    const uint32_t n_fw_output_elements = fw_output_dims[0] * fw_output_dims[1] * fw_output_dims[2];
    // Total elements of the merged output (forward + backward features).
    const uint32_t n_output_elements =
            fw_output_dims[0] * fw_output_dims[1] * (fw_output_dims[2] + bw_output_dims[2]);
    switch (input_->type) {
        case OperandType::TENSOR_FLOAT32: {
            // Forward pass.
            // NOTE(review): this float32 path fetches optional tensors with
            // GetBuffer while the float16 path uses GetOptionalBuffer — confirm
            // that omitted operands carry a null buffer so the two agree.
            std::vector<float> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(fw_input_to_input_weights_),
                    GetBuffer<const float>(fw_input_to_forget_weights_),
                    GetBuffer<const float>(fw_input_to_cell_weights_),
                    GetBuffer<const float>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_recurrent_to_input_weights_),
                    GetBuffer<const float>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(fw_cell_to_input_weights_),
                    GetBuffer<const float>(fw_cell_to_forget_weights_),
                    GetBuffer<const float>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(fw_aux_input_to_output_weights_),
                    GetBuffer<const float>(fw_input_gate_bias_),
                    GetBuffer<const float>(fw_forget_gate_bias_),
                    GetBuffer<const float>(fw_cell_bias_),
                    GetBuffer<const float>(fw_output_gate_bias_),
                    GetBuffer<const float>(fw_projection_weights_),
                    GetBuffer<const float>(fw_projection_bias_),
                    GetBuffer<const float>(fw_activation_state_),
                    GetBuffer<const float>(fw_cell_state_),
                    GetOptionalBuffer<const float>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(fw_output_layer_norm_weights_),
                    // State tensors double as outputs (updated in place).
                    GetBuffer<float>(fw_activation_state_), GetBuffer<float>(fw_cell_state_),
                    GetBuffer<float>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);
            // Backward pass (processes the sequence in reverse).
            std::vector<float> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat32(
                    params_, GetBuffer<const float>(input_), input_->shape(),
                    GetBuffer<const float>(bw_input_to_input_weights_),
                    GetBuffer<const float>(bw_input_to_forget_weights_),
                    GetBuffer<const float>(bw_input_to_cell_weights_),
                    GetBuffer<const float>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_recurrent_to_input_weights_),
                    GetBuffer<const float>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const float>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const float>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetBuffer<const float>(bw_cell_to_input_weights_),
                    GetBuffer<const float>(bw_cell_to_forget_weights_),
                    GetBuffer<const float>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const float>(aux_input_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const float>(bw_aux_input_to_output_weights_),
                    GetBuffer<const float>(bw_input_gate_bias_),
                    GetBuffer<const float>(bw_forget_gate_bias_),
                    GetBuffer<const float>(bw_cell_bias_),
                    GetBuffer<const float>(bw_output_gate_bias_),
                    GetBuffer<const float>(bw_projection_weights_),
                    GetBuffer<const float>(bw_projection_bias_),
                    GetBuffer<const float>(bw_activation_state_),
                    GetBuffer<const float>(bw_cell_state_),
                    GetOptionalBuffer<const float>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const float>(bw_output_layer_norm_weights_),
                    GetBuffer<float>(bw_activation_state_), GetBuffer<float>(bw_cell_state_),
                    // Merged mode stashes backward results behind the forward
                    // results in fw_output_'s buffer; merged below.
                    params_.merge_outputs ? GetBuffer<float>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<float>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave the two contiguous halves along the feature
                // dimension, then copy the merged result back into fw_output_.
                std::vector<float> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<float>(fw_output_), fw_output_dims,
                                    GetBuffer<float>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<float>(fw_output_));
            }
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Same structure as the float32 case, with optional tensors fetched
            // via GetOptionalBuffer (nullptr when omitted).
            // Forward pass.
            std::vector<_Float16> fw_scratch_buffer(getNumberOfElements(fw_scratch_shape_));
            const bool kForwardSequence = true;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_input_to_input_weights_),
                    GetBuffer<const _Float16>(fw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_input_to_output_weights_),
                    fw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(fw_recurrent_to_output_weights_),
                    fw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(fw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(fw_input_gate_bias_),
                    GetBuffer<const _Float16>(fw_forget_gate_bias_),
                    GetBuffer<const _Float16>(fw_cell_bias_),
                    GetBuffer<const _Float16>(fw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(fw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(fw_projection_bias_),
                    GetBuffer<const _Float16>(fw_activation_state_),
                    GetBuffer<const _Float16>(fw_cell_state_),
                    GetOptionalBuffer<const _Float16>(fw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(fw_output_layer_norm_weights_),
                    // State tensors double as outputs (updated in place).
                    GetBuffer<_Float16>(fw_activation_state_), GetBuffer<_Float16>(fw_cell_state_),
                    GetBuffer<_Float16>(fw_output_), fw_scratch_buffer.data(), params_.time_major,
                    kForwardSequence);
            // Backward pass (processes the sequence in reverse).
            std::vector<_Float16> bw_scratch_buffer(getNumberOfElements(bw_scratch_shape_));
            const bool kBackwardSequence = false;
            LSTMCell::LSTMEvalFloat16(
                    params_, GetBuffer<const _Float16>(input_), input_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_input_to_input_weights_),
                    GetBuffer<const _Float16>(bw_input_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_input_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_input_to_output_weights_),
                    bw_input_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_recurrent_to_input_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_forget_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_cell_weights_),
                    GetBuffer<const _Float16>(bw_recurrent_to_output_weights_),
                    bw_recurrent_to_output_weights_->shape(),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(aux_input_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_input_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_forget_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_cell_weights_),
                    GetOptionalBuffer<const _Float16>(bw_aux_input_to_output_weights_),
                    GetOptionalBuffer<const _Float16>(bw_input_gate_bias_),
                    GetBuffer<const _Float16>(bw_forget_gate_bias_),
                    GetBuffer<const _Float16>(bw_cell_bias_),
                    GetBuffer<const _Float16>(bw_output_gate_bias_),
                    GetOptionalBuffer<const _Float16>(bw_projection_weights_),
                    GetOptionalBuffer<const _Float16>(bw_projection_bias_),
                    GetBuffer<const _Float16>(bw_activation_state_),
                    GetBuffer<const _Float16>(bw_cell_state_),
                    GetOptionalBuffer<const _Float16>(bw_input_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_forget_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_cell_layer_norm_weights_),
                    GetOptionalBuffer<const _Float16>(bw_output_layer_norm_weights_),
                    GetBuffer<_Float16>(bw_activation_state_), GetBuffer<_Float16>(bw_cell_state_),
                    // Merged mode stashes backward results behind the forward
                    // results in fw_output_'s buffer; merged below.
                    params_.merge_outputs ? GetBuffer<_Float16>(fw_output_) + n_fw_output_elements
                                          : GetBuffer<_Float16>(bw_output_),
                    bw_scratch_buffer.data(), params_.time_major, kBackwardSequence);
            if (params_.merge_outputs) {
                // Interleave the two contiguous halves along the feature
                // dimension, then copy the merged result back into fw_output_.
                std::vector<_Float16> temp(n_output_elements);
                mergeThirdDimension(GetBuffer<_Float16>(fw_output_), fw_output_dims,
                                    GetBuffer<_Float16>(fw_output_) + n_fw_output_elements,
                                    bw_output_dims, temp.data());
                std::copy(temp.data(), temp.data() + n_output_elements,
                          GetBuffer<_Float16>(fw_output_));
            }
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(input_->type);
            return false;
        }
    }
    return true;
}
- } // namespace nn
- } // namespace android
|