@@ -1131,6 +1131,65 @@ impl Collection {
1131
1131
. collect ( ) )
1132
1132
}
1133
1133
1134
+ /// Performs rag on the [Collection]
1135
+ ///
1136
+ /// # Arguments
1137
+ /// * `query` - The query to search for
1138
+ /// * `pipeline` - The [Pipeline] to use for the search
1139
+ ///
1140
+ /// # Example
1141
+ /// ```
1142
+ /// use pgml::Collection;
1143
+ /// use pgml::Pipeline;
1144
+ /// use serde_json::json;
1145
+ /// use anyhow::Result;
1146
+ /// async fn run() -> anyhow::Result<()> {
1147
+ /// let mut collection = Collection::new("my_collection", None)?;
1148
+ /// let mut pipeline = Pipeline::new("my_pipeline", None)?;
1149
+ /// let results = collection.rag(json!({
1150
+ /// "CONTEXT": {
1151
+ /// "vector_search": {
1152
+ /// "query": {
1153
+ /// "fields": {
1154
+ /// "body": {
1155
+ /// "query": "Test document: 2",
1156
+ /// "parameters": {
1157
+ /// "prompt": "query: "
1158
+ /// }
1159
+ /// },
1160
+ /// },
1161
+ /// },
1162
+ /// "document": {
1163
+ /// "keys": [
1164
+ /// "id"
1165
+ /// ]
1166
+ /// },
1167
+ /// "limit": 2
1168
+ /// },
1169
+ /// "aggregate": {
1170
+ /// "join": "\n"
1171
+ /// }
1172
+ /// },
1173
+ /// "CUSTOM": {
1174
+ /// "sql": "SELECT 'test'"
1175
+ /// },
1176
+ /// "chat": {
1177
+ /// "model": "meta-llama/Meta-Llama-3-8B-Instruct",
1178
+ /// "messages": [
1179
+ /// {
1180
+ /// "role": "system",
1181
+ /// "content": "You are a friendly and helpful chatbot"
1182
+ /// },
1183
+ /// {
1184
+ /// "role": "user",
1185
+ /// "content": "Some text with {CONTEXT} - {CUSTOM}",
1186
+ /// }
1187
+ /// ],
1188
+ /// "max_tokens": 10
1189
+ /// }
1190
+ /// }).into(), &mut pipeline).await?;
1191
+ /// Ok(())
1192
+ /// }
1134
1193
#[ instrument( skip( self ) ) ]
1135
1194
pub async fn rag ( & self , query : Json , pipeline : & mut Pipeline ) -> anyhow:: Result < Json > {
1136
1195
let pool = get_or_initialize_pool ( & self . database_url ) . await ?;
@@ -1141,6 +1200,7 @@ impl Collection {
1141
1200
Ok ( std:: mem:: take ( & mut results[ 0 ] . 0 ) )
1142
1201
}
1143
1202
1203
+ /// Same as rag buit returns a stream of results
1144
1204
#[ instrument( skip( self ) ) ]
1145
1205
pub async fn rag_stream (
1146
1206
& self ,
0 commit comments