|
|
@@ -78,7 +78,7 @@ class RobertaEmbedding(ContextualEmbedding): |
|
|
|
word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s> |
|
|
|
来进行分类的任务将auto_truncate置为True。 |
|
|
|
:param kwargs: |
|
|
|
int min_freq: 小于该次数的词会被unk代替 |
|
|
|
int min_freq: 小于该次数的词会被unk代替, 默认为1 |
|
|
|
""" |
|
|
|
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout) |
|
|
|
|
|
|
@@ -93,7 +93,7 @@ class RobertaEmbedding(ContextualEmbedding): |
|
|
|
if '<s>' in vocab: |
|
|
|
self._word_cls_index = vocab['<s>'] |
|
|
|
|
|
|
|
min_freq = kwargs.get('min_freq', 2) |
|
|
|
min_freq = kwargs.get('min_freq', 1) |
|
|
|
self._min_freq = min_freq |
|
|
|
|
|
|
|
self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, |
|
|
@@ -464,7 +464,7 @@ class RobertaWordPieceEncoder(nn.Module): |
|
|
|
|
|
|
|
os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True) |
|
|
|
self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER)) |
|
|
|
logger.debug(f"BertWordPieceEncoder has been saved in {folder}") |
|
|
|
logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}") |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def load(cls, folder): |
|
|
|