@@ -497,6 +497,54 @@ def test_handle_execute_response_can_handle_with_direct_results(self):
497497 ttypes .TOperationState .FINISHED_STATE ,
498498 )
499499
500+ @patch ("databricks.sql.thrift_backend.TCLIService.Client" )
501+ def test_use_arrow_schema_if_available (self , tcli_service_class ):
502+ tcli_service_instance = tcli_service_class .return_value
503+ arrow_schema_mock = MagicMock (name = "Arrow schema mock" )
504+ hive_schema_mock = MagicMock (name = "Hive schema mock" )
505+
506+ t_get_result_set_metadata_resp = ttypes .TGetResultSetMetadataResp (
507+ status = self .okay_status ,
508+ resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET ,
509+ schema = hive_schema_mock ,
510+ arrowSchema = arrow_schema_mock )
511+
512+ t_execute_resp = ttypes .TExecuteStatementResp (
513+ status = self .okay_status ,
514+ directResults = None ,
515+ operationHandle = self .operation_handle ,
516+ )
517+
518+ tcli_service_instance .GetResultSetMetadata .return_value = t_get_result_set_metadata_resp
519+ thrift_backend = self ._make_fake_thrift_backend ()
520+ execute_response = thrift_backend ._handle_execute_response (t_execute_resp , Mock ())
521+
522+ self .assertEqual (execute_response .arrow_schema_bytes , arrow_schema_mock )
523+
524+ @patch ("databricks.sql.thrift_backend.TCLIService.Client" )
525+ def test_fall_back_to_hive_schema_if_no_arrow_schema (self , tcli_service_class ):
526+ tcli_service_instance = tcli_service_class .return_value
527+ hive_schema_mock = MagicMock (name = "Hive schema mock" )
528+
529+ hive_schema_req = ttypes .TGetResultSetMetadataResp (
530+ status = self .okay_status ,
531+ resultFormat = ttypes .TSparkRowSetType .ARROW_BASED_SET ,
532+ arrowSchema = None ,
533+ schema = hive_schema_mock )
534+
535+ t_execute_resp = ttypes .TExecuteStatementResp (
536+ status = self .okay_status ,
537+ directResults = None ,
538+ operationHandle = self .operation_handle ,
539+ )
540+
541+ tcli_service_instance .GetResultSetMetadata .return_value = hive_schema_req
542+ thrift_backend = self ._make_fake_thrift_backend ()
543+ thrift_backend ._handle_execute_response (t_execute_resp , Mock ())
544+
545+ self .assertEqual (hive_schema_mock ,
546+ thrift_backend ._hive_schema_to_arrow_schema .call_args [0 ][0 ])
547+
500548 @patch ("databricks.sql.thrift_backend.TCLIService.Client" )
501549 def test_handle_execute_response_reads_has_more_rows_in_direct_results (
502550 self , tcli_service_class ):
@@ -567,7 +615,7 @@ def test_handle_execute_response_reads_has_more_rows_in_result_response(
567615 max_rows = 1 ,
568616 max_bytes = 1 ,
569617 expected_row_start_offset = 0 ,
570- arrow_schema = Mock (),
618+ arrow_schema_bytes = Mock (),
571619 description = Mock ())
572620
573621 self .assertEqual (has_more_rows , has_more_rows_resp )
@@ -591,15 +639,15 @@ def test_arrow_batches_row_count_are_respected(self, tcli_service_class):
591639 pyarrow .field ("column2" , pyarrow .string ()),
592640 pyarrow .field ("column3" , pyarrow .float64 ()),
593641 pyarrow .field ("column3" , pyarrow .binary ())
594- ])
642+ ]). serialize (). to_pybytes ()
595643
596644 thrift_backend = ThriftBackend ("foobar" , 443 , "path" , [])
597645 arrow_queue , has_more_results = thrift_backend .fetch_results (
598646 op_handle = Mock (),
599647 max_rows = 1 ,
600648 max_bytes = 1 ,
601649 expected_row_start_offset = 0 ,
602- arrow_schema = schema ,
650+ arrow_schema_bytes = schema ,
603651 description = MagicMock ())
604652
605653 self .assertEqual (arrow_queue .n_valid_rows , 15 * 10 )
@@ -792,24 +840,21 @@ def test_create_arrow_table_calls_correct_conversion_method(self, convert_col_mo
792840 schema = Mock ()
793841 cols = Mock ()
794842 arrow_batches = Mock ()
843+ description = Mock ()
795844
796845 t_col_set = ttypes .TRowSet (columns = cols )
797- thrift_backend ._create_arrow_table (t_col_set , schema , Mock () )
846+ thrift_backend ._create_arrow_table (t_col_set , schema , description )
798847 convert_arrow_mock .assert_not_called ()
799- convert_col_mock .assert_called_once_with (cols , schema )
848+ convert_col_mock .assert_called_once_with (cols , description )
800849
801850 t_arrow_set = ttypes .TRowSet (arrowBatches = arrow_batches )
802851 thrift_backend ._create_arrow_table (t_arrow_set , schema , Mock ())
803852 convert_arrow_mock .assert_called_once_with (arrow_batches , schema )
804- convert_col_mock .assert_called_once_with (cols , schema )
805853
806854 def test_convert_column_based_set_to_arrow_table_without_nulls (self ):
807- schema = pyarrow .schema ([
808- pyarrow .field ("column1" , pyarrow .int32 ()),
809- pyarrow .field ("column2" , pyarrow .string ()),
810- pyarrow .field ("column3" , pyarrow .float64 ()),
811- pyarrow .field ("column3" , pyarrow .binary ())
812- ])
855+ # Deliberately duplicate the column name to check that dups work
856+ field_names = ["column1" , "column2" , "column3" , "column3" ]
857+ description = [(name , ) for name in field_names ]
813858
814859 t_cols = [
815860 ttypes .TColumn (i32Val = ttypes .TI32Column (values = [1 , 2 , 3 ], nulls = bytes (1 ))),
@@ -820,7 +865,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
820865 binaryVal = ttypes .TBinaryColumn (values = [b'\x11 ' , b'\x22 ' , b'\x33 ' ], nulls = bytes (1 )))
821866 ]
822867
823- arrow_table , n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (t_cols , schema )
868+ arrow_table , n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (
869+ t_cols , description )
824870 self .assertEqual (n_rows , 3 )
825871
826872 # Check schema, column names and types
@@ -841,12 +887,8 @@ def test_convert_column_based_set_to_arrow_table_without_nulls(self):
841887 self .assertEqual (arrow_table .column (3 ).to_pylist (), [b'\x11 ' , b'\x22 ' , b'\x33 ' ])
842888
843889 def test_convert_column_based_set_to_arrow_table_with_nulls (self ):
844- schema = pyarrow .schema ([
845- pyarrow .field ("column1" , pyarrow .int32 ()),
846- pyarrow .field ("column2" , pyarrow .string ()),
847- pyarrow .field ("column3" , pyarrow .float64 ()),
848- pyarrow .field ("column3" , pyarrow .binary ())
849- ])
890+ field_names = ["column1" , "column2" , "column3" , "column3" ]
891+ description = [(name , ) for name in field_names ]
850892
851893 t_cols = [
852894 ttypes .TColumn (i32Val = ttypes .TI32Column (values = [1 , 2 , 3 ], nulls = bytes ([1 ]))),
@@ -859,7 +901,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
859901 values = [b'\x11 ' , b'\x22 ' , b'\x33 ' ], nulls = bytes ([3 ])))
860902 ]
861903
862- arrow_table , n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (t_cols , schema )
904+ arrow_table , n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (
905+ t_cols , description )
863906 self .assertEqual (n_rows , 3 )
864907
865908 # Check data
@@ -869,12 +912,8 @@ def test_convert_column_based_set_to_arrow_table_with_nulls(self):
869912 self .assertEqual (arrow_table .column (3 ).to_pylist (), [None , None , b'\x33 ' ])
870913
871914 def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set (self ):
872- schema = pyarrow .schema ([
873- pyarrow .field ("column1" , pyarrow .string ()),
874- pyarrow .field ("column2" , pyarrow .string ()),
875- pyarrow .field ("column3" , pyarrow .string ()),
876- pyarrow .field ("column3" , pyarrow .string ())
877- ])
915+ field_names = ["column1" , "column2" , "column3" , "column3" ]
916+ description = [(name , ) for name in field_names ]
878917
879918 t_cols = [
880919 ttypes .TColumn (i32Val = ttypes .TI32Column (values = [1 , 2 , 3 ], nulls = bytes (1 ))),
@@ -885,7 +924,8 @@ def test_convert_column_based_set_to_arrow_table_uses_types_from_col_set(self):
885924 binaryVal = ttypes .TBinaryColumn (values = [b'\x11 ' , b'\x22 ' , b'\x33 ' ], nulls = bytes (1 )))
886925 ]
887926
888- arrow_table , n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (t_cols , schema )
927+ arrow_table , n_rows = ThriftBackend ._convert_column_based_set_to_arrow_table (
928+ t_cols , description )
889929 self .assertEqual (n_rows , 3 )
890930
891931 # Check schema, column names and types
0 commit comments