Commit cac1a6a

pgml Python SDK with vector search support (#636)
1 parent 04f7e26 commit cac1a6a

File tree

11 files changed, +4009 -0 lines changed


pgml-sdks/python/pgml/README.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# PostgresML Python SDK

This Python SDK provides an easy interface to use PostgresML generative AI capabilities.

## Table of Contents

- [Quickstart](#quickstart)

### Quickstart

1. Install Python 3.11. The SDK should work with Python >= 3.8; however, at this time we have only tested it with Python 3.11.
2. Clone the repository and check out the SDK branch (until this PR is merged):
```
git clone https://github.com/postgresml/postgresml
cd postgresml
git checkout santi-pgml-memory-sdk-python
cd pgml-sdks/python/pgml
```
3. Install Poetry: `pip install poetry`
4. Initialize the Python environment:
```
poetry env use python3.11
poetry shell
poetry install
poetry build
```
5. The SDK uses your local PostgresML database by default:
`postgres://postgres@127.0.0.1:5433/pgml_development`

If your local database is not up to date with `pgml.embed`, please [sign up for a free database](https://postgresml.org/signup) and set the `PGML_CONNECTION` environment variable to your serverless hosted database:

```
export PGML_CONNECTION="postgres://<username>:<password>@<hostname>:<port>/<database>"
```
6. Run the **vector search** example (a condensed sketch of the flow follows this list):

```
python examples/vector_search.py
```
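
For reference, here is a condensed sketch of what `examples/vector_search.py` does. All calls (`Database`, `create_or_get_collection`, `upsert_documents`, `generate_chunks`, `generate_embeddings`, `vector_search`, `archive_collection`) come from this SDK; the collection name and documents below are placeholders, while the bundled script loads the SQuAD dataset instead.

```
import os
from pgml import Database

# Use PGML_CONNECTION if set, otherwise fall back to the local development database.
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
db = Database(os.environ.get("PGML_CONNECTION", local_pgml))

collection = db.create_or_get_collection("quickstart_demo")  # placeholder collection name

# Placeholder documents; examples/vector_search.py loads the SQuAD dataset instead.
documents = [
    {"id": "doc-1", "text": "PostgresML brings machine learning to Postgres.", "title": "pgml"},
    {"id": "doc-2", "text": "Vector search retrieves semantically similar text.", "title": "search"},
]

collection.upsert_documents(documents)
collection.generate_chunks()      # split documents into chunks
collection.generate_embeddings()  # embed chunks with the default model

results = collection.vector_search("What is PostgresML?", top_k=2)
print(results)

db.archive_collection("quickstart_demo")
```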
Lines changed: 236 additions & 0 deletions
@@ -0,0 +1,236 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pgml import Database\n",
    "import os\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "local_pgml = \"postgres://postgres@127.0.0.1:5433/pgml_development\"\n",
    "\n",
    "conninfo = os.environ.get(\"PGML_CONNECTION\",local_pgml)\n",
    "db = Database(conninfo,min_connections=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection_name = \"test_pgml_sdk_1\"\n",
    "collection = db.create_or_get_collection(collection_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "data = load_dataset(\"squad\", split=\"train\")\n",
    "data = data.to_pandas()\n",
    "data.head()\n",
    "\n",
    "data = data.drop_duplicates(subset=[\"context\"])\n",
    "print(len(data))\n",
    "data.head()\n",
    "\n",
    "documents = [\n",
    "    {\n",
    "        'text': r['context'],\n",
    "        'metadata': {\n",
    "            'title': r['title']\n",
    "        }\n",
    "    } for r in data.to_dict(orient='records')\n",
    "]\n",
    "documents[:3]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.upsert_documents(documents[0:200])\n",
    "collection.generate_chunks()\n",
    "collection.generate_embeddings()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2)\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.register_model(model_name=\"paraphrase-MiniLM-L6-v2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.get_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(json.dumps(collection.get_models(),indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_embeddings(model_id=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2, model_id=2)\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.register_model(model_name=\"hkunlp/instructor-xl\", model_params={\"instruction\": \"Represent the Wikipedia document for retrieval: \"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.get_models()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_embeddings(model_id=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2, model_id=3, query_parameters={\"instruction\": \"Represent the Wikipedia question for retrieving supporting documents: \"})\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.register_text_splitter(splitter_name=\"RecursiveCharacterTextSplitter\",splitter_params={\"chunk_size\": 100,\"chunk_overlap\": 20})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_chunks(splitter_id=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "collection.generate_embeddings(splitter_id=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = collection.vector_search(\"Who won 20 Grammy awards?\", top_k=2, splitter_id=2)\n",
    "print(json.dumps(results,indent=2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "db.delete_collection(collection_name)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pgml-zoggicR5-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
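
Because the raw `.ipynb` JSON above is hard to read, here is the same custom model and splitter flow as a plain Python sketch. Each call mirrors a notebook cell; the hard-coded `model_id` and `splitter_id` values (2 and 3) are assumptions that depend on registration order, and the collection is assumed to already hold the upserted, chunked, and embedded documents from the earlier cells.

```
import os
from pgml import Database

db = Database(os.environ.get("PGML_CONNECTION", "postgres://postgres@127.0.0.1:5433/pgml_development"))
collection = db.create_or_get_collection("test_pgml_sdk_1")  # assumes documents were already indexed

# Register a second embedding model and re-embed the collection with it.
collection.register_model(model_name="paraphrase-MiniLM-L6-v2")
collection.generate_embeddings(model_id=2)  # model_id=2 assumes this was the second model registered
results = collection.vector_search("Who won 20 Grammy awards?", top_k=2, model_id=2)

# Instructor models take one instruction for documents and another for queries.
collection.register_model(
    model_name="hkunlp/instructor-xl",
    model_params={"instruction": "Represent the Wikipedia document for retrieval: "},
)
collection.generate_embeddings(model_id=3)
results = collection.vector_search(
    "Who won 20 Grammy awards?",
    top_k=2,
    model_id=3,
    query_parameters={"instruction": "Represent the Wikipedia question for retrieving supporting documents: "},
)

# Register a custom text splitter, then re-chunk, re-embed, and search against it.
collection.register_text_splitter(
    splitter_name="RecursiveCharacterTextSplitter",
    splitter_params={"chunk_size": 100, "chunk_overlap": 20},
)
collection.generate_chunks(splitter_id=2)
collection.generate_embeddings(splitter_id=2)
results = collection.vector_search("Who won 20 Grammy awards?", top_k=2, splitter_id=2)
```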
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
from pgml import Database
import os
import json
from datasets import load_dataset
from time import time
from rich import print as rprint

# Use PGML_CONNECTION if set, otherwise fall back to the local development database.
local_pgml = "postgres://postgres@127.0.0.1:5433/pgml_development"
conninfo = os.environ.get("PGML_CONNECTION", local_pgml)
db = Database(conninfo)

collection_name = "test_pgml_sdk_1"
collection = db.create_or_get_collection(collection_name)

# Load the SQuAD training split and keep one row per unique context.
data = load_dataset("squad", split="train")
data = data.to_pandas()
data = data.drop_duplicates(subset=["context"])

documents = [
    {"id": r["id"], "text": r["context"], "title": r["title"]}
    for r in data.to_dict(orient="records")
]

# Index the first 200 documents: upsert, chunk, and embed.
collection.upsert_documents(documents[:200])
collection.generate_chunks()
collection.generate_embeddings()

start = time()
results = collection.vector_search("Who won 20 grammy awards?", top_k=2)
rprint(json.dumps(results, indent=2))
rprint("Query time %0.3f" % (time() - start))
db.archive_collection(collection_name)
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
from .database import Database
from .collection import Collection
from .dbutils import (
    run_create_or_insert_statement,
    run_select_statement,
    run_drop_or_delete_statement,
)
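
Given these re-exports, application code can import the main entry points directly from the package. A minimal usage sketch, assuming the local default connection string from the README and that `create_or_get_collection` returns a `Collection`:

```
from pgml import Database, Collection

# Local development connection string (default from the README).
db = Database("postgres://postgres@127.0.0.1:5433/pgml_development")

# Assumption: create_or_get_collection hands back a Collection instance.
collection = db.create_or_get_collection("my_collection")
```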
