Browse Source

[to #42322933] fix generate Merge remote-tracking branch 'origin/fix_generate'

master
Yingda Chen 3 years ago
parent
commit
1c2f2055cb
2 changed files with 4 additions and 2 deletions
  1. +3
    -0
      modelscope/models/nlp/gpt3/modeling_gpt3.py
  2. +1
    -2
      requirements/nlp.txt

+ 3
- 0
modelscope/models/nlp/gpt3/modeling_gpt3.py View File

@@ -346,3 +346,6 @@ class GPT3Model(PreTrainedModel):
}
model.load_state_dict(state_dict)
return model

def prepare_inputs_for_generation(self, input_ids, *args, **kwargs):
return {'input_ids': input_ids}

+ 1
- 2
requirements/nlp.txt View File

@@ -14,5 +14,4 @@ spacy>=2.3.5
subword_nmt>=0.3.8
text2sql_lgesql
tokenizers
# recent 4.23.1 update introduce breaking api change, limit upper version temporarily.
transformers>=4.12.0,<=4.22.0
transformers

Loading…
Cancel
Save