diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h index 0d062be1e00..ff4c2e4dbf6 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferHandle.h @@ -79,6 +79,9 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; }; void SetChecksum(const Aws::String& checksum) { m_checksum = checksum; } + + std::shared_ptr GetChecksumHash() const { return m_checksumHash; } + void SetChecksumHash(std::shared_ptr hash) { m_checksumHash = hash; } private: int m_partId = 0; @@ -93,6 +96,7 @@ namespace Aws std::atomic m_downloadBuffer; bool m_lastPart = false; Aws::String m_checksum; + std::shared_ptr m_checksumHash; }; using PartPointer = std::shared_ptr< PartState >; @@ -389,6 +393,12 @@ namespace Aws Aws::String GetChecksum() const { return m_checksum; } void SetChecksum(const Aws::String& checksum) { this->m_checksum = checksum; } + void SetPartChecksum(int partId, std::shared_ptr hash) { m_partChecksums[partId] = hash; } + std::shared_ptr GetPartChecksum(int partId) const { + auto it = m_partChecksums.find(partId); + return it != m_partChecksums.end() ? it->second : nullptr; + } + private: void CleanupDownloadStream(); @@ -430,6 +440,9 @@ namespace Aws mutable std::condition_variable m_waitUntilFinishedSignal; mutable std::mutex m_getterSetterLock; Aws::String m_checksum; + // Map of part number to Hash instance for multipart download checksum validation + // TODO: Add CRT checksum combining utility when available + Aws::Map> m_partChecksums; }; AWS_TRANSFER_API Aws::OStream& operator << (Aws::OStream& s, TransferStatus status); diff --git a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h index a4b5580fd6e..725f14c1219 100644 --- a/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h +++ b/src/aws-cpp-sdk-transfer/include/aws/transfer/TransferManager.h @@ -144,6 +144,13 @@ namespace Aws * upload. Defaults to CRC64-NVME. */ Aws::S3::Model::ChecksumAlgorithm checksumAlgorithm = S3::Model::ChecksumAlgorithm::CRC64NVME; + + /** + * Enable checksum validation for downloads. When enabled, checksums will be + * calculated during download and validated against S3 response headers. + * Defaults to true. + */ + bool validateChecksums = true; }; /** diff --git a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp index 996e427e114..f6f1ed47bc6 100644 --- a/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp +++ b/src/aws-cpp-sdk-transfer/source/transfer/TransferManager.cpp @@ -51,6 +51,42 @@ namespace Aws } } + static std::shared_ptr CreateHashForAlgorithm(S3::Model::ChecksumAlgorithm algorithm) { + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) { + return Aws::MakeShared(CLASS_TAG); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { + return Aws::MakeShared(CLASS_TAG); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) { + return Aws::MakeShared(CLASS_TAG); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) { + return Aws::MakeShared(CLASS_TAG); + } + return Aws::MakeShared(CLASS_TAG); + } + + template + static Aws::String GetChecksumFromResult(const ResultT& result, S3::Model::ChecksumAlgorithm algorithm) { + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32) { + return result.GetChecksumCRC32(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC32C) { + return result.GetChecksumCRC32C(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::CRC64NVME) { + return result.GetChecksumCRC64NVME(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA1) { + return result.GetChecksumSHA1(); + } + if (algorithm == S3::Model::ChecksumAlgorithm::SHA256) { + return result.GetChecksumSHA256(); + } + return ""; + } + struct TransferHandleAsyncContext : public Aws::Client::AsyncCallerContext { std::shared_ptr handle; @@ -664,26 +700,7 @@ namespace Aws { if (handle->ShouldContinue()) { - partState->SetChecksum([&]() -> Aws::String { - if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32) - { - return outcome.GetResult().GetChecksumCRC32(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::CRC32C) - { - return outcome.GetResult().GetChecksumCRC32C(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA1) - { - return outcome.GetResult().GetChecksumSHA1(); - } - else if (m_transferConfig.checksumAlgorithm == S3::Model::ChecksumAlgorithm::SHA256) - { - return outcome.GetResult().GetChecksumSHA256(); - } - //Return empty checksum for not set. - return ""; - }()); + partState->SetChecksum(GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm)); handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); AWS_LOGSTREAM_DEBUG(CLASS_TAG, "Transfer handle [" << handle->GetId() << " successfully uploaded Part: [" << partState->GetPartId() << "] to Bucket: [" @@ -938,6 +955,61 @@ namespace Aws handle->SetContentType(getObjectOutcome.GetResult().GetContentType()); handle->ChangePartToCompleted(partState, getObjectOutcome.GetResult().GetETag()); getObjectOutcome.GetResult().GetBody().flush(); + + // Validate checksum for single-part download by reading file + if (m_transferConfig.validateChecksums) + { + Aws::String expectedChecksum = GetChecksumFromResult(getObjectOutcome.GetResult(), m_transferConfig.checksumAlgorithm); + + if (!expectedChecksum.empty() && !handle->GetTargetFilePath().empty()) + { + auto hash = CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm); + Aws::IFStream fileStream(handle->GetTargetFilePath().c_str(), std::ios::binary); + + if (fileStream.good()) + { + const size_t bufferSize = 8192; + char buffer[bufferSize]; + while (fileStream.good()) + { + fileStream.read(buffer, bufferSize); + std::streamsize bytesRead = fileStream.gcount(); + if (bytesRead > 0) + { + hash->Update(reinterpret_cast(buffer), static_cast(bytesRead)); + } + } + fileStream.close(); + + auto calculatedResult = hash->GetHash(); + if (calculatedResult.IsSuccess()) + { + Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(calculatedResult.GetResult()); + if (calculatedChecksum != expectedChecksum) + { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Checksum mismatch for single-part download. Expected: " + << expectedChecksum << ", Calculated: " << calculatedChecksum); + + // Delete the corrupted file + Aws::FileSystem::RemoveFileIfExists(handle->GetTargetFilePath().c_str()); + + handle->ChangePartToFailed(partState); + handle->UpdateStatus(TransferStatus::FAILED); + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ChecksumMismatch", + "Single-part download checksum validation failed", + false); + handle->SetError(error); + TriggerErrorCallback(handle, error); + TriggerTransferStatusUpdatedCallback(handle); + return; + } + } + } + } + } + handle->UpdateStatus(TransferStatus::COMPLETED); } else @@ -1074,6 +1146,12 @@ namespace Aws { partState->SetDownloadBuffer(buffer); + // Initialize checksum Hash for this part if validation is enabled + if (m_transferConfig.validateChecksums) + { + handle->SetPartChecksum(partState->GetPartId(), CreateHashForAlgorithm(m_transferConfig.checksumAlgorithm)); + } + auto getObjectRangeRequest = m_transferConfig.getObjectTemplate; getObjectRangeRequest.SetCustomizedAccessLogTag(m_transferConfig.customizedAccessLogTag); getObjectRangeRequest.SetContinueRequestHandler([handle](const Aws::Http::HttpRequest*) { return handle->ShouldContinue(); }); @@ -1202,6 +1280,48 @@ namespace Aws Aws::IOStream* bufferStream = partState->GetDownloadPartStream(); assert(bufferStream); + // Calculate and validate checksum for this part if validation is enabled + if (m_transferConfig.validateChecksums) + { + auto hash = handle->GetPartChecksum(partState->GetPartId()); + if (hash && partState->GetDownloadBuffer()) + { + hash->Update(partState->GetDownloadBuffer(), static_cast(partState->GetSizeInBytes())); + + // Get expected checksum from response + Aws::String expectedChecksum = GetChecksumFromResult(outcome.GetResult(), m_transferConfig.checksumAlgorithm); + + // Validate part checksum + if (!expectedChecksum.empty()) + { + auto calculatedResult = hash->GetHash(); + if (calculatedResult.IsSuccess()) + { + Aws::String calculatedChecksum = Utils::HashingUtils::Base64Encode(calculatedResult.GetResult()); + if (calculatedChecksum != expectedChecksum) + { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "Transfer handle [" << handle->GetId() + << "] Checksum mismatch for part " << partState->GetPartId() + << ". Expected: " << expectedChecksum << ", Calculated: " << calculatedChecksum); + Aws::Client::AWSError error(Aws::S3::S3Errors::INTERNAL_FAILURE, + "ChecksumMismatch", + "Part checksum validation failed", + false); + handle->ChangePartToFailed(partState); + handle->SetError(error); + TriggerErrorCallback(handle, error); + if(partState->GetDownloadBuffer()) + { + m_bufferManager.Release(partState->GetDownloadBuffer()); + partState->SetDownloadBuffer(nullptr); + } + return; + } + } + } + } + } + Aws::String errMsg{handle->WritePartToDownloadStream(bufferStream, partState->GetRangeBegin())}; if (errMsg.empty()) { handle->ChangePartToCompleted(partState, outcome.GetResult().GetETag()); @@ -1239,6 +1359,7 @@ namespace Aws { if (failedParts.size() == 0 && handle->GetBytesTransferred() == handle->GetBytesTotalSize()) { + // TODO: Combine part checksums and validate full-object checksum when CRT provides combining utility outcome.GetResult().GetBody().flush(); handle->UpdateStatus(TransferStatus::COMPLETED); }