2. 修改embedding.py中的bug 3. ConllReader默认跳过所有的DOCSTART标签 4. 交换bert的heavy lifting到_bert, 将BertEncoder在bert.py中暴露 5. crf中allow_transition的include_end_start修改为false,以与CRF的默认值适配 6. allow_transition与SpanMetric支持BIOES类型的tag 7. datainfo中增加打印格式化输出tags/v0.4.10
@@ -12,7 +12,11 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||||
__all__ = [ | __all__ = [ | ||||
"Instance", | "Instance", | ||||
"FieldArray", | "FieldArray", | ||||
"Batch", | |||||
"DataSetIter", | |||||
"BatchIter", | |||||
"TorchLoaderIter", | |||||
"Vocabulary", | "Vocabulary", | ||||
"DataSet", | "DataSet", | ||||
"Const", | "Const", | ||||
@@ -14,7 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||||
介绍core 的子模块的分工,好像必要性不大 | 介绍core 的子模块的分工,好像必要性不大 | ||||
""" | """ | ||||
from .batch import Batch | |||||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | ||||
from .const import Const | from .const import Const | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
@@ -3,7 +3,9 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"Batch" | |||||
"BatchIter", | |||||
"DataSetIter", | |||||
"TorchLoaderIter", | |||||
] | ] | ||||
import atexit | import atexit | ||||
@@ -12,9 +14,11 @@ from queue import Empty, Full | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.multiprocessing as mp | import torch.multiprocessing as mp | ||||
import torch.utils.data | |||||
from numbers import Number | from numbers import Number | ||||
from .sampler import RandomSampler | |||||
from .sampler import SequentialSampler | |||||
from .dataset import DataSet | |||||
_python_is_exit = False | _python_is_exit = False | ||||
@@ -27,162 +31,157 @@ def _set_python_is_exit(): | |||||
atexit.register(_set_python_is_exit) | atexit.register(_set_python_is_exit) | ||||
class Batch(object): | |||||
""" | |||||
别名::class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch` | |||||
Batch 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | |||||
组成 `x` 和 `y`:: | |||||
batch = Batch(data_set, batch_size=16, sampler=SequentialSampler()) | |||||
num_batch = len(batch) | |||||
for batch_x, batch_y in batch: | |||||
# do stuff ... | |||||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||||
:param int batch_size: 取出的batch大小 | |||||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.RandomSampler`. | |||||
Default: ``None`` | |||||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||||
Default: ``False`` | |||||
:param bool prefetch: 若为 ``True`` 使用多进程预先取出下一batch. | |||||
Default: ``False`` | |||||
""" | |||||
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False): | |||||
class DataSetGetter: | |||||
def __init__(self, dataset: DataSet, as_numpy=False): | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.batch_size = batch_size | |||||
if sampler is None: | |||||
sampler = RandomSampler() | |||||
self.sampler = sampler | |||||
self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input} | |||||
self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target} | |||||
self.as_numpy = as_numpy | self.as_numpy = as_numpy | ||||
self.idx_list = None | |||||
self.curidx = 0 | |||||
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | |||||
self.cur_batch_indices = None | |||||
self.prefetch = prefetch | |||||
self.lengths = 0 | |||||
def fetch_one(self): | |||||
if self.curidx >= len(self.idx_list): | |||||
return None | |||||
self.idx_list = list(range(len(dataset))) | |||||
def __getitem__(self, idx: int): | |||||
# mapping idx to sampled idx | |||||
idx = self.idx_list[idx] | |||||
inputs = {n:f.get(idx) for n, f in self.inputs.items()} | |||||
targets = {n:f.get(idx) for n, f in self.targets.items()} | |||||
return idx, inputs, targets | |||||
def __len__(self): | |||||
return len(self.dataset) | |||||
def collate_fn(self, batch: list): | |||||
batch_x = {n:[] for n in self.inputs.keys()} | |||||
batch_y = {n:[] for n in self.targets.keys()} | |||||
indices = [] | |||||
for idx, x, y in batch: | |||||
indices.append(idx) | |||||
for n, v in x.items(): | |||||
batch_x[n].append(v) | |||||
for n, v in y.items(): | |||||
batch_y[n].append(v) | |||||
def pad_batch(batch_dict, field_array): | |||||
for n, vlist in batch_dict.items(): | |||||
f = field_array[n] | |||||
if f.padder is None: | |||||
batch_dict[n] = np.array(vlist) | |||||
else: | |||||
data = f.pad(vlist) | |||||
if not self.as_numpy: | |||||
data, flag = _to_tensor(data, f.dtype) | |||||
batch_dict[n] = data | |||||
return batch_dict | |||||
return (indices, | |||||
pad_batch(batch_x, self.inputs), | |||||
pad_batch(batch_y, self.targets)) | |||||
def set_idx_list(self, idx_list): | |||||
if len(idx_list) != len(self.idx_list): | |||||
raise ValueError | |||||
self.idx_list = idx_list | |||||
def __getattr__(self, item): | |||||
if hasattr(self.dataset, item): | |||||
return getattr(self.dataset, item) | |||||
else: | else: | ||||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | |||||
batch_x, batch_y = {}, {} | |||||
indices = self.idx_list[self.curidx:endidx] | |||||
self.cur_batch_indices = indices | |||||
for field_name, field in self.dataset.get_all_fields().items(): | |||||
if field.is_target or field.is_input: | |||||
batch = field.get(indices) | |||||
if not self.as_numpy and \ | |||||
field.dtype is not None and \ | |||||
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): | |||||
batch = _to_tensor(batch) | |||||
if field.is_target: | |||||
batch_y[field_name] = batch | |||||
if field.is_input: | |||||
batch_x[field_name] = batch | |||||
self.curidx = endidx | |||||
return batch_x, batch_y | |||||
raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item)) | |||||
class SamplerAdapter(torch.utils.data.Sampler): | |||||
def __init__(self, sampler, dataset): | |||||
self.sampler = sampler | |||||
self.dataset = dataset | |||||
def __iter__(self): | def __iter__(self): | ||||
""" | |||||
Iterate on dataset, fetch batch data. Fetch process don't block the iterate process | |||||
:return: | |||||
""" | |||||
if self.prefetch: | |||||
return self._run_batch_iter(self) | |||||
def batch_iter(): | |||||
self.init_iter() | |||||
while 1: | |||||
res = self.fetch_one() | |||||
if res is None: | |||||
break | |||||
yield res | |||||
return batch_iter() | |||||
return iter(self.sampler(self.dataset)) | |||||
class BatchIter: | |||||
def __init__(self): | |||||
self.dataiter = None | |||||
self.num_batches = None | |||||
self.cur_batch_indices = None | |||||
self.batch_size = None | |||||
def init_iter(self): | def init_iter(self): | ||||
self.idx_list = self.sampler(self.dataset) | |||||
self.curidx = 0 | |||||
self.lengths = self.dataset.get_length() | |||||
pass | |||||
@staticmethod | |||||
def get_num_batches(num_samples, batch_size, drop_last): | |||||
num_batches = num_samples // batch_size | |||||
if not drop_last and (num_samples % batch_size > 0): | |||||
num_batches += 1 | |||||
return num_batches | |||||
def __iter__(self): | |||||
self.init_iter() | |||||
for indices, batch_x, batch_y in self.dataiter: | |||||
self.cur_batch_indices = indices | |||||
yield batch_x, batch_y | |||||
def get_batch_indices(self): | |||||
return self.cur_batch_indices | |||||
def __len__(self): | def __len__(self): | ||||
return self.num_batches | return self.num_batches | ||||
def get_batch_indices(self): | |||||
""" | |||||
取得当前batch在DataSet中所在的index下标序列 | |||||
:return list(int) indexes: 下标序列 | |||||
""" | |||||
return self.cur_batch_indices | |||||
@staticmethod | |||||
def _run_fetch(batch, q): | |||||
try: | |||||
global _python_is_exit | |||||
batch.init_iter() | |||||
# print('start fetch') | |||||
while 1: | |||||
res = batch.fetch_one() | |||||
# print('fetch one') | |||||
while 1: | |||||
try: | |||||
q.put(res, timeout=3) | |||||
break | |||||
except Full: | |||||
if _python_is_exit: | |||||
return | |||||
if res is None: | |||||
# print('fetch done, waiting processing') | |||||
break | |||||
# print('fetch exit') | |||||
except Exception as e: | |||||
q.put(e) | |||||
finally: | |||||
q.join() | |||||
@staticmethod | |||||
def _run_batch_iter(batch): | |||||
q = mp.JoinableQueue(maxsize=10) | |||||
fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q)) | |||||
fetch_p.daemon = True | |||||
fetch_p.start() | |||||
# print('fork fetch process') | |||||
while 1: | |||||
try: | |||||
res = q.get(timeout=1) | |||||
q.task_done() | |||||
# print('get fetched') | |||||
if res is None: | |||||
break | |||||
elif isinstance(res, Exception): | |||||
raise res | |||||
yield res | |||||
except Empty as e: | |||||
if fetch_p.is_alive(): | |||||
continue | |||||
else: | |||||
break | |||||
fetch_p.terminate() | |||||
fetch_p.join() | |||||
# print('iter done') | |||||
@property | |||||
def dataset(self): | |||||
return self.dataiter.dataset | |||||
class DataSetIter(BatchIter): | |||||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | |||||
num_workers=0, pin_memory=False, drop_last=False, | |||||
timeout=0, worker_init_fn=None): | |||||
super().__init__() | |||||
assert isinstance(dataset, DataSet) | |||||
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||||
dataset = DataSetGetter(dataset, as_numpy) | |||||
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None | |||||
self.dataiter = torch.utils.data.DataLoader( | |||||
dataset=dataset, batch_size=batch_size, sampler=sampler, | |||||
collate_fn=collate_fn, num_workers=num_workers, | |||||
pin_memory=pin_memory, drop_last=drop_last, | |||||
timeout=timeout, worker_init_fn=worker_init_fn) | |||||
self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last) | |||||
self.batch_size = batch_size | |||||
class TorchLoaderIter(BatchIter): | |||||
def __init__(self, dataset): | |||||
super().__init__() | |||||
assert isinstance(dataset, torch.utils.data.DataLoader) | |||||
self.dataiter = dataset | |||||
self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last) | |||||
self.batch_size = dataset.batch_size | |||||
def _to_tensor(batch): | |||||
class OnlineDataGettter: | |||||
# TODO | |||||
pass | |||||
class OnlineDataIter(BatchIter): | |||||
# TODO | |||||
def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False, | |||||
num_workers=0, pin_memory=False, drop_last=False, | |||||
timeout=0, worker_init_fn=None, **kwargs): | |||||
super().__init__() | |||||
def _to_tensor(batch, field_dtype): | |||||
try: | try: | ||||
if issubclass(batch.dtype.type, np.floating): | |||||
batch = torch.as_tensor(batch).float() # 默认使用float32 | |||||
if field_dtype is not None \ | |||||
and issubclass(field_dtype, Number) \ | |||||
and not isinstance(batch, torch.Tensor): | |||||
if issubclass(batch.dtype.type, np.floating): | |||||
new_batch = torch.as_tensor(batch).float() # 默认使用float32 | |||||
else: | |||||
new_batch = torch.as_tensor(batch) # 复用内存地址,避免复制 | |||||
return new_batch, True | |||||
else: | else: | ||||
batch = torch.as_tensor(batch) # 复用内存地址,避免复制 | |||||
return batch, False | |||||
except: | except: | ||||
pass | |||||
return batch | |||||
return batch, False |
@@ -176,7 +176,10 @@ class FieldArray: | |||||
if self.padder is None or pad is False: | if self.padder is None or pad is False: | ||||
return np.array(contents) | return np.array(contents) | ||||
else: | else: | ||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||||
return self.pad(contents) | |||||
def pad(self, contents): | |||||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||||
def set_padder(self, padder): | def set_padder(self, padder): | ||||
""" | """ | ||||
@@ -34,14 +34,23 @@ class LossBase(object): | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
self.param_map = {} | |||||
self._param_map = {} # key是fun的参数,value是以该值从传入的dict取出value | |||||
self._checked = False | self._checked = False | ||||
@property | |||||
def param_map(self): | |||||
if len(self._param_map) == 0: # 如果为空说明还没有初始化 | |||||
func_spect = inspect.getfullargspec(self.get_loss) | |||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||||
for arg in func_args: | |||||
self._param_map[arg] = arg | |||||
return self._param_map | |||||
def get_loss(self, *args, **kwargs): | def get_loss(self, *args, **kwargs): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
def _init_param_map(self, key_map=None, **kwargs): | def _init_param_map(self, key_map=None, **kwargs): | ||||
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map | |||||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||||
:param dict key_map: 表示key的映射关系 | :param dict key_map: 表示key的映射关系 | ||||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | ||||
@@ -53,30 +62,30 @@ class LossBase(object): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | ||||
for key, value in key_map.items(): | for key, value in key_map.items(): | ||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | |||||
self._param_map[key] = key | |||||
continue | continue | ||||
if not isinstance(key, str): | if not isinstance(key, str): | ||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | ||||
if not isinstance(value, str): | if not isinstance(value, str): | ||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | |||||
self._param_map[key] = value | |||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
for key, value in kwargs.items(): | for key, value in kwargs.items(): | ||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | |||||
self._param_map[key] = key | |||||
continue | continue | ||||
if not isinstance(value, str): | if not isinstance(value, str): | ||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | |||||
self._param_map[key] = value | |||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
for value, key_set in value_counter.items(): | for value, key_set in value_counter.items(): | ||||
if len(key_set) > 1: | if len(key_set) > 1: | ||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | ||||
# check consistence between signature and param_map | |||||
# check consistence between signature and _param_map | |||||
func_spect = inspect.getfullargspec(self.get_loss) | func_spect = inspect.getfullargspec(self.get_loss) | ||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | func_args = [arg for arg in func_spect.args if arg != 'self'] | ||||
for func_param, input_param in self.param_map.items(): | |||||
for func_param, input_param in self._param_map.items(): | |||||
if func_param not in func_args: | if func_param not in func_args: | ||||
raise NameError( | raise NameError( | ||||
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " | f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " | ||||
@@ -96,7 +105,7 @@ class LossBase(object): | |||||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | ||||
""" | """ | ||||
fast_param = {} | fast_param = {} | ||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | fast_param['pred'] = list(pred_dict.values())[0] | ||||
fast_param['target'] = list(target_dict.values())[0] | fast_param['target'] = list(target_dict.values())[0] | ||||
return fast_param | return fast_param | ||||
@@ -115,19 +124,19 @@ class LossBase(object): | |||||
return loss | return loss | ||||
if not self._checked: | if not self._checked: | ||||
# 1. check consistence between signature and param_map | |||||
# 1. check consistence between signature and _param_map | |||||
func_spect = inspect.getfullargspec(self.get_loss) | func_spect = inspect.getfullargspec(self.get_loss) | ||||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | func_args = set([arg for arg in func_spect.args if arg != 'self']) | ||||
for func_arg, input_arg in self.param_map.items(): | |||||
for func_arg, input_arg in self._param_map.items(): | |||||
if func_arg not in func_args: | if func_arg not in func_args: | ||||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.") | raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.") | ||||
# 2. only part of the param_map are passed, left are not | |||||
# 2. only part of the _param_map are passed, left are not | |||||
for arg in func_args: | for arg in func_args: | ||||
if arg not in self.param_map: | |||||
self.param_map[arg] = arg # This param does not need mapping. | |||||
if arg not in self._param_map: | |||||
self._param_map[arg] = arg # This param does not need mapping. | |||||
self._evaluate_args = func_args | self._evaluate_args = func_args | ||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} | |||||
mapped_pred_dict = {} | mapped_pred_dict = {} | ||||
mapped_target_dict = {} | mapped_target_dict = {} | ||||
@@ -149,7 +158,7 @@ class LossBase(object): | |||||
replaced_missing = list(missing) | replaced_missing = list(missing) | ||||
for idx, func_arg in enumerate(missing): | for idx, func_arg in enumerate(missing): | ||||
# Don't delete `` in this information, nor add `` | # Don't delete `` in this information, nor add `` | ||||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
f"in `{self.__class__.__name__}`)" | f"in `{self.__class__.__name__}`)" | ||||
check_res = _CheckRes(missing=replaced_missing, | check_res = _CheckRes(missing=replaced_missing, | ||||
@@ -162,6 +171,8 @@ class LossBase(object): | |||||
if check_res.missing or check_res.duplicated: | if check_res.missing or check_res.duplicated: | ||||
raise _CheckError(check_res=check_res, | raise _CheckError(check_res=check_res, | ||||
func_signature=_get_func_signature(self.get_loss)) | func_signature=_get_func_signature(self.get_loss)) | ||||
self._checked = True | |||||
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | ||||
loss = self.get_loss(**refined_args) | loss = self.get_loss(**refined_args) | ||||
@@ -115,9 +115,18 @@ class MetricBase(object): | |||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
self.param_map = {} # key is param in function, value is input param. | |||||
self._param_map = {} # key is param in function, value is input param. | |||||
self._checked = False | self._checked = False | ||||
@property | |||||
def param_map(self): | |||||
if len(self._param_map) == 0: # 如果为空说明还没有初始化 | |||||
func_spect = inspect.getfullargspec(self.evaluate) | |||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||||
for arg in func_args: | |||||
self._param_map[arg] = arg | |||||
return self._param_map | |||||
@abstractmethod | @abstractmethod | ||||
def evaluate(self, *args, **kwargs): | def evaluate(self, *args, **kwargs): | ||||
raise NotImplementedError | raise NotImplementedError | ||||
@@ -127,7 +136,7 @@ class MetricBase(object): | |||||
raise NotImplemented | raise NotImplemented | ||||
def _init_param_map(self, key_map=None, **kwargs): | def _init_param_map(self, key_map=None, **kwargs): | ||||
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map | |||||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||||
:param dict key_map: 表示key的映射关系 | :param dict key_map: 表示key的映射关系 | ||||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | :param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | ||||
@@ -139,30 +148,30 @@ class MetricBase(object): | |||||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | ||||
for key, value in key_map.items(): | for key, value in key_map.items(): | ||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | |||||
self._param_map[key] = key | |||||
continue | continue | ||||
if not isinstance(key, str): | if not isinstance(key, str): | ||||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | ||||
if not isinstance(value, str): | if not isinstance(value, str): | ||||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | |||||
self._param_map[key] = value | |||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
for key, value in kwargs.items(): | for key, value in kwargs.items(): | ||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | |||||
self._param_map[key] = key | |||||
continue | continue | ||||
if not isinstance(value, str): | if not isinstance(value, str): | ||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | |||||
self._param_map[key] = value | |||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
for value, key_set in value_counter.items(): | for value, key_set in value_counter.items(): | ||||
if len(key_set) > 1: | if len(key_set) > 1: | ||||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | ||||
# check consistence between signature and param_map | |||||
# check consistence between signature and _param_map | |||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
func_args = [arg for arg in func_spect.args if arg != 'self'] | func_args = [arg for arg in func_spect.args if arg != 'self'] | ||||
for func_param, input_param in self.param_map.items(): | |||||
for func_param, input_param in self._param_map.items(): | |||||
if func_param not in func_args: | if func_param not in func_args: | ||||
raise NameError( | raise NameError( | ||||
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " | f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " | ||||
@@ -177,7 +186,7 @@ class MetricBase(object): | |||||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | ||||
""" | """ | ||||
fast_param = {} | fast_param = {} | ||||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||||
fast_param['pred'] = list(pred_dict.values())[0] | fast_param['pred'] = list(pred_dict.values())[0] | ||||
fast_param['target'] = list(target_dict.values())[0] | fast_param['target'] = list(target_dict.values())[0] | ||||
return fast_param | return fast_param | ||||
@@ -206,19 +215,19 @@ class MetricBase(object): | |||||
if not self._checked: | if not self._checked: | ||||
if not callable(self.evaluate): | if not callable(self.evaluate): | ||||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | ||||
# 1. check consistence between signature and param_map | |||||
# 1. check consistence between signature and _param_map | |||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | func_args = set([arg for arg in func_spect.args if arg != 'self']) | ||||
for func_arg, input_arg in self.param_map.items(): | |||||
for func_arg, input_arg in self._param_map.items(): | |||||
if func_arg not in func_args: | if func_arg not in func_args: | ||||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") | raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") | ||||
# 2. only part of the param_map are passed, left are not | |||||
# 2. only part of the _param_map are passed, left are not | |||||
for arg in func_args: | for arg in func_args: | ||||
if arg not in self.param_map: | |||||
self.param_map[arg] = arg # This param does not need mapping. | |||||
if arg not in self._param_map: | |||||
self._param_map[arg] = arg # This param does not need mapping. | |||||
self._evaluate_args = func_args | self._evaluate_args = func_args | ||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} | |||||
# need to wrap inputs in dict. | # need to wrap inputs in dict. | ||||
mapped_pred_dict = {} | mapped_pred_dict = {} | ||||
@@ -242,7 +251,7 @@ class MetricBase(object): | |||||
replaced_missing = list(missing) | replaced_missing = list(missing) | ||||
for idx, func_arg in enumerate(missing): | for idx, func_arg in enumerate(missing): | ||||
# Don't delete `` in this information, nor add `` | # Don't delete `` in this information, nor add `` | ||||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||||
f"in `{self.__class__.__name__}`)" | f"in `{self.__class__.__name__}`)" | ||||
check_res = _CheckRes(missing=replaced_missing, | check_res = _CheckRes(missing=replaced_missing, | ||||
@@ -255,10 +264,10 @@ class MetricBase(object): | |||||
if check_res.missing or check_res.duplicated: | if check_res.missing or check_res.duplicated: | ||||
raise _CheckError(check_res=check_res, | raise _CheckError(check_res=check_res, | ||||
func_signature=_get_func_signature(self.evaluate)) | func_signature=_get_func_signature(self.evaluate)) | ||||
self._checked = True | |||||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | ||||
self.evaluate(**refined_args) | self.evaluate(**refined_args) | ||||
self._checked = True | |||||
return | return | ||||
@@ -416,19 +425,19 @@ def _bioes_tag_to_spans(tags, ignore_labels=None): | |||||
ignore_labels = set(ignore_labels) if ignore_labels else set() | ignore_labels = set(ignore_labels) if ignore_labels else set() | ||||
spans = [] | spans = [] | ||||
prev_bmes_tag = None | |||||
prev_bioes_tag = None | |||||
for idx, tag in enumerate(tags): | for idx, tag in enumerate(tags): | ||||
tag = tag.lower() | tag = tag.lower() | ||||
bmes_tag, label = tag[:1], tag[2:] | |||||
if bmes_tag in ('b', 's'): | |||||
bieso_tag, label = tag[:1], tag[2:] | |||||
if bieso_tag in ('b', 's'): | |||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]: | |||||
elif bieso_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: | |||||
spans[-1][1][1] = idx | spans[-1][1][1] = idx | ||||
elif bmes_tag == 'o': | |||||
elif bieso_tag == 'o': | |||||
pass | pass | ||||
else: | else: | ||||
spans.append((label, [idx, idx])) | spans.append((label, [idx, idx])) | ||||
prev_bmes_tag = bmes_tag | |||||
prev_bioes_tag = bieso_tag | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) | return [(span[0], (span[1][0], span[1][1] + 1)) | ||||
for span in spans | for span in spans | ||||
if span[0] not in ignore_labels | if span[0] not in ignore_labels | ||||
@@ -6,7 +6,7 @@ from collections import defaultdict | |||||
import torch | import torch | ||||
from . import Batch | |||||
from . import DataSetIter | |||||
from . import DataSet | from . import DataSet | ||||
from . import SequentialSampler | from . import SequentialSampler | ||||
from .utils import _build_args | from .utils import _build_args | ||||
@@ -44,8 +44,7 @@ class Predictor(object): | |||||
self.network.eval() | self.network.eval() | ||||
batch_output = defaultdict(list) | batch_output = defaultdict(list) | ||||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False, | |||||
prefetch=False) | |||||
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
if hasattr(self.network, "predict"): | if hasattr(self.network, "predict"): | ||||
predict_func = self.network.predict | predict_func = self.network.predict | ||||
@@ -37,7 +37,7 @@ import warnings | |||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from .batch import Batch | |||||
from .batch import BatchIter, DataSetIter | |||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .metrics import _prepare_metrics | from .metrics import _prepare_metrics | ||||
from .sampler import SequentialSampler | from .sampler import SequentialSampler | ||||
@@ -82,7 +82,7 @@ class Tester(object): | |||||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | ||||
""" | """ | ||||
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | |||||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): | |||||
super(Tester, self).__init__() | super(Tester, self).__init__() | ||||
if not isinstance(data, DataSet): | if not isinstance(data, DataSet): | ||||
@@ -96,6 +96,14 @@ class Tester(object): | |||||
self._model = _move_model_to_device(model, device=device) | self._model = _move_model_to_device(model, device=device) | ||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.verbose = verbose | self.verbose = verbose | ||||
if isinstance(data, DataSet): | |||||
self.data_iterator = DataSetIter( | |||||
dataset=data, batch_size=batch_size, num_workers=num_workers) | |||||
elif isinstance(data, BatchIter): | |||||
self.data_iterator = data | |||||
else: | |||||
raise TypeError("data type {} not support".format(type(data))) | |||||
# 如果是DataParallel将没有办法使用predict方法 | # 如果是DataParallel将没有办法使用predict方法 | ||||
if isinstance(self._model, nn.DataParallel): | if isinstance(self._model, nn.DataParallel): | ||||
@@ -124,7 +132,7 @@ class Tester(object): | |||||
self._model_device = _get_model_device(self._model) | self._model_device = _get_model_device(self._model) | ||||
network = self._model | network = self._model | ||||
self._mode(network, is_test=True) | self._mode(network, is_test=True) | ||||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||||
data_iterator = self.data_iterator | |||||
eval_results = {} | eval_results = {} | ||||
try: | try: | ||||
with torch.no_grad(): | with torch.no_grad(): | ||||
@@ -311,8 +311,9 @@ try: | |||||
from tqdm.auto import tqdm | from tqdm.auto import tqdm | ||||
except: | except: | ||||
from .utils import _pseudo_tqdm as tqdm | from .utils import _pseudo_tqdm as tqdm | ||||
import warnings | |||||
from .batch import Batch | |||||
from .batch import DataSetIter, BatchIter | |||||
from .callback import CallbackManager, CallbackException | from .callback import CallbackManager, CallbackException | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .losses import _prepare_losser | from .losses import _prepare_losser | ||||
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics | |||||
from .optimizer import Optimizer | from .optimizer import Optimizer | ||||
from .sampler import Sampler | from .sampler import Sampler | ||||
from .sampler import RandomSampler | from .sampler import RandomSampler | ||||
from .sampler import SequentialSampler | |||||
from .tester import Tester | from .tester import Tester | ||||
from .utils import _CheckError | from .utils import _CheckError | ||||
from .utils import _build_args | from .utils import _build_args | ||||
@@ -351,6 +351,8 @@ class Trainer(object): | |||||
:param int batch_size: 训练和验证的时候的batch大小。 | :param int batch_size: 训练和验证的时候的batch大小。 | ||||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | ||||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | ||||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | :param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | ||||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | 会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | ||||
:param int n_epochs: 需要优化迭代多少次。 | :param int n_epochs: 需要优化迭代多少次。 | ||||
@@ -367,7 +369,6 @@ class Trainer(object): | |||||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | :param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | ||||
:param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。 | :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。 | ||||
保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | 保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | ||||
:param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。 | |||||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | :param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | ||||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | :param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | ||||
的计算位置进行管理。支持以下的输入: | 的计算位置进行管理。支持以下的输入: | ||||
@@ -394,16 +395,17 @@ class Trainer(object): | |||||
""" | """ | ||||
def __init__(self, train_data, model, optimizer=None, loss=None, | def __init__(self, train_data, model, optimizer=None, loss=None, | ||||
batch_size=32, sampler=None, update_every=1, | |||||
n_epochs=10, print_every=5, | |||||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||||
num_workers=0, n_epochs=10, print_every=5, | |||||
dev_data=None, metrics=None, metric_key=None, | dev_data=None, metrics=None, metric_key=None, | ||||
validate_every=-1, save_path=None, | |||||
prefetch=False, use_tqdm=True, device=None, | |||||
callbacks=None, | |||||
check_code_level=0): | |||||
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, | |||||
callbacks=None, check_code_level=0): | |||||
if prefetch and num_workers==0: | |||||
num_workers = 1 | |||||
if prefetch: | |||||
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") | |||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
if not isinstance(train_data, DataSet): | |||||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||||
if not isinstance(model, nn.Module): | if not isinstance(model, nn.Module): | ||||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | ||||
@@ -430,17 +432,27 @@ class Trainer(object): | |||||
if metric_key is not None: | if metric_key is not None: | ||||
self.increase_better = False if metric_key[0] == "-" else True | self.increase_better = False if metric_key[0] == "-" else True | ||||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | ||||
elif len(metrics) > 0: | |||||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | |||||
else: | |||||
self.metric_key = None | |||||
# prepare loss | # prepare loss | ||||
losser = _prepare_losser(loss) | losser = _prepare_losser(loss) | ||||
# sampler check | # sampler check | ||||
if sampler is not None and not isinstance(sampler, Sampler): | if sampler is not None and not isinstance(sampler, Sampler): | ||||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | ||||
if sampler is None: | |||||
sampler = RandomSampler() | |||||
if isinstance(train_data, DataSet): | |||||
self.data_iterator = DataSetIter( | |||||
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last) | |||||
elif isinstance(train_data, BatchIter): | |||||
self.data_iterator = train_data | |||||
else: | |||||
raise TypeError("train_data type {} not support".format(type(train_data))) | |||||
if check_code_level > -1: | |||||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | ||||
metric_key=metric_key, check_level=check_code_level, | metric_key=metric_key, check_level=check_code_level, | ||||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | ||||
@@ -460,8 +472,6 @@ class Trainer(object): | |||||
self.best_dev_epoch = None | self.best_dev_epoch = None | ||||
self.best_dev_step = None | self.best_dev_step = None | ||||
self.best_dev_perf = None | self.best_dev_perf = None | ||||
self.sampler = sampler if sampler is not None else RandomSampler() | |||||
self.prefetch = prefetch | |||||
self.n_steps = (len(self.train_data) // self.batch_size + int( | self.n_steps = (len(self.train_data) // self.batch_size + int( | ||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | len(self.train_data) % self.batch_size != 0)) * self.n_epochs | ||||
@@ -493,7 +503,7 @@ class Trainer(object): | |||||
self.callback_manager = CallbackManager(env={"trainer": self}, | self.callback_manager = CallbackManager(env={"trainer": self}, | ||||
callbacks=callbacks) | callbacks=callbacks) | ||||
def train(self, load_best_model=True, on_exception='auto'): | def train(self, load_best_model=True, on_exception='auto'): | ||||
""" | """ | ||||
使用该函数使Trainer开始训练。 | 使用该函数使Trainer开始训练。 | ||||
@@ -572,8 +582,7 @@ class Trainer(object): | |||||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | ||||
self.pbar = pbar | self.pbar = pbar | ||||
avg_loss = 0 | avg_loss = 0 | ||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
data_iterator = self.data_iterator | |||||
self.batch_per_epoch = data_iterator.num_batches | self.batch_per_epoch = data_iterator.num_batches | ||||
for epoch in range(1, self.n_epochs + 1): | for epoch in range(1, self.n_epochs + 1): | ||||
self.epoch = epoch | self.epoch = epoch | ||||
@@ -746,7 +755,9 @@ class Trainer(object): | |||||
:return bool value: True means current results on dev set is the best. | :return bool value: True means current results on dev set is the best. | ||||
""" | """ | ||||
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) | |||||
indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) | |||||
if self.metric_key is None: | |||||
self.metric_key = indicator | |||||
is_better = True | is_better = True | ||||
if self.best_metric_indicator is None: | if self.best_metric_indicator is None: | ||||
# first-time validation | # first-time validation | ||||
@@ -785,15 +796,34 @@ def _get_value_info(_dict): | |||||
strs.append(_str) | strs.append(_str) | ||||
return strs | return strs | ||||
from numbers import Number | |||||
from .batch import _to_tensor | |||||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | ||||
dev_data=None, metric_key=None, | dev_data=None, metric_key=None, | ||||
check_level=0): | check_level=0): | ||||
# check get_loss 方法 | # check get_loss 方法 | ||||
model_devcie = model.parameters().__next__().device | |||||
model_devcie = _get_model_device(model=model) | |||||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||||
def _iter(): | |||||
start_idx = 0 | |||||
while start_idx<len(dataset): | |||||
batch_x = {} | |||||
batch_y = {} | |||||
for field_name, field in dataset.get_all_fields().items(): | |||||
indices = list(range(start_idx, min(start_idx+batch_size, len(dataset)))) | |||||
if field.is_target or field.is_input: | |||||
batch = field.get(indices) | |||||
if field.dtype is not None and \ | |||||
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): | |||||
batch, _ = _to_tensor(batch, field.dtype) | |||||
if field.is_target: | |||||
batch_y[field_name] = batch | |||||
if field.is_input: | |||||
batch_x[field_name] = batch | |||||
yield (batch_x, batch_y) | |||||
start_idx += batch_size | |||||
for batch_count, (batch_x, batch_y) in enumerate(_iter()): | |||||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | _move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | ||||
# forward check | # forward check | ||||
if batch_count == 0: | if batch_count == 0: | ||||
@@ -861,26 +891,16 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||||
loss, metrics = metrics | loss, metrics = metrics | ||||
if isinstance(metrics, dict): | if isinstance(metrics, dict): | ||||
if len(metrics) == 1: | |||||
# only single metric, just use it | |||||
metric_dict = list(metrics.values())[0] | |||||
metrics_name = list(metrics.keys())[0] | |||||
else: | |||||
metrics_name = metric_list[0].__class__.__name__ | |||||
if metrics_name not in metrics: | |||||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | |||||
metric_dict = metrics[metrics_name] | |||||
metric_dict = list(metrics.values())[0] # 取第一个metric | |||||
if len(metric_dict) == 1: | |||||
if metric_key is None: | |||||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | ||||
elif len(metric_dict) > 1 and metric_key is None: | |||||
raise RuntimeError( | |||||
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") | |||||
else: | else: | ||||
# metric_key is set | # metric_key is set | ||||
if metric_key not in metric_dict: | if metric_key not in metric_dict: | ||||
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") | raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") | ||||
indicator_val = metric_dict[metric_key] | indicator_val = metric_dict[metric_key] | ||||
indicator = metric_key | |||||
else: | else: | ||||
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) | raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) | ||||
return indicator_val | |||||
return indicator, indicator_val |
@@ -124,6 +124,14 @@ class DataInfo: | |||||
self.embeddings = embeddings or {} | self.embeddings = embeddings or {} | ||||
self.datasets = datasets or {} | self.datasets = datasets or {} | ||||
def __repr__(self): | |||||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||||
for name, dataset in self.datasets.items(): | |||||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||||
for name, vocab in self.vocabs.items(): | |||||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||||
return _str | |||||
class DataSetLoader: | class DataSetLoader: | ||||
""" | """ | ||||
@@ -115,7 +115,8 @@ class ConllLoader(DataSetLoader): | |||||
""" | """ | ||||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` | 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` | ||||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html | |||||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 | |||||
该符号在conll 2003中被用为文档分割符。 | |||||
列号从0开始, 每列对应内容为:: | 列号从0开始, 每列对应内容为:: | ||||
@@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
return sample | return sample | ||||
with open(path, 'r', encoding=encoding) as f: | with open(path, 'r', encoding=encoding) as f: | ||||
sample = [] | sample = [] | ||||
start = next(f) | |||||
if '-DOCSTART-' not in start: | |||||
start = next(f).strip() | |||||
if '-DOCSTART-' not in start and start!='': | |||||
sample.append(start.split()) | sample.append(start.split()) | ||||
for line_idx, line in enumerate(f, 1): | for line_idx, line in enumerate(f, 1): | ||||
if line.startswith('\n'): | |||||
line = line.strip() | |||||
if line=='': | |||||
if len(sample): | if len(sample): | ||||
try: | try: | ||||
res = parse_conll(sample) | res = parse_conll(sample) | ||||
@@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
elif line.startswith('#'): | elif line.startswith('#'): | ||||
continue | continue | ||||
else: | else: | ||||
sample.append(line.split()) | |||||
if not line.startswith('-DOCSTART-'): | |||||
sample.append(line.split()) | |||||
if len(sample) > 0: | if len(sample) > 0: | ||||
try: | try: | ||||
res = parse_conll(sample) | res = parse_conll(sample) | ||||
@@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||||
except Exception as e: | except Exception as e: | ||||
if dropna: | if dropna: | ||||
return | return | ||||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||||
print('invalid instance at line: {}'.format(line_idx)) | |||||
raise e |
@@ -9,7 +9,7 @@ from torch import nn | |||||
from ..utils import initial_parameter | from ..utils import initial_parameter | ||||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | |||||
""" | """ | ||||
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` | 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` | ||||
@@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||||
:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | :param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | ||||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | ||||
:param str encoding_type: 支持"bio", "bmes", "bmeso"。 | |||||
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | |||||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | :param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | ||||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | ||||
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | ||||
@@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | ||||
""" | """ | ||||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 | |||||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 | |||||
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
:param str from_label: 比如"PER", "LOC"等label | :param str from_label: 比如"PER", "LOC"等label | ||||
:param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | :param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | ||||
@@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||||
return to_tag in ['b', 's', 'end', 'o'] | return to_tag in ['b', 's', 'end', 'o'] | ||||
else: | else: | ||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | ||||
elif encoding_type == 'bioes': | |||||
if from_tag == 'start': | |||||
return to_tag in ['b', 's', 'o'] | |||||
elif from_tag == 'b': | |||||
return to_tag in ['i', 'e'] and from_label == to_label | |||||
elif from_tag == 'i': | |||||
return to_tag in ['i', 'e'] and from_label == to_label | |||||
elif from_tag in ['e', 's', 'o']: | |||||
return to_tag in ['b', 's', 'end', 'o'] | |||||
else: | |||||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) | |||||
else: | else: | ||||
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) | |||||
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) | |||||
class ConditionalRandomField(nn.Module): | class ConditionalRandomField(nn.Module): | ||||
@@ -18,7 +18,8 @@ __all__ = [ | |||||
"VarLSTM", | "VarLSTM", | ||||
"VarGRU" | "VarGRU" | ||||
] | ] | ||||
from .bert import BertModel | |||||
from ._bert import BertModel | |||||
from .bert import BertWordPieceEncoder | |||||
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | ||||
from .conv_maxpool import ConvMaxpool | from .conv_maxpool import ConvMaxpool | ||||
from .embedding import Embedding | from .embedding import Embedding | ||||
@@ -6,18 +6,399 @@ | |||||
""" | """ | ||||
import torch | |||||
from torch import nn | |||||
from ... import Vocabulary | from ... import Vocabulary | ||||
import collections | import collections | ||||
import os | |||||
import unicodedata | import unicodedata | ||||
from ...io.file_utils import _get_base_url, cached_path | from ...io.file_utils import _get_base_url, cached_path | ||||
from .bert import BertModel | |||||
import numpy as np | import numpy as np | ||||
from itertools import chain | from itertools import chain | ||||
import copy | |||||
import json | |||||
import math | |||||
import os | |||||
import torch | |||||
from torch import nn | |||||
CONFIG_FILE = 'bert_config.json' | |||||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||||
def gelu(x): | |||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||||
def swish(x): | |||||
return x * torch.sigmoid(x) | |||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||||
class BertLayerNorm(nn.Module): | |||||
def __init__(self, hidden_size, eps=1e-12): | |||||
super(BertLayerNorm, self).__init__() | |||||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||||
self.variance_epsilon = eps | |||||
def forward(self, x): | |||||
u = x.mean(-1, keepdim=True) | |||||
s = (x - u).pow(2).mean(-1, keepdim=True) | |||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||||
return self.weight * x + self.bias | |||||
class BertEmbeddings(nn.Module): | |||||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||||
super(BertEmbeddings, self).__init__() | |||||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||||
# any TensorFlow checkpoint file | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, input_ids, token_type_ids=None): | |||||
seq_length = input_ids.size(1) | |||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
words_embeddings = self.word_embeddings(input_ids) | |||||
position_embeddings = self.position_embeddings(position_ids) | |||||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||||
embeddings = self.LayerNorm(embeddings) | |||||
embeddings = self.dropout(embeddings) | |||||
return embeddings | |||||
class BertSelfAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||||
super(BertSelfAttention, self).__init__() | |||||
if hidden_size % num_attention_heads != 0: | |||||
raise ValueError( | |||||
"The hidden size (%d) is not a multiple of the number of attention " | |||||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||||
self.num_attention_heads = num_attention_heads | |||||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||||
def transpose_for_scores(self, x): | |||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||||
x = x.view(*new_x_shape) | |||||
return x.permute(0, 2, 1, 3) | |||||
def forward(self, hidden_states, attention_mask): | |||||
mixed_query_layer = self.query(hidden_states) | |||||
mixed_key_layer = self.key(hidden_states) | |||||
mixed_value_layer = self.value(hidden_states) | |||||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||||
attention_scores = attention_scores + attention_mask | |||||
# Normalize the attention scores to probabilities. | |||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||||
# This is actually dropping out entire tokens to attend to, which might | |||||
# seem a bit unusual, but is taken from the original Transformer paper. | |||||
attention_probs = self.dropout(attention_probs) | |||||
context_layer = torch.matmul(attention_probs, value_layer) | |||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||||
context_layer = context_layer.view(*new_context_layer_shape) | |||||
return context_layer | |||||
class BertSelfOutput(nn.Module): | |||||
def __init__(self, hidden_size, hidden_dropout_prob): | |||||
super(BertSelfOutput, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||||
super(BertAttention, self).__init__() | |||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||||
def forward(self, input_tensor, attention_mask): | |||||
self_output = self.self(input_tensor, attention_mask) | |||||
attention_output = self.output(self_output, input_tensor) | |||||
return attention_output | |||||
class BertIntermediate(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||||
super(BertIntermediate, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||||
if isinstance(hidden_act, str) else hidden_act | |||||
def forward(self, hidden_states): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.intermediate_act_fn(hidden_states) | |||||
return hidden_states | |||||
class BertOutput(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||||
super(BertOutput, self).__init__() | |||||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertLayer(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertLayer, self).__init__() | |||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob) | |||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||||
def forward(self, hidden_states, attention_mask): | |||||
attention_output = self.attention(hidden_states, attention_mask) | |||||
intermediate_output = self.intermediate(attention_output) | |||||
layer_output = self.output(intermediate_output, attention_output) | |||||
return layer_output | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertEncoder, self).__init__() | |||||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act) | |||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||||
all_encoder_layers = [] | |||||
for layer_module in self.layer: | |||||
hidden_states = layer_module(hidden_states, attention_mask) | |||||
if output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
if not output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
return all_encoder_layers | |||||
class BertPooler(nn.Module): | |||||
def __init__(self, hidden_size): | |||||
super(BertPooler, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.activation = nn.Tanh() | |||||
def forward(self, hidden_states): | |||||
# We "pool" the model by simply taking the hidden state corresponding | |||||
# to the first token. | |||||
first_token_tensor = hidden_states[:, 0] | |||||
pooled_output = self.dense(first_token_tensor) | |||||
pooled_output = self.activation(pooled_output) | |||||
return pooled_output | |||||
class BertModel(nn.Module): | |||||
"""BERT(Bidirectional Embedding Representations from Transformers). | |||||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||||
sources:: | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
用预训练权重矩阵来建立BERT模型:: | |||||
model = BertModel.from_pretrained("path/to/weights/directory") | |||||
用随机初始化权重矩阵来建立BERT模型:: | |||||
model = BertModel() | |||||
:param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 | |||||
:param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 | |||||
:param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 | |||||
:param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 | |||||
:param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 | |||||
:param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` | |||||
:param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 | |||||
:param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 | |||||
:param int max_position_embeddings: 最大的序列长度,默认值为512, | |||||
:param int type_vocab_size: 最大segment数量,默认值为2 | |||||
:param int initializer_range: 初始化权重范围,默认值为0.02 | |||||
""" | |||||
def __init__(self, vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02): | |||||
super(BertModel, self).__init__() | |||||
self.hidden_size = hidden_size | |||||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||||
type_vocab_size, hidden_dropout_prob) | |||||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||||
hidden_act) | |||||
self.pooler = BertPooler(hidden_size) | |||||
self.initializer_range = initializer_range | |||||
self.apply(self.init_bert_weights) | |||||
def init_bert_weights(self, module): | |||||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||||
# Slightly different from the TF version which uses truncated_normal for initialization | |||||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||||
elif isinstance(module, BertLayerNorm): | |||||
module.bias.data.zero_() | |||||
module.weight.data.fill_(1.0) | |||||
if isinstance(module, nn.Linear) and module.bias is not None: | |||||
module.bias.data.zero_() | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||||
if attention_mask is None: | |||||
attention_mask = torch.ones_like(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
# We create a 3D attention mask from a 2D tensor mask. | |||||
# Sizes are [batch_size, 1, 1, to_seq_length] | |||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||||
# this attention mask is more simple than the triangular masking of causal attention | |||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||||
# masked positions, this operation will create a tensor which is 0.0 for | |||||
# positions we want to attend and -10000.0 for masked positions. | |||||
# Since we are adding it to the raw scores before the softmax, this is | |||||
# effectively the same as removing these entirely. | |||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||||
encoded_layers = self.encoder(embedding_output, | |||||
extended_attention_mask, | |||||
output_all_encoded_layers=output_all_encoded_layers) | |||||
sequence_output = encoded_layers[-1] | |||||
pooled_output = self.pooler(sequence_output) | |||||
if not output_all_encoded_layers: | |||||
encoded_layers = encoded_layers[-1] | |||||
return encoded_layers, pooled_output | |||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||||
# Load config | |||||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||||
config = json.load(open(config_file, "r")) | |||||
# config = BertConfig.from_json_file(config_file) | |||||
# logger.info("Model config {}".format(config)) | |||||
# Instantiate model. | |||||
model = cls(*inputs, **config, **kwargs) | |||||
if state_dict is None: | |||||
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) | |||||
state_dict = torch.load(weights_path) | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if 'gamma' in key: | |||||
new_key = key.replace('gamma', 'weight') | |||||
if 'beta' in key: | |||||
new_key = key.replace('beta', 'bias') | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
missing_keys = [] | |||||
unexpected_keys = [] | |||||
error_msgs = [] | |||||
# copy state_dict so _load_from_state_dict can modify it | |||||
metadata = getattr(state_dict, '_metadata', None) | |||||
state_dict = state_dict.copy() | |||||
if metadata is not None: | |||||
state_dict._metadata = metadata | |||||
def load(module, prefix=''): | |||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||||
module._load_from_state_dict( | |||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||||
for name, child in module._modules.items(): | |||||
if child is not None: | |||||
load(child, prefix + name + '.') | |||||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||||
if len(missing_keys) > 0: | |||||
print("Weights of {} not initialized from pretrained model: {}".format( | |||||
model.__class__.__name__, missing_keys)) | |||||
if len(unexpected_keys) > 0: | |||||
print("Weights from pretrained model not used in {}: {}".format( | |||||
model.__class__.__name__, unexpected_keys)) | |||||
return model | |||||
def whitespace_tokenize(text): | def whitespace_tokenize(text): | ||||
"""Runs basic whitespace cleaning and splitting on a piece of text.""" | """Runs basic whitespace cleaning and splitting on a piece of text.""" | ||||
@@ -547,79 +928,3 @@ class _WordPieceBertModel(nn.Module): | |||||
outputs[l_index] = bert_outputs[l] | outputs[l_index] = bert_outputs[l] | ||||
return outputs | return outputs | ||||
class BertWordPieceEncoder(nn.Module): | |||||
""" | |||||
可以通过读取vocabulary使用的Bert的Encoder。传入vocab,然后调用index_datasets方法在vocabulary中生成word piece的表示。 | |||||
:param vocab: Vocabulary. | |||||
:param model_dir_or_name: | |||||
:param layers: | |||||
:param requires_grad: | |||||
""" | |||||
def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', | |||||
requires_grad:bool=False): | |||||
super().__init__() | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
# TODO 修改 | |||||
PRETRAINED_BERT_MODEL_DIR = {'en-base': 'bert_en-80f95ea7.tar.gz', | |||||
'cn': 'elmo_cn.zip'} | |||||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
# 检查是否存在 | |||||
elif os.path.isdir(model_dir_or_name): | |||||
model_dir = model_dir_or_name | |||||
else: | |||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||||
self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) | |||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||||
self.requires_grad = requires_grad | |||||
@property | |||||
def requires_grad(self): | |||||
""" | |||||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||||
:return: | |||||
""" | |||||
requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) | |||||
if len(requires_grads)==1: | |||||
return requires_grads.pop() | |||||
else: | |||||
return None | |||||
@requires_grad.setter | |||||
def requires_grad(self, value): | |||||
for name, param in self.named_parameters(): | |||||
param.requires_grad = value | |||||
@property | |||||
def embed_size(self): | |||||
return self._embed_size | |||||
def index_datasets(self, *datasets): | |||||
""" | |||||
对datasets进行word piece的index。 | |||||
Example:: | |||||
:param datasets: | |||||
:return: | |||||
""" | |||||
self.model.index_dataset(*datasets) | |||||
def forward(self, words, token_type_ids=None): | |||||
""" | |||||
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | |||||
删除这两个表示。 | |||||
:param words: batch_size x max_len | |||||
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话 | |||||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||||
""" | |||||
outputs = self.model(words, token_type_ids) | |||||
outputs = torch.cat([*outputs], dim=-1) | |||||
return outputs |
@@ -1,378 +1,95 @@ | |||||
""" | |||||
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||||
""" | |||||
import copy | |||||
import json | |||||
import math | |||||
import os | import os | ||||
import torch | |||||
from torch import nn | from torch import nn | ||||
import torch | |||||
from ...core import Vocabulary | |||||
from ...io.file_utils import _get_base_url, cached_path | |||||
from ._bert import _WordPieceBertModel | |||||
CONFIG_FILE = 'bert_config.json' | |||||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||||
def gelu(x): | |||||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||||
def swish(x): | |||||
return x * torch.sigmoid(x) | |||||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||||
class BertLayerNorm(nn.Module): | |||||
def __init__(self, hidden_size, eps=1e-12): | |||||
super(BertLayerNorm, self).__init__() | |||||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||||
self.variance_epsilon = eps | |||||
def forward(self, x): | |||||
u = x.mean(-1, keepdim=True) | |||||
s = (x - u).pow(2).mean(-1, keepdim=True) | |||||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||||
return self.weight * x + self.bias | |||||
class BertEmbeddings(nn.Module): | |||||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||||
super(BertEmbeddings, self).__init__() | |||||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||||
# any TensorFlow checkpoint file | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, input_ids, token_type_ids=None): | |||||
seq_length = input_ids.size(1) | |||||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
words_embeddings = self.word_embeddings(input_ids) | |||||
position_embeddings = self.position_embeddings(position_ids) | |||||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||||
embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||||
embeddings = self.LayerNorm(embeddings) | |||||
embeddings = self.dropout(embeddings) | |||||
return embeddings | |||||
class BertSelfAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||||
super(BertSelfAttention, self).__init__() | |||||
if hidden_size % num_attention_heads != 0: | |||||
raise ValueError( | |||||
"The hidden size (%d) is not a multiple of the number of attention " | |||||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||||
self.num_attention_heads = num_attention_heads | |||||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||||
def transpose_for_scores(self, x): | |||||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||||
x = x.view(*new_x_shape) | |||||
return x.permute(0, 2, 1, 3) | |||||
def forward(self, hidden_states, attention_mask): | |||||
mixed_query_layer = self.query(hidden_states) | |||||
mixed_key_layer = self.key(hidden_states) | |||||
mixed_value_layer = self.value(hidden_states) | |||||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||||
attention_scores = attention_scores + attention_mask | |||||
# Normalize the attention scores to probabilities. | |||||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||||
# This is actually dropping out entire tokens to attend to, which might | |||||
# seem a bit unusual, but is taken from the original Transformer paper. | |||||
attention_probs = self.dropout(attention_probs) | |||||
context_layer = torch.matmul(attention_probs, value_layer) | |||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||||
context_layer = context_layer.view(*new_context_layer_shape) | |||||
return context_layer | |||||
class BertSelfOutput(nn.Module): | |||||
def __init__(self, hidden_size, hidden_dropout_prob): | |||||
super(BertSelfOutput, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertAttention(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||||
super(BertAttention, self).__init__() | |||||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||||
def forward(self, input_tensor, attention_mask): | |||||
self_output = self.self(input_tensor, attention_mask) | |||||
attention_output = self.output(self_output, input_tensor) | |||||
return attention_output | |||||
class BertIntermediate(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||||
super(BertIntermediate, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||||
if isinstance(hidden_act, str) else hidden_act | |||||
def forward(self, hidden_states): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.intermediate_act_fn(hidden_states) | |||||
return hidden_states | |||||
class BertOutput(nn.Module): | |||||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||||
super(BertOutput, self).__init__() | |||||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||||
def forward(self, hidden_states, input_tensor): | |||||
hidden_states = self.dense(hidden_states) | |||||
hidden_states = self.dropout(hidden_states) | |||||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||||
return hidden_states | |||||
class BertLayer(nn.Module): | |||||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertLayer, self).__init__() | |||||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob) | |||||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||||
def forward(self, hidden_states, attention_mask): | |||||
attention_output = self.attention(hidden_states, attention_mask) | |||||
intermediate_output = self.intermediate(attention_output) | |||||
layer_output = self.output(intermediate_output, attention_output) | |||||
return layer_output | |||||
class BertEncoder(nn.Module): | |||||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||||
hidden_dropout_prob, | |||||
intermediate_size, hidden_act): | |||||
super(BertEncoder, self).__init__() | |||||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||||
intermediate_size, hidden_act) | |||||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||||
all_encoder_layers = [] | |||||
for layer_module in self.layer: | |||||
hidden_states = layer_module(hidden_states, attention_mask) | |||||
if output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
if not output_all_encoded_layers: | |||||
all_encoder_layers.append(hidden_states) | |||||
return all_encoder_layers | |||||
class BertPooler(nn.Module): | |||||
def __init__(self, hidden_size): | |||||
super(BertPooler, self).__init__() | |||||
self.dense = nn.Linear(hidden_size, hidden_size) | |||||
self.activation = nn.Tanh() | |||||
def forward(self, hidden_states): | |||||
# We "pool" the model by simply taking the hidden state corresponding | |||||
# to the first token. | |||||
first_token_tensor = hidden_states[:, 0] | |||||
pooled_output = self.dense(first_token_tensor) | |||||
pooled_output = self.activation(pooled_output) | |||||
return pooled_output | |||||
class BertModel(nn.Module): | |||||
"""BERT(Bidirectional Embedding Representations from Transformers). | |||||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||||
sources:: | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||||
用预训练权重矩阵来建立BERT模型:: | |||||
model = BertModel.from_pretrained("path/to/weights/directory") | |||||
用随机初始化权重矩阵来建立BERT模型:: | |||||
model = BertModel() | |||||
:param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小 | |||||
:param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 | |||||
:param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 | |||||
:param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 | |||||
:param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 | |||||
:param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` | |||||
:param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 | |||||
:param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 | |||||
:param int max_position_embeddings: 最大的序列长度,默认值为512, | |||||
:param int type_vocab_size: 最大segment数量,默认值为2 | |||||
:param int initializer_range: 初始化权重范围,默认值为0.02 | |||||
class BertWordPieceEncoder(nn.Module): | |||||
""" | """ | ||||
可以通过读取vocabulary使用的Bert的Encoder。传入vocab,然后调用index_datasets方法在vocabulary中生成word piece的表示。 | |||||
def __init__(self, vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02): | |||||
super(BertModel, self).__init__() | |||||
self.hidden_size = hidden_size | |||||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||||
type_vocab_size, hidden_dropout_prob) | |||||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||||
hidden_act) | |||||
self.pooler = BertPooler(hidden_size) | |||||
self.initializer_range = initializer_range | |||||
self.apply(self.init_bert_weights) | |||||
def init_bert_weights(self, module): | |||||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||||
# Slightly different from the TF version which uses truncated_normal for initialization | |||||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||||
elif isinstance(module, BertLayerNorm): | |||||
module.bias.data.zero_() | |||||
module.weight.data.fill_(1.0) | |||||
if isinstance(module, nn.Linear) and module.bias is not None: | |||||
module.bias.data.zero_() | |||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||||
if attention_mask is None: | |||||
attention_mask = torch.ones_like(input_ids) | |||||
if token_type_ids is None: | |||||
token_type_ids = torch.zeros_like(input_ids) | |||||
# We create a 3D attention mask from a 2D tensor mask. | |||||
# Sizes are [batch_size, 1, 1, to_seq_length] | |||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||||
# this attention mask is more simple than the triangular masking of causal attention | |||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||||
# masked positions, this operation will create a tensor which is 0.0 for | |||||
# positions we want to attend and -10000.0 for masked positions. | |||||
# Since we are adding it to the raw scores before the softmax, this is | |||||
# effectively the same as removing these entirely. | |||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||||
encoded_layers = self.encoder(embedding_output, | |||||
extended_attention_mask, | |||||
output_all_encoded_layers=output_all_encoded_layers) | |||||
sequence_output = encoded_layers[-1] | |||||
pooled_output = self.pooler(sequence_output) | |||||
if not output_all_encoded_layers: | |||||
encoded_layers = encoded_layers[-1] | |||||
return encoded_layers, pooled_output | |||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||||
# Load config | |||||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||||
config = json.load(open(config_file, "r")) | |||||
# config = BertConfig.from_json_file(config_file) | |||||
# logger.info("Model config {}".format(config)) | |||||
# Instantiate model. | |||||
model = cls(*inputs, **config, **kwargs) | |||||
if state_dict is None: | |||||
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) | |||||
state_dict = torch.load(weights_path) | |||||
old_keys = [] | |||||
new_keys = [] | |||||
for key in state_dict.keys(): | |||||
new_key = None | |||||
if 'gamma' in key: | |||||
new_key = key.replace('gamma', 'weight') | |||||
if 'beta' in key: | |||||
new_key = key.replace('beta', 'bias') | |||||
if new_key: | |||||
old_keys.append(key) | |||||
new_keys.append(new_key) | |||||
for old_key, new_key in zip(old_keys, new_keys): | |||||
state_dict[new_key] = state_dict.pop(old_key) | |||||
missing_keys = [] | |||||
unexpected_keys = [] | |||||
error_msgs = [] | |||||
# copy state_dict so _load_from_state_dict can modify it | |||||
metadata = getattr(state_dict, '_metadata', None) | |||||
state_dict = state_dict.copy() | |||||
if metadata is not None: | |||||
state_dict._metadata = metadata | |||||
def load(module, prefix=''): | |||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||||
module._load_from_state_dict( | |||||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||||
for name, child in module._modules.items(): | |||||
if child is not None: | |||||
load(child, prefix + name + '.') | |||||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||||
if len(missing_keys) > 0: | |||||
print("Weights of {} not initialized from pretrained model: {}".format( | |||||
model.__class__.__name__, missing_keys)) | |||||
if len(unexpected_keys) > 0: | |||||
print("Weights from pretrained model not used in {}: {}".format( | |||||
model.__class__.__name__, unexpected_keys)) | |||||
return model | |||||
:param fastNLP.Vocabulary vocab: 词表 | |||||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` | |||||
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||||
:param bool requires_grad: 是否需要gradient。 | |||||
""" | |||||
def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1', | |||||
requires_grad:bool=False): | |||||
super().__init__() | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||||
} | |||||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url) | |||||
# 检查是否存在 | |||||
elif os.path.isdir(model_dir_or_name): | |||||
model_dir = model_dir_or_name | |||||
else: | |||||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||||
self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers) | |||||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||||
self.requires_grad = requires_grad | |||||
@property | |||||
def requires_grad(self): | |||||
""" | |||||
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许 | |||||
:return: | |||||
""" | |||||
requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) | |||||
if len(requires_grads)==1: | |||||
return requires_grads.pop() | |||||
else: | |||||
return None | |||||
@requires_grad.setter | |||||
def requires_grad(self, value): | |||||
for name, param in self.named_parameters(): | |||||
param.requires_grad = value | |||||
@property | |||||
def embed_size(self): | |||||
return self._embed_size | |||||
def index_datasets(self, *datasets): | |||||
""" | |||||
根据datasets中的'words'列对datasets进行word piece的index。 | |||||
Example:: | |||||
:param datasets: | |||||
:return: | |||||
""" | |||||
self.model.index_dataset(*datasets) | |||||
def forward(self, words, token_type_ids=None): | |||||
""" | |||||
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | |||||
删除这两个表示。 | |||||
:param words: batch_size x max_len | |||||
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话 | |||||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||||
""" | |||||
outputs = self.model(words, token_type_ids) | |||||
outputs = torch.cat([*outputs], dim=-1) | |||||
return outputs |
@@ -15,7 +15,7 @@ from ...io.file_utils import cached_path, _get_base_url | |||||
from ._bert import _WordBertModel | from ._bert import _WordBertModel | ||||
from typing import List | from typing import List | ||||
from ... import DataSet, Batch, SequentialSampler | |||||
from ... import DataSet, DataSetIter, SequentialSampler | |||||
from ...core.utils import _move_model_to_device, _get_model_device | from ...core.utils import _move_model_to_device, _get_model_device | ||||
@@ -157,7 +157,6 @@ class StaticEmbedding(TokenEmbedding): | |||||
super(StaticEmbedding, self).__init__(vocab) | super(StaticEmbedding, self).__init__(vocab) | ||||
# 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, | # 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, | ||||
PRETRAIN_URL = _get_base_url('static') | |||||
PRETRAIN_STATIC_FILES = { | PRETRAIN_STATIC_FILES = { | ||||
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | 'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | ||||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | ||||
@@ -170,6 +169,7 @@ class StaticEmbedding(TokenEmbedding): | |||||
# 得到cache_path | # 得到cache_path | ||||
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | ||||
PRETRAIN_URL = _get_base_url('static') | |||||
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_path = cached_path(model_url) | model_path = cached_path(model_url) | ||||
@@ -234,7 +234,7 @@ class ContextualEmbedding(TokenEmbedding): | |||||
with torch.no_grad(): | with torch.no_grad(): | ||||
for index, dataset in enumerate(datasets): | for index, dataset in enumerate(datasets): | ||||
try: | try: | ||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False) | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
words = batch_x['words'].to(device) | words = batch_x['words'].to(device) | ||||
words_list = words.tolist() | words_list = words.tolist() | ||||
@@ -325,11 +325,11 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
self.layers = layers | self.layers = layers | ||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
PRETRAIN_URL = _get_base_url('elmo') | |||||
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | ||||
'cn': 'elmo_cn-5e9b34e2.tar.gz'} | 'cn': 'elmo_cn-5e9b34e2.tar.gz'} | ||||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | ||||
PRETRAIN_URL = _get_base_url('elmo') | |||||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_dir = cached_path(model_url) | model_dir = cached_path(model_url) | ||||
@@ -383,7 +383,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'words_to_chars_embedding' in name: # 这个不能加入到requires_grad中 | if 'words_to_chars_embedding' in name: # 这个不能加入到requires_grad中 | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
@@ -411,7 +411,6 @@ class BertEmbedding(ContextualEmbedding): | |||||
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | ||||
super(BertEmbedding, self).__init__(vocab) | super(BertEmbedding, self).__init__(vocab) | ||||
# 根据model_dir_or_name检查是否存在并下载 | # 根据model_dir_or_name检查是否存在并下载 | ||||
PRETRAIN_URL = _get_base_url('bert') | |||||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | ||||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | 'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | ||||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | 'en-base-cased': 'bert-base-cased-f89bfe08.zip', | ||||
@@ -427,6 +426,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
} | } | ||||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | ||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | ||||
model_url = PRETRAIN_URL + model_name | model_url = PRETRAIN_URL + model_name | ||||
model_dir = cached_path(model_url) | model_dir = cached_path(model_url) | ||||
@@ -478,7 +478,7 @@ class BertEmbedding(ContextualEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中 | if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中 | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
@@ -566,6 +566,7 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
for i in range(len(kernel_sizes))]) | for i in range(len(kernel_sizes))]) | ||||
self._embed_size = embed_size | self._embed_size = embed_size | ||||
self.fc = nn.Linear(sum(filter_nums), embed_size) | self.fc = nn.Linear(sum(filter_nums), embed_size) | ||||
self.init_param() | |||||
def forward(self, words): | def forward(self, words): | ||||
""" | """ | ||||
@@ -618,9 +619,17 @@ class CNNCharEmbedding(TokenEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
def init_param(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset | |||||
continue | |||||
if param.data.dim()>1: | |||||
nn.init.xavier_normal_(param, 1) | |||||
else: | |||||
nn.init.uniform_(param, -1, 1) | |||||
class LSTMCharEmbedding(TokenEmbedding): | class LSTMCharEmbedding(TokenEmbedding): | ||||
""" | """ | ||||
@@ -744,7 +753,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||||
def requires_grad(self, value): | def requires_grad(self, value): | ||||
for name, param in self.named_parameters(): | for name, param in self.named_parameters(): | ||||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | ||||
pass | |||||
continue | |||||
param.requires_grad = value | param.requires_grad = value | ||||
@@ -35,8 +35,18 @@ class LSTM(nn.Module): | |||||
self.batch_first = batch_first | self.batch_first = batch_first | ||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | ||||
dropout=dropout, bidirectional=bidirectional) | dropout=dropout, bidirectional=bidirectional) | ||||
self.init_param() | |||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def init_param(self): | |||||
for name, param in self.named_parameters(): | |||||
if 'bias_i' in name: | |||||
param.data.fill_(1) | |||||
elif 'bias_h' in name: | |||||
param.data.fill_(0) | |||||
else: | |||||
nn.init.xavier_normal_(param) | |||||
def forward(self, x, seq_len=None, h0=None, c0=None): | def forward(self, x, seq_len=None, h0=None, c0=None): | ||||
""" | """ | ||||
@@ -184,11 +184,8 @@ def train(path): | |||||
m.weight.requires_grad = True | m.weight.requires_grad = True | ||||
# Trainer | # Trainer | ||||
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||||
**train_args.data, | |||||
optimizer=fastNLP.Adam(**optim_args.data), | |||||
save_path=path, | |||||
trainer = Trainer(train_data=train_data, model=model, optimizer=fastNLP.Adam(**optim_args.data), loss=ParserLoss(), | |||||
dev_data=dev_data, metrics=ParserMetric(), metric_key='UAS', save_path=path, | |||||
callbacks=[MyCallback()]) | callbacks=[MyCallback()]) | ||||
# Start training | # Start training | ||||
@@ -89,11 +89,11 @@ def train(train_data_path, dev_data_path, checkpoint=None, save=None): | |||||
model = torch.load(checkpoint) | model = torch.load(checkpoint) | ||||
# call trainer to train | # call trainer to train | ||||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||||
target="truth", | |||||
seq_lens="word_seq_origin_len"), | |||||
dev_data=dev_data, metric_key="f", | |||||
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save) | |||||
trainer = Trainer(dataset, model, loss=None, n_epochs=20, print_every=10, dev_data=dev_data, | |||||
metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||||
target="truth", | |||||
seq_lens="word_seq_origin_len"), metric_key="f", save_path=save, | |||||
use_tqdm=True) | |||||
trainer.train(load_best_model=True) | trainer.train(load_best_model=True) | ||||
# save model & pipeline | # save model & pipeline | ||||
@@ -149,14 +149,10 @@ def train(): | |||||
) if x.requires_grad and x.size(0) != len(word_v)] | ) if x.requires_grad and x.size(0) != len(word_v)] | ||||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | ||||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | {'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | ||||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||||
loss=loss, metrics=metric, metric_key=metric_key, | |||||
optimizer=torch.optim.Adam(optim_cfg), | |||||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000, | |||||
device=device, | |||||
use_tqdm=False, prefetch=False, | |||||
save_path=g_args.log, | |||||
callbacks=[MyCallback()]) | |||||
trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, | |||||
batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, | |||||
metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, | |||||
device=device, callbacks=[MyCallback()]) | |||||
trainer.train() | trainer.train() | ||||
tester = FN.Tester(data=test_data, model=model, metrics=metric, | tester = FN.Tester(data=test_data, model=model, metrics=metric, | ||||
@@ -70,19 +70,10 @@ test_data = preprocess_data(test_data, bert_dirs) | |||||
model = BertForNLI(bert_dir=bert_dirs) | model = BertForNLI(bert_dir=bert_dirs) | ||||
trainer = Trainer( | |||||
train_data=train_data, | |||||
model=model, | |||||
optimizer=Adam(lr=2e-5, model_params=model.parameters()), | |||||
batch_size=torch.cuda.device_count() * 12, | |||||
n_epochs=4, | |||||
print_every=-1, | |||||
dev_data=dev_data, | |||||
metrics=AccuracyMetric(), | |||||
metric_key='acc', | |||||
device=[i for i in range(torch.cuda.device_count())], | |||||
check_code_level=-1 | |||||
) | |||||
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()), | |||||
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data, | |||||
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], | |||||
check_code_level=-1) | |||||
trainer.train(load_best_model=True) | trainer.train(load_best_model=True) | ||||
tester = Tester( | tester = Tester( | ||||
@@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||||
} | } | ||||
如果paths为不合法的,将直接进行raise相应的错误 | 如果paths为不合法的,将直接进行raise相应的错误 | ||||
:param paths: 路径 | |||||
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt, | |||||
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if isinstance(paths, str): | if isinstance(paths, str): | ||||
@@ -3,7 +3,7 @@ import unittest | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP import Batch | |||||
from fastNLP import DataSetIter | |||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
from fastNLP import SequentialSampler | from fastNLP import SequentialSampler | ||||
@@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase): | |||||
dataset = construct_dataset( | dataset = construct_dataset( | ||||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | ||||
dataset.set_target() | dataset.set_target() | ||||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
cnt = 0 | cnt = 0 | ||||
for _, _ in batch: | for _, _ in batch: | ||||
@@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | ||||
self.assertEqual(len(x["x"]), 4) | self.assertEqual(len(x["x"]), 4) | ||||
@@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase): | |||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertEqual(x["x"].shape, (4, 4)) | self.assertEqual(x["x"].shape, (4, 4)) | ||||
self.assertEqual(y["y"].shape, (4, 4)) | self.assertEqual(y["y"].shape, (4, 4)) | ||||
@@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase): | |||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertEqual(x["x"].shape, (4, 4)) | self.assertEqual(x["x"].shape, (4, 4)) | ||||
self.assertEqual(y["y"].shape, (4, 4)) | self.assertEqual(y["y"].shape, (4, 4)) | ||||
@@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase): | |||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | self.assertTrue(isinstance(x["x"], torch.Tensor)) | ||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | self.assertEqual(tuple(x["x"].shape), (4, 4)) | ||||
@@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase): | |||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | self.assertTrue(isinstance(x["x"], torch.Tensor)) | ||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | self.assertEqual(tuple(x["x"].shape), (4, 4)) | ||||
@@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase): | |||||
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | self.assertTrue(isinstance(x["x"], torch.Tensor)) | ||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | self.assertEqual(tuple(x["x"].shape), (4, 4)) | ||||
@@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase): | |||||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
ds.set_target("y") | ds.set_target("y") | ||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||||
for x, y in iter: | for x, y in iter: | ||||
print(x, y) | print(x, y) | ||||
@@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase): | |||||
num_samples = 1000 | num_samples = 1000 | ||||
dataset = generate_fake_dataset(num_samples) | dataset = generate_fake_dataset(num_samples) | ||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
pass | pass | ||||
@@ -40,89 +40,50 @@ class TestCallback(unittest.TestCase): | |||||
def test_gradient_clip(self): | def test_gradient_clip(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=20, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=20, print_every=50, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)], check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
def test_early_stop(self): | def test_early_stop(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=20, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.01), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[EarlyStopCallback(5)]) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=20, print_every=50, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||||
callbacks=[EarlyStopCallback(5)], check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
def test_lr_scheduler(self): | def test_lr_scheduler(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=optimizer, | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||||
trainer = Trainer(data_set, model, optimizer=optimizer, loss=BCELoss(pred="predict", target="y"), batch_size=32, | |||||
n_epochs=5, print_every=50, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))], | |||||
check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
def test_KeyBoardInterrupt(self): | def test_KeyBoardInterrupt(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[ControlC(False)]) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, callbacks=[ControlC(False)], | |||||
check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
def test_LRFinder(self): | def test_LRFinder(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[LRFinder(len(data_set) // 32)]) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, | |||||
callbacks=[LRFinder(len(data_set) // 32)], check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
def test_TensorboardCallback(self): | def test_TensorboardCallback(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=5, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[TensorboardCallback("loss", "metric")]) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||||
callbacks=[TensorboardCallback("loss", "metric")], check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
def test_readonly_property(self): | def test_readonly_property(self): | ||||
@@ -141,16 +102,9 @@ class TestCallback(unittest.TestCase): | |||||
print(self.optimizer) | print(self.optimizer) | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=total_epochs, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
callbacks=[MyCallback()]) | |||||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=total_epochs, print_every=50, dev_data=data_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=[MyCallback()], | |||||
check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
assert passed_epochs == list(range(1, total_epochs + 1)) | assert passed_epochs == list(range(1, total_epochs + 1)) |
@@ -161,7 +161,15 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_duplicate(self): | |||||
# 0.4.1的潜在bug,不能出现形参重复的情况 | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||||
target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
def test_seq_len(self): | def test_seq_len(self): | ||||
N = 256 | N = 256 | ||||
seq_len = torch.zeros(N).long() | seq_len = torch.zeros(N).long() | ||||
@@ -46,18 +46,10 @@ class TrainerTestGround(unittest.TestCase): | |||||
model = NaiveClassifier(2, 1) | model = NaiveClassifier(2, 1) | ||||
trainer = Trainer(train_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | |||||
n_epochs=10, | |||||
batch_size=32, | |||||
print_every=50, | |||||
validate_every=-1, | |||||
dev_data=dev_set, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=True, | |||||
save_path=None) | |||||
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||||
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set, | |||||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||||
use_tqdm=True, check_code_level=2) | |||||
trainer.train() | trainer.train() | ||||
""" | """ | ||||
# 应该正确运行 | # 应该正确运行 | ||||
@@ -83,10 +75,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
model = Model() | model = Model() | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model | |||||
) | |||||
trainer = Trainer(train_data=dataset, model=model) | |||||
""" | """ | ||||
# 应该获取到的报错提示 | # 应该获取到的报错提示 | ||||
NameError: | NameError: | ||||
@@ -116,12 +105,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'loss': loss} | return {'loss': loss} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||||
trainer.train() | trainer.train() | ||||
""" | """ | ||||
# 应该正确运行 | # 应该正确运行 | ||||
@@ -147,12 +131,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||||
trainer.train() | trainer.train() | ||||
def test_trainer_suggestion4(self): | def test_trainer_suggestion4(self): | ||||
@@ -175,12 +154,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||||
def test_trainer_suggestion5(self): | def test_trainer_suggestion5(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
@@ -203,12 +177,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'loss': loss} | return {'loss': loss} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||||
def test_trainer_suggestion6(self): | def test_trainer_suggestion6(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
@@ -233,14 +202,8 @@ class TrainerTestGround(unittest.TestCase): | |||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
loss=CrossEntropyLoss(), | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2) | |||||
trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset, | |||||
metrics=AccuracyMetric(), use_tqdm=False) | |||||
""" | """ | ||||
def test_trainer_multiprocess(self): | def test_trainer_multiprocess(self): | ||||
@@ -130,11 +130,8 @@ class ModelRunner(): | |||||
tester = Tester(data=data, model=model, metrics=metrics, | tester = Tester(data=data, model=model, metrics=metrics, | ||||
batch_size=BATCH_SIZE, verbose=0) | batch_size=BATCH_SIZE, verbose=0) | ||||
before_train = tester.test() | before_train = tester.test() | ||||
trainer = Trainer(model=model, train_data=data, dev_data=None, | |||||
n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, | |||||
loss=loss, | |||||
save_path=None, | |||||
use_tqdm=False) | |||||
trainer = Trainer(train_data=data, model=model, loss=loss, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS, | |||||
dev_data=None, save_path=None, use_tqdm=False) | |||||
trainer.train(load_best_model=False) | trainer.train(load_best_model=False) | ||||
after_train = tester.test() | after_train = tester.test() | ||||
for metric_name, v1 in before_train.items(): | for metric_name, v1 in before_train.items(): | ||||
@@ -1,6 +1,5 @@ | |||||
import unittest | import unittest | ||||
import fastNLP | |||||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | ||||
from .model_runner import * | from .model_runner import * | ||||
@@ -10,14 +10,14 @@ class TestCRF(unittest.TestCase): | |||||
id2label = {0: 'B', 1: 'I', 2:'O'} | id2label = {0: 'B', 1: 'I', 2:'O'} | ||||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | ||||
(2, 4), (3, 0), (3, 2)} | (2, 4), (3, 0), (3, 2)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | ||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | ||||
allowed_transitions(id2label) | |||||
allowed_transitions(id2label, include_start_end=True) | |||||
labels = ['O'] | labels = ['O'] | ||||
for label in ['X', 'Y']: | for label in ['X', 'Y']: | ||||
@@ -27,7 +27,7 @@ class TestCRF(unittest.TestCase): | |||||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | ||||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | (2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | ||||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | (4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||||
labels = [] | labels = [] | ||||
for label in ['X', 'Y']: | for label in ['X', 'Y']: | ||||
@@ -37,7 +37,7 @@ class TestCRF(unittest.TestCase): | |||||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | ||||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | (3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | ||||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | (7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | ||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||||
def test_case2(self): | def test_case2(self): | ||||
# 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。 | ||||
@@ -60,10 +60,10 @@ class TestTutorial(unittest.TestCase): | |||||
print(test_data[0]) | print(test_data[0]) | ||||
# 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 | # 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 | ||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.batch import DataSetIter | |||||
from fastNLP.core.sampler import RandomSampler | from fastNLP.core.sampler import RandomSampler | ||||
batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||||
batch_iterator = DataSetIter(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||||
for batch_x, batch_y in batch_iterator: | for batch_x, batch_y in batch_iterator: | ||||
print("batch_x has: ", batch_x) | print("batch_x has: ", batch_x) | ||||
print("batch_y has: ", batch_y) | print("batch_y has: ", batch_y) | ||||
@@ -85,12 +85,8 @@ class TestTutorial(unittest.TestCase): | |||||
# 实例化Trainer,传入模型和数据,进行训练 | # 实例化Trainer,传入模型和数据,进行训练 | ||||
# 先在test_data拟合(确保模型的实现是正确的) | # 先在test_data拟合(确保模型的实现是正确的) | ||||
copy_model = deepcopy(model) | copy_model = deepcopy(model) | ||||
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, | |||||
loss=loss, | |||||
metrics=metric, | |||||
save_path=None, | |||||
batch_size=32, | |||||
n_epochs=5) | |||||
overfit_trainer = Trainer(train_data=test_data, model=copy_model, loss=loss, batch_size=32, n_epochs=5, | |||||
dev_data=test_data, metrics=metric, save_path=None) | |||||
overfit_trainer.train() | overfit_trainer.train() | ||||
# 用train_data训练,在test_data验证 | # 用train_data训练,在test_data验证 | ||||
@@ -147,13 +143,8 @@ class TestTutorial(unittest.TestCase): | |||||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam | from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam | ||||
trainer = Trainer(model=model, | |||||
train_data=train_data, | |||||
dev_data=dev_data, | |||||
loss=CrossEntropyLoss(), | |||||
optimizer= Adam(), | |||||
metrics=AccuracyMetric(target='target') | |||||
) | |||||
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(), loss=CrossEntropyLoss(), | |||||
dev_data=dev_data, metrics=AccuracyMetric(target='target')) | |||||
trainer.train() | trainer.train() | ||||
print('Train finished!') | print('Train finished!') | ||||