|
@@ -128,7 +128,7 @@ class SequenceClassificationTrainer(BaseTrainer): |
|
|
collate_fn=pre_dataset.batch_fn) |
|
|
collate_fn=pre_dataset.batch_fn) |
|
|
|
|
|
|
|
|
# generate a model |
|
|
# generate a model |
|
|
model = SequenceClassification(checkpoint_path) |
|
|
|
|
|
|
|
|
model = SequenceClassification.from_pretrained(checkpoint_path) |
|
|
|
|
|
|
|
|
# copy from easynlp (start) |
|
|
# copy from easynlp (start) |
|
|
model.eval() |
|
|
model.eval() |
|
|