
1. Adapt to the rework of Batch on top of PyTorch's DataLoader (see the usage sketch below).
2. Fix a bug in embedding.py.
3. ConllReader skips all -DOCSTART- lines by default.
4. Move BERT's heavy lifting into _bert.py and expose BertEncoder from bert.py.
5. Change the include_start_end default of allowed_transitions in crf.py to False, matching the CRF default.
6. allowed_transitions and SpanMetric support BIOES-style tags.
7. Add formatted print output to DataInfo.
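
The headline change replaces the old Batch iterator with thin wrappers around torch.utils.data.DataLoader (DataSetIter, BatchIter, TorchLoaderIter in fastNLP/core/batch.py below). A minimal migration sketch, assuming only the DataSetIter signature shown in that diff; the toy DataSet here is illustrative and not part of this commit:

    from fastNLP import DataSet, DataSetIter
    from fastNLP.core.sampler import SequentialSampler

    # toy data, for illustration only
    ds = DataSet({'x': [[1, 2], [3, 4, 5]], 'y': [0, 1]})
    ds.set_input('x')
    ds.set_target('y')

    # before this commit: Batch(ds, batch_size=2, sampler=SequentialSampler())
    it = DataSetIter(ds, batch_size=2, sampler=SequentialSampler(), num_workers=0)
    for batch_x, batch_y in it:
        print(batch_x['x'].shape, batch_y['y'])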
tags/v0.4.10 · yh_cc, 6 years ago · commit 2f5d8967a3
31 changed files with 913 additions and 899 deletions
  1. fastNLP/__init__.py (+5, -1)
  2. fastNLP/core/__init__.py (+1, -1)
  3. fastNLP/core/batch.py (+147, -148)
  4. fastNLP/core/field.py (+4, -1)
  5. fastNLP/core/losses.py (+28, -17)
  6. fastNLP/core/metrics.py (+32, -23)
  7. fastNLP/core/predictor.py (+2, -3)
  8. fastNLP/core/tester.py (+11, -3)
  9. fastNLP/core/trainer.py (+59, -39)
  10. fastNLP/io/base_loader.py (+8, -0)
  11. fastNLP/io/dataset_loader.py (+2, -1)
  12. fastNLP/io/file_reader.py (+8, -5)
  13. fastNLP/modules/decoder/crf.py (+15, -5)
  14. fastNLP/modules/encoder/__init__.py (+2, -1)
  15. fastNLP/modules/encoder/_bert.py (+385, -80)
  16. fastNLP/modules/encoder/bert.py (+88, -371)
  17. fastNLP/modules/encoder/embedding.py (+18, -9)
  18. fastNLP/modules/encoder/lstm.py (+11, -1)
  19. reproduction/Biaffine_parser/run.py (+2, -5)
  20. reproduction/POS_tagging/train_pos_tag.py (+5, -5)
  21. reproduction/Star_transformer/train.py (+4, -8)
  22. reproduction/matching/snli.py (+4, -13)
  23. reproduction/utils.py (+2, -1)
  24. test/core/test_batch.py (+10, -10)
  25. test/core/test_callbacks.py (+27, -73)
  26. test/core/test_metrics.py (+9, -1)
  27. test/core/test_trainer.py (+11, -48)
  28. test/models/model_runner.py (+2, -5)
  29. test/models/test_biaffine_parser.py (+0, -1)
  30. test/modules/decoder/test_CRF.py (+5, -5)
  31. test/test_tutorials.py (+6, -15)

fastNLP/__init__.py (+5, -1)

@@ -12,7 +12,11 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的
__all__ = [
"Instance",
"FieldArray",
"Batch",

"DataSetIter",
"BatchIter",
"TorchLoaderIter",

"Vocabulary",
"DataSet",
"Const",


fastNLP/core/__init__.py (+1, -1)

@@ -14,7 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa
介绍core 的子模块的分工,好像必要性不大
"""
from .batch import Batch
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
from .const import Const
from .dataset import DataSet


fastNLP/core/batch.py (+147, -148)

@@ -3,7 +3,9 @@ batch 模块实现了 fastNLP 所需的 Batch 类。

"""
__all__ = [
"Batch"
"BatchIter",
"DataSetIter",
"TorchLoaderIter",
]

import atexit
@@ -12,9 +14,11 @@ from queue import Empty, Full
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.utils.data
from numbers import Number

from .sampler import RandomSampler
from .sampler import SequentialSampler
from .dataset import DataSet

_python_is_exit = False

@@ -27,162 +31,157 @@ def _set_python_is_exit():
atexit.register(_set_python_is_exit)


class Batch(object):
"""
别名::class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`

Batch 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出,
组成 `x` 和 `y`::

batch = Batch(data_set, batch_size=16, sampler=SequentialSampler())
num_batch = len(batch)
for batch_x, batch_y in batch:
# do stuff ...

:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集
:param int batch_size: 取出的batch大小
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.RandomSampler`.
Default: ``None``
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`.
Default: ``False``
:param bool prefetch: 若为 ``True`` 使用多进程预先取出下一batch.
Default: ``False``
"""
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
class DataSetGetter:
def __init__(self, dataset: DataSet, as_numpy=False):
self.dataset = dataset
self.batch_size = batch_size
if sampler is None:
sampler = RandomSampler()
self.sampler = sampler
self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
self.as_numpy = as_numpy
self.idx_list = None
self.curidx = 0
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
self.cur_batch_indices = None
self.prefetch = prefetch
self.lengths = 0
def fetch_one(self):
if self.curidx >= len(self.idx_list):
return None
self.idx_list = list(range(len(dataset)))

def __getitem__(self, idx: int):
# mapping idx to sampled idx
idx = self.idx_list[idx]
inputs = {n:f.get(idx) for n, f in self.inputs.items()}
targets = {n:f.get(idx) for n, f in self.targets.items()}
return idx, inputs, targets

def __len__(self):
return len(self.dataset)

def collate_fn(self, batch: list):
batch_x = {n:[] for n in self.inputs.keys()}
batch_y = {n:[] for n in self.targets.keys()}
indices = []
for idx, x, y in batch:
indices.append(idx)
for n, v in x.items():
batch_x[n].append(v)
for n, v in y.items():
batch_y[n].append(v)

def pad_batch(batch_dict, field_array):
for n, vlist in batch_dict.items():
f = field_array[n]
if f.padder is None:
batch_dict[n] = np.array(vlist)
else:
data = f.pad(vlist)
if not self.as_numpy:
data, flag = _to_tensor(data, f.dtype)
batch_dict[n] = data
return batch_dict

return (indices,
pad_batch(batch_x, self.inputs),
pad_batch(batch_y, self.targets))

def set_idx_list(self, idx_list):
if len(idx_list) != len(self.idx_list):
raise ValueError
self.idx_list = idx_list

def __getattr__(self, item):
if hasattr(self.dataset, item):
return getattr(self.dataset, item)
else:
endidx = min(self.curidx + self.batch_size, len(self.idx_list))
batch_x, batch_y = {}, {}
indices = self.idx_list[self.curidx:endidx]
self.cur_batch_indices = indices
for field_name, field in self.dataset.get_all_fields().items():
if field.is_target or field.is_input:
batch = field.get(indices)
if not self.as_numpy and \
field.dtype is not None and \
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
batch = _to_tensor(batch)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
batch_x[field_name] = batch
self.curidx = endidx
return batch_x, batch_y
raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item))


class SamplerAdapter(torch.utils.data.Sampler):
def __init__(self, sampler, dataset):
self.sampler = sampler
self.dataset = dataset

def __iter__(self):
"""
Iterate on dataset, fetch batch data. Fetch process don't block the iterate process
:return:
"""
if self.prefetch:
return self._run_batch_iter(self)
def batch_iter():
self.init_iter()
while 1:
res = self.fetch_one()
if res is None:
break
yield res
return batch_iter()
return iter(self.sampler(self.dataset))


class BatchIter:
def __init__(self):
self.dataiter = None
self.num_batches = None
self.cur_batch_indices = None
self.batch_size = None

def init_iter(self):
self.idx_list = self.sampler(self.dataset)
self.curidx = 0
self.lengths = self.dataset.get_length()
pass

@staticmethod
def get_num_batches(num_samples, batch_size, drop_last):
num_batches = num_samples // batch_size
if not drop_last and (num_samples % batch_size > 0):
num_batches += 1
return num_batches

def __iter__(self):
self.init_iter()
for indices, batch_x, batch_y in self.dataiter:
self.cur_batch_indices = indices
yield batch_x, batch_y

def get_batch_indices(self):
return self.cur_batch_indices

def __len__(self):
return self.num_batches
def get_batch_indices(self):
"""
取得当前batch在DataSet中所在的index下标序列

:return list(int) indexes: 下标序列
"""
return self.cur_batch_indices
@staticmethod
def _run_fetch(batch, q):
try:
global _python_is_exit
batch.init_iter()
# print('start fetch')
while 1:
res = batch.fetch_one()
# print('fetch one')
while 1:
try:
q.put(res, timeout=3)
break
except Full:
if _python_is_exit:
return
if res is None:
# print('fetch done, waiting processing')
break
# print('fetch exit')
except Exception as e:
q.put(e)
finally:
q.join()
@staticmethod
def _run_batch_iter(batch):
q = mp.JoinableQueue(maxsize=10)
fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q))
fetch_p.daemon = True
fetch_p.start()
# print('fork fetch process')
while 1:
try:
res = q.get(timeout=1)
q.task_done()
# print('get fetched')
if res is None:
break
elif isinstance(res, Exception):
raise res
yield res
except Empty as e:
if fetch_p.is_alive():
continue
else:
break
fetch_p.terminate()
fetch_p.join()
# print('iter done')
@property
def dataset(self):
return self.dataiter.dataset


class DataSetIter(BatchIter):
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None):
super().__init__()
assert isinstance(dataset, DataSet)
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
dataset = DataSetGetter(dataset, as_numpy)
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
self.dataiter = torch.utils.data.DataLoader(
dataset=dataset, batch_size=batch_size, sampler=sampler,
collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last)
self.batch_size = batch_size


class TorchLoaderIter(BatchIter):
def __init__(self, dataset):
super().__init__()
assert isinstance(dataset, torch.utils.data.DataLoader)
self.dataiter = dataset
self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last)
self.batch_size = dataset.batch_size


def _to_tensor(batch):
class OnlineDataGettter:
# TODO
pass


class OnlineDataIter(BatchIter):
# TODO
def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False,
num_workers=0, pin_memory=False, drop_last=False,
timeout=0, worker_init_fn=None, **kwargs):
super().__init__()


def _to_tensor(batch, field_dtype):
try:
if issubclass(batch.dtype.type, np.floating):
batch = torch.as_tensor(batch).float() # 默认使用float32
if field_dtype is not None \
and issubclass(field_dtype, Number) \
and not isinstance(batch, torch.Tensor):
if issubclass(batch.dtype.type, np.floating):
new_batch = torch.as_tensor(batch).float() # 默认使用float32
else:
new_batch = torch.as_tensor(batch) # 复用内存地址,避免复制
return new_batch, True
else:
batch = torch.as_tensor(batch) # 复用内存地址,避免复制
return batch, False
except:
pass
return batch
return batch, False
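
TorchLoaderIter above simply wraps a ready-made torch.utils.data.DataLoader. Judging from BatchIter.__iter__, the wrapped loader must already yield (indices, batch_x, batch_y) triples; a sketch under that assumption, with toy data:

    import torch
    from torch.utils.data import DataLoader
    from fastNLP.core import TorchLoaderIter

    # each sample already carries (index, input dict, target dict)
    samples = [(i, {'x': torch.randn(5)}, {'y': torch.tensor(i % 2)}) for i in range(8)]

    def collate(batch):
        indices = [s[0] for s in batch]
        batch_x = {'x': torch.stack([s[1]['x'] for s in batch])}
        batch_y = {'y': torch.stack([s[2]['y'] for s in batch])}
        return indices, batch_x, batch_y

    loader = DataLoader(samples, batch_size=4, collate_fn=collate)
    for batch_x, batch_y in TorchLoaderIter(loader):
        print(batch_x['x'].shape, batch_y['y'])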

fastNLP/core/field.py (+4, -1)

@@ -176,7 +176,10 @@ class FieldArray:
if self.padder is None or pad is False:
return np.array(contents)
else:
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
return self.pad(contents)

def pad(self, contents):
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

def set_padder(self, padder):
"""


fastNLP/core/losses.py (+28, -17)

@@ -34,14 +34,23 @@ class LossBase(object):
"""
def __init__(self):
self.param_map = {}
self._param_map = {} # key是fun的参数,value是以该值从传入的dict取出value
self._checked = False

@property
def param_map(self):
if len(self._param_map) == 0: # 如果为空说明还没有初始化
func_spect = inspect.getfullargspec(self.get_loss)
func_args = [arg for arg in func_spect.args if arg != 'self']
for arg in func_args:
self._param_map[arg] = arg
return self._param_map

def get_loss(self, *args, **kwargs):
raise NotImplementedError
def _init_param_map(self, key_map=None, **kwargs):
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map

:param dict key_map: 表示key的映射关系
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系
@@ -53,30 +62,30 @@ class LossBase(object):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if not isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
# check consistence between signature and param_map
# check consistence between signature and _param_map
func_spect = inspect.getfullargspec(self.get_loss)
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
for func_param, input_param in self._param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the "
@@ -96,7 +105,7 @@ class LossBase(object):
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
@@ -115,19 +124,19 @@ class LossBase(object):
return loss
if not self._checked:
# 1. check consistence between signature and param_map
# 1. check consistence between signature and _param_map
func_spect = inspect.getfullargspec(self.get_loss)
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
for func_arg, input_arg in self._param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.")
# 2. only part of the param_map are passed, left are not
# 2. only part of the _param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg # This param does not need mapping.
if arg not in self._param_map:
self._param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}

mapped_pred_dict = {}
mapped_target_dict = {}
@@ -149,7 +158,7 @@ class LossBase(object):
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
check_res = _CheckRes(missing=replaced_missing,
@@ -162,6 +171,8 @@ class LossBase(object):
if check_res.missing or check_res.duplicated:
raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.get_loss))
self._checked = True

refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
loss = self.get_loss(**refined_args)


fastNLP/core/metrics.py (+32, -23)

@@ -115,9 +115,18 @@ class MetricBase(object):
"""
def __init__(self):
self.param_map = {} # key is param in function, value is input param.
self._param_map = {} # key is param in function, value is input param.
self._checked = False

@property
def param_map(self):
if len(self._param_map) == 0: # 如果为空说明还没有初始化
func_spect = inspect.getfullargspec(self.evaluate)
func_args = [arg for arg in func_spect.args if arg != 'self']
for arg in func_args:
self._param_map[arg] = arg
return self._param_map

@abstractmethod
def evaluate(self, *args, **kwargs):
raise NotImplementedError
@@ -127,7 +136,7 @@ class MetricBase(object):
raise NotImplemented
def _init_param_map(self, key_map=None, **kwargs):
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map

:param dict key_map: 表示key的映射关系
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系
@@ -139,30 +148,30 @@ class MetricBase(object):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if not isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
self._param_map[key] = key
continue
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
self._param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
# check consistence between signature and param_map
# check consistence between signature and _param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
for func_param, input_param in self._param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
@@ -177,7 +186,7 @@ class MetricBase(object):
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
@@ -206,19 +215,19 @@ class MetricBase(object):
if not self._checked:
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
# 1. check consistence between signature and param_map
# 1. check consistence between signature and _param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
for func_arg, input_arg in self._param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
# 2. only part of the param_map are passed, left are not
# 2. only part of the _param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg # This param does not need mapping.
if arg not in self._param_map:
self._param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()}
# need to wrap inputs in dict.
mapped_pred_dict = {}
@@ -242,7 +251,7 @@ class MetricBase(object):
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
check_res = _CheckRes(missing=replaced_missing,
@@ -255,10 +264,10 @@ class MetricBase(object):
if check_res.missing or check_res.duplicated:
raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.evaluate))
self._checked = True
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
self.evaluate(**refined_args)
self._checked = True
return

@@ -416,19 +425,19 @@ def _bioes_tag_to_spans(tags, ignore_labels=None):
ignore_labels = set(ignore_labels) if ignore_labels else set()

spans = []
prev_bmes_tag = None
prev_bioes_tag = None
for idx, tag in enumerate(tags):
tag = tag.lower()
bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'):
bieso_tag, label = tag[:1], tag[2:]
if bieso_tag in ('b', 's'):
spans.append((label, [idx, idx]))
elif bmes_tag in ('i', 'e') and prev_bmes_tag in ('b', 'i') and label == spans[-1][0]:
elif bieso_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx
elif bmes_tag == 'o':
elif bieso_tag == 'o':
pass
else:
spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag
prev_bioes_tag = bieso_tag
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
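
A quick check of the renamed helper, using exactly the function body shown above; spans come back as (label, (start, end_exclusive)) pairs:

    from fastNLP.core.metrics import _bioes_tag_to_spans

    print(_bioes_tag_to_spans(['B-PER', 'E-PER', 'O', 'S-LOC']))
    # [('per', (0, 2)), ('loc', (3, 4))]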


fastNLP/core/predictor.py (+2, -3)

@@ -6,7 +6,7 @@ from collections import defaultdict

import torch

from . import Batch
from . import DataSetIter
from . import DataSet
from . import SequentialSampler
from .utils import _build_args
@@ -44,8 +44,7 @@ class Predictor(object):

self.network.eval()
batch_output = defaultdict(list)
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
prefetch=False)
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

if hasattr(self.network, "predict"):
predict_func = self.network.predict


fastNLP/core/tester.py (+11, -3)

@@ -37,7 +37,7 @@ import warnings
import torch
import torch.nn as nn

from .batch import Batch
from .batch import BatchIter, DataSetIter
from .dataset import DataSet
from .metrics import _prepare_metrics
from .sampler import SequentialSampler
@@ -82,7 +82,7 @@ class Tester(object):
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
"""
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
super(Tester, self).__init__()
if not isinstance(data, DataSet):
@@ -96,6 +96,14 @@ class Tester(object):
self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size
self.verbose = verbose

if isinstance(data, DataSet):
self.data_iterator = DataSetIter(
dataset=data, batch_size=batch_size, num_workers=num_workers)
elif isinstance(data, BatchIter):
self.data_iterator = data
else:
raise TypeError("data type {} not support".format(type(data)))
# 如果是DataParallel将没有办法使用predict方法
if isinstance(self._model, nn.DataParallel):
@@ -124,7 +132,7 @@ class Tester(object):
self._model_device = _get_model_device(self._model)
network = self._model
self._mode(network, is_test=True)
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
data_iterator = self.data_iterator
eval_results = {}
try:
with torch.no_grad():


fastNLP/core/trainer.py (+59, -39)

@@ -311,8 +311,9 @@ try:
from tqdm.auto import tqdm
except:
from .utils import _pseudo_tqdm as tqdm
import warnings

from .batch import Batch
from .batch import DataSetIter, BatchIter
from .callback import CallbackManager, CallbackException
from .dataset import DataSet
from .losses import _prepare_losser
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics
from .optimizer import Optimizer
from .sampler import Sampler
from .sampler import RandomSampler
from .sampler import SequentialSampler
from .tester import Tester
from .utils import _CheckError
from .utils import _build_args
@@ -351,6 +351,8 @@ class Trainer(object):
:param int batch_size: 训练和验证的时候的batch大小。
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward`
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler`
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch
:param num_workers: int, 有多少个线程来进行数据pad处理。
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。
:param int n_epochs: 需要优化迭代多少次。
@@ -367,7 +369,6 @@ class Trainer(object):
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
:param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。
保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
:param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
的计算位置进行管理。支持以下的输入:
@@ -394,16 +395,17 @@ class Trainer(object):
"""
def __init__(self, train_data, model, optimizer=None, loss=None,
batch_size=32, sampler=None, update_every=1,
n_epochs=10, print_every=5,
batch_size=32, sampler=None, drop_last=False, update_every=1,
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
validate_every=-1, save_path=None,
prefetch=False, use_tqdm=True, device=None,
callbacks=None,
check_code_level=0):
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
callbacks=None, check_code_level=0):
if prefetch and num_workers==0:
num_workers = 1
if prefetch:
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")

super(Trainer, self).__init__()
if not isinstance(train_data, DataSet):
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
@@ -430,17 +432,27 @@ class Trainer(object):
if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
elif len(metrics) > 0:
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')
else:
self.metric_key = None
# prepare loss
losser = _prepare_losser(loss)
# sampler check
if sampler is not None and not isinstance(sampler, Sampler):
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

if sampler is None:
sampler = RandomSampler()

if isinstance(train_data, DataSet):
self.data_iterator = DataSetIter(
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
elif isinstance(train_data, BatchIter):
self.data_iterator = train_data
else:
raise TypeError("train_data type {} not support".format(type(train_data)))
if check_code_level > -1:
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
metric_key=metric_key, check_level=check_code_level,
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
@@ -460,8 +472,6 @@ class Trainer(object):
self.best_dev_epoch = None
self.best_dev_step = None
self.best_dev_perf = None
self.sampler = sampler if sampler is not None else RandomSampler()
self.prefetch = prefetch
self.n_steps = (len(self.train_data) // self.batch_size + int(
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
@@ -493,7 +503,7 @@ class Trainer(object):
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)
def train(self, load_best_model=True, on_exception='auto'):
"""
使用该函数使Trainer开始训练。
@@ -572,8 +582,7 @@ class Trainer(object):
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
self.pbar = pbar
avg_loss = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
prefetch=self.prefetch)
data_iterator = self.data_iterator
self.batch_per_epoch = data_iterator.num_batches
for epoch in range(1, self.n_epochs + 1):
self.epoch = epoch
@@ -746,7 +755,9 @@ class Trainer(object):

:return bool value: True means current results on dev set is the best.
"""
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
if self.metric_key is None:
self.metric_key = indicator
is_better = True
if self.best_metric_indicator is None:
# first-time validation
@@ -785,15 +796,34 @@ def _get_value_info(_dict):
strs.append(_str)
return strs


from numbers import Number
from .batch import _to_tensor
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None,
check_level=0):
# check get_loss 方法
model_devcie = model.parameters().__next__().device
model_devcie = _get_model_device(model=model)
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch):
def _iter():
start_idx = 0
while start_idx<len(dataset):
batch_x = {}
batch_y = {}
for field_name, field in dataset.get_all_fields().items():
indices = list(range(start_idx, min(start_idx+batch_size, len(dataset))))
if field.is_target or field.is_input:
batch = field.get(indices)
if field.dtype is not None and \
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
batch, _ = _to_tensor(batch, field.dtype)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
batch_x[field_name] = batch
yield (batch_x, batch_y)
start_idx += batch_size

for batch_count, (batch_x, batch_y) in enumerate(_iter()):
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check
if batch_count == 0:
@@ -861,26 +891,16 @@ def _check_eval_results(metrics, metric_key, metric_list):
loss, metrics = metrics
if isinstance(metrics, dict):
if len(metrics) == 1:
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else:
metrics_name = metric_list[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]
metric_dict = list(metrics.values())[0] # 取第一个metric
if len(metric_dict) == 1:
if metric_key is None:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else:
# metric_key is set
if metric_key not in metric_dict:
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
indicator_val = metric_dict[metric_key]
indicator = metric_key
else:
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
return indicator_val
return indicator, indicator_val

fastNLP/io/base_loader.py (+8, -0)

@@ -124,6 +124,14 @@ class DataInfo:
self.embeddings = embeddings or {}
self.datasets = datasets or {}

def __repr__(self):
_str = 'In total {} datasets:\n'.format(len(self.datasets))
for name, dataset in self.datasets.items():
_str += '\t{} has {} instances.\n'.format(name, len(dataset))
_str += 'In total {} vocabs:\n'.format(len(self.vocabs))
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str

class DataSetLoader:
"""


fastNLP/io/dataset_loader.py (+2, -1)

@@ -115,7 +115,8 @@ class ConllLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`

读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为
该符号在conll 2003中被用为文档分割符。

列号从0开始, 每列对应内容为::



fastNLP/io/file_reader.py (+8, -5)

@@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
return sample
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f)
if '-DOCSTART-' not in start:
start = next(f).strip()
if '-DOCSTART-' not in start and start!='':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
if line.startswith('\n'):
line = line.strip()
if line=='':
if len(sample):
try:
res = parse_conll(sample)
@@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
elif line.startswith('#'):
continue
else:
sample.append(line.split())
if not line.startswith('-DOCSTART-'):
sample.append(line.split())
if len(sample) > 0:
try:
res = parse_conll(sample)
@@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
raise ValueError('invalid instance at line: {}'.format(line_idx))
print('invalid instance at line: {}'.format(line_idx))
raise e

fastNLP/modules/decoder/crf.py (+15, -5)

@@ -9,7 +9,7 @@ from torch import nn
from ..utils import initial_parameter


def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
"""
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`

@@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):

:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。
:param str encoding_type: 支持"bio", "bmes", "bmeso"。
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头;
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx);
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容
@@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label):
"""

:param str encoding_type: 支持"BIO", "BMES", "BEMSO"。
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
:param str from_label: 比如"PER", "LOC"等label
:param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag
@@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag))
elif encoding_type == 'bioes':
if from_tag == 'start':
return to_tag in ['b', 's', 'o']
elif from_tag == 'b':
return to_tag in ['i', 'e'] and from_label == to_label
elif from_tag == 'i':
return to_tag in ['i', 'e'] and from_label == to_label
elif from_tag in ['e', 's', 'o']:
return to_tag in ['b', 's', 'end', 'o']
else:
raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag))
else:
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type))
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type))


class ConditionalRandomField(nn.Module):
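
With the new 'bioes' branch, allowed_transitions accepts BIOES tag maps directly, and include_start_end now defaults to False, so only tag-to-tag pairs are returned unless it is set explicitly. A minimal sketch with a hypothetical tag map:

    from fastNLP.modules.decoder.crf import allowed_transitions

    id2target = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'E-PER', 4: 'S-PER'}
    print(allowed_transitions(id2target, encoding_type='bioes'))
    # e.g. (1, 2) for B-PER -> I-PER; start/end transitions are omitted by default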


fastNLP/modules/encoder/__init__.py (+2, -1)

@@ -18,7 +18,8 @@ __all__ = [
"VarLSTM",
"VarGRU"
]
from .bert import BertModel
from ._bert import BertModel
from .bert import BertWordPieceEncoder
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding


fastNLP/modules/encoder/_bert.py (+385, -80)

@@ -6,18 +6,399 @@
"""


import torch
from torch import nn

from ... import Vocabulary
import collections

import os
import unicodedata
from ...io.file_utils import _get_base_url, cached_path
from .bert import BertModel
import numpy as np
from itertools import chain
import copy
import json
import math
import os

import torch
from torch import nn

CONFIG_FILE = 'bert_config.json'
MODEL_WEIGHTS = 'pytorch_model.bin'


def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias


class BertEmbeddings(nn.Module):
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)

def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)

words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings


class BertSelfAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)

self.dropout = nn.Dropout(attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer


class BertSelfOutput(nn.Module):
def __init__(self, hidden_size, hidden_dropout_prob):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class BertAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)

def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output


class BertIntermediate(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_act):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = ACT2FN[hidden_act] \
if isinstance(hidden_act, str) else hidden_act

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states


class BertOutput(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
super(BertOutput, self).__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class BertLayer(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act):
super(BertLayer, self).__init__()
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob)
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)

def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output


class BertEncoder(nn.Module):
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob,
intermediate_size, hidden_act):
super(BertEncoder, self).__init__()
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])

def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers


class BertPooler(nn.Module):
def __init__(self, hidden_size):
super(BertPooler, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output


class BertModel(nn.Module):
"""BERT(Bidirectional Embedding Representations from Transformers).

如果你想使用预训练好的权重矩阵,请在以下网址下载.
sources::

'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",


用预训练权重矩阵来建立BERT模型::

model = BertModel.from_pretrained("path/to/weights/directory")

用随机初始化权重矩阵来建立BERT模型::

model = BertModel()

:param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小
:param int hidden_size: 隐层大小,默认值为768,为BERT base的版本
:param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本
:param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本
:param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本
:param str hidden_act: FFN隐藏层激活函数,默认值为``gelu``
:param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1
:param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1
:param int max_position_embeddings: 最大的序列长度,默认值为512,
:param int type_vocab_size: 最大segment数量,默认值为2
:param int initializer_range: 初始化权重范围,默认值为0.02
"""

def __init__(self, vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
super(BertModel, self).__init__()
self.hidden_size = hidden_size
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings,
type_vocab_size, hidden_dropout_prob)
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads,
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size,
hidden_act)
self.pooler = BertPooler(hidden_size)
self.initializer_range = initializer_range

self.apply(self.init_bert_weights)

def init_bert_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output

@classmethod
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs):
# Load config
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
config = json.load(open(config_file, "r"))
# config = BertConfig.from_json_file(config_file)
# logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(*inputs, **config, **kwargs)
if state_dict is None:
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS)
state_dict = torch.load(weights_path)

old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')

load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
print("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
print("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
return model












def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
@@ -547,79 +928,3 @@ class _WordPieceBertModel(nn.Module):
outputs[l_index] = bert_outputs[l]
return outputs

class BertWordPieceEncoder(nn.Module):
"""
可以通过读取vocabulary使用的Bert的Encoder。传入vocab,然后调用index_datasets方法在vocabulary中生成word piece的表示。

:param vocab: Vocabulary.
:param model_dir_or_name:
:param layers:
:param requires_grad:
"""
def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1',
requires_grad:bool=False):
super().__init__()
PRETRAIN_URL = _get_base_url('bert')
# TODO 修改
PRETRAINED_BERT_MODEL_DIR = {'en-base': 'bert_en-80f95ea7.tar.gz',
'cn': 'elmo_cn.zip'}

if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
# 检查是否存在
elif os.path.isdir(model_dir_or_name):
model_dir = model_dir_or_name
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers)
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
self.requires_grad = requires_grad

@property
def requires_grad(self):
"""
Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return:
"""
requires_grads = set([param.requires_grad for name, param in self.named_parameters()])
if len(requires_grads)==1:
return requires_grads.pop()
else:
return None

@requires_grad.setter
def requires_grad(self, value):
for name, param in self.named_parameters():
param.requires_grad = value

@property
def embed_size(self):
return self._embed_size

def index_datasets(self, *datasets):
"""
对datasets进行word piece的index。

Example::

:param datasets:
:return:
"""
self.model.index_dataset(*datasets)

def forward(self, words, token_type_ids=None):
"""
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
删除这两个表示。

:param words: batch_size x max_len
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
"""
outputs = self.model(words, token_type_ids)
outputs = torch.cat([*outputs], dim=-1)

return outputs

fastNLP/modules/encoder/bert.py (+88, -371)

@@ -1,378 +1,95 @@
"""
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.

"""
import copy
import json
import math
import os

import torch
from torch import nn
import torch
from ...core import Vocabulary
from ...io.file_utils import _get_base_url, cached_path
from ._bert import _WordPieceBertModel

CONFIG_FILE = 'bert_config.json'
MODEL_WEIGHTS = 'pytorch_model.bin'


def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
return x * torch.sigmoid(x)


ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class BertLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-12):
super(BertLayerNorm, self).__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.bias = nn.Parameter(torch.zeros(hidden_size))
self.variance_epsilon = eps

def forward(self, x):
u = x.mean(-1, keepdim=True)
s = (x - u).pow(2).mean(-1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.variance_epsilon)
return self.weight * x + self.bias


class BertEmbeddings(nn.Module):
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob):
super(BertEmbeddings, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size)

# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)

def forward(self, input_ids, token_type_ids=None):
seq_length = input_ids.size(1)
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)

words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)

embeddings = words_embeddings + position_embeddings + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings


class BertSelfAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
super(BertSelfAttention, self).__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
"The hidden size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, num_attention_heads))
self.num_attention_heads = num_attention_heads
self.attention_head_size = int(hidden_size / num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(hidden_size, self.all_head_size)
self.key = nn.Linear(hidden_size, self.all_head_size)
self.value = nn.Linear(hidden_size, self.all_head_size)

self.dropout = nn.Dropout(attention_probs_dropout_prob)

def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(self, hidden_states, attention_mask):
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)

query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
value_layer = self.transpose_for_scores(mixed_value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + attention_mask

# Normalize the attention scores to probabilities.
attention_probs = nn.Softmax(dim=-1)(attention_scores)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self.dropout(attention_probs)

context_layer = torch.matmul(attention_probs, value_layer)
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(*new_context_layer_shape)
return context_layer


class BertSelfOutput(nn.Module):
def __init__(self, hidden_size, hidden_dropout_prob):
super(BertSelfOutput, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class BertAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
super(BertAttention, self).__init__()
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob)

def forward(self, input_tensor, attention_mask):
self_output = self.self(input_tensor, attention_mask)
attention_output = self.output(self_output, input_tensor)
return attention_output


class BertIntermediate(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_act):
super(BertIntermediate, self).__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = ACT2FN[hidden_act] \
if isinstance(hidden_act, str) else hidden_act

def forward(self, hidden_states):
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states


class BertOutput(nn.Module):
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob):
super(BertOutput, self).__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12)
self.dropout = nn.Dropout(hidden_dropout_prob)

def forward(self, hidden_states, input_tensor):
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
return hidden_states


class BertLayer(nn.Module):
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act):
super(BertLayer, self).__init__()
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob)
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act)
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob)

def forward(self, hidden_states, attention_mask):
attention_output = self.attention(hidden_states, attention_mask)
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
return layer_output


class BertEncoder(nn.Module):
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob,
hidden_dropout_prob,
intermediate_size, hidden_act):
super(BertEncoder, self).__init__()
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob,
intermediate_size, hidden_act)
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)])

def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
all_encoder_layers = []
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, attention_mask)
if output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
if not output_all_encoded_layers:
all_encoder_layers.append(hidden_states)
return all_encoder_layers


class BertPooler(nn.Module):
def __init__(self, hidden_size):
super(BertPooler, self).__init__()
self.dense = nn.Linear(hidden_size, hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output


class BertModel(nn.Module):
"""BERT(Bidirectional Embedding Representations from Transformers).

如果你想使用预训练好的权重矩阵,请在以下网址下载.
sources::

'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz",
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz",
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz",
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz",
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz",
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz",


用预训练权重矩阵来建立BERT模型::

model = BertModel.from_pretrained("path/to/weights/directory")

用随机初始化权重矩阵来建立BERT模型::

model = BertModel()

:param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小
:param int hidden_size: 隐层大小,默认值为768,为BERT base的版本
:param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本
:param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本
:param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本
:param str hidden_act: FFN隐藏层激活函数,默认值为``gelu``
:param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1
:param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1
:param int max_position_embeddings: 最大的序列长度,默认值为512,
:param int type_vocab_size: 最大segment数量,默认值为2
:param int initializer_range: 初始化权重范围,默认值为0.02
    """

def __init__(self, vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02):
super(BertModel, self).__init__()
self.hidden_size = hidden_size
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings,
type_vocab_size, hidden_dropout_prob)
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads,
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size,
hidden_act)
self.pooler = BertPooler(hidden_size)
self.initializer_range = initializer_range

self.apply(self.init_bert_weights)

def init_bert_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=self.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()

def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)

# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
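        # e.g. attention_mask [[1, 1, 0]] becomes a (1, 1, 1, 3) tensor [[[[0., 0., -10000.]]]],
        # which is added to the raw attention scores so masked positions vanish after the softmax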

embedding_output = self.embeddings(input_ids, token_type_ids)
encoded_layers = self.encoder(embedding_output,
extended_attention_mask,
output_all_encoded_layers=output_all_encoded_layers)
sequence_output = encoded_layers[-1]
pooled_output = self.pooler(sequence_output)
if not output_all_encoded_layers:
encoded_layers = encoded_layers[-1]
return encoded_layers, pooled_output

@classmethod
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs):
# Load config
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
config = json.load(open(config_file, "r"))
# config = BertConfig.from_json_file(config_file)
# logger.info("Model config {}".format(config))
# Instantiate model.
model = cls(*inputs, **config, **kwargs)
if state_dict is None:
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS)
state_dict = torch.load(weights_path)
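
        # older checkpoints name the LayerNorm parameters 'gamma'/'beta'; rename them to the
        # 'weight'/'bias' names used by BertLayerNorm above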

old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)

missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata

def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')

load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
if len(missing_keys) > 0:
print("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
print("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
return model
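

# A minimal usage sketch (not part of the library code). It assumes a local directory
# "path/to/weights/directory" containing the config and weight files referenced by
# CONFIG_FILE and MODEL_WEIGHTS in from_pretrained above.
#
#     import torch
#     model = BertModel.from_pretrained("path/to/weights/directory")
#     input_ids = torch.randint(0, model.embeddings.word_embeddings.num_embeddings, (2, 8))
#     attention_mask = torch.ones_like(input_ids)
#     encoded_layers, pooled_output = model(input_ids, attention_mask=attention_mask,
#                                           output_all_encoded_layers=False)
#     # encoded_layers: (2, 8, hidden_size); pooled_output: (2, hidden_size)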

class BertWordPieceEncoder(nn.Module):
    """
    A BERT encoder driven by a vocabulary. Pass in a vocab, then call the index_datasets method to generate the
    word-piece representation for that vocabulary.

    :param fastNLP.Vocabulary vocab: the vocabulary
    :param str model_dir_or_name: directory containing the model, or the name of a pretrained model.
        Defaults to ``en-base-uncased``
    :param str layers: the layers that form the final representation, separated by ','; negative numbers index
        layers counting back from the last one
    :param bool requires_grad: whether gradients are required for the parameters
    """
def __init__(self, vocab:Vocabulary, model_dir_or_name:str='en-base', layers:str='-1',
requires_grad:bool=False):
super().__init__()
PRETRAIN_URL = _get_base_url('bert')
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
'en-large-uncased': 'bert-large-uncased-20939f45.zip',
'en-large-cased': 'bert-large-cased-e0cf90fc.zip',

'cn': 'bert-base-chinese-29d0a84a.zip',
'cn-base': 'bert-base-chinese-29d0a84a.zip',

'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
}

if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
            # otherwise treat model_dir_or_name as a local directory and check that it exists
elif os.path.isdir(model_dir_or_name):
model_dir = model_dir_or_name
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

self.model = _WordPieceBertModel(model_dir=model_dir, vocab=vocab, layers=layers)
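        # the final representation concatenates the selected layers, so its size is len(layers) * hidden_size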
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
self.requires_grad = requires_grad

@property
def requires_grad(self):
"""
        Whether the parameters of this embedding may be optimized. True: all parameters are optimized;
        False: no parameters are optimized; None: some parameters are optimized and some are not
:return:
"""
requires_grads = set([param.requires_grad for name, param in self.named_parameters()])
if len(requires_grads)==1:
return requires_grads.pop()
else:
return None

@requires_grad.setter
def requires_grad(self, value):
for name, param in self.named_parameters():
param.requires_grad = value

@property
def embed_size(self):
return self._embed_size

def index_datasets(self, *datasets):
"""
        Index word pieces for the given datasets, based on the 'words' field of each dataset.

Example::
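
            # a minimal sketch: `tr_data` and `dev_data` are assumed to be fastNLP DataSet
            # objects that already contain a 'words' field
            encoder = BertWordPieceEncoder(vocab)
            encoder.index_datasets(tr_data, dev_data)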

:param datasets:
:return:
"""
self.model.index_dataset(*datasets)

def forward(self, words, token_type_ids=None):
"""
        Compute the BERT embedding of ``words``. A [CLS] token is prepended and a [SEP] token appended to every
        sentence before the representations are computed.

        :param words: batch_size x max_len
        :param token_type_ids: batch_size x max_len, used to distinguish the first sentence from the second
        :return: torch.FloatTensor of shape batch_size x max_len x (hidden_size * len(self.layers))
"""
outputs = self.model(words, token_type_ids)
outputs = torch.cat([*outputs], dim=-1)

return outputs
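

# A minimal end-to-end sketch of BertWordPieceEncoder (assumptions: `vocab` is a fastNLP
# Vocabulary built from the data, `train_data` is a DataSet whose input field is 'words',
# and `words` is a batch_size x max_len LongTensor of indices from that vocabulary):
#
#     encoder = BertWordPieceEncoder(vocab, model_dir_or_name='en-base-uncased', layers='-1,-2')
#     encoder.index_datasets(train_data)
#     reps = encoder(words)
#     # reps: batch_size x max_len x encoder.embed_size  (= 2 * hidden_size for two layers)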

+ 18
- 9
fastNLP/modules/encoder/embedding.py View File

@@ -15,7 +15,7 @@ from ...io.file_utils import cached_path, _get_base_url
from ._bert import _WordBertModel
from typing import List

from ... import DataSet, Batch, SequentialSampler
from ... import DataSet, DataSetIter, SequentialSampler
from ...core.utils import _move_model_to_device, _get_model_device


@@ -157,7 +157,6 @@ class StaticEmbedding(TokenEmbedding):
super(StaticEmbedding, self).__init__(vocab)

# 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server,
PRETRAIN_URL = _get_base_url('static')
PRETRAIN_STATIC_FILES = {
'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
@@ -170,6 +169,7 @@ class StaticEmbedding(TokenEmbedding):

# 得到cache_path
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
PRETRAIN_URL = _get_base_url('static')
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_path = cached_path(model_url)
@@ -234,7 +234,7 @@ class ContextualEmbedding(TokenEmbedding):
with torch.no_grad():
for index, dataset in enumerate(datasets):
try:
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False)
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_x, batch_y in batch:
words = batch_x['words'].to(device)
words_list = words.tolist()
@@ -325,11 +325,11 @@ class ElmoEmbedding(ContextualEmbedding):
self.layers = layers

# 根据model_dir_or_name检查是否存在并下载
PRETRAIN_URL = _get_base_url('elmo')
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz',
'cn': 'elmo_cn-5e9b34e2.tar.gz'}

if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
PRETRAIN_URL = _get_base_url('elmo')
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
@@ -383,7 +383,7 @@ class ElmoEmbedding(ContextualEmbedding):
def requires_grad(self, value):
for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name: # 这个不能加入到requires_grad中
pass
continue
param.requires_grad = value


@@ -411,7 +411,6 @@ class BertEmbedding(ContextualEmbedding):
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False):
super(BertEmbedding, self).__init__(vocab)
# 根据model_dir_or_name检查是否存在并下载
PRETRAIN_URL = _get_base_url('bert')
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
'en-base-cased': 'bert-base-cased-f89bfe08.zip',
@@ -427,6 +426,7 @@ class BertEmbedding(ContextualEmbedding):
}

if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
model_url = PRETRAIN_URL + model_name
model_dir = cached_path(model_url)
@@ -478,7 +478,7 @@ class BertEmbedding(ContextualEmbedding):
def requires_grad(self, value):
for name, param in self.named_parameters():
if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中
pass
continue
param.requires_grad = value


@@ -566,6 +566,7 @@ class CNNCharEmbedding(TokenEmbedding):
for i in range(len(kernel_sizes))])
self._embed_size = embed_size
self.fc = nn.Linear(sum(filter_nums), embed_size)
self.init_param()

def forward(self, words):
"""
@@ -618,9 +619,17 @@ class CNNCharEmbedding(TokenEmbedding):
def requires_grad(self, value):
for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
pass
continue
param.requires_grad = value

def init_param(self):
for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset
continue
if param.data.dim()>1:
nn.init.xavier_normal_(param, 1)
else:
nn.init.uniform_(param, -1, 1)

class LSTMCharEmbedding(TokenEmbedding):
"""
@@ -744,7 +753,7 @@ class LSTMCharEmbedding(TokenEmbedding):
def requires_grad(self, value):
for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
pass
continue
param.requires_grad = value




+ 11
- 1
fastNLP/modules/encoder/lstm.py View File

@@ -35,8 +35,18 @@ class LSTM(nn.Module):
self.batch_first = batch_first
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first,
dropout=dropout, bidirectional=bidirectional)
self.init_param()
initial_parameter(self, initial_method)

def init_param(self):
for name, param in self.named_parameters():
if 'bias_i' in name:
param.data.fill_(1)
elif 'bias_h' in name:
param.data.fill_(0)
else:
nn.init.xavier_normal_(param)

def forward(self, x, seq_len=None, h0=None, c0=None):
"""



+ 2
- 5
reproduction/Biaffine_parser/run.py View File

@@ -184,11 +184,8 @@ def train(path):
m.weight.requires_grad = True

# Trainer
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data,
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS',
**train_args.data,
optimizer=fastNLP.Adam(**optim_args.data),
save_path=path,
trainer = Trainer(train_data=train_data, model=model, optimizer=fastNLP.Adam(**optim_args.data), loss=ParserLoss(),
dev_data=dev_data, metrics=ParserMetric(), metric_key='UAS', save_path=path,
callbacks=[MyCallback()])

# Start training


+ 5
- 5
reproduction/POS_tagging/train_pos_tag.py View File

@@ -89,11 +89,11 @@ def train(train_data_path, dev_data_path, checkpoint=None, save=None):
model = torch.load(checkpoint)

# call trainer to train
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
target="truth",
seq_lens="word_seq_origin_len"),
dev_data=dev_data, metric_key="f",
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save)
trainer = Trainer(dataset, model, loss=None, n_epochs=20, print_every=10, dev_data=dev_data,
metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict",
target="truth",
seq_lens="word_seq_origin_len"), metric_key="f", save_path=save,
use_tqdm=True)
trainer.train(load_best_model=True)

# save model & pipeline


+ 4
- 8
reproduction/Star_transformer/train.py View File

@@ -149,14 +149,10 @@ def train():
) if x.requires_grad and x.size(0) != len(word_v)]
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1},
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ]
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data,
loss=loss, metrics=metric, metric_key=metric_key,
optimizer=torch.optim.Adam(optim_cfg),
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000,
device=device,
use_tqdm=False, prefetch=False,
save_path=g_args.log,
callbacks=[MyCallback()])
trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss,
batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric,
metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False,
device=device, callbacks=[MyCallback()])

trainer.train()
tester = FN.Tester(data=test_data, model=model, metrics=metric,


+ 4
- 13
reproduction/matching/snli.py View File

@@ -70,19 +70,10 @@ test_data = preprocess_data(test_data, bert_dirs)

model = BertForNLI(bert_dir=bert_dirs)

trainer = Trainer(
train_data=train_data,
model=model,
optimizer=Adam(lr=2e-5, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 12,
n_epochs=4,
print_every=-1,
dev_data=dev_data,
metrics=AccuracyMetric(),
metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1
)
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(lr=2e-5, model_params=model.parameters()),
batch_size=torch.cuda.device_count() * 12, n_epochs=4, print_every=-1, dev_data=dev_data,
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
trainer.train(load_best_model=True)

tester = Tester(


+ 2
- 1
reproduction/utils.py View File

@@ -13,7 +13,8 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
}
如果paths为不合法的,将直接进行raise相应的错误

:param paths: 路径
:param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train.txt,
test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
if isinstance(paths, str):


+ 10
- 10
test/core/test_batch.py View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch

from fastNLP import Batch
from fastNLP import DataSetIter
from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import SequentialSampler
@@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase):
dataset = construct_dataset(
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
dataset.set_target()
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
cnt = 0
for _, _ in batch:
@@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
for x, y in iter:
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
self.assertEqual(len(x["x"]), 4)
@@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase):
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
for x, y in iter:
self.assertEqual(x["x"].shape, (4, 4))
self.assertEqual(y["y"].shape, (4, 4))
@@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase):
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
for x, y in iter:
self.assertEqual(x["x"].shape, (4, 4))
self.assertEqual(y["y"].shape, (4, 4))
@@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase):
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
for x, y in iter:
self.assertTrue(isinstance(x["x"], torch.Tensor))
self.assertEqual(tuple(x["x"].shape), (4, 4))
@@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase):
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
for x, y in iter:
self.assertTrue(isinstance(x["x"], torch.Tensor))
self.assertEqual(tuple(x["x"].shape), (4, 4))
@@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase):
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
for x, y in iter:
self.assertTrue(isinstance(x["x"], torch.Tensor))
self.assertEqual(tuple(x["x"].shape), (4, 4))
@@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase):
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
for x, y in iter:
print(x, y)
@@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase):
num_samples = 1000
dataset = generate_fake_dataset(num_samples)
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_x, batch_y in batch:
pass


+ 27
- 73
test/core/test_callbacks.py View File

@@ -40,89 +40,50 @@ class TestCallback(unittest.TestCase):
def test_gradient_clip(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=20,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)], check_code_level=2)
trainer.train()
def test_early_stop(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=20,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.01),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[EarlyStopCallback(5)])
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=20, print_every=50, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
callbacks=[EarlyStopCallback(5)], check_code_level=2)
trainer.train()
def test_lr_scheduler(self):
data_set, model = prepare_env()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=optimizer,
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))])
trainer = Trainer(data_set, model, optimizer=optimizer, loss=BCELoss(pred="predict", target="y"), batch_size=32,
n_epochs=5, print_every=50, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))],
check_code_level=2)
trainer.train()
def test_KeyBoardInterrupt(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
callbacks=[ControlC(False)])
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, callbacks=[ControlC(False)],
check_code_level=2)
trainer.train()
def test_LRFinder(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
callbacks=[LRFinder(len(data_set) // 32)])
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=5, print_every=50, use_tqdm=False,
callbacks=[LRFinder(len(data_set) // 32)], check_code_level=2)
trainer.train()
def test_TensorboardCallback(self):
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=5,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[TensorboardCallback("loss", "metric")])
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False,
callbacks=[TensorboardCallback("loss", "metric")], check_code_level=2)
trainer.train()
def test_readonly_property(self):
@@ -141,16 +102,9 @@ class TestCallback(unittest.TestCase):
print(self.optimizer)
data_set, model = prepare_env()
trainer = Trainer(data_set, model,
loss=BCELoss(pred="predict", target="y"),
n_epochs=total_epochs,
batch_size=32,
print_every=50,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=False,
dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"),
callbacks=[MyCallback()])
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=total_epochs, print_every=50, dev_data=data_set,
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=[MyCallback()],
check_code_level=2)
trainer.train()
assert passed_epochs == list(range(1, total_epochs + 1))

+ 9
- 1
test/core/test_metrics.py View File

@@ -161,7 +161,15 @@ class TestAccuracyMetric(unittest.TestCase):
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_duplicate(self):
# 0.4.1的潜在bug,不能出现形参重复的情况
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0}
target_dict = {'targets':torch.zeros(4, 3), 'target': 0}
metric(pred_dict=pred_dict, target_dict=target_dict)


def test_seq_len(self):
N = 256
seq_len = torch.zeros(N).long()


+ 11
- 48
test/core/test_trainer.py View File

@@ -46,18 +46,10 @@ class TrainerTestGround(unittest.TestCase):
model = NaiveClassifier(2, 1)
trainer = Trainer(train_set, model,
loss=BCELoss(pred="predict", target="y"),
metrics=AccuracyMetric(pred="predict", target="y"),
n_epochs=10,
batch_size=32,
print_every=50,
validate_every=-1,
dev_data=dev_set,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=True,
save_path=None)
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set,
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
use_tqdm=True, check_code_level=2)
trainer.train()
"""
# 应该正确运行
@@ -83,10 +75,7 @@ class TrainerTestGround(unittest.TestCase):
model = Model()
with self.assertRaises(RuntimeError):
trainer = Trainer(
train_data=dataset,
model=model
)
trainer = Trainer(train_data=dataset, model=model)
"""
# 应该获取到的报错提示
NameError:
@@ -116,12 +105,7 @@ class TrainerTestGround(unittest.TestCase):
return {'loss': loss}
model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
trainer.train()
"""
# 应该正确运行
@@ -147,12 +131,7 @@ class TrainerTestGround(unittest.TestCase):
model = Model()
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
trainer.train()
def test_trainer_suggestion4(self):
@@ -175,12 +154,7 @@ class TrainerTestGround(unittest.TestCase):
model = Model()
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
def test_trainer_suggestion5(self):
# 检查报错提示能否正确提醒用户
@@ -203,12 +177,7 @@ class TrainerTestGround(unittest.TestCase):
return {'loss': loss}
model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False)
def test_trainer_suggestion6(self):
# 检查报错提示能否正确提醒用户
@@ -233,14 +202,8 @@ class TrainerTestGround(unittest.TestCase):
model = Model()
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
loss=CrossEntropyLoss(),
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2)
trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset,
metrics=AccuracyMetric(), use_tqdm=False)
"""
def test_trainer_multiprocess(self):


+ 2
- 5
test/models/model_runner.py View File

@@ -130,11 +130,8 @@ class ModelRunner():
tester = Tester(data=data, model=model, metrics=metrics,
batch_size=BATCH_SIZE, verbose=0)
before_train = tester.test()
trainer = Trainer(model=model, train_data=data, dev_data=None,
n_epochs=N_EPOCHS, batch_size=BATCH_SIZE,
loss=loss,
save_path=None,
use_tqdm=False)
trainer = Trainer(train_data=data, model=model, loss=loss, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS,
dev_data=None, save_path=None, use_tqdm=False)
trainer.train(load_best_model=False)
after_train = tester.test()
for metric_name, v1 in before_train.items():


+ 0
- 1
test/models/test_biaffine_parser.py View File

@@ -1,6 +1,5 @@
import unittest

import fastNLP
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric
from .model_runner import *



+ 5
- 5
test/modules/decoder/test_CRF.py View File

@@ -10,14 +10,14 @@ class TestCRF(unittest.TestCase):
id2label = {0: 'B', 1: 'I', 2:'O'}
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2),
(2, 4), (3, 0), (3, 2)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))

id2label = {0: 'B', 1:'M', 2:'E', 3:'S'}
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))

id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"}
allowed_transitions(id2label)
allowed_transitions(id2label, include_start_end=True)

labels = ['O']
for label in ['X', 'Y']:
@@ -27,7 +27,7 @@ class TestCRF(unittest.TestCase):
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1),
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3),
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label)))
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True)))

labels = []
for label in ['X', 'Y']:
@@ -37,7 +37,7 @@ class TestCRF(unittest.TestCase):
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4),
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0),
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)}
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES')))
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True)))

def test_case2(self):
# 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。


+ 6
- 15
test/test_tutorials.py View File

@@ -60,10 +60,10 @@ class TestTutorial(unittest.TestCase):
print(test_data[0])

# 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具
from fastNLP.core.batch import Batch
from fastNLP.core.batch import DataSetIter
from fastNLP.core.sampler import RandomSampler

batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
batch_iterator = DataSetIter(dataset=train_data, batch_size=2, sampler=RandomSampler())
for batch_x, batch_y in batch_iterator:
print("batch_x has: ", batch_x)
print("batch_y has: ", batch_y)
@@ -85,12 +85,8 @@ class TestTutorial(unittest.TestCase):
# 实例化Trainer,传入模型和数据,进行训练
# 先在test_data拟合(确保模型的实现是正确的)
copy_model = deepcopy(model)
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,
loss=loss,
metrics=metric,
save_path=None,
batch_size=32,
n_epochs=5)
overfit_trainer = Trainer(train_data=test_data, model=copy_model, loss=loss, batch_size=32, n_epochs=5,
dev_data=test_data, metrics=metric, save_path=None)
overfit_trainer.train()

# 用train_data训练,在test_data验证
@@ -147,13 +143,8 @@ class TestTutorial(unittest.TestCase):

from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam

trainer = Trainer(model=model,
train_data=train_data,
dev_data=dev_data,
loss=CrossEntropyLoss(),
optimizer= Adam(),
metrics=AccuracyMetric(target='target')
)
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(), loss=CrossEntropyLoss(),
dev_data=dev_data, metrics=AccuracyMetric(target='target'))
trainer.train()
print('Train finished!')


