Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/databricks/sql/cloudfetch/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,14 @@ class DownloadableResultSettings:
link_expiry_buffer_secs (int): Time in seconds to prevent download of a link before it expires. Default 0 secs.
download_timeout (int): Timeout for download requests. Default 60 secs.
max_consecutive_file_download_retries (int): Number of consecutive download retries before shutting down.
min_cloudfetch_download_speed (float): Threshold in MB/s below which to log warning. Default 0.1 MB/s.
"""

is_lz4_compressed: bool
link_expiry_buffer_secs: int = 0
download_timeout: int = 60
max_consecutive_file_download_retries: int = 0
min_cloudfetch_download_speed: float = 0.1


class ResultSetDownloadHandler:
Expand Down Expand Up @@ -90,6 +92,8 @@ def run(self) -> DownloadedFile:
self.link, self.settings.link_expiry_buffer_secs
)

start_time = time.time()

with self._http_client.execute(
method=HttpMethod.GET,
url=self.link.fileLink,
Expand All @@ -102,6 +106,13 @@ def run(self) -> DownloadedFile:

# Save (and decompress if needed) the downloaded file
compressed_data = response.content

# Log download metrics
download_duration = time.time() - start_time
self._log_download_metrics(
self.link.fileLink, len(compressed_data), download_duration
)

decompressed_data = (
ResultSetDownloadHandler._decompress_data(compressed_data)
if self.settings.is_lz4_compressed
Expand All @@ -128,6 +139,32 @@ def run(self) -> DownloadedFile:
self.link.rowCount,
)

def _log_download_metrics(
    self, url: str, bytes_downloaded: int, duration_seconds: float
) -> None:
    """Log CloudFetch download throughput metrics.

    Emits an INFO record for every completed download, and a WARNING
    when the observed speed falls below
    ``self.settings.min_cloudfetch_download_speed``.

    Args:
        url: Presigned file URL the payload was fetched from.
        bytes_downloaded: Size of the downloaded payload in bytes.
        duration_seconds: Wall-clock duration of the download.
    """
    # Guard against a zero duration (sub-resolution clocks, mocked time)
    # so the MB/s computation cannot raise ZeroDivisionError.
    megabytes = float(bytes_downloaded) / (1024 * 1024)
    speed_mbps = (
        megabytes / duration_seconds if duration_seconds > 0 else float("inf")
    )

    # Strip the query string before logging: presigned URLs carry auth
    # tokens in their query parameters and must not end up in logs.
    url_endpoint = url.split("?")[0]

    # INFO level logging (lazy %-style args so formatting is skipped
    # when the level is disabled)
    logger.info(
        "CloudFetch download completed: %.4f MB/s, %d bytes in %.3fs from %s",
        speed_mbps,
        bytes_downloaded,
        duration_seconds,
        url_endpoint,
    )

    # WARN level logging if below threshold; log the sanitized endpoint
    # here too (the original full URL would leak the presigned token).
    if speed_mbps < self.settings.min_cloudfetch_download_speed:
        logger.warning(
            "CloudFetch download slower than threshold: %.4f MB/s (threshold: %.1f MB/s) from %s",
            speed_mbps,
            self.settings.min_cloudfetch_download_speed,
            url_endpoint,
        )

@staticmethod
def _validate_link(link: TSparkArrowResultLink, expiry_buffer_secs: int):
"""
Expand Down
23 changes: 21 additions & 2 deletions tests/unit/test_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ class DownloaderTests(unittest.TestCase):
Unit tests for checking downloader logic.
"""

def _setup_time_mock_for_download(self, mock_time, end_time):
"""Helper to setup time mock that handles logging system calls."""
call_count = [0]
def time_side_effect():
call_count[0] += 1
if call_count[0] <= 2: # First two calls (validation, start_time)
return 1000
else: # All subsequent calls (logging, duration calculation)
return end_time
mock_time.side_effect = time_side_effect

@patch("time.time", return_value=1000)
def test_run_link_expired(self, mock_time):
settings = Mock()
Expand Down Expand Up @@ -75,13 +86,17 @@ def test_run_get_response_not_ok(self, mock_time):
d.run()
self.assertTrue("404" in str(context.exception))

@patch("time.time", return_value=1000)
@patch("time.time")
def test_run_uncompressed_successful(self, mock_time):
self._setup_time_mock_for_download(mock_time, 1000.5)

http_client = DatabricksHttpClient.get_instance()
file_bytes = b"1234567890" * 10
settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings.is_lz4_compressed = False
settings.min_cloudfetch_download_speed = 1.0
result_link = Mock(bytesNum=100, expiryTime=1001)
result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=abc123"

with patch.object(
http_client,
Expand All @@ -95,15 +110,19 @@ def test_run_uncompressed_successful(self, mock_time):

assert file.file_bytes == b"1234567890" * 10

@patch("time.time", return_value=1000)
@patch("time.time")
def test_run_compressed_successful(self, mock_time):
self._setup_time_mock_for_download(mock_time, 1000.2)

http_client = DatabricksHttpClient.get_instance()
file_bytes = b"1234567890" * 10
compressed_bytes = b'\x04"M\x18h@d\x00\x00\x00\x00\x00\x00\x00#\x14\x00\x00\x00\xaf1234567890\n\x00BP67890\x00\x00\x00\x00'

settings = Mock(link_expiry_buffer_secs=0, download_timeout=0, use_proxy=False)
settings.is_lz4_compressed = True
settings.min_cloudfetch_download_speed = 1.0
result_link = Mock(bytesNum=100, expiryTime=1001)
result_link.fileLink = "https://s3.amazonaws.com/bucket/file.arrow?token=xyz789"
with patch.object(
http_client,
"execute",
Expand Down
Loading