Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ namespace Aws

Aws::String GetChecksum() const { return m_checksum; };
void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; }

std::shared_ptr<Aws::Utils::Crypto::Hash> GetChecksumHash() const { return m_checksumHash; }
void SetChecksumHash(std::shared_ptr<Aws::Utils::Crypto::Hash> hash) { m_checksumHash = hash; }
private:

int m_partId = 0;
Expand All @@ -93,6 +96,7 @@ namespace Aws
std::atomic<unsigned char*> m_downloadBuffer;
bool m_lastPart = false;
Aws::String m_checksum;
std::shared_ptr<Aws::Utils::Crypto::Hash> m_checksumHash;
};

using PartPointer = std::shared_ptr< PartState >;
Expand Down Expand Up @@ -389,6 +393,12 @@ namespace Aws
Aws::String GetChecksum() const { return m_checksum; }
void SetChecksum(const Aws::String& checksum) { this->m_checksum = checksum; }

void SetPartChecksum(int partId, std::shared_ptr<Aws::Utils::Crypto::Hash> hash) { m_partChecksums[partId] = hash; }
std::shared_ptr<Aws::Utils::Crypto::Hash> GetPartChecksum(int partId) const {
auto it = m_partChecksums.find(partId);
return it != m_partChecksums.end() ? it->second : nullptr;
}

private:
void CleanupDownloadStream();

Expand Down Expand Up @@ -430,6 +440,9 @@ namespace Aws
mutable std::condition_variable m_waitUntilFinishedSignal;
mutable std::mutex m_getterSetterLock;
Aws::String m_checksum;
// Map of part number to Hash instance for multipart download checksum validation
// TODO: Add CRT checksum combining utility when available
Aws::Map<int, std::shared_ptr<Aws::Utils::Crypto::Hash>> m_partChecksums;
};

AWS_TRANSFER_API Aws::OStream& operator << (Aws::OStream& s, TransferStatus status);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ namespace Aws
* upload. Defaults to CRC64-NVME.
*/
Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC64NVME;

/**
* Enable checksum validation for downloads. When enabled, checksums will be
* calculated during download and validated against S3 response headers.
* Defaults to true.
*/
bool validateChecksums = true;
};

/**
Expand Down
161 changes: 141 additions & 20 deletions src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,42 @@ namespace Aws
}
}

static std::shared_ptr<Utils::Crypto::Hash> CreateHashForAlgorithm(S3::Model::ChecksumAlgorithm algorithm) {
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) {
return Aws::MakeShared<Utils::Crypto::CRC32>(CLASS_TAG);
}
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
return Aws::MakeShared<Utils::Crypto::CRC32C>(CLASS_TAG);
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) {
return Aws::MakeShared<Utils::Crypto::Sha1>(CLASS_TAG);
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) {
return Aws::MakeShared<Utils::Crypto::Sha256>(CLASS_TAG);
}
return Aws::MakeShared<Utils::Crypto::CRC64>(CLASS_TAG);
}

template <typename ResultT>
static Aws::String GetChecksumFromResult(const ResultT& result, S3::Model::ChecksumAlgorithm algorithm) {
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) {
return result.GetChecksumCRC32();
}
if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) {
return result.GetChecksumCRC32C();
}
if (algorithm == S3::Model::ChecksumAlgorithm::CRC64NVME) {
return result.GetChecksumCRC64NVME();
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) {
return result.GetChecksumSHA1();
}
if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) {
return result.GetChecksumSHA256();
}
return "";
}

struct TransferHandleAsyncContext : public Aws::Client::AsyncCallerContext
{
std::shared_ptr<TransferHandle> handle;
Expand Down Expand Up @@ -664,26 +700,7 @@ namespace Aws
{
if (handle->ShouldContinue())
{
partState->SetChecksum([&]() -> Aws::String {
if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32)
{
return outcome.GetResult().GetChecksumCRC32();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C)
{
return outcome.GetResult().GetChecksumCRC32C();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1)
{
return outcome.GetResult().GetChecksumSHA1();
}
else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256)
{
return outcome.GetResult().GetChecksumSHA256();
}
//Return empty checksum for not set.
return "";
}());
partState->SetChecksum(GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm));
handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag());
AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: ["
Expand Down Expand Up @@ -938,6 +955,61 @@ namespace Aws
handle->SetContentType(getObjectOutcome.GetResult().GetContentType());
handle->ChangePartToCompleted(partState, getObjectOutcome.GetResult().GetETag());
getObjectOutcome.GetResult().GetBody().flush();

// Validate checksum for single-part download by reading file
if (m_transferConfig.validateChecksums)
{
Aws::String expectedChecksum = GetChecksumFromResult(getObjectOutcome.GetResult(), m_transferConfig.checksumAlgorithm);

if (!expectedChecksum.empty() && !handle->GetTargetFilePath().empty())
{
auto hash = CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm);
Aws::IFStream fileStream(handle->GetTargetFilePath().c_str(), std::ios::binary);

if (fileStream.good())
{
const size_t bufferSize = 8192;
char buffer[bufferSize];
while (fileStream.good())
{
fileStream.read(buffer, bufferSize);
std::streamsize bytesRead = fileStream.gcount();
if (bytesRead > 0)
{
hash->Update(reinterpret_cast<unsigned char*>(buffer), static_cast<size_t>(bytesRead));
}
}
fileStream.close();

auto calculatedResult = hash->GetHash();
if (calculatedResult.IsSuccess())
{
Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(calculatedResult.GetResult());
if (calculatedChecksum != expectedChecksum)
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< "] Checksum mismatch for single-part download. Expected: "
<< expectedChecksum << ", Calculated: " << calculatedChecksum);

// Delete the corrupted file
Aws::FileSystem::RemoveFileIfExists(handle->GetTargetFilePath().c_str());

handle->ChangePartToFailed(partState);
handle->UpdateStatus(TransferStatus::FAILED);
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,
"ChecksumMismatch",
"Single-part download checksum validation failed",
false);
handle->SetError(error);
TriggerErrorCallback(handle, error);
TriggerTransferStatusUpdatedCallback(handle);
return;
}
}
}
}
}

handle->UpdateStatus(TransferStatus::COMPLETED);
}
else
Expand Down Expand Up @@ -1074,6 +1146,12 @@ namespace Aws
{
partState->SetDownloadBuffer(buffer);

// Initialize checksum Hash for this part if validation is enabled
if (m_transferConfig.validateChecksums)
{
handle->SetPartChecksum(partState->GetPartId(), CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm));
}

auto getObjectRangeRequest = m_transferConfig.getObjectTemplate;
getObjectRangeRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag);
getObjectRangeRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); });
Expand Down Expand Up @@ -1202,6 +1280,48 @@ namespace Aws
Aws::IOStream* bufferStream = partState->GetDownloadPartStream();
assert(bufferStream);

// Calculate and validate checksum for this part if validation is enabled
if (m_transferConfig.validateChecksums)
{
auto hash = handle->GetPartChecksum(partState->GetPartId());
if (hash && partState->GetDownloadBuffer())
{
hash->Update(partState->GetDownloadBuffer(), static_cast<size_t>(partState->GetSizeInBytes()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets not do this, in the get object we already validate the checksum of the body i think the idea here is more so the compare the checksum returned from get object, combined across parts, to make sure that the combined object has the same full object checksum.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright


// Get expected checksum from response
Aws::String expectedChecksum = GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm);

// Validate part checksum
if (!expectedChecksum.empty())
{
auto calculatedResult = hash->GetHash();
if (calculatedResult.IsSuccess())
{
Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(calculatedResult.GetResult());
if (calculatedChecksum != expectedChecksum)
{
AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId()
<< "] Checksum mismatch for part " << partState->GetPartId()
<< ". Expected: " << expectedChecksum << ", Calculated: " << calculatedChecksum);
Aws::Client::AWSError<Aws::S3::S3Errors> error(Aws::S3::S3Errors::INTERNAL_FAILURE,
"ChecksumMismatch",
"Part checksum validation failed",
false);
handle->ChangePartToFailed(partState);
handle->SetError(error);
TriggerErrorCallback(handle, error);
if(partState->GetDownloadBuffer())
{
m_bufferManager.Release(partState->GetDownloadBuffer());
partState->SetDownloadBuffer(nullptr);
}
return;
}
}
}
}
}

Aws::String errMsg{handle->WritePartToDownloadStream(bufferStream, partState->GetRangeBegin())};
if (errMsg.empty()) {
handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag());
Expand Down Expand Up @@ -1239,6 +1359,7 @@ namespace Aws
{
if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize())
{
// TODO: Combine part checksums and validate full-object checksum when CRT provides combining utility
outcome.GetResult().GetBody().flush();
handle->UpdateStatus(TransferStatus::COMPLETED);
}
Expand Down
Loading