Commit e8ad7c5

Python SDK documentation, tests and examples for 0.7.0 (#754)
1 parent 403816a commit e8ad7c5

19 files changed: +3697 -2449 lines changed

pgml-sdks/python/pgml/README.md

Lines changed: 53 additions & 53 deletions
@@ -81,14 +81,16 @@ import json
 from datasets import load_dataset
 from time import time
 from rich import print as rprint
+import asyncio

-local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+async def main():
+    local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+    conninfo = os.environ.get("PGML_CONNECTION", local_pgml)

-conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
-db = Database(conninfo)
+    db = Database(conninfo)

-collection_name = "test_pgml_sdk_1"
-collection = db.create_or_get_collection(collection_name)
+    collection_name = "test_collection"
+    collection = await db.create_or_get_collection(collection_name)
 ```

 **Explanation:**
@@ -98,19 +100,21 @@ collection = db.create_or_get_collection(collection_name)
 - An instance of the Database class is created by passing the connection information.
 - The [`create_or_get_collection`](#create-or-get-a-collection) method retrieves the collection named `test_pgml_sdk_1` if it exists, or creates a new one.

-```python
-data = load_dataset("squad", split="train")
-data = data.to_pandas()
-data = data.drop_duplicates(subset=["context"])
-
-documents = [
-    {'id': r['id'], "text": r["context"], "title": r["title"]}
-    for r in data.to_dict(orient="records")
-]
+Continuing within `async def main():`

-collection.upsert_documents(documents[:200])
-collection.generate_chunks()
-collection.generate_embeddings()
+```python
+    data = load_dataset("squad", split="train")
+    data = data.to_pandas()
+    data = data.drop_duplicates(subset=["context"])
+
+    documents = [
+        {'id': r['id'], "text": r["context"], "title": r["title"]}
+        for r in data.to_dict(orient="records")
+    ]
+
+    await collection.upsert_documents(documents[:200])
+    await collection.generate_chunks()
+    await collection.generate_embeddings()
 ```

 **Explanation:**
@@ -121,12 +125,13 @@ collection.generate_embeddings()
 - The [`generate_chunks`](#generate-chunks) method splits the documents into smaller text chunks for efficient indexing and search.
 - The [`generate_embeddings`](#generate-embeddings) method generates embeddings for the documents in the collection.

+Continuing within `async def main():`
 ```python
-start = time()
-results = collection.vector_search("Who won 20 grammy awards?", top_k=2)
-rprint(json.dumps(results, indent=2))
-rprint("Query time: %0.3f seconds" % (time() - start))
-db.archive_collection(collection_name)
+    start = time()
+    results = await collection.vector_search("Who won 20 grammy awards?", top_k=2)
+    rprint(json.dumps(results, indent=2))
+    rprint("Query time: %0.3f seconds" % (time() - start))
+    await db.archive_collection(collection_name)
 ```

 **Explanation:**
@@ -137,6 +142,12 @@ db.archive_collection(collection_name)
 - The query time is calculated by subtracting the start time from the current time.
 - Finally, the `archive_collection` method is called to archive the collection and free up resources in the PostgresML database.

+Call the `main` function in an async loop.
+
+```python
+if __name__ == "__main__":
+    asyncio.run(main())
+```
 **Running the Code**

 Open a terminal or command prompt and navigate to the directory where the file is saved.
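
Taken together, the hunks above assemble into the following complete quickstart. This is a reference sketch, not part of the diff: `import os` and the `from pgml import Database` import are assumed from earlier, unchanged sections of the README.

```python
import os
import json
import asyncio
from datasets import load_dataset
from time import time
from rich import print as rprint

from pgml import Database  # assumed import path; see the README's install section


async def main():
    # Connect to a local PostgresML database, or the one in PGML_CONNECTION
    local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
    conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
    db = Database(conninfo)

    collection_name = "test_collection"
    collection = await db.create_or_get_collection(collection_name)

    # Load SQuAD and keep one row per unique context passage
    data = load_dataset("squad", split="train")
    data = data.to_pandas()
    data = data.drop_duplicates(subset=["context"])
    documents = [
        {"id": r["id"], "text": r["context"], "title": r["title"]}
        for r in data.to_dict(orient="records")
    ]

    # Index the first 200 documents: upsert, chunk, then embed
    await collection.upsert_documents(documents[:200])
    await collection.generate_chunks()
    await collection.generate_embeddings()

    # Semantic search, then clean up
    start = time()
    results = await collection.vector_search("Who won 20 grammy awards?", top_k=2)
    rprint(json.dumps(results, indent=2))
    rprint("Query time: %0.3f seconds" % (time() - start))
    await db.archive_collection(collection_name)


if __name__ == "__main__":
    asyncio.run(main())
```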
@@ -193,32 +204,32 @@ This initializes a connection pool to the DB and creates a table named `pgml.col
 #### Create or Get a Collection

 ```python
-collection_name = "test_pgml_sdk_1"
-collection = db.create_or_get_collection(collection_name)
+collection_name = "test_collection"
+collection = await db.create_or_get_collection(collection_name)
 ```

 This creates a new schema in a PostgreSQL database if it does not already exist and creates tables and indices for documents, chunks, models, splitters, and embeddings.

 #### Upsert Documents

 ```python
-collection.upsert_documents(documents)
+await collection.upsert_documents(documents)
 ```

 The method is used to insert or update documents in a database table based on their ID, text, and metadata.

 #### Generate Chunks

 ```python
-collection.generate_chunks(splitter_id = 1)
+await collection.generate_chunks(splitter_id = 1)
 ```

 This method is used to generate chunks of text from unchunked documents using a specified text splitter. By default it uses `RecursiveCharacterTextSplitter` with default parameters. `splitter_id` is optional. You can pass a `splitter_id` corresponding to a new splitter that is registered. See below for `register_text_splitter`.

 #### Generate Embeddings

 ```python
-collection.generate_embeddings(model_id = 1, splitter_id = 1)
+await collection.generate_embeddings(model_id = 1, splitter_id = 1)
 ```

 This method generates embeddings using the chunks from the text. By default it uses the `intfloat/e5-small` embedding model. `model_id` is optional. You can pass a `model_id` corresponding to a new model that is registered, and a `splitter_id`. See below for `register_model`.
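
To make the indexing flow concrete, here is a minimal round trip using the methods above with their documented defaults, run inside the quickstart's `async def main():`. The document ids and texts are illustrative only.

```python
# Sketch: upsert two documents, then chunk and embed them with the defaults
# described above. Ids and texts here are made up for illustration.
documents = [
    {"id": "doc-1", "text": "PostgresML brings machine learning to Postgres.", "title": "Intro"},
    {"id": "doc-2", "text": "Embeddings enable semantic search over text.", "title": "Search"},
]
await collection.upsert_documents(documents)  # insert, or update rows matching these ids
await collection.generate_chunks()            # default splitter: RecursiveCharacterTextSplitter
await collection.generate_embeddings()        # default model: intfloat/e5-small
```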
@@ -227,53 +238,42 @@ This method generates embeddings using the chunks from the text. By default it u
 #### Vector Search

 ```python
-results = collection.vector_search("Who won 20 grammy awards?", top_k=2, model_id = 1, splitter_id = 1)
+results = await collection.vector_search("Who won 20 grammy awards?", top_k=2, model_id = 1, splitter_id = 1)
 ```

 This method converts the input query into embeddings and searches the embeddings table for the nearest match. You can change the number of results using `top_k`. You can also pass the specific `splitter_id` and `model_id` that were used for chunking and generating embeddings.

 #### Register Model

 ```python
-collection.register_model(model_name="hkunlp/instructor-xl", model_params={"instruction": "Represent the Wikipedia document for retrieval: "})
+await collection.register_model(model_name="hkunlp/instructor-xl", model_params={"instruction": "Represent the Wikipedia document for retrieval: "})
 ```

 This function allows for the registration of a model in a database, creating a record if it does not already exist. `model_name` is the name of the open source HuggingFace model being registered and `model_params` is a dictionary containing parameters for configuring the model. It can be empty if no parameters are needed.

 #### Register Text Splitter

 ```python
-collection.register_text_splitter(splitter_name="RecursiveCharacterTextSplitter",splitter_params={"chunk_size": 100,"chunk_overlap": 20})
+await collection.register_text_splitter(splitter_name="recursive_character", splitter_params={"chunk_size": 100, "chunk_overlap": 20})
 ```

-This function allows for the registration of a text spliter in a database, creating a record if it doesn't already exist. `splitter_name` is the name of the splitter from [LangChain](https://python.langchain.com/en/latest/reference/modules/text_splitter.html) and `splitter_params` are chunking parameters that the splitter supports.
-
-
-### Developer Setup
-1. Install Python 3.11. SDK should work for Python >=3.8.
-2. Install poetry `pip install poetry`
-3. Initialize Python environment
-
-```
-poetry env use python3.11
-poetry shell
-poetry install
-poetry build
-```
-4. SDK uses your local PostgresML database by default
-`postgres://postgres@127.0.0.1:5433/pgml_development`
-
-If it is not up to date with `pgml.embed` please [signup for a free database](https://postgresml.org/signup) and set `PGML_CONNECTION` environment variable with serverless hosted database.
+This function allows for the registration of a text splitter in a database, creating a record if it doesn't already exist. The following [LangChain](https://python.langchain.com/en/latest/reference/modules/text_splitter.html) splitters are supported.

 ```
-export PGML_CONNECTION="postgres://<username>:<password>@<hostname>:<port>/pgm<database>"
+SPLITTERS = {
+    "character": CharacterTextSplitter,
+    "latex": LatexTextSplitter,
+    "markdown": MarkdownTextSplitter,
+    "nltk": NLTKTextSplitter,
+    "python": PythonCodeTextSplitter,
+    "recursive_character": RecursiveCharacterTextSplitter,
+    "spacy": SpacyTextSplitter,
+}
 ```

-5. Run tests

-```
-LOGLEVEL=INFO python -m unittest tests/test_collection.py
-```
+### Developer Setup
+This Python library is generated from our core rust-sdk. Please check the [rust-sdk documentation](../../rust/pgml/README.md) for developer setup.

 ### API Reference
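
Tying the registration APIs together before the next file: a sketch that registers a custom splitter and model and then uses them by id. The diff does not show what the `register_*` calls return, so the concrete id values below are assumptions for illustration.

```python
# Sketch: register a LangChain splitter (one of the SPLITTERS keys above) and a
# HuggingFace model, then chunk, embed, and search with their ids.
await collection.register_text_splitter(
    splitter_name="markdown",
    splitter_params={"chunk_size": 1500, "chunk_overlap": 40},
)
await collection.register_model(
    model_name="hkunlp/instructor-xl",
    model_params={"instruction": "Represent the Wikipedia document for retrieval: "},
)

splitter_id, model_id = 2, 2  # hypothetical ids of the records registered above
await collection.generate_chunks(splitter_id=splitter_id)
await collection.generate_embeddings(model_id=model_id, splitter_id=splitter_id)
results = await collection.vector_search(
    "Who won 20 grammy awards?", top_k=2, model_id=model_id, splitter_id=splitter_id
)
```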
279279

pgml-sdks/python/pgml/examples/extractive_question_answering.py

Lines changed: 58 additions & 49 deletions
@@ -5,65 +5,74 @@
 from time import time
 from dotenv import load_dotenv
 from rich.console import Console
-from psycopg import sql
-from pgml.dbutils import run_select_statement
+from psycopg_pool import ConnectionPool
+import asyncio

-load_dotenv()
-console = Console()
+async def main():
+    load_dotenv()
+    console = Console()
+    local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+    conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
+    db = Database(conninfo)

-local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
+    collection_name = "squad_collection"
+    collection = await db.create_or_get_collection(collection_name)

-conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
-db = Database(conninfo)

-collection_name = "squad_collection"
-collection = db.create_or_get_collection(collection_name)
+    data = load_dataset("squad", split="train")
+    data = data.to_pandas()
+    data = data.drop_duplicates(subset=["context"])

+    documents = [
+        {"id": r["id"], "text": r["context"], "title": r["title"]}
+        for r in data.to_dict(orient="records")
+    ]

-data = load_dataset("squad", split="train")
-data = data.to_pandas()
-data = data.drop_duplicates(subset=["context"])
-
-documents = [
-    {"id": r["id"], "text": r["context"], "title": r["title"]}
-    for r in data.to_dict(orient="records")
-]
-
-collection.upsert_documents(documents[:200])
-collection.generate_chunks()
-collection.generate_embeddings()
+    console.print("Upserting documents ..")
+    await collection.upsert_documents(documents[:200])
+    console.print("Generating chunks ..")
+    await collection.generate_chunks()
+    console.print("Generating embeddings ..")
+    await collection.generate_embeddings()

-start = time()
-query = "Who won more than 20 grammy awards?"
-results = collection.vector_search(query, top_k=5)
-_end = time()
-console.print("\nResults for '%s'" % (query), style="bold")
-console.print(results)
-console.print("Query time = %0.3f" % (_end - start))
+    console.print("Querying ..")
+    start = time()
+    query = "Who won more than 20 grammy awards?"
+    results = await collection.vector_search(query, top_k=5)
+    _end = time()
+    console.print("\nResults for '%s'" % (query), style="bold")
+    console.print(results)
+    console.print("Query time = %0.3f" % (_end - start))

-# Get the context passage and use pgml.transform to get short answer to the question
+    # Get the context passage and use pgml.transform to get short answer to the question

+    console.print("Getting context passage ..")
+    context = " ".join(results[0][1].strip().split())
+    context = context.replace('"', '\\"').replace("'", "''")

-conn = db.pool.getconn()
-context = " ".join(results[0]["chunk"].strip().split())
-context = context.replace('"', '\\"').replace("'", "''")
+    select_statement = """SELECT pgml.transform(
+        'question-answering',
+        inputs => ARRAY[
+            '{
+                \"question\": \"%s\",
+                \"context\": \"%s\"
+            }'
+        ]
+    ) AS answer;""" % (
+        query,
+        context,
+    )

-select_statement = """SELECT pgml.transform(
-    'question-answering',
-    inputs => ARRAY[
-        '{
-            \"question\": \"%s\",
-            \"context\": \"%s\"
-        }'
-    ]
-) AS answer;""" % (
-    query,
-    context,
-)
+    pool = ConnectionPool(conninfo)
+    conn = pool.getconn()
+    cursor = conn.cursor()
+    cursor.execute(select_statement)
+    results = cursor.fetchall()
+    pool.putconn(conn)

-results = run_select_statement(conn, select_statement)
-db.pool.putconn(conn)
+    console.print("\nResults for query '%s'" % query)
+    console.print(results)
+    await db.archive_collection(collection_name)

-console.print("\nResults for query '%s'" % query)
-console.print(results)
-db.archive_collection(collection_name)
+if __name__ == "__main__":
+    asyncio.run(main())
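
A side note on the new `pgml.transform` call: it assembles SQL via %-formatting with hand-rolled escaping. Under the same assumptions the example already makes (psycopg 3 via `psycopg_pool`, and `pgml.transform` accepting an array of JSON strings), a sketch that binds the payload as a query parameter instead:

```python
import json
from psycopg_pool import ConnectionPool


def answer_question(conninfo: str, question: str, context: str):
    # Build the JSON payload once; parameter binding handles quoting/escaping.
    payload = json.dumps({"question": question, "context": context})
    query = "SELECT pgml.transform('question-answering', inputs => ARRAY[%s]) AS answer;"
    pool = ConnectionPool(conninfo)
    try:
        conn = pool.getconn()
        try:
            with conn.cursor() as cursor:
                cursor.execute(query, (payload,))
                return cursor.fetchall()
        finally:
            pool.putconn(conn)
    finally:
        pool.close()
```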
