Fix bug where linear_output_ is not cleared when linear_before_reset is true and no bias input is provided. Requires a batch size of 3 or more to trigger if initial_h is not provided. (#3893)
skottmckay authored May 12, 2020
1 parent 475ea38 commit 99415f0
Showing 3 changed files with 47 additions and 35 deletions.
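
For context, linear_output_ holds the recurrence part of the hidden gate when linear_before_reset is set. A minimal numpy sketch of that term, following the ONNX GRU formula (names are illustrative only, not the onnxruntime internals):

import numpy as np

def hidden_gate_linear_before_reset(Xt, Ht_1, Wh, Rh, rt, Rbh=None, Wbh=None):
    # The linear term the C++ code keeps in linear_output_:
    #   Ht-1 * Rh^T, plus Rbh when a bias input is provided.
    linear = Ht_1 @ Rh.T
    if Rbh is not None:
        linear = linear + Rbh
    # ONNX GRU with linear_before_reset != 0:
    #   ht = tanh(Xt*(Wh^T) + rt (.) (Ht-1*(Rh^T) + Rbh) + Wbh)
    ht = Xt @ Wh.T + rt * linear
    if Wbh is not None:
        ht = ht + Wbh
    return np.tanh(ht)

The GEMM that produces this term writes into a buffer that is reused across time steps. Before this commit it always accumulated into that buffer (beta == 1), which is only correct when the buffer has been reloaded with Rbh; with no bias input nothing reloads it, so values from the previous time step leaked into the result.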
20 changes: 8 additions & 12 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc
@@ -526,22 +526,18 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
const int total_rows = max_sequence_length * batch_size_;

float alpha = 1.0f;
float beta = 0.0f; // zero out outputZRH_ when calling ComputeGemm.

// apply weights to all the inputs
ComputeGemm(total_rows, hidden_size_x3, input_size_, alpha,
inputs.cbegin(), inputs.cend(),
input_size_,
input_weights.cbegin(), input_weights.cend(),
input_size_, beta,
input_size_, 0.f,
outputZRH_.begin(), outputZRH_.end(),
hidden_size_x3, ttp_);

DumpMatrix("inputs with weights applied", outputZRH_.data(), seq_length_ * batch_size_ * 3, hidden_size_);

// set to 1 so the weighted inputs in outputZRH_ are added to the result in the next call to ComputeGemm
beta = 1.0f;

// output shape is [seq_length, num_directions, batch_size, hidden_size]
// if we are doing 2 directions and this is the forward pass we're writing to the real output so
// need to include num_directions in the step length.
@@ -599,7 +595,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
prev_Ht, prev_Ht_end,
hidden_size_,
recurrent_weightsZR.cbegin(), recurrent_weightsZR.cend(),
hidden_size_, beta,
hidden_size_, 1.f, // beta == 1 so we add existing values in outputZRH_
outputZRH_.begin() + out_added_offset, outputZRH_.end(),
hidden_size_x3, ttp_);

@@ -619,8 +615,10 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
prev_Ht, prev_Ht_end, // Ht-1
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
linear_output_.begin(), linear_output_.end(), // pre: Rbh if use_bias_, post:output
hidden_size_,
use_bias_ ? 1.f : 0.f, // don't add values in linear_output_ if no bias input
linear_output_.begin(),
linear_output_.end(), // pre: Rbh if use_bias_, post:output
hidden_size_, ttp_);

DumpMatrix("Ht-1 * (Rh^T) + Rbh " + seqno_str, linear_output_.data(), batch_size_, hidden_size_);
@@ -690,7 +688,7 @@ void UniDirectionalGru<T>::Compute(const gsl::span<const T>& inputs_arg,
cur_h_local, cur_h_local_end, // rt (.) Ht-1
hidden_size_,
recurrent_weightsH.cbegin(), recurrent_weightsH.cend(), // Rh^T
hidden_size_, beta,
hidden_size_, 1.f, // beta == 1 to add Xt*(Wh^T) from out_H
out_H, outputZRH_.end(),
hidden_size_x3, ttp_);
}
@@ -830,9 +828,7 @@ void UniDirectionalGru<T>::AllocateBuffers() {
}

if (linear_before_reset_) {
// if use_bias_ is true we copy bias values to this as the first use. if it's false we don't and need to initialize
bool fill = !use_bias_;
linear_output_ = Allocate(allocator_, batch_size_ * hidden_size_, linear_output_ptr_, fill);
linear_output_ = Allocate(allocator_, batch_size_ * hidden_size_, linear_output_ptr_);
}

auto batch_times_seq_length = batch_size_ * seq_length_;
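
The beta arguments above follow the usual GEMM convention. A small numpy sketch of the semantics the comments rely on (ComputeGemm's actual signature differs; this only illustrates alpha and beta):

import numpy as np

def gemm(A, B, C, alpha=1.0, beta=0.0):
    # C = alpha * A @ B + beta * C
    # beta == 0: existing contents of C are overwritten
    #            (real BLAS does not even read C in this case).
    # beta == 1: existing contents of C are added into the result, so the
    #            buffer must already hold something meaningful, e.g. Rbh.
    C[:] = alpha * (A @ B) + beta * C
    return C

Passing use_bias_ ? 1.f : 0.f for the linear_output_ GEMM makes the no-bias case an overwrite instead of an accumulation, which is why the zero-fill at allocation time is no longer needed.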
15 changes: 9 additions & 6 deletions onnxruntime/test/providers/cpu/rnn/GRU.py
@@ -260,17 +260,20 @@ def ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows():
print_results(fw_output)

@staticmethod
def BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRows(linear_before_reset=0):
def BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=0):

print(GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRows.__name__ +
print(GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBias.__name__ +
'.linear_before_reset=' + str(linear_before_reset))

seq_length = 2
batch_size = 2
batch_size = 3 if linear_before_reset else 2
input_size = 1
hidden_size = 3

input = np.array([[[1.], [2.]], [[10.], [11.]]]).astype(np.float32)
if linear_before_reset:
input = np.array([[[1.], [2.], [3.]], [[10.], [11.], [12.]]]).astype(np.float32)
else:
input = np.array([[[1.], [2.]], [[10.], [11.]]]).astype(np.float32)

W = np.array([0.1, 0.2, 0.3, 1, 2, 3, 10, 11, 12]).astype(np.float32).reshape(1, 3 * hidden_size, input_size)

@@ -384,8 +387,8 @@ def Legacy_TestGRUOpBidirectionalBasic():

GRU_ONNXRuntimeUnitTests.ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows()
GRU_ONNXRuntimeUnitTests.ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows()
GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRows()
GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRows(linear_before_reset=1)
GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBias()
GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=1)

GRU_ONNXRuntimeUnitTests.ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallel()
GRU_ONNXRuntimeUnitTests.ForwardDefaultActivationsSimpleWeightsWithBiasBatchParallelLinearBeforeReset()
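
The expected values in the C++ test below are produced by this Python reference. A hedged example of regenerating them, assuming GRU.py is importable from its directory and that the static method prints the forward/reverse outputs via print_results as shown above:

# Hypothetical invocation; run from onnxruntime/test/providers/cpu/rnn/.
from GRU import GRU_ONNXRuntimeUnitTests

GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=1)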
47 changes: 30 additions & 17 deletions onnxruntime/test/providers/cpu/rnn/deep_cpu_gru_op_test.cc
@@ -106,14 +106,21 @@ void DefaultActivationsSimpleWeightsNoBias(std::string direction,
const std::vector<float>& Y_h_data,
bool linear_before_reset = false) {
int64_t seq_length = 2;
int batch_size = 2;
int batch_size = linear_before_reset ? 3 : 2; // extra row to validate usage of linear_output_
int64_t input_size = 1;
int64_t hidden_size = 3;

int num_directions = direction == "bidirectional" ? 2 : 1;

std::vector<float> X_data{1.f, 2.f,
10.f, 11.f};
std::vector<float> X_data;

if (linear_before_reset) {
X_data = {1.f, 2.f, 3.f,
10.f, 11.f, 12.f};
} else {
X_data = {1.f, 2.f,
10.f, 11.f};
}

std::vector<float> W_data{0.1f, 0.2f, 0.3f, // wz
1.f, 2.f, 3.f, // wr
@@ -170,7 +177,7 @@ TEST(GRUTest, ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows) {
DefaultActivationsSimpleWeightsNoBias("reverse", Y_data, Y_h_data);
}

TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRows) {
TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBias) {
std::vector<float> Y_data{
// forward output for input sequence 0
0.4750208f, 0.450166f, 0.4255575f,
@@ -200,32 +207,38 @@ TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRows) {
DefaultActivationsSimpleWeightsNoBias("bidirectional", Y_data, Y_h_data);
}

TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBiasTwoRowsLinearBeforeReset) {
TEST(GRUTest, BidirectionalDefaultActivationsSimpleWeightsNoBiasLinearBeforeReset) {
std::vector<float> Y_data{
// forward output for input sequence 0
0.47502081f, 0.450166f, 0.42555748f,
0.450166f, 0.40131234f, 0.35434369f,
0.4750208f, 0.450166f, 0.4255575f,
0.45016602f, 0.40131235f, 0.35434368f,
0.42555748f, 0.35434369f, 0.28905049f,

// reverse output for input sequence 0 [sequence 1 in reversed input]
0.60827853f, 0.50623393f, 0.4426924f,
0.5803454f, 0.4527356f, 0.36886264f,
0.6082785f, 0.50623393f, 0.4426924f,
0.5803454f, 0.4527356f, 0.36886263f,
0.5521325f, 0.40092295f, 0.30118297f,

// forward output for input sequence 1
0.60270932f, 0.50830227f, 0.44950222f,
0.57543688f, 0.45485455f, 0.37478411f,
0.6027093f, 0.5083023f, 0.44950223f,
0.5754369f, 0.45485455f, 0.3747841f,
0.54791767f, 0.40301081f, 0.30608854f,

// reverse output for input sequence 1 [sequence 0 in reversed input]
0.26894142f, 0.11920292f, 0.04742587f,
0.24973989f, 0.09975048f, 0.03557118f};
0.26894143f, 0.11920292f, 0.04742587f,
0.24973989f, 0.09975048f, 0.03557118f,
0.23147521f, 0.08317269f, 0.02659699f};

std::vector<float> Y_h_data{
// we did the forward processing of input[1] last
0.60270932f, 0.50830227f, 0.44950222f,
0.57543688f, 0.45485455f, 0.37478411f,
0.6027093f, 0.5083023f, 0.44950223f,
0.5754369f, 0.45485455f, 0.3747841f,
0.54791767f, 0.40301081f, 0.30608854f,

// and the reverse processing of input[0] last as the input order was reversed
0.60827853f, 0.50623393f, 0.4426924f,
0.5803454f, 0.4527356f, 0.36886264f};
0.6082785f, 0.50623393f, 0.4426924f,
0.5803454f, 0.4527356f, 0.36886263f,
0.5521325f, 0.40092295f, 0.30118297f};

DefaultActivationsSimpleWeightsNoBias("bidirectional", Y_data, Y_h_data, true);
}
