Skip to content

Commit 9800636

Browse files
stronger typing of Cursor and ExecuteResponse
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent eecc67d commit 9800636

File tree

2 files changed

+21
-19
lines changed

2 files changed

+21
-19
lines changed

src/databricks/sql/backend/databricks_client.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from abc import ABC, abstractmethod
1414
from typing import Dict, Tuple, List, Optional, Any, Union
1515

16+
from databricks.sql.client import Cursor
1617
from databricks.sql.thrift_api.TCLIService import ttypes
1718
from databricks.sql.backend.types import SessionId, CommandId
1819
from databricks.sql.utils import ExecuteResponse
@@ -75,7 +76,7 @@ def execute_command(
7576
max_rows: int,
7677
max_bytes: int,
7778
lz4_compression: bool,
78-
cursor: Any,
79+
cursor: Cursor,
7980
use_cloud_fetch: bool,
8081
parameters: List[ttypes.TSparkParameter],
8182
async_op: bool,
@@ -173,7 +174,7 @@ def get_query_state(self, command_id: CommandId) -> ttypes.TOperationState:
173174
def get_execution_result(
174175
self,
175176
command_id: CommandId,
176-
cursor: Any,
177+
cursor: Cursor,
177178
) -> ExecuteResponse:
178179
"""
179180
Retrieves the results of a previously executed command.
@@ -201,8 +202,8 @@ def get_catalogs(
201202
session_id: SessionId,
202203
max_rows: int,
203204
max_bytes: int,
204-
cursor: Any,
205-
) -> Any:
205+
cursor: Cursor,
206+
) -> ExecuteResponse:
206207
"""
207208
Retrieves a list of available catalogs.
208209
@@ -230,10 +231,10 @@ def get_schemas(
230231
session_id: SessionId,
231232
max_rows: int,
232233
max_bytes: int,
233-
cursor: Any,
234+
cursor: Cursor,
234235
catalog_name: Optional[str] = None,
235236
schema_name: Optional[str] = None,
236-
) -> Any:
237+
) -> ExecuteResponse:
237238
"""
238239
Retrieves a list of available schemas.
239240
@@ -263,12 +264,12 @@ def get_tables(
263264
session_id: SessionId,
264265
max_rows: int,
265266
max_bytes: int,
266-
cursor: Any,
267+
cursor: Cursor,
267268
catalog_name: Optional[str] = None,
268269
schema_name: Optional[str] = None,
269270
table_name: Optional[str] = None,
270271
table_types: Optional[List[str]] = None,
271-
) -> Any:
272+
) -> ExecuteResponse:
272273
"""
273274
Retrieves a list of available tables.
274275
@@ -300,12 +301,12 @@ def get_columns(
300301
session_id: SessionId,
301302
max_rows: int,
302303
max_bytes: int,
303-
cursor: Any,
304+
cursor: Cursor,
304305
catalog_name: Optional[str] = None,
305306
schema_name: Optional[str] = None,
306307
table_name: Optional[str] = None,
307308
column_name: Optional[str] = None,
308-
) -> Any:
309+
) -> ExecuteResponse:
309310
"""
310311
Retrieves column metadata for tables.
311312

src/databricks/sql/backend/thrift_backend.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import threading
88
from typing import List, Union, Any
99

10+
from databricks.sql.client import Cursor
1011
from databricks.sql.thrift_api.TCLIService.ttypes import TOperationState
1112
from databricks.sql.backend.types import (
1213
SessionId,
@@ -928,7 +929,7 @@ def execute_command(
928929
max_rows: int,
929930
max_bytes: int,
930931
lz4_compression: bool,
931-
cursor: Any,
932+
cursor: Cursor,
932933
use_cloud_fetch=True,
933934
parameters=[],
934935
async_op=False,
@@ -986,8 +987,8 @@ def get_catalogs(
986987
session_id: SessionId,
987988
max_rows: int,
988989
max_bytes: int,
989-
cursor: Any,
990-
):
990+
cursor: Cursor,
991+
) -> ExecuteResponse:
991992
thrift_handle = session_id.to_thrift_handle()
992993
if not thrift_handle:
993994
raise ValueError("Not a valid Thrift session ID")
@@ -1006,10 +1007,10 @@ def get_schemas(
10061007
session_id: SessionId,
10071008
max_rows: int,
10081009
max_bytes: int,
1009-
cursor: Any,
1010+
cursor: Cursor,
10101011
catalog_name=None,
10111012
schema_name=None,
1012-
):
1013+
) -> ExecuteResponse:
10131014
thrift_handle = session_id.to_thrift_handle()
10141015
if not thrift_handle:
10151016
raise ValueError("Not a valid Thrift session ID")
@@ -1030,12 +1031,12 @@ def get_tables(
10301031
session_id: SessionId,
10311032
max_rows: int,
10321033
max_bytes: int,
1033-
cursor: Any,
1034+
cursor: Cursor,
10341035
catalog_name=None,
10351036
schema_name=None,
10361037
table_name=None,
10371038
table_types=None,
1038-
):
1039+
) -> ExecuteResponse:
10391040
thrift_handle = session_id.to_thrift_handle()
10401041
if not thrift_handle:
10411042
raise ValueError("Not a valid Thrift session ID")
@@ -1058,12 +1059,12 @@ def get_columns(
10581059
session_id: SessionId,
10591060
max_rows: int,
10601061
max_bytes: int,
1061-
cursor: Any,
1062+
cursor: Cursor,
10621063
catalog_name=None,
10631064
schema_name=None,
10641065
table_name=None,
10651066
column_name=None,
1066-
):
1067+
) -> ExecuteResponse:
10671068
thrift_handle = session_id.to_thrift_handle()
10681069
if not thrift_handle:
10691070
raise ValueError("Not a valid Thrift session ID")

0 commit comments

Comments
 (0)