Skip to content

Commit

Permalink
Fix handling of some errors during an AWS HTTP request (#7811)
Browse files Browse the repository at this point in the history
- Fixes the retry mechanism of some AWS HTTP requests
  in some cases, like STS.
  The osquery MakeRequest implementation for the AWS SDK
  was incorrectly setting a 200 response code when the osquery http client
  would throw an exception, due to some internal error or simply
  due to reaching the timeout for sending the request.
  Not only this hides some of the logging that could happen
  when this is reported as an error, but it also prevents
  the AWS SDK logic to retry again.

- Improve again the STS credentials retrieval failure message,
  since in some cases the error message was empty.
  Now print the error message when present, the STS error type,
  and the HTTP response code, when present.

- Improve support on shutting down quickly when the AWS logger plugin
  is retrying sending logs.
  • Loading branch information
Smjert authored Nov 22, 2022
1 parent 0ab780a commit 65216e1
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 10 deletions.
40 changes: 36 additions & 4 deletions osquery/utils/aws/aws_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,21 @@ std::shared_ptr<Aws::Http::HttpResponse> OsqueryHttpClient::MakeRequest(
auto response = std::make_shared<Standard::StandardHttpResponse>(request_ptr);
http::Response resp;

if (osquery::shutdownRequested()) {
/* This is technically a client error, but some AWS requests
consider any client error as retryable.
Since we want to stop the retries,
we use instead a non-retryable response code,
although we did not have any response.
We also log the reason of the failure to provide more information */

response->SetResponseCode(Aws::Http::HttpResponseCode::BLOCKED);
LOG(WARNING) << "An AWS request has been blocked since a shutdown has been "
"requested";

return response;
}

try {
switch (request.GetMethod()) {
case Aws::Http::HttpMethod::HTTP_GET:
Expand Down Expand Up @@ -268,8 +283,8 @@ std::shared_ptr<Aws::Http::HttpResponse> OsqueryHttpClient::MakeRequest(
request.GetMethod())
<< " request to URL (" << url << "): " << e.what();

response->SetResponseCode(
static_cast<Aws::Http::HttpResponseCode>(resp.status()));
response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION);
response->SetClientErrorMessage(e.what());
}

return response;
Expand Down Expand Up @@ -335,11 +350,28 @@ OsquerySTSAWSCredentialsProvider::GetAWSCredentials() {
access_key_id_ = sts_result.GetCredentials().GetAccessKeyId();
secret_access_key_ = sts_result.GetCredentials().GetSecretAccessKey();
session_token_ = sts_result.GetCredentials().GetSessionToken();

// Calculate when our credentials will expire.
token_expire_time_ = current_time + FLAGS_aws_sts_timeout;
} else {
LOG(ERROR) << "Failed to create STS temporary credentials, error: "
<< sts_outcome.GetError().GetMessage();
const auto& error = sts_outcome.GetError();

std::stringstream error_message;

error_message << static_cast<int>(error.GetErrorType());

if (error.GetResponseCode() !=
Aws::Http::HttpResponseCode::REQUEST_NOT_MADE) {
error_message << ", HTTP responde code: "
<< static_cast<int>(error.GetResponseCode());
}

if (!error.GetMessage().empty()) {
error_message << ", error message: " << error.GetMessage();
}

LOG(ERROR) << "Failed to create STS temporary credentials, error type: "
<< error_message.rdbuf();
}
}
return Aws::Auth::AWSCredentials(
Expand Down
27 changes: 22 additions & 5 deletions plugins/logger/aws_log_forwarder.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,9 @@ class AwsLogForwarder : public BufferedLogForwarder {
}

std::stringstream output;
output << name_ << ": The following log records have been discarded "
"because they were too big:\n";
output << name_
<< ": The following log records have been discarded "
"because they were too big:\n";

for (const auto& record : discarded_records) {
output << record << "\n";
Expand Down Expand Up @@ -132,8 +133,9 @@ class AwsLogForwarder : public BufferedLogForwarder {
if (!status.ok()) {
// To achieve behavior parity with TLS logger plugin, skip non-JSON
// content
LOG(ERROR) << name_ << ": The following log record has been discarded "
"because it was not in JSON format: "
LOG(ERROR) << name_
<< ": The following log record has been discarded "
"because it was not in JSON format: "
<< record;

continue;
Expand Down Expand Up @@ -198,6 +200,12 @@ class AwsLogForwarder : public BufferedLogForwarder {
(retry == 0 ? 0 : base_retry_delay) + (retry * 1000U);
if (retry_delay != 0) {
pause(std::chrono::milliseconds(retry_delay));

/* Stop retrying, osquery should shutdown; we fail the send
so that it's attempted again at the next start */
if (interrupted()) {
return false;
}
}

// Attempt to send the batch
Expand Down Expand Up @@ -279,6 +287,15 @@ class AwsLogForwarder : public BufferedLogForwarder {
for (auto batch_it = batch_list.begin(); batch_it != batch_list.end();) {
auto& batch = *batch_it;
if (!sendBatch(batch, status_output)) {
/* Since we are shutting down, we don't want to count this send failure
as a real error; returning with failure here will make
the BufferedLogForwarder try to send this batch again
when osquery starts again */
if (interrupted()) {
return Status::failure(
"Interrupted sending log batch due to osquery shutdown");
}

// We couldn't write some of the records; log them locally so that the
// administrator will at least be able to inspect them
dumpBatchToErrorLog(batch);
Expand Down Expand Up @@ -339,4 +356,4 @@ class AwsLogForwarder : public BufferedLogForwarder {
/// Service endpoint override
std::string endpoint_override_;
};
}
} // namespace osquery
10 changes: 9 additions & 1 deletion plugins/logger/buffered.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ void BufferedLogForwarder::check() {
status = send(results, "result");
if (!status.ok()) {
VLOG(1) << "Error sending results to logger: " << status.getMessage();

if (interrupted()) {
return;
}
} else {
// Clear the results logs once they were sent.
iterate(indexes, ([this](std::string& index) {
Expand All @@ -88,6 +92,10 @@ void BufferedLogForwarder::check() {
status = send(statuses, "status");
if (!status.ok()) {
VLOG(1) << "Error sending status to logger: " << status.getMessage();

if (interrupted()) {
return;
}
} else {
// Clear the status logs once they were sent.
iterate(indexes, ([this](std::string& index) {
Expand Down Expand Up @@ -288,4 +296,4 @@ Status BufferedLogForwarder::deleteValueWithCount(const std::string& domain,
}
return status;
}
}
} // namespace osquery

0 comments on commit 65216e1

Please sign in to comment.