@@ -81,6 +81,12 @@ class DataSetGetter:
             raise ValueError
         self.idx_list = idx_list
 
+    def __getattr__(self, item):
+        if hasattr(self.dataset, item):
+            return getattr(self.dataset, item)
+        else:
+            raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item))
+
 
 class SamplerAdapter(torch.utils.data.Sampler):
     def __init__(self, sampler, dataset):
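The new __getattr__ hook makes the wrapper transparent: attribute lookups that miss on DataSetGetter fall through to the wrapped dataset, which is what lets the hasattr(dataset, 'collate_fn') check in the next hunk succeed. A minimal standalone sketch of the pattern, with hypothetical class names:

    class _Inner:
        collate_fn = staticmethod(lambda batch: batch)  # hypothetical attribute

    class _Wrapper:
        def __init__(self, inner):
            self.inner = inner

        def __getattr__(self, item):  # consulted only when normal lookup fails
            if hasattr(self.inner, item):
                return getattr(self.inner, item)
            raise AttributeError(item)

    w = _Wrapper(_Inner())
    assert hasattr(w, 'collate_fn')  # resolved through the wrapped object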
@@ -131,9 +137,9 @@ class DataSetIter(BatchIter):
                  timeout=0, worker_init_fn=None):
         super().__init__()
         assert isinstance(dataset, DataSet)
-        sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
         dataset = DataSetGetter(dataset, as_numpy)
         collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
+        sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
         self.dataiter = torch.utils.data.DataLoader(
             dataset=dataset, batch_size=batch_size, sampler=sampler,
             collate_fn=collate_fn, num_workers=num_workers,
@@ -179,8 +179,6 @@ class FieldArray:
         return self.pad(contents)
 
     def pad(self, contents):
-        if self.padder is None:
-            raise RuntimeError
         return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
 
     def set_padder(self, padder):
@@ -355,8 +353,15 @@ class FieldArray:
         :return: Counter, keys are labels, values are occurrence counts
         """
         count = Counter()
+
+        def cum(cell):
+            if _is_iterable(cell) and not isinstance(cell, str):
+                for cell_ in cell:
+                    cum(cell_)
+            else:
+                count[cell] += 1
         for cell in self.content:
-            count[cell] += 1
+            cum(cell)
         return count
 
     def _after_process(self, new_contents, inplace):
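What the recursive cum helper changes: nested sequence fields are now counted element-wise instead of hashing whole lists (which would raise TypeError: unhashable type). A standalone illustration with made-up data and a simplified stand-in for _is_iterable:

    from collections import Counter

    def _is_iterable(x):  # simplified stand-in for fastNLP's helper
        try:
            iter(x)
            return True
        except TypeError:
            return False

    count = Counter()

    def cum(cell):
        if _is_iterable(cell) and not isinstance(cell, str):
            for cell_ in cell:
                cum(cell_)
        else:
            count[cell] += 1

    for cell in [['B', 'I', 'O'], ['O', 'O']]:  # a nested label field
        cum(cell)
    print(count)  # Counter({'O': 3, 'B': 1, 'I': 1})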
@@ -34,14 +34,23 @@ class LossBase(object):
     """
 
     def __init__(self):
-        self.param_map = {}
+        self._param_map = {}  # keys are parameters of the function; values are the keys used to fetch them from the input dict
         self._checked = False
 
+    @property
+    def param_map(self):
+        if len(self._param_map) == 0:  # an empty map means it has not been initialized yet
+            func_spect = inspect.getfullargspec(self.get_loss)
+            func_args = [arg for arg in func_spect.args if arg != 'self']
+            for arg in func_args:
+                self._param_map[arg] = arg
+        return self._param_map
+
     def get_loss(self, *args, **kwargs):
         raise NotImplementedError
 
     def _init_param_map(self, key_map=None, **kwargs):
-        """Check key_map and the other parameter maps, and add the mappings to self.param_map
+        """Check key_map and the other parameter maps, and add the mappings to self._param_map
 
         :param dict key_map: the key mapping
         :param kwargs: every key-value pair in the keyword args is turned into a mapping
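The new read-only param_map property builds the mapping lazily from the get_loss signature the first time it is read. A minimal sketch of the behavior (subclass and signature hypothetical):

    import inspect

    class _Base:
        def __init__(self):
            self._param_map = {}

        @property
        def param_map(self):
            if len(self._param_map) == 0:  # not initialized yet
                spec = inspect.getfullargspec(self.get_loss)
                self._param_map.update({arg: arg for arg in spec.args if arg != 'self'})
            return self._param_map

    class _MyLoss(_Base):
        def get_loss(self, pred, target):  # hypothetical signature
            return 0.0

    print(_MyLoss().param_map)  # {'pred': 'pred', 'target': 'target'}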
@@ -53,30 +62,30 @@ class LossBase(object): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | ||||
for key, value in key_map.items(): | for key, value in key_map.items(): | ||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | |||||
self._param_map[key] = key | |||||
continue | continue | ||||
if not isinstance(key, str): | if not isinstance(key, str): | ||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | ||||
if not isinstance(value, str): | if not isinstance(value, str): | ||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | |||||
self._param_map[key] = value | |||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
for key, value in kwargs.items(): | for key, value in kwargs.items(): | ||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | |||||
self._param_map[key] = key | |||||
continue | continue | ||||
if not isinstance(value, str): | if not isinstance(value, str): | ||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | |||||
self._param_map[key] = value | |||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
for value, key_set in value_counter.items(): | for value, key_set in value_counter.items(): | ||||
if len(key_set) > 1: | if len(key_set) > 1: | ||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | ||||
# check consistence between signature and param_map | |||||
# check consistence between signature and _param_map | |||||
func_spect = inspect.getfullargspec(self.get_loss) | func_spect = inspect.getfullargspec(self.get_loss) | ||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | func_args = [arg for arg in func_spect.args if arg != 'self'] | ||||
for func_param, input_param in self.param_map.items(): | |||||
for func_param, input_param in self._param_map.items(): | |||||
if func_param not in func_args: | if func_param not in func_args: | ||||
raise NameError( | raise NameError( | ||||
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " | f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " | ||||
@@ -96,7 +105,7 @@ class LossBase(object):
         :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
         """
         fast_param = {}
-        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
+        if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
             fast_param['target'] = list(target_dict.values())[0]
             return fast_param
@@ -115,19 +124,19 @@ class LossBase(object):
             return loss
 
         if not self._checked:
-            # 1. check consistence between signature and param_map
+            # 1. check consistence between signature and _param_map
             func_spect = inspect.getfullargspec(self.get_loss)
             func_args = set([arg for arg in func_spect.args if arg != 'self'])
-            for func_arg, input_arg in self.param_map.items():
+            for func_arg, input_arg in self._param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")
 
-            # 2. only part of the param_map are passed, left are not
+            # 2. only part of the _param_map are passed, left are not
             for arg in func_args:
-                if arg not in self.param_map:
-                    self.param_map[arg] = arg  # This param does not need mapping.
+                if arg not in self._param_map:
+                    self._param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
-            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
+            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
 
         mapped_pred_dict = {}
         mapped_target_dict = {}
@@ -149,7 +158,7 @@ class LossBase(object):
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
                 # Don't delete `` in this information, nor add ``
-                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
+                replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                     f"in `{self.__class__.__name__}`)"
 
             check_res = _CheckRes(missing=replaced_missing,
@@ -162,6 +171,8 @@ class LossBase(object):
             if check_res.missing or check_res.duplicated:
                 raise _CheckError(check_res=check_res,
                                   func_signature=_get_func_signature(self.get_loss))
+            self._checked = True
+
         refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
 
         loss = self.get_loss(**refined_args)
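In _param_map the keys are get_loss parameters and the values are keys of the incoming dicts; _reverse_param_map inverts it so the incoming dicts can be renamed before the call. A tiny illustration with made-up values:

    _param_map = {'pred': 'output', 'target': 'label'}  # func arg -> input key
    _reverse_param_map = {input_arg: func_arg for func_arg, input_arg in _param_map.items()}
    pred_dict = {'output': [0.1, 0.9]}
    mapped = {_reverse_param_map.get(k, k): v for k, v in pred_dict.items()}
    print(mapped)  # {'pred': [0.1, 0.9]}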
@@ -115,9 +115,18 @@ class MetricBase(object):
     """
 
     def __init__(self):
-        self.param_map = {}  # key is param in function, value is input param.
+        self._param_map = {}  # key is param in function, value is input param.
         self._checked = False
 
+    @property
+    def param_map(self):
+        if len(self._param_map) == 0:  # an empty map means it has not been initialized yet
+            func_spect = inspect.getfullargspec(self.evaluate)
+            func_args = [arg for arg in func_spect.args if arg != 'self']
+            for arg in func_args:
+                self._param_map[arg] = arg
+        return self._param_map
+
     @abstractmethod
     def evaluate(self, *args, **kwargs):
         raise NotImplementedError
@@ -127,7 +136,7 @@ class MetricBase(object):
         raise NotImplemented
 
     def _init_param_map(self, key_map=None, **kwargs):
-        """Check key_map and the other parameter maps, and add the mappings to self.param_map
+        """Check key_map and the other parameter maps, and add the mappings to self._param_map
 
         :param dict key_map: the key mapping
         :param kwargs: every key-value pair in the keyword args is turned into a mapping
@@ -139,30 +148,30 @@ class MetricBase(object):
                 raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
             for key, value in key_map.items():
                 if value is None:
-                    self.param_map[key] = key
+                    self._param_map[key] = key
                     continue
                 if not isinstance(key, str):
                     raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
                 if not isinstance(value, str):
                     raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
-                self.param_map[key] = value
+                self._param_map[key] = value
                 value_counter[value].add(key)
         for key, value in kwargs.items():
             if value is None:
-                self.param_map[key] = key
+                self._param_map[key] = key
                 continue
             if not isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
-            self.param_map[key] = value
+            self._param_map[key] = value
             value_counter[value].add(key)
         for value, key_set in value_counter.items():
             if len(key_set) > 1:
                 raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
 
-        # check consistence between signature and param_map
+        # check consistence between signature and _param_map
         func_spect = inspect.getfullargspec(self.evaluate)
         func_args = [arg for arg in func_spect.args if arg != 'self']
-        for func_param, input_param in self.param_map.items():
+        for func_param, input_param in self._param_map.items():
             if func_param not in func_args:
                 raise NameError(
                     f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
@@ -177,7 +186,7 @@ class MetricBase(object):
         :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
         """
         fast_param = {}
-        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
+        if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             fast_param['pred'] = list(pred_dict.values())[0]
             fast_param['target'] = list(target_dict.values())[0]
             return fast_param
@@ -206,19 +215,19 @@ class MetricBase(object):
         if not self._checked:
             if not callable(self.evaluate):
                 raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
-            # 1. check consistence between signature and param_map
+            # 1. check consistence between signature and _param_map
             func_spect = inspect.getfullargspec(self.evaluate)
             func_args = set([arg for arg in func_spect.args if arg != 'self'])
-            for func_arg, input_arg in self.param_map.items():
+            for func_arg, input_arg in self._param_map.items():
                 if func_arg not in func_args:
                     raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
 
-            # 2. only part of the param_map are passed, left are not
+            # 2. only part of the _param_map are passed, left are not
             for arg in func_args:
-                if arg not in self.param_map:
-                    self.param_map[arg] = arg  # This param does not need mapping.
+                if arg not in self._param_map:
+                    self._param_map[arg] = arg  # This param does not need mapping.
             self._evaluate_args = func_args
-            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
+            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
 
         # need to wrap inputs in dict.
         mapped_pred_dict = {}
@@ -242,7 +251,7 @@ class MetricBase(object):
             replaced_missing = list(missing)
             for idx, func_arg in enumerate(missing):
                 # Don't delete `` in this information, nor add ``
-                replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
+                replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
                     f"in `{self.__class__.__name__}`)"
 
             check_res = _CheckRes(missing=replaced_missing,
@@ -255,10 +264,10 @@ class MetricBase(object):
             if check_res.missing or check_res.duplicated:
                 raise _CheckError(check_res=check_res,
                                   func_signature=_get_func_signature(self.evaluate))
+            self._checked = True
         refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
 
         self.evaluate(**refined_args)
-        self._checked = True
 
         return
@@ -416,19 +425,19 @@ def _bioes_tag_to_spans(tags, ignore_labels=None):
     ignore_labels = set(ignore_labels) if ignore_labels else set()
 
     spans = []
-    prev_bmes_tag = None
+    prev_bioes_tag = None
     for idx, tag in enumerate(tags):
         tag = tag.lower()
-        bmes_tag, label = tag[:1], tag[2:]
-        if bmes_tag in ('b', 's'):
+        bioes_tag, label = tag[:1], tag[2:]
+        if bioes_tag in ('b', 's'):
             spans.append((label, [idx, idx]))
-        elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]:
+        elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
             spans[-1][1][1] = idx
-        elif bmes_tag == 'o':
+        elif bioes_tag == 'o':
             pass
         else:
             spans.append((label, [idx, idx]))
-        prev_bmes_tag = bmes_tag
+        prev_bioes_tag = bioes_tag
     return [(span[0], (span[1][0], span[1][1] + 1))
             for span in spans
             if span[0] not in ignore_labels
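An illustrative call of the function this hunk renames variables in (tags made up): contiguous b/i followed by e of the same label extend one span, while s opens and closes a single-token span.

    spans = _bioes_tag_to_spans(['B-PER', 'E-PER', 'S-LOC', 'O'])
    print(spans)  # [('per', (0, 2)), ('loc', (2, 3))]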
@@ -432,9 +432,8 @@ class Trainer(object):
         if metric_key is not None:
             self.increase_better = False if metric_key[0] == "-" else True
             self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
-        elif len(metrics) > 0:
-            self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
         else:
             self.metric_key = None
 
         # prepare loss
         losser = _prepare_losser(loss)
@@ -454,9 +453,7 @@ class Trainer(object):
             raise TypeError("train_data type {} not support".format(type(train_data)))
 
         if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
-            # TODO consider how to check different dataset types
-            _check_code(data_iterator=self.data_iterator,
-                        model=model, losser=losser, metrics=metrics, dev_data=dev_data,
+            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level,
                         batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
             # _check_code is how fastNLP helps you check that your code is correct. If you see this comment in an error traceback, please examine your code carefully.
@@ -758,7 +755,9 @@ class Trainer(object):
         :return bool value: True means current results on dev set is the best.
         """
-        indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+        indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+        if self.metric_key is None:
+            self.metric_key = indicator
         is_better = True
         if self.best_metric_indicator is None:
             # first-time validation
@@ -797,16 +796,34 @@ def _get_value_info(_dict):
         strs.append(_str)
     return strs
 
 
-def _check_code(data_iterator, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
+from numbers import Number
+from .batch import _to_tensor
+
+
+def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None,
                 check_level=0):
     # check the get_loss method
-    model_devcie = model.parameters().__next__().device
+    model_devcie = _get_model_device(model=model)
 
-    batch = data_iterator
-    dataset = data_iterator.dataset
-    for batch_count, (batch_x, batch_y) in enumerate(batch):
+    def _iter():
+        start_idx = 0
+        while start_idx < len(dataset):
+            batch_x = {}
+            batch_y = {}
+            for field_name, field in dataset.get_all_fields().items():
+                indices = list(range(start_idx, min(start_idx + batch_size, len(dataset))))
+                if field.is_target or field.is_input:
+                    batch = field.get(indices)
+                    if field.dtype is not None and \
+                            issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
+                        batch, _ = _to_tensor(batch, field.dtype)
+                    if field.is_target:
+                        batch_y[field_name] = batch
+                    if field.is_input:
+                        batch_x[field_name] = batch
+            yield (batch_x, batch_y)
+            start_idx += batch_size
+
+    for batch_count, (batch_x, batch_y) in enumerate(_iter()):
         _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
         # forward check
         if batch_count == 0:
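The replacement _iter helper slices index windows directly off the dataset instead of going through a DataSetIter; the windowing itself is just (standalone sketch):

    batch_size, n = 2, 5
    start_idx, windows = 0, []
    while start_idx < n:
        windows.append(list(range(start_idx, min(start_idx + batch_size, n))))
        start_idx += batch_size
    print(windows)  # [[0, 1], [2, 3], [4]]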
@@ -874,26 +891,16 @@ def _check_eval_results(metrics, metric_key, metric_list):
         loss, metrics = metrics
 
     if isinstance(metrics, dict):
-        if len(metrics) == 1:
-            # only single metric, just use it
-            metric_dict = list(metrics.values())[0]
-            metrics_name = list(metrics.keys())[0]
-        else:
-            metrics_name = metric_list[0].__class__.__name__
-            if metrics_name not in metrics:
-                raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-            metric_dict = metrics[metrics_name]
+        metric_dict = list(metrics.values())[0]  # take the first metric
 
-        if len(metric_dict) == 1:
+        if metric_key is None:
             indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-        elif len(metric_dict) > 1 and metric_key is None:
-            raise RuntimeError(
-                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
         else:
             # metric_key is set
             if metric_key not in metric_dict:
                 raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
             indicator_val = metric_dict[metric_key]
+            indicator = metric_key
     else:
         raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
-    return indicator_val
+    return indicator, indicator_val
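Callers of _check_eval_results now receive both the metric name and its value. An illustrative call (input made up; metric_list is unused on the dict path):

    metrics = {'AccuracyMetric': {'acc': 0.92}}
    indicator, indicator_val = _check_eval_results(metrics, metric_key=None, metric_list=None)
    print(indicator, indicator_val)  # acc 0.92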
@@ -124,6 +124,14 @@ class DataInfo:
         self.embeddings = embeddings or {}
         self.datasets = datasets or {}
 
+    def __repr__(self):
+        _str = 'In total {} datasets:\n'.format(len(self.datasets))
+        for name, dataset in self.datasets.items():
+            _str += '\t{} has {} instances.\n'.format(name, len(dataset))
+        _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
+        for name, vocab in self.vocabs.items():
+            _str += '\t{} has {} entries.\n'.format(name, len(vocab))
+        return _str
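With the new __repr__, printing a DataInfo produces a summary like (hypothetical contents):

    In total 2 datasets:
        train has 10000 instances.
        dev has 1000 instances.
    In total 1 vocabs:
        words has 21000 entries.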
 
 class DataSetLoader:
     """
 
@@ -120,7 +120,8 @@ class ConllLoader(DataSetLoader):
     """
     Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`
 
-    Reads data in CoNLL format. See http://conll.cemantix.org/2012/data.html for details on the format.
+    Reads data in CoNLL format. See http://conll.cemantix.org/2012/data.html for details on the format. Lines that
+    start with "-DOCSTART-" are ignored, since that token is used as a document separator in CoNLL 2003.
 
     Column indices start from 0; each column holds::
@@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
         return sample
 
     with open(path, 'r', encoding=encoding) as f:
         sample = []
-        start = next(f)
-        if '-DOCSTART-' not in start:
+        start = next(f).strip()
+        if '-DOCSTART-' not in start and start != '':
             sample.append(start.split())
         for line_idx, line in enumerate(f, 1):
-            if line.startswith('\n'):
+            line = line.strip()
+            if line == '':
                 if len(sample):
                     try:
                         res = parse_conll(sample)
@@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             elif line.startswith('#'):
                 continue
             else:
-                sample.append(line.split())
+                if not line.startswith('-DOCSTART-'):
+                    sample.append(line.split())
         if len(sample) > 0:
             try:
                 res = parse_conll(sample)
@@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
             except Exception as e:
                 if dropna:
                     return
-                raise ValueError('invalid instance at line: {}'.format(line_idx))
+                print('invalid instance at line: {}'.format(line_idx))
+                raise e
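Given a file like the following (made-up CoNLL-2003-style content), the reader now skips the -DOCSTART- sentinel wherever it appears, and blank lines still delimit samples:

    -DOCSTART- -X- -X- O

    EU NNP B-NP B-ORG
    rejects VBZ B-VP O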
@@ -9,7 +9,7 @@ from torch import nn
 
 from ..utils import initial_parameter
 
 
-def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
+def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
     """
     Alias: :class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`
 
@@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
     :param dict id2target: keys are label indices; values are str tags or tag-labels. A value may be tag-only, e.g. "B",
         "M"; or "B-NN", "M-NN", where tag and label must be separated by "-". Usually obtained via Vocabulary.idx2word.
-    :param str encoding_type: supports "bio", "bmes", "bmeso".
+    :param str encoding_type: supports "bio", "bmes", "bmeso", "bioes".
     :param bool include_start_end: whether to include start and end transitions. For example, in BIO, b/o may appear
         at the beginning, but i may not; if True, the result includes (start_idx, b_idx) and (start_idx, o_idx) but not
         (start_idx, i_idx), with start_idx=len(id2label) and end_idx=len(id2label)+1. If False, the result contains
         nothing related to start and end.
@@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
 def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
     """
 
-    :param str encoding_type: supports "BIO", "BMES", "BMESO".
+    :param str encoding_type: supports "BIO", "BMES", "BMESO", "BIOES".
     :param str from_tag: a tag such as "B" or "M", plus the two special tags start and end
     :param str from_label: a label such as "PER" or "LOC"
     :param str to_tag: a tag such as "B" or "M", plus the two special tags start and end
@@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
             return to_tag in ['b', 's', 'end', 'o']
         else:
             raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
+    elif encoding_type == 'bioes':
+        if from_tag == 'start':
+            return to_tag in ['b', 's', 'o']
+        elif from_tag == 'b':
+            return to_tag in ['i', 'e'] and from_label == to_label
+        elif from_tag == 'i':
+            return to_tag in ['i', 'e'] and from_label == to_label
+        elif from_tag in ['e', 's', 'o']:
+            return to_tag in ['b', 's', 'end', 'o']
+        else:
+            raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag))
     else:
-        raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))
+        raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type))
 
 
 class ConditionalRandomField(nn.Module):
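A quick check of the new 'bioes' branch (id2target made up): 'b' may only transition to 'i' or 'e' with the same label, while 'e', 's' and 'o' may start a new span. allowed_transitions returns (from_index, to_index) pairs:

    id2target = {0: 'B-PER', 1: 'I-PER', 2: 'E-PER', 3: 'S-LOC', 4: 'O'}
    trans = allowed_transitions(id2target, encoding_type='bioes', include_start_end=False)
    assert (0, 1) in trans      # B-PER -> I-PER is legal
    assert (0, 3) not in trans  # B-PER -> S-LOC is not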
@@ -24,7 +24,8 @@ __all__ = [
     "VarLSTM",
     "VarGRU"
 ]
-from .bert import BertModel
+from ._bert import BertModel
+from .bert import BertWordPieceEncoder
 from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 from .conv_maxpool import ConvMaxpool
 from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
@@ -6,18 +6,399 @@
 """
-import torch
-from torch import nn
 from ...core.vocabulary import Vocabulary
 import collections
-import os
 import unicodedata
 from ...io.file_utils import _get_base_url, cached_path
-from .bert import BertModel
 import numpy as np
 from itertools import chain
+import copy
+import json
+import math
+import os
+
+import torch
+from torch import nn
+
+CONFIG_FILE = 'bert_config.json'
+MODEL_WEIGHTS = 'pytorch_model.bin'
+
+
+def gelu(x):
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
+class BertLayerNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-12):
+        super(BertLayerNorm, self).__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.bias = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, x):
+        u = x.mean(-1, keepdim=True)
+        s = (x - u).pow(2).mean(-1, keepdim=True)
+        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+        return self.weight * x + self.bias
+
+
+class BertEmbeddings(nn.Module):
+    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob):
+        super(BertEmbeddings, self).__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
+        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
+        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)
+
+        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+        # any TensorFlow checkpoint file
+        self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(hidden_dropout_prob)
+
+    def forward(self, input_ids, token_type_ids=None):
+        seq_length = input_ids.size(1)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        words_embeddings = self.word_embeddings(input_ids)
+        position_embeddings = self.position_embeddings(position_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+        embeddings = words_embeddings + position_embeddings + token_type_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+
+class BertSelfAttention(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
+        super(BertSelfAttention, self).__init__()
+        if hidden_size % num_attention_heads != 0:
+            raise ValueError(
+                "The hidden size (%d) is not a multiple of the number of attention "
+                "heads (%d)" % (hidden_size, num_attention_heads))
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_size = int(hidden_size / num_attention_heads)
+        self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+        self.query = nn.Linear(hidden_size, self.all_head_size)
+        self.key = nn.Linear(hidden_size, self.all_head_size)
+        self.value = nn.Linear(hidden_size, self.all_head_size)
+
+        self.dropout = nn.Dropout(attention_probs_dropout_prob)
+
+    def transpose_for_scores(self, x):
+        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+        x = x.view(*new_x_shape)
+        return x.permute(0, 2, 1, 3)
+
+    def forward(self, hidden_states, attention_mask):
+        mixed_query_layer = self.query(hidden_states)
+        mixed_key_layer = self.key(hidden_states)
+        mixed_value_layer = self.value(hidden_states)
+
+        query_layer = self.transpose_for_scores(mixed_query_layer)
+        key_layer = self.transpose_for_scores(mixed_key_layer)
+        value_layer = self.transpose_for_scores(mixed_value_layer)
+
+        # Take the dot product between "query" and "key" to get the raw attention scores.
+        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+        # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
+        attention_scores = attention_scores + attention_mask
+
+        # Normalize the attention scores to probabilities.
+        attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        attention_probs = self.dropout(attention_probs)
+
+        context_layer = torch.matmul(attention_probs, value_layer)
+        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+        return context_layer
+
+
+class BertSelfOutput(nn.Module):
+    def __init__(self, hidden_size, hidden_dropout_prob):
+        super(BertSelfOutput, self).__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertAttention(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
+        super(BertAttention, self).__init__()
+        self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
+        self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)
+
+    def forward(self, input_tensor, attention_mask):
+        self_output = self.self(input_tensor, attention_mask)
+        attention_output = self.output(self_output, input_tensor)
+        return attention_output
+
+
+class BertIntermediate(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, hidden_act):
+        super(BertIntermediate, self).__init__()
+        self.dense = nn.Linear(hidden_size, intermediate_size)
+        self.intermediate_act_fn = ACT2FN[hidden_act] \
+            if isinstance(hidden_act, str) else hidden_act
+
+    def forward(self, hidden_states):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
+        super(BertOutput, self).__init__()
+        self.dense = nn.Linear(intermediate_size, hidden_size)
+        self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
+        self.dropout = nn.Dropout(hidden_dropout_prob)
+
+    def forward(self, hidden_states, input_tensor):
+        hidden_states = self.dense(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
+                 intermediate_size, hidden_act):
+        super(BertLayer, self).__init__()
+        self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
+                                       hidden_dropout_prob)
+        self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
+        self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)
+
+    def forward(self, hidden_states, attention_mask):
+        attention_output = self.attention(hidden_states, attention_mask)
+        intermediate_output = self.intermediate(attention_output)
+        layer_output = self.output(intermediate_output, attention_output)
+        return layer_output
+
+
+class BertEncoder(nn.Module):
+    def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob,
+                 hidden_dropout_prob,
+                 intermediate_size, hidden_act):
+        super(BertEncoder, self).__init__()
+        layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
+                          intermediate_size, hidden_act)
+        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])
+
+    def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
+        all_encoder_layers = []
+        for layer_module in self.layer:
+            hidden_states = layer_module(hidden_states, attention_mask)
+            if output_all_encoded_layers:
+                all_encoder_layers.append(hidden_states)
+        if not output_all_encoded_layers:
+            all_encoder_layers.append(hidden_states)
+        return all_encoder_layers
+
+
+class BertPooler(nn.Module):
+    def __init__(self, hidden_size):
+        super(BertPooler, self).__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(self, hidden_states):
+        # We "pool" the model by simply taking the hidden state corresponding
+        # to the first token.
+        first_token_tensor = hidden_states[:, 0]
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+        return pooled_output
+
+
+class BertModel(nn.Module):
+    """BERT (Bidirectional Encoder Representations from Transformers).
+
+    If you want to use pretrained weights, download them from the following URLs.
+
+    sources::
+
+        'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
+        'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
+        'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
+        'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
+        'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
+        'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
+        'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",
+
+    Build a BERT model from pretrained weights::
+
+        model = BertModel.from_pretrained("path/to/weights/directory")
+
+    Build a BERT model with randomly initialized weights::
+
+        model = BertModel()
+
+    :param int vocab_size: vocabulary size; the default 30522 is the size for English uncased BERT
+    :param int hidden_size: hidden size; the default 768 is the BERT-base value
+    :param int num_hidden_layers: number of hidden layers; the default 12 is the BERT-base value
+    :param int num_attention_heads: number of attention heads; the default 12 is the BERT-base value
+    :param int intermediate_size: FFN hidden size; the default 3072 is the BERT-base value
+    :param str hidden_act: FFN activation function; defaults to ``gelu``
+    :param float hidden_dropout_prob: FFN dropout; defaults to 0.1
+    :param float attention_probs_dropout_prob: attention dropout; defaults to 0.1
+    :param int max_position_embeddings: maximum sequence length; defaults to 512
+    :param int type_vocab_size: maximum number of segments; defaults to 2
+    :param int initializer_range: weight initialization range; defaults to 0.02
+    """
+
+    def __init__(self, vocab_size=30522,
+                 hidden_size=768,
+                 num_hidden_layers=12,
+                 num_attention_heads=12,
+                 intermediate_size=3072,
+                 hidden_act="gelu",
+                 hidden_dropout_prob=0.1,
+                 attention_probs_dropout_prob=0.1,
+                 max_position_embeddings=512,
+                 type_vocab_size=2,
+                 initializer_range=0.02):
+        super(BertModel, self).__init__()
+        self.hidden_size = hidden_size
+        self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings,
+                                         type_vocab_size, hidden_dropout_prob)
+        self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads,
+                                   attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size,
+                                   hidden_act)
+        self.pooler = BertPooler(hidden_size)
+        self.initializer_range = initializer_range
+
+        self.apply(self.init_bert_weights)
+
+    def init_bert_weights(self, module):
+        if isinstance(module, (nn.Linear, nn.Embedding)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
+        elif isinstance(module, BertLayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        if isinstance(module, nn.Linear) and module.bias is not None:
+            module.bias.data.zero_()
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
+
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        embedding_output = self.embeddings(input_ids, token_type_ids)
+        encoded_layers = self.encoder(embedding_output,
+                                      extended_attention_mask,
+                                      output_all_encoded_layers=output_all_encoded_layers)
+        sequence_output = encoded_layers[-1]
+        pooled_output = self.pooler(sequence_output)
+        if not output_all_encoded_layers:
+            encoded_layers = encoded_layers[-1]
+        return encoded_layers, pooled_output
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs):
+        # Load config
+        config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
+        config = json.load(open(config_file, "r"))
+        # config = BertConfig.from_json_file(config_file)
+        # logger.info("Model config {}".format(config))
+        # Instantiate model.
+        model = cls(*inputs, **config, **kwargs)
+        if state_dict is None:
+            weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS)
+            state_dict = torch.load(weights_path)
+
+        old_keys = []
+        new_keys = []
+        for key in state_dict.keys():
+            new_key = None
+            if 'gamma' in key:
+                new_key = key.replace('gamma', 'weight')
+            if 'beta' in key:
+                new_key = key.replace('beta', 'bias')
+            if new_key:
+                old_keys.append(key)
+                new_keys.append(new_key)
+        for old_key, new_key in zip(old_keys, new_keys):
+            state_dict[new_key] = state_dict.pop(old_key)
+
+        missing_keys = []
+        unexpected_keys = []
+        error_msgs = []
+        # copy state_dict so _load_from_state_dict can modify it
+        metadata = getattr(state_dict, '_metadata', None)
+        state_dict = state_dict.copy()
+        if metadata is not None:
+            state_dict._metadata = metadata
+
+        def load(module, prefix=''):
+            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+            module._load_from_state_dict(
+                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+            for name, child in module._modules.items():
+                if child is not None:
+                    load(child, prefix + name + '.')
+
+        load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
+        if len(missing_keys) > 0:
+            print("Weights of {} not initialized from pretrained model: {}".format(
+                model.__class__.__name__, missing_keys))
+        if len(unexpected_keys) > 0:
+            print("Weights from pretrained model not used in {}: {}".format(
+                model.__class__.__name__, unexpected_keys))
+        return model
 
 
 def whitespace_tokenize(text):
     """Runs basic whitespace cleaning and splitting on a piece of text."""
@@ -547,79 +928,3 @@ class _WordPieceBertModel(nn.Module):
             outputs[l_index] = bert_outputs[l]
         return outputs
 
-
-class BertWordPieceEncoder(nn.Module):
-    """
-    A Bert encoder driven by a vocabulary. Pass in a vocab, then call index_datasets to generate word-piece
-    representations for the vocabulary.
-
-    :param vocab: Vocabulary.
-    :param model_dir_or_name:
-    :param layers:
-    :param requires_grad:
-    """
-
-    def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base', layers: str = '-1',
-                 requires_grad: bool = False):
-        super().__init__()
-        PRETRAIN_URL = _get_base_url('bert')
-        # TODO fix
-        PRETRAINED_BERT_MODEL_DIR = {'en-base': 'bert_en-80f95ea7.tar.gz',
-                                     'cn': 'elmo_cn.zip'}
-
-        if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
-            model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
-            model_url = PRETRAIN_URL + model_name
-            model_dir = cached_path(model_url)
-            # check whether the directory exists
-        elif os.path.isdir(model_dir_or_name):
-            model_dir = model_dir_or_name
-        else:
-            raise ValueError(f"Cannot recognize {model_dir_or_name}.")
-
-        self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers)
-        self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
-        self.requires_grad = requires_grad
-
-    @property
-    def requires_grad(self):
-        """
-        Whether the embedding parameters are trainable. True: all parameters are optimized; False: none are;
-        None: some are and some are not.
-        :return:
-        """
-        requires_grads = set([param.requires_grad for name, param in self.named_parameters()])
-        if len(requires_grads) == 1:
-            return requires_grads.pop()
-        else:
-            return None
-
-    @requires_grad.setter
-    def requires_grad(self, value):
-        for name, param in self.named_parameters():
-            param.requires_grad = value
-
-    @property
-    def embed_size(self):
-        return self._embed_size
-
-    def index_datasets(self, *datasets):
-        """
-        Index the datasets with word pieces.
-
-        Example::
-
-        :param datasets:
-        :return:
-        """
-        self.model.index_dataset(*datasets)
-
-    def forward(self, words, token_type_ids=None):
-        """
-        Compute the BERT embeddings of words. Before computing, [CLS] is added at the start of each sentence and
-        [SEP] at the end, and include_cls_sep decides whether these two representations are removed.
-
-        :param words: batch_size x max_len
-        :param token_type_ids: batch_size x max_len, distinguishing the first sentence from the second
-        :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
-        """
-        outputs = self.model(words, token_type_ids)
-        outputs = torch.cat([*outputs], dim=-1)
-        return outputs
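BertWordPieceEncoder is removed here and re-exported from bert.py (per the __init__ import hunk above). Per its own docstring, usage looks like this sketch (vocab and dataset assumed to exist):

    encoder = BertWordPieceEncoder(vocab, model_dir_or_name='/path/to/bert/dir')  # vocab: a fastNLP Vocabulary
    encoder.index_datasets(train_data)  # train_data: a fastNLP DataSet to index with word pieces
    reps = encoder(words)               # words: LongTensor of batch_size x max_len
    # reps: batch_size x max_len x (768 * number of layers)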
@@ -1,378 +1,95 @@
-"""
-bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
-
-"""
-import copy
-import json
-import math
 import os
-import torch
 from torch import nn
+import torch
+from ...core import Vocabulary
+from ...io.file_utils import _get_base_url, cached_path
+from ._bert import _WordPieceBertModel
CONFIG_FILE = 'bert_config.json' | |||||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||||
def gelu(x): | |||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||||
def swish(x): | |||||
return x * torch.sigmoid(x) | |||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||||
class BertLayerNorm(nn.Module): | |||||
def __init__(self, hidden_size, eps=1e-12): | |||||
super(BertLayerNorm, self).__init__() | |||||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||||
self.variance_epsilon = eps | |||||
def forward(self, x): | |||||
u = x.mean(-1, keepdim=True) | |||||
s = (x - u).pow(2).mean(-1, keepdim=True) | |||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||||
return self.weight * x + self.bias | |||||
class BertEmbeddings(nn.Module): | |||||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||||
super(BertEmbeddings, self).__init__() | |||||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||||
# any TensorFlow checkpoint file | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, input_ids, token_type_ids=None): | |||||
seq_length = input_ids.size(1) | |||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
words_embeddings = self.word_embeddings(input_ids) | |||||
position_embeddings = self.position_embeddings(position_ids) | |||||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||||
embeddings = self.LayerNorm(embeddings) | |||||
embeddings = self.dropout(embeddings) | |||||
return embeddings | |||||
class BertSelfAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||||
super(BertSelfAttention, self).__init__() | |||||
if hidden_size % num_attention_heads != 0: | |||||
raise ValueError( | |||||
"The hidden size (%d) is not a multiple of the number of attention " | |||||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||||
self.num_attention_heads = num_attention_heads | |||||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||||
def transpose_for_scores(self, x): | |||||
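# reshape (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size) | |||||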
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||||
x = x.view(*new_x_shape) | |||||
return x.permute(0, 2, 1, 3) | |||||
def forward(self, hidden_states, attention_mask): | |||||
mixed_query_layer = self.query(hidden_states) | |||||
mixed_key_layer = self.key(hidden_states) | |||||
mixed_value_layer = self.value(hidden_states) | |||||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||||
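# scale by sqrt(head size) so the logits stay in a range where softmax has useful gradients | |||||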
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||||
# Apply the attention mask (precomputed for all layers in the BertModel forward() function) | |||||
attention_scores = attention_scores + attention_mask | |||||
# Normalize the attention scores to probabilities. | |||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||||
# This is actually dropping out entire tokens to attend to, which might | |||||
# seem a bit unusual, but is taken from the original Transformer paper. | |||||
attention_probs = self.dropout(attention_probs) | |||||
context_layer = torch.matmul(attention_probs, value_layer) | |||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||||
context_layer = context_layer.view(*new_context_layer_shape) | |||||
return context_layer | |||||
class BertSelfOutput(nn.Module): | |||||
def __init__(self, hidden_size, hidden_dropout_prob): | |||||
super(BertSelfOutput, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||||
super(BertAttention, self).__init__() | |||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||||
def forward(self, input_tensor, attention_mask): | |||||
self_output = self.self(input_tensor, attention_mask) | |||||
attention_output = self.output(self_output, input_tensor) | |||||
return attention_output | |||||
class BertIntermediate(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||||
super(BertIntermediate, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||||
if isinstance(hidden_act, str) else hidden_act | |||||
def forward(self, hidden_states): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.intermediate_act_fn(hidden_states) | |||||
return hidden_states | |||||
class BertOutput(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||||
super(BertOutput, self).__init__() | |||||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertLayer(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertLayer, self).__init__() | |||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob) | |||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||||
def forward(self, hidden_states, attention_mask): | |||||
attention_output = self.attention(hidden_states, attention_mask) | |||||
intermediate_output = self.intermediate(attention_output) | |||||
layer_output = self.output(intermediate_output, attention_output) | |||||
return layer_output | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertEncoder, self).__init__() | |||||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act) | |||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||||
all_encoder_layers = [] | |||||
for layer_module in self.layer: | |||||
hidden_states = layer_module(hidden_states, attention_mask) | |||||
if output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
if not output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
return all_encoder_layers | |||||
class BertPooler(nn.Module): | |||||
def __init__(self, hidden_size): | |||||
super(BertPooler, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.activation = nn.Tanh() | |||||
def forward(self, hidden_states): | |||||
# We "pool" the model by simply taking the hidden state corresponding | |||||
# to the first token. | |||||
first_token_tensor = hidden_states[:, 0] | |||||
pooled_output = self.dense(first_token_tensor) | |||||
pooled_output = self.activation(pooled_output) | |||||
return pooled_output | |||||
class BertModel(nn.Module): | |||||
"""BERT(Bidirectional Embedding Representations from Transformers). | |||||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||||
sources:: | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
用预训练权重矩阵来建立BERT模型:: | |||||
model = BertModel.from_pretrained("path/to/weights/directory") | |||||
用随机初始化权重矩阵来建立BERT模型:: | |||||
model = BertModel() | |||||
:param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 | |||||
:param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 | |||||
:param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 | |||||
:param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 | |||||
:param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 | |||||
:param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` | |||||
:param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 | |||||
:param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 | |||||
:param int max_position_embeddings: 最大的序列长度,默认值为512, | |||||
:param int type_vocab_size: 最大segment数量,默认值为2 | |||||
:param int initializer_range: 初始化权重范围,默认值为0.02 | |||||
class BertWordPieceEncoder(nn.Module): | |||||
""" | """ | ||||
可以通过读取vocabulary使用的Bert的Encoder。传入vocab,然后调用index_datasets方法在vocabulary中生成word piece的表示。 | |||||
def __init__(self, vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02): | |||||
super(BertModel, self).__init__() | |||||
self.hidden_size = hidden_size | |||||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||||
type_vocab_size, hidden_dropout_prob) | |||||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||||
hidden_act) | |||||
self.pooler = BertPooler(hidden_size) | |||||
self.initializer_range = initializer_range | |||||
self.apply(self.init_bert_weights) | |||||
def init_bert_weights(self, module): | |||||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||||
# Slightly different from the TF version which uses truncated_normal for initialization | |||||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||||
elif isinstance(module, BertLayerNorm): | |||||
module.bias.data.zero_() | |||||
module.weight.data.fill_(1.0) | |||||
if isinstance(module, nn.Linear) and module.bias is not None: | |||||
module.bias.data.zero_() | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||||
if attention_mask is None: | |||||
attention_mask = torch.ones_like(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
# We create a 3D attention mask from a 2D tensor mask. | |||||
# Sizes are [batch_size, 1, 1, to_seq_length] | |||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||||
# this attention mask is simpler than the triangular masking of causal attention | |||||
# used in OpenAI GPT; we just need to prepare the broadcast dimension here. | |||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||||
# masked positions, this operation will create a tensor which is 0.0 for | |||||
# positions we want to attend and -10000.0 for masked positions. | |||||
# Since we are adding it to the raw scores before the softmax, this is | |||||
# effectively the same as removing these entirely. | |||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||||
encoded_layers = self.encoder(embedding_output, | |||||
extended_attention_mask, | |||||
output_all_encoded_layers=output_all_encoded_layers) | |||||
sequence_output = encoded_layers[-1] | |||||
pooled_output = self.pooler(sequence_output) | |||||
if not output_all_encoded_layers: | |||||
encoded_layers = encoded_layers[-1] | |||||
return encoded_layers, pooled_output | |||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||||
# Load config | |||||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||||
with open(config_file, "r") as f: | |||||
config = json.load(f) | |||||
# config = BertConfig.from_json_file(config_file) | |||||
# logger.info("Model config {}".format(config)) | |||||
# Instantiate model. | |||||
model = cls(*inputs, **config, **kwargs) | |||||
if state_dict is None: | |||||
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) | |||||
state_dict = torch.load(weights_path) | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if 'gamma' in key: | |||||
new_key = key.replace('gamma', 'weight') | |||||
if 'beta' in key: | |||||
new_key = key.replace('beta', 'bias') | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
missing_keys = [] | |||||
unexpected_keys = [] | |||||
error_msgs = [] | |||||
# copy state_dict so _load_from_state_dict can modify it | |||||
metadata = getattr(state_dict, '_metadata', None) | |||||
state_dict = state_dict.copy() | |||||
if metadata is not None: | |||||
state_dict._metadata = metadata | |||||
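# recurse through submodules so missing/unexpected keys are collected per prefix | |||||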
def load(module, prefix=''): | |||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||||
module._load_from_state_dict( | |||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||||
for name, child in module._modules.items(): | |||||
if child is not None: | |||||
load(child, prefix + name + '.') | |||||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||||
if len(missing_keys) > 0: | |||||
print("Weights of {} not initialized from pretrained model: {}".format( | |||||
model.__class__.__name__, missing_keys)) | |||||
if len(unexpected_keys) > 0: | |||||
print("Weights from pretrained model not used in {}: {}".format( | |||||
model.__class__.__name__, unexpected_keys)) | |||||
return model | |||||
class BertWordPieceEncoder(nn.Module): | |||||
""" | |||||
A BERT encoder driven by a vocabulary: pass in a vocab, then call the index_datasets method to generate the word-piece representation for that vocabulary. | |||||
:param fastNLP.Vocabulary vocab: the vocabulary | |||||
:param str model_dir_or_name: directory containing the model, or a model name; defaults to ``en-base-uncased`` | |||||
:param str layers: layers used in the final representation, given as comma-separated indices; negative values index from the last layer | |||||
:param bool requires_grad: whether gradients are required | |||||
""" | |||||
def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base-uncased', layers:str='-1', | |||||
requires_grad:bool=False): | |||||
super().__init__() | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||||
} | |||||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
# check whether the directory exists | |||||
elif os.path.isdir(model_dir_or_name): | |||||
model_dir = model_dir_or_name | |||||
else: | |||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||||
self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) | |||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||||
self.requires_grad = requires_grad | |||||
@property | |||||
def requires_grad(self): | |||||
""" | |||||
Whether the parameters of this embedding may be optimized. True: all parameters are optimized; False: none are; None: some are and some are not | |||||
:return: | |||||
""" | |||||
requires_grads = set(param.requires_grad for name, param in self.named_parameters()) | |||||
if len(requires_grads) == 1: | |||||
return requires_grads.pop() | |||||
else: | |||||
return None | |||||
@requires_grad.setter | |||||
def requires_grad(self, value): | |||||
for name, param in self.named_parameters(): | |||||
param.requires_grad = value | |||||
@property | |||||
def embed_size(self): | |||||
return self._embed_size | |||||
def index_datasets(self, *datasets): | |||||
""" | |||||
Index the 'words' field of each dataset into word pieces. | |||||
Example:: | |||||
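# hypothetical call, assuming tr_data and dev_data are fastNLP DataSet objects: | |||||
# encoder.index_datasets(tr_data, dev_data) | |||||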
:param datasets: one or more DataSet objects whose 'words' field will be indexed | |||||
:return: | |||||
""" | |||||
self.model.index_dataset(*datasets) | |||||
def forward(self, words, token_type_ids=None): | |||||
""" | |||||
Compute the BERT embedding of words. Before the computation, [CLS] is inserted at the start of every sentence and [SEP] at the end, and include_cls_sep decides whether these two representations are removed. | |||||
:param words: batch_size x max_len | |||||
:param token_type_ids: batch_size x max_len, used to distinguish the first sentence from the second | |||||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||||
""" | |||||
outputs = self.model(words, token_type_ids) | |||||
outputs = torch.cat(outputs, dim=-1) | |||||
return outputs |
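A minimal usage sketch for BertWordPieceEncoder follows; the vocabulary contents, the dataset, and the local weight directory are illustrative assumptions, not part of the original source:: | |||||
from fastNLP import DataSet, Vocabulary | |||||
vocab = Vocabulary() | |||||
vocab.update(["this", "is", "a", "demo", "sentence"])  # hypothetical vocabulary | |||||
encoder = BertWordPieceEncoder(vocab, model_dir_or_name='/path/to/bert/dir', layers='-1,-2') | |||||
ds = DataSet({'words': [['this', 'is', 'a', 'demo', 'sentence']]})  # hypothetical dataset | |||||
encoder.index_datasets(ds)  # adds the word-piece index field to ds | |||||
# words: batch_size x max_len tensor of word-piece ids from the indexed dataset | |||||
# reps = encoder(words)  # -> batch_size x max_len x (768 * 2), assuming a base-size model and two layers | |||||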
@@ -165,7 +165,6 @@ class StaticEmbedding(TokenEmbedding): | |||||
super(StaticEmbedding, self).__init__(vocab) | super(StaticEmbedding, self).__init__(vocab) | ||||
# First define which static embeddings can be downloaded. We will probably need to set up our own server for this, | # First define which static embeddings can be downloaded. We will probably need to set up our own server for this, | ||||
PRETRAIN_URL = _get_base_url('static') | |||||
PRETRAIN_STATIC_FILES = { | PRETRAIN_STATIC_FILES = { | ||||
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | ||||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | ||||
@@ -178,6 +177,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
# derive the cache_path | # derive the cache_path | ||||
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | ||||
PRETRAIN_URL = _get_base_url('static') | |||||
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_path = cached_path(model_url) | model_path = cached_path(model_url) | ||||
@@ -333,12 +333,11 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
self.layers = layers | self.layers = layers | ||||
# check model_dir_or_name and download it if necessary | # check model_dir_or_name and download it if necessary | ||||
PRETRAIN_URL = _get_base_url('elmo') | |||||
# TODO add the models hosted on Baidu Cloud | |||||
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | ||||
'cn': 'elmo_cn-5e9b34e2.tar.gz'} | 'cn': 'elmo_cn-5e9b34e2.tar.gz'} | ||||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | ||||
PRETRAIN_URL = _get_base_url('elmo') | |||||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_dir = cached_path(model_url) | model_dir = cached_path(model_url) | ||||
@@ -392,7 +391,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'words_to_chars_embedding' in name: # this one must not be included in requires_grad | if 'words_to_chars_embedding' in name: # this one must not be included in requires_grad | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
@@ -420,7 +419,6 @@ class BertEmbedding(ContextualEmbedding): | |||||
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | ||||
super(BertEmbedding, self).__init__(vocab) | super(BertEmbedding, self).__init__(vocab) | ||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
PRETRAIN_URL = _get_base_url('bert') | |||||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | ||||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | ||||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | 'en-base-cased': 'bert-base-cased-f89bfe08.zip', | ||||
@@ -436,6 +434,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
} | } | ||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | ||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_dir = cached_path(model_url) | model_dir = cached_path(model_url) | ||||
@@ -487,7 +486,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'word_pieces_lengths' in name: # this one must not be included in requires_grad | if 'word_pieces_lengths' in name: # this one must not be included in requires_grad | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
@@ -575,6 +574,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
for i in range(len(kernel_sizes))]) | for i in range(len(kernel_sizes))]) | ||||
self._embed_size = embed_size | self._embed_size = embed_size | ||||
self.fc = nn.Linear(sum(filter_nums), embed_size) | self.fc = nn.Linear(sum(filter_nums), embed_size) | ||||
self.init_param() | |||||
def forward(self, words): | def forward(self, words): | ||||
""" | """ | ||||
@@ -627,9 +627,17 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # these must not be included in requires_grad | if 'words_to_chars_embedding' in name or 'word_lengths' in name: # these must not be included in requires_grad | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
def init_param(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # these must not be reset | |||||
continue | |||||
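# weight matrices get Xavier-normal init; 1-d parameters (biases) get a uniform init in [-1, 1] | |||||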
if param.data.dim() > 1: | |||||
nn.init.xavier_normal_(param, 1) | |||||
else: | |||||
nn.init.uniform_(param, -1, 1) | |||||
class LSTMCharEmbedding(TokenEmbedding): | class LSTMCharEmbedding(TokenEmbedding): | ||||
""" | """ | ||||
@@ -753,7 +761,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # these must not be included in requires_grad | if 'words_to_chars_embedding' in name or 'word_lengths' in name: # these must not be included in requires_grad | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
@@ -35,8 +35,18 @@ class LSTM(nn.Module): | |||||
self.batch_first = batch_first | self.batch_first = batch_first | ||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | ||||
dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
self.init_param() | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def init_param(self): | |||||
for name, param in self.named_parameters(): | |||||
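# fill input-hidden biases with 1 and hidden-hidden biases with 0; all other parameters get Xavier-normal init | |||||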
if 'bias_i' in name: | |||||
param.data.fill_(1) | |||||
elif 'bias_h' in name: | |||||
param.data.fill_(0) | |||||
else: | |||||
nn.init.xavier_normal_(param) | |||||
def forward(self, x, seq_len=None, h0=None, c0=None): | def forward(self, x, seq_len=None, h0=None, c0=None): | ||||
""" | """ | ||||
@@ -57,8 +57,12 @@ callbacks = [clipper] | |||||
# if pretrain: | # if pretrain: | ||||
# fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) | # fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) | ||||
# callbacks.append(fixer) | # callbacks.append(fixer) | ||||
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler, | |||||
update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(), | |||||
metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, | |||||
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, | |||||
batch_size=32, sampler=sampler, update_every=5, | |||||
n_epochs=3, print_every=5, | |||||
dev_data=data.datasets['dev'], metrics=RelayMetric(), metric_key='f', | |||||
validate_every=-1, save_path=None, | |||||
prefetch=True, use_tqdm=True, device=device, | |||||
callbacks=callbacks, | |||||
check_code_level=0) | check_code_level=0) | ||||
trainer.train() | trainer.train() |
@@ -25,7 +25,7 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
if not os.path.isfile(train_fp): | if not os.path.isfile(train_fp): | ||||
raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | ||||
files = {'train': train_fp} | files = {'train': train_fp} | ||||
for filename in ['test.txt', 'dev.txt']: | |||||
for filename in ['dev.txt', 'test.txt']: | |||||
fp = os.path.join(paths, filename) | fp = os.path.join(paths, filename) | ||||
if os.path.isfile(fp): | if os.path.isfile(fp): | ||||
files[filename.split('.')[0]] = fp | files[filename.split('.')[0]] = fp | ||||
@@ -161,7 +161,15 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
print(e) | print(e) | ||||
return | return | ||||
self.fail("No exception caught.") | self.fail("No exception caught.") | ||||
def test_duplicate(self): | |||||
# potential bug in 0.4.1: duplicated parameter names must not cause errors | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||||
target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
def test_seq_len(self): | def test_seq_len(self): | ||||
N = 256 | N = 256 | ||||
seq_len = torch.zeros(N).long() | seq_len = torch.zeros(N).long() | ||||
@@ -1,6 +1,5 @@ | |||||
import unittest | import unittest | ||||
import fastNLP | |||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | ||||
from .model_runner import * | from .model_runner import * | ||||
@@ -10,14 +10,14 @@ class TestCRF(unittest.TestCase): | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | id2label = {0: 'B', 1: 'I', 2:'O'} | ||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | ||||
(2, 4), (3, 0), (3, 2)} | (2, 4), (3, 0), (3, 2)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | ||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | ||||
allowed_transitions(id2label) | |||||
allowed_transitions(id2label, include_start_end=True) | |||||
labels = ['O'] | labels = ['O'] | ||||
for label in ['X', 'Y']: | for label in ['X', 'Y']: | ||||
@@ -27,7 +27,7 @@ class TestCRF(unittest.TestCase): | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | ||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | ||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||||
labels = [] | labels = [] | ||||
for label in ['X', 'Y']: | for label in ['X', 'Y']: | ||||
@@ -37,7 +37,7 @@ class TestCRF(unittest.TestCase): | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | ||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | ||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||||
def test_case2(self): | def test_case2(self): | ||||
# Test whether the CRF avoids decoding illegal transitions; verified against allennlp. | # Test whether the CRF avoids decoding illegal transitions; verified against allennlp. | ||||
@@ -80,7 +80,7 @@ class TestTutorial(unittest.TestCase): | |||||
test_data.rename_field('label', 'label_seq') | test_data.rename_field('label', 'label_seq') | ||||
loss = CrossEntropyLoss(pred="output", target="label_seq") | loss = CrossEntropyLoss(pred="output", target="label_seq") | ||||
metric = AccuracyMetric(pred="predict", target="label_seq") | |||||
metric = AccuracyMetric(target="label_seq") | |||||
# Instantiate a Trainer, pass in the model and data, and train | # Instantiate a Trainer, pass in the model and data, and train | ||||
# First fit on test_data (to make sure the model implementation is correct) | # First fit on test_data (to make sure the model implementation is correct) | ||||
@@ -90,16 +90,19 @@ class TestTutorial(unittest.TestCase): | |||||
overfit_trainer.train() | overfit_trainer.train() | ||||
# Train on train_data and validate on test_data | # Train on train_data and validate on test_data | ||||
trainer = Trainer(train_data=train_data, model=model, loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
batch_size=32, n_epochs=5, dev_data=test_data, | |||||
metrics=AccuracyMetric(pred="predict", target="label_seq"), save_path=None) | |||||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | |||||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||||
metrics=AccuracyMetric(target="label_seq"), | |||||
save_path=None, | |||||
batch_size=32, | |||||
n_epochs=5) | |||||
trainer.train() | trainer.train() | ||||
print('Train finished!') | print('Train finished!') | ||||
# Use Tester to evaluate on test_data | # Use Tester to evaluate on test_data | ||||
from fastNLP import Tester | from fastNLP import Tester | ||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(target="label_seq"), | |||||
batch_size=4) | batch_size=4) | ||||
acc = tester.test() | acc = tester.test() | ||||
print(acc) | print(acc) | ||||