1
+ use anyhow:: Context ;
1
2
use rust_bridge:: { alias, alias_methods} ;
2
3
use sqlx:: Row ;
3
4
use tracing:: instrument;
@@ -13,7 +14,7 @@ use crate::{get_or_initialize_pool, query_runner::QueryRunner, types::Json};
13
14
#[ cfg( feature = "python" ) ]
14
15
use crate :: { query_runner:: QueryRunnerPython , types:: JsonPython } ;
15
16
16
- #[ alias_methods( new, query, transform) ]
17
+ #[ alias_methods( new, query, transform, embed , embed_batch ) ]
17
18
impl Builtins {
18
19
pub fn new ( database_url : Option < String > ) -> Self {
19
20
Self { database_url }
@@ -87,6 +88,55 @@ impl Builtins {
87
88
let results = results. first ( ) . unwrap ( ) . get :: < serde_json:: Value , _ > ( 0 ) ;
88
89
Ok ( Json ( results) )
89
90
}
91
+
92
+ /// Run the built-in `pgml.embed()` function.
93
+ ///
94
+ /// # Arguments
95
+ ///
96
+ /// * `model` - The model to use.
97
+ /// * `text` - The text to embed.
98
+ ///
99
+ pub async fn embed ( & self , model : & str , text : & str ) -> anyhow:: Result < Json > {
100
+ let pool = get_or_initialize_pool ( & self . database_url ) . await ?;
101
+ let query = sqlx:: query ( "SELECT embed FROM pgml.embed($1, $2)" ) ;
102
+ let result = query. bind ( model) . bind ( text) . fetch_one ( & pool) . await ?;
103
+ let result = result. get :: < Vec < f32 > , _ > ( 0 ) ;
104
+ let result = serde_json:: to_value ( result) ?;
105
+ Ok ( Json ( result) )
106
+ }
107
+
108
+ /// Run the built-in `pgml.embed()` function, but with handling for batch inputs and outputs.
109
+ ///
110
+ /// # Arguments
111
+ ///
112
+ /// * `model` - The model to use.
113
+ /// * `texts` - The texts to embed.
114
+ ///
115
+ pub async fn embed_batch ( & self , model : & str , texts : Json ) -> anyhow:: Result < Json > {
116
+ let texts = texts
117
+ . 0
118
+ . as_array ( )
119
+ . with_context ( || "embed_batch takes an array of strings" ) ?
120
+ . into_iter ( )
121
+ . map ( |v| {
122
+ v. as_str ( )
123
+ . with_context ( || "only text embeddings are supported" )
124
+ . unwrap ( )
125
+ . to_string ( )
126
+ } )
127
+ . collect :: < Vec < String > > ( ) ;
128
+ let pool = get_or_initialize_pool ( & self . database_url ) . await ?;
129
+ let query = sqlx:: query ( "SELECT embed AS embed_batch FROM pgml.embed($1, $2)" ) ;
130
+ let results = query
131
+ . bind ( model)
132
+ . bind ( texts)
133
+ . fetch_all ( & pool)
134
+ . await ?
135
+ . into_iter ( )
136
+ . map ( |embeddings| embeddings. get :: < Vec < f32 > , _ > ( 0 ) )
137
+ . collect :: < Vec < Vec < f32 > > > ( ) ;
138
+ Ok ( Json ( serde_json:: to_value ( results) ?) )
139
+ }
90
140
}
91
141
92
142
#[ cfg( test) ]
@@ -117,4 +167,28 @@ mod tests {
117
167
assert ! ( results. as_array( ) . is_some( ) ) ;
118
168
Ok ( ( ) )
119
169
}
170
+
171
+ #[ tokio:: test]
172
+ async fn can_embed ( ) -> anyhow:: Result < ( ) > {
173
+ internal_init_logger ( None , None ) . ok ( ) ;
174
+ let builtins = Builtins :: new ( None ) ;
175
+ let results = builtins. embed ( "intfloat/e5-small-v2" , "test" ) . await ?;
176
+ assert ! ( results. as_array( ) . is_some( ) ) ;
177
+ Ok ( ( ) )
178
+ }
179
+
180
+ #[ tokio:: test]
181
+ async fn can_embed_batch ( ) -> anyhow:: Result < ( ) > {
182
+ internal_init_logger ( None , None ) . ok ( ) ;
183
+ let builtins = Builtins :: new ( None ) ;
184
+ let results = builtins
185
+ . embed_batch (
186
+ "intfloat/e5-small-v2" ,
187
+ Json ( serde_json:: json!( [ "test" , "test2" , ] ) ) ,
188
+ )
189
+ . await ?;
190
+ assert ! ( results. as_array( ) . is_some( ) ) ;
191
+ assert_eq ! ( results. as_array( ) . unwrap( ) . len( ) , 2 ) ;
192
+ Ok ( ( ) )
193
+ }
120
194
}
0 commit comments