diff --git a/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h b/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h index 0249d7d0124..360b39a1622 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h +++ b/src/aws-cpp-sdk-core/include/aws/core/client/ClientConfiguration.h @@ -78,6 +78,16 @@ namespace Aws WHEN_REQUIRED, }; + /** + * Control HTTP client chunking implementation mode. + * DEFAULT: Use SDK's ChunkingInterceptor for aws-chunked encoding + * CLIENT_IMPLEMENTATION: Rely on HTTP client's native chunking (default for custom clients) + */ + enum class HttpClientChunkedMode { + DEFAULT, + CLIENT_IMPLEMENTATION, + }; + struct RequestCompressionConfig { UseRequestCompression useRequestCompression=UseRequestCompression::ENABLE; size_t requestMinCompressionSizeBytes = 10240; @@ -493,6 +503,12 @@ namespace Aws * https://docs.aws.amazon.com/sdkref/latest/guide/feature-account-endpoints.html */ Aws::String accountIdEndpointMode = "preferred"; + + /** + * Control HTTP client chunking implementation mode. + * Default is set automatically: CLIENT_IMPLEMENTATION for custom clients, DEFAULT for AWS clients. + */ + HttpClientChunkedMode httpClientChunkedMode = HttpClientChunkedMode::CLIENT_IMPLEMENTATION; /** * Configuration structure for credential providers in the AWS SDK. * This structure allows passing configuration options to credential providers diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h index cb6e928e768..d38c77f4dc0 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h @@ -48,6 +48,11 @@ namespace Aws */ virtual bool SupportsChunkedTransferEncoding() const { return true; } + /** + * Returns true if this is a default AWS SDK HTTP client implementation. + */ + virtual bool IsDefaultAwsHttpClient() const { return false; } + /** * Stops all requests in progress and prevents any others from initiating. */ diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h index e5cb2533387..a0a87619042 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h @@ -52,6 +52,8 @@ namespace Aws Aws::Utils::RateLimits::RateLimiterInterface* readLimiter, Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter) const override; + bool IsDefaultAwsHttpClient() const override { return true; } + private: // Yeah, I know, but someone made MakeRequest() const and didn't think about the fact that // making an HTTP request most certainly mutates state. It was me. I'm the person that did that, and diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h index 924cd59d830..087ed8d2c6f 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/curl/CurlHttpClient.h @@ -37,6 +37,8 @@ class AWS_CORE_API CurlHttpClient: public HttpClient Aws::Utils::RateLimits::RateLimiterInterface* readLimiter = nullptr, Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter = nullptr) const override; + bool IsDefaultAwsHttpClient() const override { return true; } + static void InitGlobalState(); static void CleanupGlobalState(); diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h index 24a427edee7..995b20197a6 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/IXmlHttpRequest2HttpClient.h @@ -54,6 +54,8 @@ namespace Aws */ virtual bool SupportsChunkedTransferEncoding() const override { return false; } + bool IsDefaultAwsHttpClient() const override { return true; } + protected: /** * Override any configuration on request handle. diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h index 61e3b0c4a3c..4c630780e37 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinHttpSyncHttpClient.h @@ -42,6 +42,8 @@ namespace Aws */ const char* GetLogTag() const override { return "WinHttpSyncHttpClient"; } + bool IsDefaultAwsHttpClient() const override { return true; } + private: // WinHttp specific implementations void* OpenRequest(const std::shared_ptr& request, void* connection, const Aws::StringStream& ss) const override; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h index 52a1ce2d8f4..51b2680c4a6 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/windows/WinINetSyncHttpClient.h @@ -39,6 +39,8 @@ namespace Aws * Gets log tag for use in logging in the base class. */ const char* GetLogTag() const override { return "WinInetSyncHttpClient"; } + + bool IsDefaultAwsHttpClient() const override { return true; } private: // WinHttp specific implementations diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h index b808fe2bf54..bd3f2704380 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -20,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -99,8 +101,13 @@ namespace client m_serviceUserAgentName(std::move(serviceUserAgentName)), m_httpClient(std::move(httpClient)), m_errorMarshaller(std::move(errorMarshaller)), - m_interceptors{Aws::MakeShared("AwsSmithyClientBase", *m_clientConfig)} + m_interceptors({ + Aws::MakeShared("AwsSmithyClientBase", *m_clientConfig), + Aws::MakeShared("AwsSmithyClientBase", + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : m_clientConfig->httpClientChunkedMode) + }) { + baseInit(); } diff --git a/src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h b/src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h new file mode 100644 index 00000000000..2e421a1cb59 --- /dev/null +++ b/src/aws-cpp-sdk-core/include/smithy/client/features/ChunkingInterceptor.h @@ -0,0 +1,232 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace smithy { +namespace client { +namespace features { + +static const size_t AWS_DATA_BUFFER_SIZE = 65536; +static const char* ALLOCATION_TAG = "ChunkingInterceptor"; +static const char* CHECKSUM_HEADER_PREFIX = "x-amz-checksum-"; + +template +class AwsChunkedStreamBuf : public std::streambuf { +public: + AwsChunkedStreamBuf(Aws::Http::HttpRequest* request, + const std::shared_ptr& stream, + size_t bufferSize = DataBufferSize) + : m_request(request), + m_stream(stream), + m_data(bufferSize) + { + assert(m_stream != nullptr); + if (m_stream == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "stream is null"); + } + assert(m_request != nullptr); + if (m_request == nullptr) { + AWS_LOGSTREAM_ERROR("AwsChunkedStream", "request is null"); + } + } + +protected: + int_type underflow() override { + if (gptr() && gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + + // Compact buffer when consumed data exceeds half buffer size + if (m_chunkingBufferPos > m_chunkingBuffer.GetLength() / 2) { + size_t remaining = m_chunkingBufferSize - m_chunkingBufferPos; + if (remaining > 0) { + std::memmove(m_chunkingBuffer.GetUnderlyingData(), + m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos, + remaining); + } + m_chunkingBufferSize = remaining; + m_chunkingBufferPos = 0; + } + + // only read and write to chunked stream if the underlying stream + // is still in a valid state and we have buffer space + if (m_stream->good()) { + // Check if we have enough space for worst-case chunk (data + header + footer) + size_t maxChunkSize = m_data.GetLength() + 20; // data + hex header + CRLF + if (m_chunkingBufferSize + maxChunkSize <= m_chunkingBuffer.GetLength()) { + // Try to read in a 64K chunk, if we cant we know the stream is over + m_stream->read(m_data.GetUnderlyingData(), m_data.GetLength()); + size_t bytesRead = static_cast(m_stream->gcount()); + writeChunk(bytesRead); + + // if we've read everything from the stream, we want to add the trailer + // to the underlying stream + if ((m_stream->peek() == EOF || m_stream->eof()) && !m_stream->bad()) { + writeTrailerToUnderlyingStream(); + } + } + } + + // if the chunking buffer is empty there is nothing to read + if (m_chunkingBufferPos >= m_chunkingBufferSize) { + return traits_type::eof(); + } + + // Set up buffer pointers to read from chunking buffer + size_t remainingBytes = m_chunkingBufferSize - m_chunkingBufferPos; + size_t bytesToRead = std::min(remainingBytes, DataBufferSize); + + setg(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos, + m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos, + m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferPos + bytesToRead); + + m_chunkingBufferPos += bytesToRead; + + return traits_type::to_int_type(*gptr()); + } + +private: + void writeTrailerToUnderlyingStream() { + Aws::String trailer = "0\r\n"; + if (m_request->GetRequestHash().second != nullptr) { + trailer += "x-amz-checksum-" + m_request->GetRequestHash().first + ":" + + Aws::Utils::HashingUtils::Base64Encode(m_request->GetRequestHash().second->GetHash().GetResult()) + "\r\n"; + } + trailer += "\r\n"; + if (m_chunkingBufferSize + trailer.length() <= m_chunkingBuffer.GetLength()) { + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, trailer.c_str(), trailer.length()); + m_chunkingBufferSize += trailer.length(); + } + } + + void writeChunk(size_t bytesRead) { + if (m_request->GetRequestHash().second != nullptr) { + m_request->GetRequestHash().second->Update(reinterpret_cast(m_data.GetUnderlyingData()), bytesRead); + } + + if (bytesRead > 0) { + Aws::String chunkHeader = Aws::Utils::StringUtils::ToHexString(bytesRead) + "\r\n"; + size_t totalSize = chunkHeader.length() + bytesRead + 2; + if (m_chunkingBufferSize + totalSize <= m_chunkingBuffer.GetLength()) { + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, chunkHeader.c_str(), chunkHeader.length()); + m_chunkingBufferSize += chunkHeader.length(); + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, m_data.GetUnderlyingData(), bytesRead); + m_chunkingBufferSize += bytesRead; + std::memcpy(m_chunkingBuffer.GetUnderlyingData() + m_chunkingBufferSize, "\r\n", 2); + m_chunkingBufferSize += 2; + } + } + } + + Aws::Utils::Array m_chunkingBuffer{DataBufferSize * 4}; + size_t m_chunkingBufferSize{0}; + size_t m_chunkingBufferPos{0}; + Aws::Http::HttpRequest* m_request{nullptr}; + std::shared_ptr m_stream; + Aws::Utils::Array m_data; +}; + +class AwsChunkedIOStream : public Aws::IOStream { +public: + AwsChunkedIOStream(Aws::Http::HttpRequest* request, + const std::shared_ptr& originalBody, + size_t bufferSize = AWS_DATA_BUFFER_SIZE) + : Aws::IOStream(&m_buf), + m_buf(request, originalBody, bufferSize) {} + +private: + AwsChunkedStreamBuf<> m_buf; +}; + +/** + * Interceptor that handles chunked encoding for streaming requests with checksums. + * Wraps request body with chunked stream and sets appropriate headers. + */ +class ChunkingInterceptor : public smithy::interceptor::Interceptor { +public: + explicit ChunkingInterceptor(Aws::Client::HttpClientChunkedMode httpClientChunkedMode) + : m_httpClientChunkedMode(httpClientChunkedMode) {} + ~ChunkingInterceptor() override = default; + + ModifyRequestOutcome ModifyBeforeSigning(smithy::interceptor::InterceptorContext& context) override { + auto request = context.GetTransmitRequest(); + + if (!ShouldApplyChunking(request)) { + return request; + } + + auto originalBody = request->GetContentBody(); + if (!originalBody) { + return request; + } + + // Set up chunked encoding headers for checksum calculation + const auto& hashPair = request->GetRequestHash(); + if (hashPair.second != nullptr) { + Aws::String checksumHeaderValue = Aws::String(CHECKSUM_HEADER_PREFIX) + hashPair.first; + request->DeleteHeader(checksumHeaderValue.c_str()); + request->SetHeaderValue(Aws::Http::AWS_TRAILER_HEADER, checksumHeaderValue); + request->SetTransferEncoding(Aws::Http::CHUNKED_VALUE); + + if (!request->HasContentEncoding()) { + request->SetContentEncoding(Aws::Http::AWS_CHUNKED_VALUE); + } else { + Aws::String currentEncoding = request->GetContentEncoding(); + if (currentEncoding.find(Aws::Http::AWS_CHUNKED_VALUE) == Aws::String::npos) { + request->SetContentEncoding(Aws::String{Aws::Http::AWS_CHUNKED_VALUE} + "," + currentEncoding); + } + } + + if (request->HasHeader(Aws::Http::CONTENT_LENGTH_HEADER)) { + request->SetHeaderValue(Aws::Http::DECODED_CONTENT_LENGTH_HEADER, request->GetHeaderValue(Aws::Http::CONTENT_LENGTH_HEADER)); + request->DeleteHeader(Aws::Http::CONTENT_LENGTH_HEADER); + } + } + + auto chunkedBody = Aws::MakeShared( + ALLOCATION_TAG, request.get(), originalBody); + + request->AddContentBody(chunkedBody); + return request; + } + + ModifyResponseOutcome ModifyBeforeDeserialization(smithy::interceptor::InterceptorContext& context) override { + return context.GetTransmitResponse(); + } + +private: + bool ShouldApplyChunking(const std::shared_ptr& request) const { + // Use configuration setting to determine chunking behavior + if (m_httpClientChunkedMode != Aws::Client::HttpClientChunkedMode::DEFAULT) { + return false; + } + + if (!request || !request->GetContentBody()) { + return false; + } + + // Check if request has checksum requirements + const auto& hashPair = request->GetRequestHash(); + return hashPair.second != nullptr; + } + + Aws::Client::HttpClientChunkedMode m_httpClientChunkedMode; +}; + +} // namespace features +} // namespace client +} // namespace smithy \ No newline at end of file diff --git a/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp b/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp index 1fc6094955c..bcca2d2e602 100644 --- a/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp +++ b/src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp @@ -218,26 +218,10 @@ bool AWSAuthV4Signer::SignRequestWithCreds(Aws::Http::HttpRequest& request, cons request.SetAwsSessionToken(credentials.GetSessionToken()); } - // If the request checksum, set the signer to use a unsigned - // trailing payload. otherwise use it in the header - if (request.GetRequestHash().second != nullptr && !request.GetRequestHash().first.empty() && request.GetContentBody() != nullptr) { - AWS_LOGSTREAM_DEBUG(v4LogTag, "Note: Http payloads are not being signed. signPayloads=" - << signBody << " http scheme=" << Http::SchemeMapper::ToString(request.GetUri().GetScheme())); - if (request.GetRequestHash().second != nullptr) { + // If the request has checksum and chunking was applied by interceptor, use streaming payload + if (request.GetRequestHash().second != nullptr && !request.GetRequestHash().first.empty() && + request.GetContentBody() != nullptr && request.HasHeader(Http::AWS_TRAILER_HEADER)) { payloadHash = STREAMING_UNSIGNED_PAYLOAD_TRAILER; - Aws::String checksumHeaderValue = Aws::String("x-amz-checksum-") + request.GetRequestHash().first; - request.DeleteHeader(checksumHeaderValue.c_str()); - request.SetHeaderValue(Http::AWS_TRAILER_HEADER, checksumHeaderValue); - request.SetTransferEncoding(CHUNKED_VALUE); - request.HasContentEncoding() - ? request.SetContentEncoding(Aws::String{Http::AWS_CHUNKED_VALUE} + "," + request.GetContentEncoding()) - : request.SetContentEncoding(Http::AWS_CHUNKED_VALUE); - - if (request.HasHeader(Http::CONTENT_LENGTH_HEADER)) { - request.SetHeaderValue(Http::DECODED_CONTENT_LENGTH_HEADER, request.GetHeaderValue(Http::CONTENT_LENGTH_HEADER)); - request.DeleteHeader(Http::CONTENT_LENGTH_HEADER); - } - } } else { payloadHash = ComputePayloadHash(request); if (payloadHash.empty()) { diff --git a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp index 1d4733f6eb6..90f83a58f33 100644 --- a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp +++ b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp @@ -46,6 +46,7 @@ #include #include +#include #include #include @@ -139,7 +140,8 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment), m_requestCompressionConfig(configuration.requestCompressionConfig), m_userAgentInterceptor{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration, m_retryStrategy->GetStrategyName(), m_serviceName)}, - m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG), m_userAgentInterceptor} + m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG), Aws::MakeShared(AWS_CLIENT_LOG_TAG, + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : configuration.httpClientChunkedMode), m_userAgentInterceptor} { } @@ -165,7 +167,8 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment), m_requestCompressionConfig(configuration.requestCompressionConfig), m_userAgentInterceptor{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration, m_retryStrategy->GetStrategyName(), m_serviceName)}, - m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration), m_userAgentInterceptor} + m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG, configuration), Aws::MakeShared(AWS_CLIENT_LOG_TAG, + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : configuration.httpClientChunkedMode), m_userAgentInterceptor} { } diff --git a/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp b/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp index 6eaf26e2a38..fe21ba5de11 100644 --- a/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp +++ b/src/aws-cpp-sdk-core/source/client/ClientConfiguration.cpp @@ -220,6 +220,7 @@ void setLegacyClientConfigurationParameters(ClientConfiguration& clientConfig) clientConfig.writeRateLimiter = nullptr; clientConfig.readRateLimiter = nullptr; clientConfig.httpLibOverride = Aws::Http::TransferLibType::DEFAULT_CLIENT; + clientConfig.httpClientChunkedMode = HttpClientChunkedMode::CLIENT_IMPLEMENTATION; clientConfig.followRedirects = FollowRedirectsPolicy::DEFAULT; clientConfig.disableExpectHeader = false; clientConfig.enableClockSkewAdjustment = true; diff --git a/src/aws-cpp-sdk-core/source/client/UserAgent.cpp b/src/aws-cpp-sdk-core/source/client/UserAgent.cpp index 909184b447a..dd6fa87c3a7 100644 --- a/src/aws-cpp-sdk-core/source/client/UserAgent.cpp +++ b/src/aws-cpp-sdk-core/source/client/UserAgent.cpp @@ -183,6 +183,15 @@ Aws::String UserAgent::SerializeWithFeatures(const Aws::Set& f SerializeMetadata(METADATA, m_compilerMetadata); } + // Add HTTP client metadata +#if AWS_SDK_USE_CRT_HTTP + SerializeMetadata(METADATA, "http#crt"); +#elif ENABLE_CURL_CLIENT + SerializeMetadata(METADATA, "http#curl"); +#elif ENABLE_WINDOWS_CLIENT + SerializeMetadata(METADATA, "http#winhttp"); +#endif + // metrics Aws::Vector encodedMetrics{}; diff --git a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp index 4c392bdf280..14c1ef25b0f 100644 --- a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp @@ -9,7 +9,6 @@ #include #include #include -#include #include #include @@ -379,11 +378,7 @@ namespace Aws if (request->GetContentBody()) { bool isStreaming = request->IsEventStreamRequest(); - if (request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER) == Aws::Http::AWS_CHUNKED_VALUE) { - crtRequest->SetBody(Aws::MakeShared>(CRT_HTTP_CLIENT_TAG, request.get(), request->GetContentBody())); - } else { - crtRequest->SetBody(Aws::MakeShared(CRT_HTTP_CLIENT_TAG, m_configuration.writeRateLimiter, request->GetContentBody(), *this, *request, isStreaming)); - } + crtRequest->SetBody(Aws::MakeShared(CRT_HTTP_CLIENT_TAG, m_configuration.writeRateLimiter, request->GetContentBody(), *this, *request, isStreaming)); } Crt::Http::HttpRequestOptions requestOptions; diff --git a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp index 1ac37f63eaf..58fc56875de 100644 --- a/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/curl/CurlHttpClient.cpp @@ -14,7 +14,6 @@ #include #include #include -#include #include #include @@ -155,21 +154,16 @@ static const char* CURL_HTTP_CLIENT_TAG = "CurlHttpClient"; struct CurlReadCallbackContext { CurlReadCallbackContext(const CurlHttpClient* client, CURL* curlHandle, HttpRequest* request, - Aws::Utils::RateLimits::RateLimiterInterface* limiter, - std::shared_ptr> chunkedStream = nullptr) + Aws::Utils::RateLimits::RateLimiterInterface* limiter) : m_client(client), m_curlHandle(curlHandle), m_rateLimiter(limiter), - m_request(request), - m_chunkEnd(false), - m_chunkedStream{std::move(chunkedStream)} {} + m_request(request) {} const CurlHttpClient* m_client; CURL* m_curlHandle; Aws::Utils::RateLimits::RateLimiterInterface* m_rateLimiter; HttpRequest* m_request; - bool m_chunkEnd; - std::shared_ptr> m_chunkedStream; }; static int64_t GetContentLengthFromHeader(CURL* connectionHandle, @@ -315,8 +309,6 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo const std::shared_ptr& ioStream = request->GetContentBody(); size_t amountToRead = size * nmemb; - bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && - request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER).find(Aws::Http::AWS_CHUNKED_VALUE) != Aws::String::npos; if (ioStream != nullptr && amountToRead > 0) { @@ -334,8 +326,6 @@ static size_t ReadBody(char* ptr, size_t size, size_t nmemb, void* userdata, boo return 0; } amountRead = (size_t)ioStream->readsome(ptr, amountToRead); - } else if (isAwsChunked && context->m_chunkedStream != nullptr) { - amountRead = context->m_chunkedStream->BufferedRead(ptr, amountToRead); } else { ioStream->read(ptr, amountToRead); amountRead = static_cast(ioStream->gcount()); @@ -380,7 +370,7 @@ static size_t SeekBody(void* userdata, curl_off_t offset, int origin) return CURL_SEEKFUNC_FAIL; } - // fail seek for aws-chunk encoded body as the length and offset is unknown + // Fail seek for aws-chunk encoded body as the length and offset is unknown if (context->m_request && context->m_request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && context->m_request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER).find(Aws::Http::AWS_CHUNKED_VALUE) != Aws::String::npos) @@ -388,6 +378,7 @@ static size_t SeekBody(void* userdata, curl_off_t offset, int origin) return CURL_SEEKFUNC_FAIL; } + HttpRequest* request = context->m_request; const std::shared_ptr& ioStream = request->GetContentBody(); @@ -713,13 +704,7 @@ std::shared_ptr CurlHttpClient::MakeRequest(const std::shared_ptr< CurlWriteCallbackContext writeContext(this, connectionHandle ,request.get(), response.get(), readLimiter); - const auto readContext = [this, &connectionHandle, &request, &writeLimiter]() -> CurlReadCallbackContext { - if (request->GetContentBody() != nullptr) { - auto chunkedBodyPtr = Aws::MakeShared>(CURL_HTTP_CLIENT_TAG, request.get(), request->GetContentBody()); - return {this, connectionHandle, request.get(), writeLimiter, std::move(chunkedBodyPtr)}; - } - return {this, connectionHandle, request.get(), writeLimiter}; - }(); + CurlReadCallbackContext readContext(this, connectionHandle, request.get(), writeLimiter); SetOptCodeForHttpMethod(connectionHandle, request); diff --git a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp index 7677e02052f..ee35bb5a81f 100644 --- a/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/windows/WinSyncHttpClient.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -100,21 +99,14 @@ bool WinSyncHttpClient::StreamPayloadToRequest(const std::shared_ptrHasTransferEncoding() && request->GetTransferEncoding() == Aws::Http::CHUNKED_VALUE; - bool isAwsChunked = request->HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && - request->GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER).find(Aws::Http::AWS_CHUNKED_VALUE) != Aws::String::npos; auto payloadStream = request->GetContentBody(); - const char CRLF[] = "\r\n"; if(payloadStream) { uint64_t bytesWritten; uint64_t bytesToRead = HTTP_REQUEST_WRITE_BUFFER_LENGTH; auto startingPos = payloadStream->tellg(); bool done = false; - // aws-chunk = hex(chunk-size) + CRLF + chunk-data + CRLF - // Length of hex(HTTP_REQUEST_WRITE_BUFFER_LENGTH) is 4; - // Length of each CRLF is 2. - // Reserve 8 bytes in total, should the request be aws-chunked. - char streamBuffer[ HTTP_REQUEST_WRITE_BUFFER_LENGTH + 8 ]; + char streamBuffer[HTTP_REQUEST_WRITE_BUFFER_LENGTH]; while(success && !done) { payloadStream->read(streamBuffer, bytesToRead); @@ -124,21 +116,6 @@ bool WinSyncHttpClient::StreamPayloadToRequest(const std::shared_ptr 0) { - if (isAwsChunked) - { - if (request->GetRequestHash().second != nullptr) - { - request->GetRequestHash().second->Update(reinterpret_cast(streamBuffer), static_cast(bytesRead)); - } - - Aws::String hex = Aws::Utils::StringUtils::ToHexString(static_cast(bytesRead)); - memcpy(streamBuffer + hex.size() + 2, streamBuffer, static_cast(bytesRead)); - memcpy(streamBuffer + hex.size() + 2 + bytesRead, CRLF, 2); - memcpy(streamBuffer, hex.c_str(), hex.size()); - memcpy(streamBuffer + hex.size(), CRLF, 2); - bytesRead += hex.size() + 4; - } - bytesWritten = DoWriteData(hHttpRequest, streamBuffer, bytesRead, isChunked); if (!bytesWritten) { @@ -164,27 +141,6 @@ bool WinSyncHttpClient::StreamPayloadToRequest(const std::shared_ptrGetRequestHash().second != nullptr) - { - chunkedTrailer << "x-amz-checksum-" << request->GetRequestHash().first << ":" - << Aws::Utils::HashingUtils::Base64Encode(request->GetRequestHash().second->GetHash().GetResult()) << CRLF; - } - chunkedTrailer << CRLF; - bytesWritten = DoWriteData(hHttpRequest, const_cast(chunkedTrailer.str().c_str()), chunkedTrailer.str().size(), isChunked); - if (!bytesWritten) - { - success = false; - } - else if(writeLimiter) - { - writeLimiter->ApplyAndPayForCost(bytesWritten); - } - } - if (success && isChunked) { bytesWritten = FinalizeWriteData(hHttpRequest); diff --git a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp index e799ab5a5a1..0d70b089e5c 100644 --- a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp +++ b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp @@ -25,6 +25,7 @@ #include using namespace smithy::client; +using namespace smithy::client::features; using namespace smithy::interceptor; using namespace smithy::components::tracing; @@ -102,7 +103,12 @@ void AwsSmithyClientBase::baseCopyAssign(const AwsSmithyClientBase& other, m_serviceUserAgentName = other.m_serviceUserAgentName; m_httpClient = std::move(httpClient); m_errorMarshaller = std::move(errorMarshaller); - m_interceptors = Aws::Vector>{Aws::MakeShared("AwsSmithyClientBase")}; + + m_interceptors = Aws::Vector>{ + Aws::MakeShared("AwsSmithyClientBase", *m_clientConfig), + Aws::MakeShared("AwsSmithyClientBase", + m_httpClient->IsDefaultAwsHttpClient() ? Aws::Client::HttpClientChunkedMode::DEFAULT : m_clientConfig->httpClientChunkedMode) + }; baseCopyInit(); } diff --git a/tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp b/tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp new file mode 100644 index 00000000000..742219e3505 --- /dev/null +++ b/tests/aws-cpp-sdk-core-tests/utils/stream/ChunkingInterceptorTest.cpp @@ -0,0 +1,162 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace Aws; +using namespace Aws::Http::Standard; +using namespace smithy::client::features; +using namespace Aws::Utils::Crypto; + +const char* CHUNKING_TEST_LOG_TAG = "CHUNKING_INTERCEPTOR_TEST"; + +// Mock implementation of AmazonWebServiceRequest +class MockRequest : public Aws::AmazonWebServiceRequest { +public: + std::shared_ptr GetBody() const override { return nullptr; } + Aws::Http::HeaderValueCollection GetHeaders() const override { return {}; } + const char* GetServiceRequestName() const override { return "MockRequest"; } +}; + +class ChunkingInterceptorTest : public Aws::Testing::AwsCppSdkGTestSuite { +protected: + template + void withChunkedStream(const std::string& input, size_t bufferSize, Fn&& fn) { + StandardHttpRequest request{"test.com", Http::HttpMethod::HTTP_GET}; + auto requestHash = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + request.SetRequestHash("crc32", requestHash); + auto inputStream = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + *inputStream << input; + + AwsChunkedIOStream wrapper{&request, inputStream, bufferSize}; + Aws::IOStream* stream = &wrapper; + + fn(*stream); + } +}; + +TEST_F(ChunkingInterceptorTest, ChunkedStreamShouldWork) { + withChunkedStream("1234567890123456789012345", 10, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + std::stringstream output; + + // Read in 10-byte chunks + for (int i = 0; i < 4; i++) { + chunkedStream.read(buffer, 10); + auto bytesRead = chunkedStream.gcount(); + output.write(buffer, bytesRead); + } + + // Read trailing checksum (greater than 10 chars) + chunkedStream.read(buffer, 40); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(36ul, bytesRead); + output.write(buffer, bytesRead); + + EXPECT_EQ("A\r\n1234567890\r\nA\r\n1234567890\r\n5\r\n12345\r\n0\r\nx-amz-checksum-crc32:78DeVw==\r\n\r\n", output.str()); + }); +} + +TEST_F(ChunkingInterceptorTest, ShouldNotRequireTwoReadsOnSmallChunk) { + withChunkedStream("12345", 100, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + chunkedStream.read(buffer, 100); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(46ul, bytesRead); + + std::string output(buffer, bytesRead); + EXPECT_EQ("5\r\n12345\r\n0\r\nx-amz-checksum-crc32:y/U6HA==\r\n\r\n", output); + }); +} + +TEST_F(ChunkingInterceptorTest, ShouldWorkOnSmallBuffer) { + withChunkedStream("1234567890", 5, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + + // First read - explicitly ask for 10 bytes (first chunk: "5\r\n12345\r\n") + chunkedStream.read(buffer, 10); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(10ul, bytesRead); + std::string firstRead(buffer, bytesRead); + EXPECT_EQ("5\r\n12345\r\n", firstRead); + + // Second read - now we expect the rest (46 bytes: second chunk + trailer) + chunkedStream.read(buffer, 100); + bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(46ul, bytesRead); + std::string secondRead(buffer, bytesRead); + EXPECT_EQ("5\r\n67890\r\n0\r\nx-amz-checksum-crc32:Jh2u5Q==\r\n\r\n", secondRead); + + // Subsequent reads should return 0 + chunkedStream.read(buffer, 100); + bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(0ul, bytesRead); + }); +} + +TEST_F(ChunkingInterceptorTest, ShouldWorkOnEmptyStream) { + withChunkedStream("", 5, [](Aws::IOStream& chunkedStream) { + char buffer[100]; + chunkedStream.read(buffer, 100); + auto bytesRead = static_cast(chunkedStream.gcount()); + EXPECT_EQ(36ul, bytesRead); + + std::string output(buffer, bytesRead); + EXPECT_EQ("0\r\nx-amz-checksum-crc32:AAAAAA==\r\n\r\n", output); + }); +} + +// Custom HTTP client (inherits default IsDefaultAwsHttpClient() = false from base class) +class CustomHttpClient : public Aws::Http::HttpClient { +public: + std::shared_ptr MakeRequest(const std::shared_ptr&, + Aws::Utils::RateLimits::RateLimiterInterface*, + Aws::Utils::RateLimits::RateLimiterInterface*) const override { + return nullptr; + } +}; + +TEST_F(ChunkingInterceptorTest, ShouldNotApplyChunkingForCustomHttpClient) { + // Simulate the GetChunkingConfig behavior from AWSClient.cpp + // When IsDefaultAwsHttpClient() returns false, httpClientChunkedMode is set to CLIENT_IMPLEMENTATION + Aws::Client::ClientConfiguration config; + auto customHttpClient = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + + // This simulates the logic in GetChunkingConfig function + if (!customHttpClient->IsDefaultAwsHttpClient()) { + config.httpClientChunkedMode = Aws::Client::HttpClientChunkedMode::CLIENT_IMPLEMENTATION; + } + + ChunkingInterceptor interceptor(config.httpClientChunkedMode); + + // Create request with checksum (would normally trigger chunking) + auto request = Aws::MakeShared(CHUNKING_TEST_LOG_TAG, "test.com", Http::HttpMethod::HTTP_POST); + auto requestHash = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + request->SetRequestHash("crc32", requestHash); + + auto inputStream = Aws::MakeShared(CHUNKING_TEST_LOG_TAG); + *inputStream << "test data"; + request->AddContentBody(inputStream); + + // Create interceptor context with a mock request + MockRequest mockRequest; + smithy::interceptor::InterceptorContext context(mockRequest); + context.SetTransmitRequest(request); + + // Apply interceptor + auto result = interceptor.ModifyBeforeSigning(context); + + // Verify chunking was NOT applied because custom HTTP client uses default IsDefaultAwsHttpClient() = false + EXPECT_EQ(request, result.GetResult()); + EXPECT_FALSE(request->HasHeader(Aws::Http::AWS_TRAILER_HEADER)); + EXPECT_FALSE(request->HasHeader(Aws::Http::TRANSFER_ENCODING_HEADER)); +} \ No newline at end of file