@@ -1017,7 +1017,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
         logs["step"] = state.global_step
         logs["max_steps"] = state.max_steps
         logs["timestamp"] = str(datetime.now())
-        print_info(json.dumps(logs))
+        print_info(json.dumps(logs, indent=4))
         insert_logs(self.project_id, self.model_id, json.dumps(logs))
 
 
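Note: the indent=4 change above only affects console readability; insert_logs still receives the compact form. A minimal sketch of the effect, assuming a typical Hugging Face Trainer log dict (the loss and learning-rate values are illustrative):

    import json
    from datetime import datetime

    # Illustrative fields emitted by the HF Trainer, plus the ones on_log adds
    logs = {"loss": 0.4321, "learning_rate": 2e-05}
    logs["step"] = 10
    logs["max_steps"] = 100
    logs["timestamp"] = str(datetime.now())
    print(json.dumps(logs, indent=4))  # one key per line on the console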
@@ -1275,7 +1275,6 @@ def evaluate(self):
 
         if "eval_accuracy" in metrics.keys():
             metrics["accuracy"] = metrics.pop("eval_accuracy")
-
 
         # Drop all the keys that are not floats or ints to be compatible for pgml-extension metrics typechecks
         metrics = {
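The hunk above is cut off mid-statement at "metrics = {"; the comment describes the intent. A hedged sketch of such a numeric-only filter (not the exact repository code):

    # Keep only int/float values so the pgml-extension metric typechecks pass
    metrics = {"accuracy": 0.91, "eval_runtime": 3.2, "epoch": 1.0, "report": "..."}
    metrics = {
        key: value
        for key, value in metrics.items()
        if isinstance(value, (int, float))
    }
    # result: {'accuracy': 0.91, 'eval_runtime': 3.2, 'epoch': 1.0}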
@@ -1286,6 +1285,7 @@ def evaluate(self):
 
         return metrics
 
+
 class FineTuningTextPairClassification(FineTuningTextClassification):
     def __init__(
         self,
@@ -1313,7 +1313,7 @@ def __init__(
         super().__init__(
             project_id, model_id, train_dataset, test_dataset, path, hyperparameters
         )
-
+
     def tokenize_function(self, example):
         """
         Tokenizes the input text using the tokenizer specified in the class.
@@ -1326,13 +1326,20 @@ def tokenize_function(self, example):
 
         """
         if self.tokenizer_args:
-            tokenized_example = self.tokenizer(example["text1"], example["text2"], **self.tokenizer_args)
+            tokenized_example = self.tokenizer(
+                example["text1"], example["text2"], **self.tokenizer_args
+            )
         else:
             tokenized_example = self.tokenizer(
-                example["text1"], example["text2"], padding=True, truncation=True, return_tensors="pt"
+                example["text1"],
+                example["text2"],
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
             )
         return tokenized_example
 
+
 class FineTuningConversation(FineTuningBase):
     def __init__(
         self,
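For context, a self-contained sketch of the reformatted text-pair call path, assuming a standard Hugging Face tokenizer; the checkpoint name and sample texts are illustrative, not from the repository:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    example = {"text1": "A man plays a guitar.", "text2": "Someone is making music."}
    # The two texts are encoded together as a pair, matching the else-branch defaults above
    tokenized_example = tokenizer(
        example["text1"],
        example["text2"],
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    print(tokenized_example["input_ids"].shape)  # e.g. torch.Size([1, 17])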
@@ -1459,7 +1466,7 @@ def formatting_prompts_func(example):
             callbacks=[PGMLCallback(self.project_id, self.model_id)],
         )
         print_info("Creating Supervised Fine Tuning trainer done. Training ... ")
-
+
         # Train
         self.trainer.train()
 