From 876058556deabcdf1a399e79983444d97ec790f2 Mon Sep 17 00:00:00 2001 From: hemu Date: Fri, 14 Oct 2022 18:15:52 +0800 Subject: [PATCH] fix generate --- modelscope/models/nlp/gpt3/modeling_gpt3.py | 3 +++ requirements/nlp.txt | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/modelscope/models/nlp/gpt3/modeling_gpt3.py b/modelscope/models/nlp/gpt3/modeling_gpt3.py index 498d15de..ade36e36 100644 --- a/modelscope/models/nlp/gpt3/modeling_gpt3.py +++ b/modelscope/models/nlp/gpt3/modeling_gpt3.py @@ -346,3 +346,6 @@ class GPT3Model(PreTrainedModel): } model.load_state_dict(state_dict) return model + + def prepare_inputs_for_generation(self, input_ids, *args, **kwargs): + return {'input_ids': input_ids} diff --git a/requirements/nlp.txt b/requirements/nlp.txt index 2e0838fc..123c238e 100644 --- a/requirements/nlp.txt +++ b/requirements/nlp.txt @@ -14,5 +14,4 @@ spacy>=2.3.5 subword_nmt>=0.3.8 text2sql_lgesql tokenizers -# recent 4.23.1 update introduce breaking api change, limit upper version temporarily. -transformers>=4.12.0,<=4.22.0 +transformers