Skip to content

Commit ca562ab

Browse files
NiallEgan authored and susodapop committed
Use Arrow schema if available
This PR changes the Python client to use the Arrow schema if it has been sent by the server, instead of re-constructing an approximation from the Hive schema. The primary difference is in the timezone information for timestamps.

* Added new unit tests to check that the correct field is used
* Adapted integration tests to add timezones as appropriate
1 parent 963d5b0 commit ca562ab

File tree

6 files changed

+90
-48
lines changed

6 files changed

+90
-48
lines changed

cmdexec/clients/python/dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ thrift==0.13.0
66
pandas==1.3.4
77
future==0.18.2
88
packaging==21.3
9+
pytz==2021.3

cmdexec/clients/python/src/databricks/sql/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def __init__(self,
480480
self.arraysize = arraysize
481481
self.thrift_backend = thrift_backend
482482
self.description = execute_response.description
483-
self._arrow_schema = execute_response.arrow_schema
483+
self._arrow_schema_bytes = execute_response.arrow_schema_bytes
484484
self._next_row_index = 0
485485

486486
if execute_response.arrow_queue:
@@ -505,7 +505,7 @@ def _fill_results_buffer(self):
505505
max_rows=self.arraysize,
506506
max_bytes=self.buffer_size_bytes,
507507
expected_row_start_offset=self._next_row_index,
508-
arrow_schema=self._arrow_schema,
508+
arrow_schema_bytes=self._arrow_schema_bytes,
509509
description=self.description)
510510
self.results = results
511511
self.has_more_rows = has_more_rows

cmdexec/clients/python/src/databricks/sql/thrift_backend.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ def open_session(self, session_configuration, catalog, schema):
330330
initial_namespace = None
331331

332332
open_session_req = ttypes.TOpenSessionReq(
333-
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V4,
333+
client_protocol_i64=ttypes.TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V5,
334334
client_protocol=None,
335335
initialNamespace=initial_namespace,
336336
canUseMultipleCatalogs=True,
@@ -376,13 +376,13 @@ def _poll_for_status(self, op_handle):
376376
)
377377
return self.make_request(self._client.GetOperationStatus, req)
378378

379-
def _create_arrow_table(self, t_row_set, arrow_schema, description):
379+
def _create_arrow_table(self, t_row_set, schema_bytes, description):
380380
if t_row_set.columns is not None:
381381
arrow_table, num_rows = ThriftBackend._convert_column_based_set_to_arrow_table(
382-
t_row_set.columns, arrow_schema)
382+
t_row_set.columns, description)
383383
elif t_row_set.arrowBatches is not None:
384384
arrow_table, num_rows = ThriftBackend._convert_arrow_based_set_to_arrow_table(
385-
t_row_set.arrowBatches, arrow_schema)
385+
t_row_set.arrowBatches, schema_bytes)
386386
else:
387387
raise OperationalError("Unsupported TRowSet instance {}".format(t_row_set))
388388
return self._convert_decimals_in_arrow_table(arrow_table, description), num_rows
@@ -404,9 +404,8 @@ def _convert_decimals_in_arrow_table(table, description):
404404
return table
405405

406406
@staticmethod
407-
def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema):
407+
def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema_bytes):
408408
ba = bytearray()
409-
schema_bytes = schema.serialize().to_pybytes()
410409
ba += schema_bytes
411410
n_rows = 0
412411
for arrow_batch in arrow_batches:
@@ -416,13 +415,13 @@ def _convert_arrow_based_set_to_arrow_table(arrow_batches, schema):
416415
return arrow_table, n_rows
417416

418417
@staticmethod
419-
def _convert_column_based_set_to_arrow_table(columns, schema):
418+
def _convert_column_based_set_to_arrow_table(columns, description):
420419
arrow_table = pyarrow.Table.from_arrays(
421420
[ThriftBackend._convert_column_to_arrow_array(c) for c in columns],
422421
# Only use the column names from the schema, the types are determined by the
423422
# physical types used in column based set, as they can differ from the
424423
# mapping used in _hive_schema_to_arrow_schema.
425-
names=[c.name for c in schema])
424+
names=[c[0] for c in description])
426425
return arrow_table, arrow_table.num_rows
427426

428427
@staticmethod
@@ -555,13 +554,14 @@ def _results_message_to_execute_response(self, resp, operation_state):
555554
has_more_rows = (not direct_results) or (not direct_results.resultSet) \
556555
or direct_results.resultSet.hasMoreRows
557556
description = self._hive_schema_to_description(t_result_set_metadata_resp.schema)
558-
arrow_schema = self._hive_schema_to_arrow_schema(t_result_set_metadata_resp.schema)
557+
schema_bytes = (t_result_set_metadata_resp.arrowSchema or self._hive_schema_to_arrow_schema(
558+
t_result_set_metadata_resp.schema).serialize().to_pybytes())
559559

560560
if direct_results and direct_results.resultSet:
561561
assert (direct_results.resultSet.results.startRowOffset == 0)
562562
assert (direct_results.resultSetMetadata)
563563
arrow_results, n_rows = self._create_arrow_table(direct_results.resultSet.results,
564-
arrow_schema, description)
564+
schema_bytes, description)
565565
arrow_queue_opt = ArrowQueue(arrow_results, n_rows, 0)
566566
else:
567567
arrow_queue_opt = None
@@ -572,7 +572,7 @@ def _results_message_to_execute_response(self, resp, operation_state):
572572
has_more_rows=has_more_rows,
573573
command_handle=resp.operationHandle,
574574
description=description,
575-
arrow_schema=arrow_schema)
575+
arrow_schema_bytes=schema_bytes)
576576

577577
def _wait_until_command_done(self, op_handle, initial_operation_status_resp):
578578
if initial_operation_status_resp:
@@ -697,8 +697,8 @@ def _handle_execute_response(self, resp, cursor):
697697

698698
return self._results_message_to_execute_response(resp, final_operation_state)
699699

700-
def fetch_results(self, op_handle, max_rows, max_bytes, expected_row_start_offset, arrow_schema,
701-
description):
700+
def fetch_results(self, op_handle, max_rows, max_bytes, expected_row_start_offset,
701+
arrow_schema_bytes, description):
702702
assert (op_handle is not None)
703703

704704
req = ttypes.TFetchResultsReq(
@@ -716,7 +716,8 @@ def fetch_results(self, op_handle, max_rows, max_bytes, expected_row_start_offse
716716
if resp.results.startRowOffset > expected_row_start_offset:
717717
logger.warning("Expected results to start from {} but they instead start at {}".format(
718718
expected_row_start_offset, resp.results.startRowOffset))
719-
arrow_results, n_rows = self._create_arrow_table(resp.results, arrow_schema, description)
719+
arrow_results, n_rows = self._create_arrow_table(resp.results, arrow_schema_bytes,
720+
description)
720721
arrow_queue = ArrowQueue(arrow_results, n_rows)
721722

722723
return arrow_queue, resp.hasMoreRows

cmdexec/clients/python/src/databricks/sql/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def remaining_rows(self) -> pyarrow.Table:
3535

3636
ExecuteResponse = namedtuple(
3737
'ExecuteResponse', 'status has_been_closed_server_side has_more_rows description '
38-
'command_handle arrow_queue arrow_schema')
38+
'command_handle arrow_queue arrow_schema_bytes')
3939

4040

4141
def _bound(min_x, max_x, x):

cmdexec/clients/python/tests/test_fetches.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def make_dummy_result_set_from_initial_results(initial_results):
4242
description=Mock(),
4343
command_handle=None,
4444
arrow_queue=arrow_queue,
45-
arrow_schema=schema))
45+
arrow_schema_bytes=schema.serialize().to_pybytes()))
4646
num_cols = len(initial_results[0]) if initial_results else 0
4747
rs.description = [(f'col{col_id}', 'integer', None, None, None, None, None)
4848
for col_id in range(num_cols)]
@@ -52,8 +52,8 @@ def make_dummy_result_set_from_initial_results(initial_results):
5252
def make_dummy_result_set_from_batch_list(batch_list):
5353
batch_index = 0
5454

55-
def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, arrow_schema,
56-
description):
55+
def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset,
56+
arrow_schema_bytes, description):
5757
nonlocal batch_index
5858
results = FetchTests.make_arrow_queue(batch_list[batch_index])
5959
batch_index += 1
@@ -75,7 +75,7 @@ def fetch_results(op_handle, max_rows, max_bytes, expected_row_start_offset, arr
7575
for col_id in range(num_cols)],
7676
command_handle=None,
7777
arrow_queue=None,
78-
arrow_schema=None))
78+
arrow_schema_bytes=None))
7979
return rs
8080

8181
def assertEqualRowValues(self, actual, expected):

cmdexec/clients/python/tests/test_thrift_backend.py

Lines changed: 67 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,54 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
497497
ttypes.TOperationState.FINISHED_STATE,
498498
)
499499

500+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
501+
def test_use_arrow_schema_if_available(self, tcli_service_class):
502+
tcli_service_instance = tcli_service_class.return_value
503+
arrow_schema_mock = MagicMock(name="Arrow schema mock")
504+
hive_schema_mock = MagicMock(name="Hive schema mock")
505+
506+
t_get_result_set_metadata_resp = ttypes.TGetResultSetMetadataResp(
507+
status=self.okay_status,
508+
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
509+
schema=hive_schema_mock,
510+
arrowSchema=arrow_schema_mock)
511+
512+
t_execute_resp = ttypes.TExecuteStatementResp(
513+
status=self.okay_status,
514+
directResults=None,
515+
operationHandle=self.operation_handle,
516+
)
517+
518+
tcli_service_instance.GetResultSetMetadata.return_value = t_get_result_set_metadata_resp
519+
thrift_backend = self._make_fake_thrift_backend()
520+
execute_response = thrift_backend._handle_execute_response(t_execute_resp, Mock())
521+
522+
self.assertEqual(execute_response.arrow_schema_bytes, arrow_schema_mock)
523+
524+
@patch("databricks.sql.thrift_backend.TCLIService.Client")
525+
def test_fall_back_to_hive_schema_if_no_arrow_schema(self, tcli_service_class):
526+
tcli_service_instance = tcli_service_class.return_value
527+
hive_schema_mock = MagicMock(name="Hive schema mock")
528+
529+
hive_schema_req = ttypes.TGetResultSetMetadataResp(
530+
status=self.okay_status,
531+
resultFormat=ttypes.TSparkRowSetType.ARROW_BASED_SET,
532+
arrowSchema=None,
533+
schema=hive_schema_mock)
534+
535+
t_execute_resp = ttypes.TExecuteStatementResp(
536+
status=self.okay_status,
537+
directResults=None,
538+
operationHandle=self.operation_handle,
539+
)
540+
541+
tcli_service_instance.GetResultSetMetadata.return_value = hive_schema_req
542+
thrift_backend = self._make_fake_thrift_backend()
543+
thrift_backend._handle_execute_response(t_execute_resp, Mock())
544+
545+
self.assertEqual(hive_schema_mock,
546+
thrift_backend._hive_schema_to_arrow_schema.call_args[0][0])
547+
500548
@patch("databricks.sql.thrift_backend.TCLIService.Client")
501549
def test_handle_execute_response_reads_has_more_rows_in_direct_results(
502550
self, tcli_service_class):
@@ -567,7 +615,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
567615
max_rows=1,
568616
max_bytes=1,
569617
expected_row_start_offset=0,
570-
arrow_schema=Mock(),
618+
arrow_schema_bytes=Mock(),
571619
description=Mock())
572620

573621
self.assertEqual(has_more_rows, has_more_rows_resp)
@@ -591,15 +639,15 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
591639
pyarrow.field("column2", pyarrow.string()),
592640
pyarrow.field("column3", pyarrow.float64()),
593641
pyarrow.field("column3", pyarrow.binary())
594-
])
642+
]).serialize().to_pybytes()
595643

596644
thrift_backend = ThriftBackend("foobar", 443, "path", [])
597645
arrow_queue, has_more_results = thrift_backend.fetch_results(
598646
op_handle=Mock(),
599647
max_rows=1,
600648
max_bytes=1,
601649
expected_row_start_offset=0,
602-
arrow_schema=schema,
650+
arrow_schema_bytes=schema,
603651
description=MagicMock())
604652

605653
self.assertEqual(arrow_queue.n_valid_rows, 15 * 10)
@@ -792,24 +840,21 @@ def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mo
792840
schema = Mock()
793841
cols = Mock()
794842
arrow_batches = Mock()
843+
description = Mock()
795844

796845
t_col_set = ttypes.TRowSet(columns=cols)
797-
thrift_backend._create_arrow_table(t_col_set, schema, Mock())
846+
thrift_backend._create_arrow_table(t_col_set, schema, description)
798847
convert_arrow_mock.assert_not_called()
799-
convert_col_mock.assert_called_once_with(cols, schema)
848+
convert_col_mock.assert_called_once_with(cols, description)
800849

801850
t_arrow_set = ttypes.TRowSet(arrowBatches=arrow_batches)
802851
thrift_backend._create_arrow_table(t_arrow_set, schema, Mock())
803852
convert_arrow_mock.assert_called_once_with(arrow_batches, schema)
804-
convert_col_mock.assert_called_once_with(cols, schema)
805853

806854
def test_convert_column_based_set_to_arrow_table_without_nulls(self):
807-
schema = pyarrow.schema([
808-
pyarrow.field("column1", pyarrow.int32()),
809-
pyarrow.field("column2", pyarrow.string()),
810-
pyarrow.field("column3", pyarrow.float64()),
811-
pyarrow.field("column3", pyarrow.binary())
812-
])
855+
# Deliberately duplicate the column name to check that dups work
856+
field_names = ["column1", "column2", "column3", "column3"]
857+
description = [(name, ) for name in field_names]
813858

814859
t_cols = [
815860
ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))),
@@ -820,7 +865,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
820865
binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1)))
821866
]
822867

823-
arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table(t_cols, schema)
868+
arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table(
869+
t_cols, description)
824870
self.assertEqual(n_rows, 3)
825871

826872
# Check schema, column names and types
@@ -841,12 +887,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
841887
self.assertEqual(arrow_table.column(3).to_pylist(), [b'\x11', b'\x22', b'\x33'])
842888

843889
def test_convert_column_based_set_to_arrow_table_with_nulls(self):
844-
schema = pyarrow.schema([
845-
pyarrow.field("column1", pyarrow.int32()),
846-
pyarrow.field("column2", pyarrow.string()),
847-
pyarrow.field("column3", pyarrow.float64()),
848-
pyarrow.field("column3", pyarrow.binary())
849-
])
890+
field_names = ["column1", "column2", "column3", "column3"]
891+
description = [(name, ) for name in field_names]
850892

851893
t_cols = [
852894
ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes([1]))),
@@ -859,7 +901,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
859901
values=[b'\x11', b'\x22', b'\x33'], nulls=bytes([3])))
860902
]
861903

862-
arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table(t_cols, schema)
904+
arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table(
905+
t_cols, description)
863906
self.assertEqual(n_rows, 3)
864907

865908
# Check data
@@ -869,12 +912,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
869912
self.assertEqual(arrow_table.column(3).to_pylist(), [None, None, b'\x33'])
870913

871914
def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
872-
schema = pyarrow.schema([
873-
pyarrow.field("column1", pyarrow.string()),
874-
pyarrow.field("column2", pyarrow.string()),
875-
pyarrow.field("column3", pyarrow.string()),
876-
pyarrow.field("column3", pyarrow.string())
877-
])
915+
field_names = ["column1", "column2", "column3", "column3"]
916+
description = [(name, ) for name in field_names]
878917

879918
t_cols = [
880919
ttypes.TColumn(i32Val=ttypes.TI32Column(values=[1, 2, 3], nulls=bytes(1))),
@@ -885,7 +924,8 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
885924
binaryVal=ttypes.TBinaryColumn(values=[b'\x11', b'\x22', b'\x33'], nulls=bytes(1)))
886925
]
887926

888-
arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table(t_cols, schema)
927+
arrow_table, n_rows = ThriftBackend._convert_column_based_set_to_arrow_table(
928+
t_cols, description)
889929
self.assertEqual(n_rows, 3)
890930

891931
# Check schema, column names and types

0 commit comments

Comments (0)