diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index b23f81e2..89b55a25 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -81,6 +81,12 @@ class DataSetGetter: raise ValueError self.idx_list = idx_list + def __getattr__(self, item): + if hasattr(self.dataset, item): + return getattr(self.dataset, item) + else: + raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item)) + class SamplerAdapter(torch.utils.data.Sampler): def __init__(self, sampler, dataset): @@ -131,9 +137,9 @@ class DataSetIter(BatchIter): timeout=0, worker_init_fn=None): super().__init__() assert isinstance(dataset, DataSet) + sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) dataset = DataSetGetter(dataset, as_numpy) collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None - sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) self.dataiter = torch.utils.data.DataLoader( dataset=dataset, batch_size=batch_size, sampler=sampler, collate_fn=collate_fn, num_workers=num_workers, diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py index a8836b5a..7dc29ba3 100644 --- a/fastNLP/core/field.py +++ b/fastNLP/core/field.py @@ -179,8 +179,6 @@ class FieldArray: return self.pad(contents) def pad(self, contents): - if self.padder is None: - raise RuntimeError return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) def set_padder(self, padder): @@ -355,8 +353,15 @@ class FieldArray: :return: Counter, key是label,value是出现次数 """ count = Counter() + + def cum(cell): + if _is_iterable(cell) and not isinstance(cell, str): + for cell_ in cell: + cum(cell_) + else: + count[cell] += 1 for cell in self.content: - count[cell] += 1 + cum(cell) return count def _after_process(self, new_contents, inplace): diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 8b17f75a..66234ce7 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -34,14 +34,23 @@ class LossBase(object): """ def __init__(self): - self.param_map = {} + self._param_map = {} # key是fun的参数,value是以该值从传入的dict取出value self._checked = False - + + @property + def param_map(self): + if len(self._param_map) == 0: # 如果为空说明还没有初始化 + func_spect = inspect.getfullargspec(self.get_loss) + func_args = [arg for arg in func_spect.args if arg != 'self'] + for arg in func_args: + self._param_map[arg] = arg + return self._param_map + def get_loss(self, *args, **kwargs): raise NotImplementedError def _init_param_map(self, key_map=None, **kwargs): - """检查key_map和其他参数map,并将这些映射关系添加到self.param_map + """检查key_map和其他参数map,并将这些映射关系添加到self._param_map :param dict key_map: 表示key的映射关系 :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 @@ -53,30 +62,30 @@ class LossBase(object): raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) for key, value in key_map.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(key, str): raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") if not isinstance(value, str): raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for key, value in kwargs.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(value, str): raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") - self.param_map[key] = value + 
self._param_map[key] = value value_counter[value].add(key) for value, key_set in value_counter.items(): if len(key_set) > 1: raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") - # check consistence between signature and param_map + # check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.get_loss) func_args = [arg for arg in func_spect.args if arg != 'self'] - for func_param, input_param in self.param_map.items(): + for func_param, input_param in self._param_map.items(): if func_param not in func_args: raise NameError( f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " @@ -96,7 +105,7 @@ class LossBase(object): :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. """ fast_param = {} - if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: fast_param['pred'] = list(pred_dict.values())[0] fast_param['target'] = list(target_dict.values())[0] return fast_param @@ -115,19 +124,19 @@ class LossBase(object): return loss if not self._checked: - # 1. check consistence between signature and param_map + # 1. check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.get_loss) func_args = set([arg for arg in func_spect.args if arg != 'self']) - for func_arg, input_arg in self.param_map.items(): + for func_arg, input_arg in self._param_map.items(): if func_arg not in func_args: raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.") - # 2. only part of the param_map are passed, left are not + # 2. only part of the _param_map are passed, left are not for arg in func_args: - if arg not in self.param_map: - self.param_map[arg] = arg # This param does not need mapping. + if arg not in self._param_map: + self._param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args - self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} + self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} mapped_pred_dict = {} mapped_target_dict = {} @@ -149,7 +158,7 @@ class LossBase(object): replaced_missing = list(missing) for idx, func_arg in enumerate(missing): # Don't delete `` in this information, nor add `` - replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ + replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ f"in `{self.__class__.__name__}`)" check_res = _CheckRes(missing=replaced_missing, @@ -162,6 +171,8 @@ class LossBase(object): if check_res.missing or check_res.duplicated: raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.get_loss)) + self._checked = True + refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) loss = self.get_loss(**refined_args) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 37a94a08..cfcb9039 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -115,9 +115,18 @@ class MetricBase(object): """ def __init__(self): - self.param_map = {} # key is param in function, value is input param. + self._param_map = {} # key is param in function, value is input param. 
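[Editor's note: a minimal, self-contained sketch of the lazy default mapping that the new `param_map` property in LossBase/MetricBase implements: when no explicit mapping has been registered, every argument of the wrapped function (except `self`) maps to itself. The class and function names below are illustrative, not fastNLP API.]

import inspect

class _MappedCall:
    """Illustrative only: derive a default param_map from a function signature."""
    def __init__(self, fn, key_map=None):
        self.fn = fn
        self._param_map = dict(key_map or {})

    @property
    def param_map(self):
        if len(self._param_map) == 0:  # not initialized yet -> identity mapping
            func_spect = inspect.getfullargspec(self.fn)
            for arg in func_spect.args:
                if arg != 'self':
                    self._param_map[arg] = arg
        return self._param_map

def get_loss(pred, target):
    return abs(pred - target)

print(_MappedCall(get_loss).param_map)                       # {'pred': 'pred', 'target': 'target'}
print(_MappedCall(get_loss, {'pred': 'output'}).param_map)   # {'pred': 'output'}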
self._checked = False + @property + def param_map(self): + if len(self._param_map) == 0: # 如果为空说明还没有初始化 + func_spect = inspect.getfullargspec(self.evaluate) + func_args = [arg for arg in func_spect.args if arg != 'self'] + for arg in func_args: + self._param_map[arg] = arg + return self._param_map + @abstractmethod def evaluate(self, *args, **kwargs): raise NotImplementedError @@ -127,7 +136,7 @@ class MetricBase(object): raise NotImplemented def _init_param_map(self, key_map=None, **kwargs): - """检查key_map和其他参数map,并将这些映射关系添加到self.param_map + """检查key_map和其他参数map,并将这些映射关系添加到self._param_map :param dict key_map: 表示key的映射关系 :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 @@ -139,30 +148,30 @@ class MetricBase(object): raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) for key, value in key_map.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(key, str): raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") if not isinstance(value, str): raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for key, value in kwargs.items(): if value is None: - self.param_map[key] = key + self._param_map[key] = key continue if not isinstance(value, str): raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") - self.param_map[key] = value + self._param_map[key] = value value_counter[value].add(key) for value, key_set in value_counter.items(): if len(key_set) > 1: raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") - # check consistence between signature and param_map + # check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = [arg for arg in func_spect.args if arg != 'self'] - for func_param, input_param in self.param_map.items(): + for func_param, input_param in self._param_map.items(): if func_param not in func_args: raise NameError( f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " @@ -177,7 +186,7 @@ class MetricBase(object): :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. """ fast_param = {} - if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: + if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: fast_param['pred'] = list(pred_dict.values())[0] fast_param['target'] = list(target_dict.values())[0] return fast_param @@ -206,19 +215,19 @@ class MetricBase(object): if not self._checked: if not callable(self.evaluate): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") - # 1. check consistence between signature and param_map + # 1. check consistence between signature and _param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = set([arg for arg in func_spect.args if arg != 'self']) - for func_arg, input_arg in self.param_map.items(): + for func_arg, input_arg in self._param_map.items(): if func_arg not in func_args: raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") - # 2. only part of the param_map are passed, left are not + # 2. only part of the _param_map are passed, left are not for arg in func_args: - if arg not in self.param_map: - self.param_map[arg] = arg # This param does not need mapping. 
+ if arg not in self._param_map: + self._param_map[arg] = arg # This param does not need mapping. self._evaluate_args = func_args - self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} + self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} # need to wrap inputs in dict. mapped_pred_dict = {} @@ -242,7 +251,7 @@ class MetricBase(object): replaced_missing = list(missing) for idx, func_arg in enumerate(missing): # Don't delete `` in this information, nor add `` - replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ + replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ f"in `{self.__class__.__name__}`)" check_res = _CheckRes(missing=replaced_missing, @@ -255,10 +264,10 @@ class MetricBase(object): if check_res.missing or check_res.duplicated: raise _CheckError(check_res=check_res, func_signature=_get_func_signature(self.evaluate)) + self._checked = True refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) self.evaluate(**refined_args) - self._checked = True return @@ -416,19 +425,19 @@ def _bioes_tag_to_spans(tags, ignore_labels=None): ignore_labels = set(ignore_labels) if ignore_labels else set() spans = [] - prev_bmes_tag = None + prev_bioes_tag = None for idx, tag in enumerate(tags): tag = tag.lower() - bmes_tag, label = tag[:1], tag[2:] - if bmes_tag in ('b', 's'): + bieso_tag, label = tag[:1], tag[2:] + if bieso_tag in ('b', 's'): spans.append((label, [idx, idx])) - elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]: + elif bieso_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: spans[-1][1][1] = idx - elif bmes_tag == 'o': + elif bieso_tag == 'o': pass else: spans.append((label, [idx, idx])) - prev_bmes_tag = bmes_tag + prev_bioes_tag = bieso_tag return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 41f760e3..8dece12d 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -432,9 +432,8 @@ class Trainer(object): if metric_key is not None: self.increase_better = False if metric_key[0] == "-" else True self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key - elif len(metrics) > 0: - self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') - + else: + self.metric_key = None # prepare loss losser = _prepare_losser(loss) @@ -454,9 +453,7 @@ class Trainer(object): raise TypeError("train_data type {} not support".format(type(train_data))) if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): - # TODO 考虑不同的dataset类型怎么check - _check_code(data_iterator=self.data_iterator, - model=model, losser=losser, metrics=metrics, dev_data=dev_data, + _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, metric_key=metric_key, check_level=check_code_level, batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 @@ -758,7 +755,9 @@ class Trainer(object): :return bool value: True means current results on dev set is the best. 
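[Editor's note: a standalone sketch of the BIOES decoding rule that `_bioes_tag_to_spans` applies above: 'b'/'s' opens a span, 'i'/'e' extends the current span only when it follows 'b'/'i' with the same label, 'o' is skipped, and anything ill-formed starts a fresh span. The `ignore_labels` filtering of the real function is omitted; tag strings such as 'B-PER' are assumed for illustration.]

def bioes_to_spans(tags):
    """Convert BIOES tags (e.g. 'B-PER', 'O') into (label, (start, end)) spans, end exclusive."""
    spans = []
    prev_tag = None
    for idx, tag in enumerate(tags):
        tag = tag.lower()
        bioes_tag, label = tag[:1], tag[2:]
        if bioes_tag in ('b', 's'):
            spans.append((label, [idx, idx]))
        elif bioes_tag in ('i', 'e') and prev_tag in ('b', 'i') and label == spans[-1][0]:
            spans[-1][1][1] = idx
        elif bioes_tag == 'o':
            pass
        else:  # ill-formed sequence: fall back to opening a new span
            spans.append((label, [idx, idx]))
        prev_tag = bioes_tag
    return [(label, (start, end + 1)) for label, (start, end) in spans]

print(bioes_to_spans(['B-per', 'E-per', 'O', 'S-loc']))  # [('per', (0, 2)), ('loc', (3, 4))]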
""" - indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) + indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) + if self.metric_key is None: + self.metric_key = indicator is_better = True if self.best_metric_indicator is None: # first-time validation @@ -797,16 +796,34 @@ def _get_value_info(_dict): strs.append(_str) return strs - -def _check_code(data_iterator, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, +from numbers import Number +from .batch import _to_tensor +def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): # check get_loss 方法 - model_devcie = model.parameters().__next__().device + model_devcie = _get_model_device(model=model) - batch = data_iterator - dataset = data_iterator.dataset - for batch_count, (batch_x, batch_y) in enumerate(batch): + def _iter(): + start_idx = 0 + while start_idx 1 and metric_key is None: - raise RuntimeError( - f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") else: # metric_key is set if metric_key not in metric_dict: raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") indicator_val = metric_dict[metric_key] + indicator = metric_key else: raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) - return indicator_val + return indicator, indicator_val diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py index adfa8ca1..465fb7e8 100644 --- a/fastNLP/io/base_loader.py +++ b/fastNLP/io/base_loader.py @@ -124,6 +124,14 @@ class DataInfo: self.embeddings = embeddings or {} self.datasets = datasets or {} + def __repr__(self): + _str = 'In total {} datasets:\n'.format(len(self.datasets)) + for name, dataset in self.datasets.items(): + _str += '\t{} has {} instances.\n'.format(name, len(dataset)) + _str += 'In total {} vocabs:\n'.format(len(self.vocabs)) + for name, vocab in self.vocabs.items(): + _str += '\t{} has {} entries.\n'.format(name, len(vocab)) + return _str class DataSetLoader: """ diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 0595ad46..c63ff2f4 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -120,7 +120,8 @@ class ConllLoader(DataSetLoader): """ 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` - 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html + 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 
数据中以"-DOCSTART-"开头的行将被忽略,因为 + 该符号在conll 2003中被用为文档分割符。 列号从0开始, 每列对应内容为:: diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 5963bb56..34b5d7c0 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): return sample with open(path, 'r', encoding=encoding) as f: sample = [] - start = next(f) - if '-DOCSTART-' not in start: + start = next(f).strip() + if '-DOCSTART-' not in start and start!='': sample.append(start.split()) for line_idx, line in enumerate(f, 1): - if line.startswith('\n'): + line = line.strip() + if line=='': if len(sample): try: res = parse_conll(sample) @@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): elif line.startswith('#'): continue else: - sample.append(line.split()) + if not line.startswith('-DOCSTART-'): + sample.append(line.split()) if len(sample) > 0: try: res = parse_conll(sample) @@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): except Exception as e: if dropna: return - raise ValueError('invalid instance at line: {}'.format(line_idx)) + print('invalid instance at line: {}'.format(line_idx)) + raise e diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index beb2b9be..c0717d6f 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -9,7 +9,7 @@ from torch import nn from ..utils import initial_parameter -def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): +def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): """ 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` @@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): :param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 - :param str encoding_type: 支持"bio", "bmes", "bmeso"。 + :param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 @@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ - :param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 + :param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 :param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag :param str from_label: 比如"PER", "LOC"等label :param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag @@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label return to_tag in ['b', 's', 'end', 'o'] else: raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) - + elif encoding_type == 'bioes': + if from_tag == 'start': + return to_tag in ['b', 's', 'o'] + elif from_tag == 'b': + return to_tag in ['i', 'e'] and from_label == to_label + elif from_tag == 'i': + return to_tag in ['i', 'e'] and from_label == to_label + elif from_tag in ['e', 's', 'o']: + return to_tag in ['b', 's', 'end', 'o'] + else: + raise ValueError("Unexpect tag type {}. 
Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) else: - raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) + raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) class ConditionalRandomField(nn.Module): diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py index 4be75f20..349bce69 100644 --- a/fastNLP/modules/encoder/__init__.py +++ b/fastNLP/modules/encoder/__init__.py @@ -24,7 +24,8 @@ __all__ = [ "VarLSTM", "VarGRU" ] -from .bert import BertModel +from ._bert import BertModel +from .bert import BertWordPieceEncoder from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder from .conv_maxpool import ConvMaxpool from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \ diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py index 1423f333..317b78d8 100644 --- a/fastNLP/modules/encoder/_bert.py +++ b/fastNLP/modules/encoder/_bert.py @@ -6,18 +6,399 @@ """ -import torch -from torch import nn from ...core.vocabulary import Vocabulary import collections -import os import unicodedata from ...io.file_utils import _get_base_url, cached_path -from .bert import BertModel import numpy as np from itertools import chain +import copy +import json +import math +import os + +import torch +from torch import nn + +CONFIG_FILE = 'bert_config.json' +MODEL_WEIGHTS = 'pytorch_model.bin' + + +def gelu(x): + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + + +def swish(x): + return x * torch.sigmoid(x) + + +ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} + + +class BertLayerNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-12): + super(BertLayerNorm, self).__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.bias = nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class BertEmbeddings(nn.Module): + def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): + super(BertEmbeddings, self).__init__() + self.word_embeddings = nn.Embedding(vocab_size, hidden_size) + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids=None): + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + words_embeddings = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + + embeddings = words_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, hidden_size, 
num_attention_heads, attention_probs_dropout_prob): + super(BertSelfAttention, self).__init__() + if hidden_size % num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (hidden_size, num_attention_heads)) + self.num_attention_heads = num_attention_heads + self.attention_head_size = int(hidden_size / num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(hidden_size, self.all_head_size) + self.key = nn.Linear(hidden_size, self.all_head_size) + self.value = nn.Linear(hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask): + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
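[Editor's note: the head splitting and additive masking performed by BertSelfAttention above, condensed into a few lines with made-up sizes; this is an illustrative re-derivation, not fastNLP API, and reuses one tensor as query/key/value for brevity.]

import math
import torch

batch, seq_len, hidden, heads = 2, 5, 16, 4
head_dim = hidden // heads
x = torch.randn(batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding

def split_heads(t):
    # (batch, seq_len, hidden) -> (batch, heads, seq_len, head_dim)
    return t.view(batch, seq_len, heads, head_dim).permute(0, 2, 1, 3)

q, k, v = split_heads(x), split_heads(x), split_heads(x)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(head_dim)
scores = scores + (1.0 - mask[:, None, None, :].float()) * -10000.0  # additive padding mask
probs = torch.softmax(scores, dim=-1)
context = torch.matmul(probs, v).permute(0, 2, 1, 3).reshape(batch, seq_len, hidden)
print(context.shape)  # torch.Size([2, 5, 16])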
+ attention_probs = self.dropout(attention_probs) + + context_layer = torch.matmul(attention_probs, value_layer) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + return context_layer + + +class BertSelfOutput(nn.Module): + def __init__(self, hidden_size, hidden_dropout_prob): + super(BertSelfOutput, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) + self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) + + def forward(self, input_tensor, attention_mask): + self_output = self.self(input_tensor, attention_mask) + attention_output = self.output(self_output, input_tensor) + return attention_output + + +class BertIntermediate(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_act): + super(BertIntermediate, self).__init__() + self.dense = nn.Linear(hidden_size, intermediate_size) + self.intermediate_act_fn = ACT2FN[hidden_act] \ + if isinstance(hidden_act, str) else hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): + super(BertOutput, self).__init__() + self.dense = nn.Linear(intermediate_size, hidden_size) + self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) + self.dropout = nn.Dropout(hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertLayer, self).__init__() + self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob) + self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) + self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) + + def forward(self, hidden_states, attention_mask): + attention_output = self.attention(hidden_states, attention_mask) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, + hidden_dropout_prob, + intermediate_size, hidden_act): + super(BertEncoder, self).__init__() + layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, + intermediate_size, hidden_act) + self.layer 
= nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): + all_encoder_layers = [] + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + return all_encoder_layers + + +class BertPooler(nn.Module): + def __init__(self, hidden_size): + super(BertPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + """BERT(Bidirectional Embedding Representations from Transformers). + + 如果你想使用预训练好的权重矩阵,请在以下网址下载. + sources:: + + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", + + + 用预训练权重矩阵来建立BERT模型:: + + model = BertModel.from_pretrained("path/to/weights/directory") + + 用随机初始化权重矩阵来建立BERT模型:: + + model = BertModel() + + :param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 + :param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 + :param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 + :param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 + :param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 + :param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` + :param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 + :param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 + :param int max_position_embeddings: 最大的序列长度,默认值为512, + :param int type_vocab_size: 最大segment数量,默认值为2 + :param int initializer_range: 初始化权重范围,默认值为0.02 + """ + + def __init__(self, vocab_size=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02): + super(BertModel, self).__init__() + self.hidden_size = hidden_size + self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, + type_vocab_size, hidden_dropout_prob) + self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, + attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, + hidden_act) + self.pooler = BertPooler(hidden_size) + self.initializer_range = initializer_range + + self.apply(self.init_bert_weights) + + def init_bert_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from 
the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + elif isinstance(module, BertLayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + embedding_output = self.embeddings(input_ids, token_type_ids) + encoded_layers = self.encoder(embedding_output, + extended_attention_mask, + output_all_encoded_layers=output_all_encoded_layers) + sequence_output = encoded_layers[-1] + pooled_output = self.pooler(sequence_output) + if not output_all_encoded_layers: + encoded_layers = encoded_layers[-1] + return encoded_layers, pooled_output + + @classmethod + def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): + # Load config + config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) + config = json.load(open(config_file, "r")) + # config = BertConfig.from_json_file(config_file) + # logger.info("Model config {}".format(config)) + # Instantiate model. 
+ model = cls(*inputs, **config, **kwargs) + if state_dict is None: + weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) + state_dict = torch.load(weights_path) + + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + + missing_keys = [] + unexpected_keys = [] + error_msgs = [] + # copy state_dict so _load_from_state_dict can modify it + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + def load(module, prefix=''): + local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) + module._load_from_state_dict( + state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(model, prefix='' if hasattr(model, 'bert') else 'bert.') + if len(missing_keys) > 0: + print("Weights of {} not initialized from pretrained model: {}".format( + model.__class__.__name__, missing_keys)) + if len(unexpected_keys) > 0: + print("Weights from pretrained model not used in {}: {}".format( + model.__class__.__name__, unexpected_keys)) + return model + + + + + + + + + + + def whitespace_tokenize(text): """Runs basic whitespace cleaning and splitting on a piece of text.""" @@ -547,79 +928,3 @@ class _WordPieceBertModel(nn.Module): outputs[l_index] = bert_outputs[l] return outputs -class BertWordPieceEncoder(nn.Module): - """ - 可以通过读取vocabulary使用的Bert的Encoder。传入vocab,然后调用index_datasets方法在vocabulary中生成word piece的表示。 - - :param vocab: Vocabulary. 
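[Editor's note: the checkpoint-compatibility shim used by `from_pretrained` above (renaming TensorFlow-style `gamma`/`beta` LayerNorm parameters to PyTorch's `weight`/`bias`) shown in isolation; the state dict below is fabricated for illustration.]

import torch

state_dict = {
    'embeddings.LayerNorm.gamma': torch.ones(4),
    'embeddings.LayerNorm.beta': torch.zeros(4),
    'pooler.dense.weight': torch.randn(4, 4),
}

old_keys, new_keys = [], []
for key in state_dict.keys():
    new_key = key.replace('gamma', 'weight').replace('beta', 'bias')
    if new_key != key:
        old_keys.append(key)
        new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
    state_dict[new_key] = state_dict.pop(old_key)

print(sorted(state_dict))
# ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'pooler.dense.weight']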
- :param model_dir_or_name: - :param layers: - :param requires_grad: - """ - def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', - requires_grad:bool=False): - super().__init__() - PRETRAIN_URL = _get_base_url('bert') - # TODO 修改 - PRETRAINED_BERT_MODEL_DIR = {'en-base': 'bert_en-80f95ea7.tar.gz', - 'cn': 'elmo_cn.zip'} - - if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: - model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] - model_url = PRETRAIN_URL + model_name - model_dir = cached_path(model_url) - # 检查是否存在 - elif os.path.isdir(model_dir_or_name): - model_dir = model_dir_or_name - else: - raise ValueError(f"Cannot recognize {model_dir_or_name}.") - - self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) - self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size - self.requires_grad = requires_grad - - @property - def requires_grad(self): - """ - Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 - :return: - """ - requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) - if len(requires_grads)==1: - return requires_grads.pop() - else: - return None - - @requires_grad.setter - def requires_grad(self, value): - for name, param in self.named_parameters(): - param.requires_grad = value - - @property - def embed_size(self): - return self._embed_size - - def index_datasets(self, *datasets): - """ - 对datasets进行word piece的index。 - - Example:: - - :param datasets: - :return: - """ - self.model.index_dataset(*datasets) - - def forward(self, words, token_type_ids=None): - """ - 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 - 删除这两个表示。 - - :param words: batch_size x max_len - :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话 - :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) - """ - outputs = self.model(words, token_type_ids) - outputs = torch.cat([*outputs], dim=-1) - - return outputs \ No newline at end of file diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 38a35fc9..e9739c28 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -1,378 +1,95 @@ -""" -bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. 
-""" -import copy -import json -import math import os - -import torch from torch import nn +import torch +from ...core import Vocabulary +from ...io.file_utils import _get_base_url, cached_path +from ._bert import _WordPieceBertModel -CONFIG_FILE = 'bert_config.json' -MODEL_WEIGHTS = 'pytorch_model.bin' - - -def gelu(x): - return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) - - -def swish(x): - return x * torch.sigmoid(x) - - -ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} - - -class BertLayerNorm(nn.Module): - def __init__(self, hidden_size, eps=1e-12): - super(BertLayerNorm, self).__init__() - self.weight = nn.Parameter(torch.ones(hidden_size)) - self.bias = nn.Parameter(torch.zeros(hidden_size)) - self.variance_epsilon = eps - - def forward(self, x): - u = x.mean(-1, keepdim=True) - s = (x - u).pow(2).mean(-1, keepdim=True) - x = (x - u) / torch.sqrt(s + self.variance_epsilon) - return self.weight * x + self.bias - - -class BertEmbeddings(nn.Module): - def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): - super(BertEmbeddings, self).__init__() - self.word_embeddings = nn.Embedding(vocab_size, hidden_size) - self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) - self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) - - def forward(self, input_ids, token_type_ids=None): - seq_length = input_ids.size(1) - position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) - position_ids = position_ids.unsqueeze(0).expand_as(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - words_embeddings = self.word_embeddings(input_ids) - position_embeddings = self.position_embeddings(position_ids) - token_type_embeddings = self.token_type_embeddings(token_type_ids) - - embeddings = words_embeddings + position_embeddings + token_type_embeddings - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): - super(BertSelfAttention, self).__init__() - if hidden_size % num_attention_heads != 0: - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (hidden_size, num_attention_heads)) - self.num_attention_heads = num_attention_heads - self.attention_head_size = int(hidden_size / num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(hidden_size, self.all_head_size) - self.key = nn.Linear(hidden_size, self.all_head_size) - self.value = nn.Linear(hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(attention_probs_dropout_prob) - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward(self, hidden_states, attention_mask): - mixed_query_layer = self.query(hidden_states) - mixed_key_layer = self.key(hidden_states) - mixed_value_layer = self.value(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - key_layer = 
self.transpose_for_scores(mixed_key_layer) - value_layer = self.transpose_for_scores(mixed_value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - context_layer = torch.matmul(attention_probs, value_layer) - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(*new_context_layer_shape) - return context_layer - - -class BertSelfOutput(nn.Module): - def __init__(self, hidden_size, hidden_dropout_prob): - super(BertSelfOutput, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): - super(BertAttention, self).__init__() - self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) - self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) - - def forward(self, input_tensor, attention_mask): - self_output = self.self(input_tensor, attention_mask) - attention_output = self.output(self_output, input_tensor) - return attention_output - - -class BertIntermediate(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_act): - super(BertIntermediate, self).__init__() - self.dense = nn.Linear(hidden_size, intermediate_size) - self.intermediate_act_fn = ACT2FN[hidden_act] \ - if isinstance(hidden_act, str) else hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): - super(BertOutput, self).__init__() - self.dense = nn.Linear(intermediate_size, hidden_size) - self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) - self.dropout = nn.Dropout(hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertLayer(nn.Module): - def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act): - super(BertLayer, self).__init__() - self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob) - self.intermediate = BertIntermediate(hidden_size, 
intermediate_size, hidden_act) - self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) - - def forward(self, hidden_states, attention_mask): - attention_output = self.attention(hidden_states, attention_mask) - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, - hidden_dropout_prob, - intermediate_size, hidden_act): - super(BertEncoder, self).__init__() - layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, - intermediate_size, hidden_act) - self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) - - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): - all_encoder_layers = [] - for layer_module in self.layer: - hidden_states = layer_module(hidden_states, attention_mask) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - return all_encoder_layers - - -class BertPooler(nn.Module): - def __init__(self, hidden_size): - super(BertPooler, self).__init__() - self.dense = nn.Linear(hidden_size, hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertModel(nn.Module): - """BERT(Bidirectional Embedding Representations from Transformers). - - 如果你想使用预训练好的权重矩阵,请在以下网址下载. 
- sources:: - - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", - - - 用预训练权重矩阵来建立BERT模型:: - - model = BertModel.from_pretrained("path/to/weights/directory") - - 用随机初始化权重矩阵来建立BERT模型:: - - model = BertModel() - :param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 - :param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 - :param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 - :param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 - :param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 - :param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` - :param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 - :param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 - :param int max_position_embeddings: 最大的序列长度,默认值为512, - :param int type_vocab_size: 最大segment数量,默认值为2 - :param int initializer_range: 初始化权重范围,默认值为0.02 +class BertWordPieceEncoder(nn.Module): """ + 可以通过读取vocabulary使用的Bert的Encoder。传入vocab,然后调用index_datasets方法在vocabulary中生成word piece的表示。 - def __init__(self, vocab_size=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02): - super(BertModel, self).__init__() - self.hidden_size = hidden_size - self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, - type_vocab_size, hidden_dropout_prob) - self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, - attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, - hidden_act) - self.pooler = BertPooler(hidden_size) - self.initializer_range = initializer_range - - self.apply(self.init_bert_weights) - - def init_bert_weights(self, module): - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_(mean=0.0, std=self.initializer_range) - elif isinstance(module, BertLayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - if token_type_ids is None: - token_type_ids = torch.zeros_like(input_ids) - - # We create a 3D attention mask from a 2D tensor mask. 
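[Editor's note: a hypothetical usage sketch of the relocated BertWordPieceEncoder, based only on the signature and docstrings visible in this patch (vocab, model_dir_or_name, layers, index_datasets, forward on batch_size x max_len indices). The model directory path and dataset contents are made up, and running it requires downloaded BERT weights.]

from fastNLP import DataSet, Vocabulary
from fastNLP.modules.encoder.bert import BertWordPieceEncoder

ds = DataSet({'words': [['this', 'is', 'fastNLP', '.'], ['another', 'sentence', '.']]})
vocab = Vocabulary().from_dataset(ds, field_name='words')
vocab.index_dataset(ds, field_name='words')

# '/path/to/bert/dir' is a placeholder for a directory containing bert_config.json,
# pytorch_model.bin and vocab.txt; a key of PRETRAINED_BERT_MODEL_DIR also works.
encoder = BertWordPieceEncoder(vocab, model_dir_or_name='/path/to/bert/dir', layers='-1')
encoder.index_datasets(ds)  # adds word-piece indices derived from the 'words' column
# At training time, forward(words) expects a LongTensor of shape batch_size x max_len and
# returns batch_size x max_len x (hidden_size * number_of_selected_layers).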
- # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - - embedding_output = self.embeddings(input_ids, token_type_ids) - encoded_layers = self.encoder(embedding_output, - extended_attention_mask, - output_all_encoded_layers=output_all_encoded_layers) - sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) - if not output_all_encoded_layers: - encoded_layers = encoded_layers[-1] - return encoded_layers, pooled_output - - @classmethod - def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): - # Load config - config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) - config = json.load(open(config_file, "r")) - # config = BertConfig.from_json_file(config_file) - # logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(*inputs, **config, **kwargs) - if state_dict is None: - weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) - state_dict = torch.load(weights_path) - - old_keys = [] - new_keys = [] - for key in state_dict.keys(): - new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') - if new_key: - old_keys.append(key) - new_keys.append(new_key) - for old_key, new_key in zip(old_keys, new_keys): - state_dict[new_key] = state_dict.pop(old_key) - - missing_keys = [] - unexpected_keys = [] - error_msgs = [] - # copy state_dict so _load_from_state_dict can modify it - metadata = getattr(state_dict, '_metadata', None) - state_dict = state_dict.copy() - if metadata is not None: - state_dict._metadata = metadata - - def load(module, prefix=''): - local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) - module._load_from_state_dict( - state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) - for name, child in module._modules.items(): - if child is not None: - load(child, prefix + name + '.') - - load(model, prefix='' if hasattr(model, 'bert') else 'bert.') - if len(missing_keys) > 0: - print("Weights of {} not initialized from pretrained model: {}".format( - model.__class__.__name__, missing_keys)) - if len(unexpected_keys) > 0: - print("Weights from pretrained model not used in {}: {}".format( - model.__class__.__name__, unexpected_keys)) - return model + :param fastNLP.Vocabulary vocab: 词表 + :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` + :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 + :param bool requires_grad: 是否需要gradient。 + """ + def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', + requires_grad:bool=False): + super().__init__() + PRETRAIN_URL = _get_base_url('bert') + 
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', + 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', + 'en-base-cased': 'bert-base-cased-f89bfe08.zip', + 'en-large-uncased': 'bert-large-uncased-20939f45.zip', + 'en-large-cased': 'bert-large-cased-e0cf90fc.zip', + + 'cn': 'bert-base-chinese-29d0a84a.zip', + 'cn-base': 'bert-base-chinese-29d0a84a.zip', + + 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', + 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', + 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', + } + + if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: + model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] + model_url = PRETRAIN_URL + model_name + model_dir = cached_path(model_url) + # 检查是否存在 + elif os.path.isdir(model_dir_or_name): + model_dir = model_dir_or_name + else: + raise ValueError(f"Cannot recognize {model_dir_or_name}.") + + self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) + self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size + self.requires_grad = requires_grad + + @property + def requires_grad(self): + """ + Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 + :return: + """ + requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) + if len(requires_grads)==1: + return requires_grads.pop() + else: + return None + + @requires_grad.setter + def requires_grad(self, value): + for name, param in self.named_parameters(): + param.requires_grad = value + + @property + def embed_size(self): + return self._embed_size + + def index_datasets(self, *datasets): + """ + 根据datasets中的'words'列对datasets进行word piece的index。 + + Example:: + + :param datasets: + :return: + """ + self.model.index_dataset(*datasets) + + def forward(self, words, token_type_ids=None): + """ + 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 + 删除这两个表示。 + + :param words: batch_size x max_len + :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话 + :return: torch.FloatTensor. 
batch_size x max_len x (768*len(self.layers)) + """ + outputs = self.model(words, token_type_ids) + outputs = torch.cat([*outputs], dim=-1) + + return outputs \ No newline at end of file diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py index f956aae7..7fd85578 100644 --- a/fastNLP/modules/encoder/embedding.py +++ b/fastNLP/modules/encoder/embedding.py @@ -165,7 +165,6 @@ class StaticEmbedding(TokenEmbedding): super(StaticEmbedding, self).__init__(vocab) # 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, - PRETRAIN_URL = _get_base_url('static') PRETRAIN_STATIC_FILES = { 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', @@ -178,6 +177,7 @@ class StaticEmbedding(TokenEmbedding): # 得到cache_path if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: + PRETRAIN_URL = _get_base_url('static') model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] model_url = PRETRAIN_URL + model_name model_path = cached_path(model_url) @@ -333,12 +333,11 @@ class ElmoEmbedding(ContextualEmbedding): self.layers = layers # 根据model_dir_or_name检查是否存在并下载 - PRETRAIN_URL = _get_base_url('elmo') - # TODO 把baidu云上的加上去 PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', 'cn': 'elmo_cn-5e9b34e2.tar.gz'} if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: + PRETRAIN_URL = _get_base_url('elmo') model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] model_url = PRETRAIN_URL + model_name model_dir = cached_path(model_url) @@ -392,7 +391,7 @@ class ElmoEmbedding(ContextualEmbedding): def requires_grad(self, value): for name, param in self.named_parameters(): if 'words_to_chars_embedding' in name: # 这个不能加入到requires_grad中 - pass + continue param.requires_grad = value @@ -420,7 +419,6 @@ class BertEmbedding(ContextualEmbedding): pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): super(BertEmbedding, self).__init__(vocab) # 根据model_dir_or_name检查是否存在并下载 - PRETRAIN_URL = _get_base_url('bert') PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', 'en-base-cased': 'bert-base-cased-f89bfe08.zip', @@ -436,6 +434,7 @@ class BertEmbedding(ContextualEmbedding): } if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: + PRETRAIN_URL = _get_base_url('bert') model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] model_url = PRETRAIN_URL + model_name model_dir = cached_path(model_url) @@ -487,7 +486,7 @@ class BertEmbedding(ContextualEmbedding): def requires_grad(self, value): for name, param in self.named_parameters(): if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中 - pass + continue param.requires_grad = value @@ -575,6 +574,7 @@ class CNNCharEmbedding(TokenEmbedding): for i in range(len(kernel_sizes))]) self._embed_size = embed_size self.fc = nn.Linear(sum(filter_nums), embed_size) + self.init_param() def forward(self, words): """ @@ -627,9 +627,17 @@ class CNNCharEmbedding(TokenEmbedding): def requires_grad(self, value): for name, param in self.named_parameters(): if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 - pass + continue param.requires_grad = value + def init_param(self): + for name, param in self.named_parameters(): + if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset + continue + if param.data.dim()>1: + nn.init.xavier_normal_(param, 1) + else: + nn.init.uniform_(param, -1, 1) class LSTMCharEmbedding(TokenEmbedding): """ @@ -753,7 +761,7 @@ class 
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index b4d3aff2..3b97f4a7 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -35,8 +35,18 @@ class LSTM(nn.Module):
         self.batch_first = batch_first
         self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
                             dropout=dropout, bidirectional=bidirectional)
+        self.init_param()
         initial_parameter(self, initial_method)
-    
+
+    def init_param(self):
+        for name, param in self.named_parameters():
+            if 'bias_i' in name:
+                param.data.fill_(1)
+            elif 'bias_h' in name:
+                param.data.fill_(0)
+            else:
+                nn.init.xavier_normal_(param)
+
     def forward(self, x, seq_len=None, h0=None, c0=None):
         """
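The names matched by the new init_param above come from nn.LSTM itself. A short standalone sketch of the same initialisation applied directly to a torch LSTM, with arbitrary sizes, to make the matching explicit:

    import torch
    from torch import nn

    lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1, bidirectional=True)
    for name, param in lstm.named_parameters():
        if 'bias_i' in name:              # matches bias_ih_l0, bias_ih_l0_reverse, ...
            param.data.fill_(1)
        elif 'bias_h' in name:            # matches bias_hh_l0, bias_hh_l0_reverse, ...
            param.data.fill_(0)
        else:                             # weight_ih_l0, weight_hh_l0, ... (all 2-D)
            nn.init.xavier_normal_(param)

    print(sorted(name for name, _ in lstm.named_parameters()))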
diff --git a/reproduction/seqence_labelling/cws/train_shift_relay.py b/reproduction/seqence_labelling/cws/train_shift_relay.py
index 55576575..805521e7 100644
--- a/reproduction/seqence_labelling/cws/train_shift_relay.py
+++ b/reproduction/seqence_labelling/cws/train_shift_relay.py
@@ -57,8 +57,12 @@ callbacks = [clipper]
 # if pretrain:
 #     fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
 #     callbacks.append(fixer)
-trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler,
-                  update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(),
-                  metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks,
+trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None,
+                  batch_size=32, sampler=sampler, update_every=5,
+                  n_epochs=3, print_every=5,
+                  dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f',
+                  validate_every=-1, save_path=None,
+                  prefetch=True, use_tqdm=True, device=device,
+                  callbacks=callbacks,
                   check_code_level=0)
 trainer.train()
\ No newline at end of file
diff --git a/reproduction/utils.py b/reproduction/utils.py
index bbfed4dd..26b2014c 100644
--- a/reproduction/utils.py
+++ b/reproduction/utils.py
@@ -25,7 +25,7 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
         if not os.path.isfile(train_fp):
             raise FileNotFoundError(f"train.txt is not found in folder {paths}.")
         files = {'train': train_fp}
-        for filename in ['test.txt', 'dev.txt']:
+        for filename in ['dev.txt', 'test.txt']:
             fp = os.path.join(paths, filename)
             if os.path.isfile(fp):
                 files[filename.split('.')[0]] = fp
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index f3b0178c..9c8a586c 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -161,7 +161,15 @@ class TestAccuracyMetric(unittest.TestCase):
             print(e)
             return
         self.assertTrue(True, False), "No exception catches."
-    
+
+    def test_duplicate(self):
+        # potential bug in 0.4.1: duplicated parameter names must not break the call
+        metric = AccuracyMetric(pred='predictions', target='targets')
+        pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0}
+        target_dict = {'targets':torch.zeros(4, 3), 'target': 0}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+
+
     def test_seq_len(self):
         N = 256
         seq_len = torch.zeros(N).long()
diff --git a/test/models/test_biaffine_parser.py b/test/models/test_biaffine_parser.py
index e6fca6a8..4f93b994 100644
--- a/test/models/test_biaffine_parser.py
+++ b/test/models/test_biaffine_parser.py
@@ -1,6 +1,5 @@
 import unittest
 
-import fastNLP
 from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
 from .model_runner import *
 
diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py
index 5dec7d47..647af7d3 100644
--- a/test/modules/decoder/test_CRF.py
+++ b/test/modules/decoder/test_CRF.py
@@ -10,14 +10,14 @@ class TestCRF(unittest.TestCase):
         id2label = {0: 'B', 1: 'I', 2:'O'}
         expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), (2, 4),
                         (3, 0), (3, 2)}
-        self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
 
         id2label = {0: 'B', 1:'M', 2:'E', 3:'S'}
         expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
-        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))
 
         id2label = {0: 'B', 1: 'I', 2:'O', 3: '', 4:""}
-        allowed_transitions(id2label)
+        allowed_transitions(id2label, include_start_end=True)
 
         labels = ['O']
         for label in ['X', 'Y']:
@@ -27,7 +27,7 @@ class TestCRF(unittest.TestCase):
         expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
                         (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
                         (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
-        self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))
 
         labels = []
         for label in ['X', 'Y']:
@@ -37,7 +37,7 @@ class TestCRF(unittest.TestCase):
         expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3),
                         (3, 4), (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7),
                         (6, 9), (7, 0), (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
-        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
+        self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))
 
     def test_case2(self):
         # test that the CRF never decodes an illegal transition; verified against allennlp
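A note on why the assertions now pass include_start_end=True: in the first expected set the label ids only run from 0 to 2, so the extra ids 3 and 4 are presumably the virtual start and end tags, which the call has to be asked to include. This reading is an inference from the expected sets rather than something stated in the change; it can be checked with plain Python:

    # transitions in the first expected set that involve ids beyond the three labels
    expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4),
                    (2, 0), (2, 2), (2, 4), (3, 0), (3, 2)}
    num_labels = 3
    print(sorted(pair for pair in expected_res if max(pair) >= num_labels))
    # [(0, 4), (1, 4), (2, 4), (3, 0), (3, 2)]  -> 3 only ever starts, 4 only ever ends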
diff --git a/test/test_tutorials.py b/test/test_tutorials.py
index 428d584d..87910c3d 100644
--- a/test/test_tutorials.py
+++ b/test/test_tutorials.py
@@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase):
         test_data.rename_field('label', 'label_seq')
 
         loss = CrossEntropyLoss(pred="output", target="label_seq")
-        metric = AccuracyMetric(pred="predict", target="label_seq")
+        metric = AccuracyMetric(target="label_seq")
 
         # instantiate the Trainer, pass in the model and the data, and train
         # first fit on test_data (to make sure the model implementation is correct)
@@ -90,16 +90,19 @@ class TestTutorial(unittest.TestCase):
         overfit_trainer.train()
 
         # train on train_data, validate on test_data
-        trainer = Trainer(train_data=train_data, model=model, loss=CrossEntropyLoss(pred="output", target="label_seq"),
-                          batch_size=32, n_epochs=5, dev_data=test_data,
-                          metrics=AccuracyMetric(pred="predict", target="label_seq"), save_path=None)
+        trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
+                          loss=CrossEntropyLoss(pred="output", target="label_seq"),
+                          metrics=AccuracyMetric(target="label_seq"),
+                          save_path=None,
+                          batch_size=32,
+                          n_epochs=5)
         trainer.train()
         print('Train finished!')
 
         # use the Tester to evaluate on test_data
         from fastNLP import Tester
-        tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
+        tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(target="label_seq"),
                         batch_size=4)
         acc = tester.test()
         print(acc)
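The tutorial now builds AccuracyMetric(target="label_seq") without an explicit pred mapping. A minimal sketch of the convention this appears to rely on, inferred from this changeset rather than taken from the fastNLP API: any evaluate() argument left unmapped is looked up in the input dicts under its own name. Both resolve_param_map and the evaluate signature below are illustrative only:

    import inspect

    def resolve_param_map(func, explicit_map):
        # arguments without an explicit mapping default to their own names
        args = [a for a in inspect.getfullargspec(func).args if a != 'self']
        return {arg: explicit_map.get(arg, arg) for arg in args}

    def evaluate(self, pred, target, seq_len=None):
        pass

    print(resolve_param_map(evaluate, {'target': 'label_seq'}))
    # {'pred': 'pred', 'target': 'label_seq', 'seq_len': 'seq_len'}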