@@ -0,0 +1,16 @@ | |||
.gitignore | |||
.DS_Store | |||
.ipynb_checkpoints | |||
*.pyc | |||
__pycache__ | |||
*.swp | |||
.vscode/ | |||
.idea/** | |||
caches | |||
# fitlog | |||
.fitlog | |||
logs/ | |||
.fitconfig |
@@ -8,7 +8,7 @@ install: | |||
- pip install pytest-cov | |||
# command to run tests | |||
script: | |||
- pytest --cov=./ | |||
- pytest --cov=./ test/ | |||
after_success: | |||
- bash <(curl -s https://codecov.io/bash) |
@@ -56,6 +56,7 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models | |||
快速入门 <user/quickstart> | |||
详细指南 <user/tutorial_one> | |||
科研指南 <user/with_fitlog> | |||
注释语法 <user/example> | |||
API 文档 | |||
------------- | |||
@@ -0,0 +1,104 @@ | |||
====== | |||
大标题 | |||
====== | |||
.. note:: | |||
中文标题需要符号的数量至少是中文字数的两倍 | |||
.. warning:: | |||
符号的数量只可以多,不可以少。 | |||
小标题1 | |||
########### | |||
小标题2 | |||
********* | |||
小标题3(正常使用) | |||
======================== | |||
小标题4 | |||
------------------- | |||
参考 http://docutils.sourceforge.net/docs/user/rst/quickref.html | |||
常见语法 | |||
============ | |||
*emphasis* | |||
**strong** | |||
`text` | |||
``inline literal`` | |||
http://docutils.sf.net/ 孤立的网址会自动生成链接 | |||
显示为特定的文字的链接 `sohu <http://www.sohu.com>`_ | |||
突出显示的 | |||
上面文字 | |||
正常缩进 | |||
形成锻炼 | |||
特殊模块 | |||
============ | |||
选项会自动识别 | |||
-v An option | |||
-o file Same with value | |||
--delta A long option | |||
--delta=len Same with value | |||
图片 | |||
.. image:: ../figures/procedures.PNG | |||
:height: 200 | |||
:width: 560 | |||
:scale: 50 | |||
:alt: alternate text | |||
:align: center | |||
显示一个冒号的代码块:: | |||
中间要空一行 | |||
:: | |||
不显示冒号的代码块 | |||
.. code-block:: python | |||
:linenos: | |||
:emphasize-lines: 1,3 | |||
print("专业的代码块") | |||
print("") | |||
print("有行号和高亮") | |||
数学块 | |||
.. math:: | |||
2Na + 2H_2O = 2NaOH + H_2 \uparrow
各种连接 | |||
=========== | |||
:doc:`/user/with_fitlog` | |||
:mod:`~fastNLP.core.batch` | |||
:class:`~fastNLP.Batch` | |||
~表示只显示最后一项
:meth:`fastNLP.DataSet.apply` | |||
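These roles are mostly used inside docstrings. A minimal sketch of how they might appear there (the function name ``run_experiment`` is made up for illustration):

.. code-block:: python

    def run_experiment(dataset):
        """一个示意用的函数,演示上面的引用语法。

        数据迭代请参考 :mod:`~fastNLP.core.batch` 中的 :class:`~fastNLP.Batch` ,
        预处理可以使用 :meth:`fastNLP.DataSet.apply` ,整体流程见 :doc:`/user/with_fitlog` 。
        """
        pass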
@@ -49,7 +49,7 @@ | |||
.. code-block:: python | |||
from fastNLP.models import CNNText | |||
model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1) | |||
model = CNNText((len(vocab),50), num_classes=5, dropout=0.1) | |||
:class:`~fastNLP.models.CNNText` 的网络结构如下:: | |||
@@ -121,4 +121,4 @@ | |||
In Epoch:6/Step:12, got best dev performance:AccuracyMetric: acc=0.8 | |||
Reloaded the best model. | |||
这份教程只是简单地介绍了使用 fastNLP 工作的流程,具体的细节分析见 :doc:`/user/tutorial_one` | |||
这份教程只是简单地介绍了使用 fastNLP 工作的流程,具体的细节分析见 :doc:`/user/tutorial_one` |
@@ -12,7 +12,11 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||
__all__ = [ | |||
"Instance", | |||
"FieldArray", | |||
"Batch", | |||
"DataSetIter", | |||
"BatchIter", | |||
"TorchLoaderIter", | |||
"Vocabulary", | |||
"DataSet", | |||
"Const", | |||
@@ -14,7 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||
介绍core 的子模块的分工,好像必要性不大 | |||
""" | |||
from .batch import Batch | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||
from .const import Const | |||
from .dataset import DataSet | |||
@@ -3,7 +3,9 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 | |||
""" | |||
__all__ = [ | |||
"Batch" | |||
"BatchIter", | |||
"DataSetIter", | |||
"TorchLoaderIter", | |||
] | |||
import atexit | |||
@@ -12,8 +14,11 @@ from queue import Empty, Full | |||
import numpy as np | |||
import torch | |||
import torch.multiprocessing as mp | |||
import torch.utils.data | |||
from numbers import Number | |||
from .sampler import RandomSampler | |||
from .sampler import SequentialSampler | |||
from .dataset import DataSet | |||
_python_is_exit = False | |||
@@ -26,160 +31,163 @@ def _set_python_is_exit(): | |||
atexit.register(_set_python_is_exit) | |||
class Batch(object): | |||
""" | |||
别名::class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch` | |||
Batch 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | |||
组成 `x` 和 `y`:: | |||
batch = Batch(data_set, batch_size=16, sampler=SequentialSampler()) | |||
num_batch = len(batch) | |||
for batch_x, batch_y in batch: | |||
# do stuff ... | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.RandomSampler`. | |||
Default: ``None`` | |||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||
Default: ``False`` | |||
:param bool prefetch: 若为 ``True`` 使用多进程预先取出下一batch. | |||
Default: ``False`` | |||
""" | |||
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False): | |||
class DataSetGetter: | |||
def __init__(self, dataset: DataSet, as_numpy=False): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
if sampler is None: | |||
sampler = RandomSampler() | |||
self.sampler = sampler | |||
self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input} | |||
self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target} | |||
self.as_numpy = as_numpy | |||
self.idx_list = None | |||
self.curidx = 0 | |||
self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0) | |||
self.cur_batch_indices = None | |||
self.prefetch = prefetch | |||
self.lengths = 0 | |||
def fetch_one(self): | |||
if self.curidx >= len(self.idx_list): | |||
return None | |||
self.idx_list = list(range(len(dataset))) | |||
def __getitem__(self, idx: int): | |||
# mapping idx to sampled idx | |||
idx = self.idx_list[idx] | |||
inputs = {n:f.get(idx) for n, f in self.inputs.items()} | |||
targets = {n:f.get(idx) for n, f in self.targets.items()} | |||
return idx, inputs, targets | |||
def __len__(self): | |||
return len(self.dataset) | |||
def collate_fn(self, batch: list): | |||
batch_x = {n:[] for n in self.inputs.keys()} | |||
batch_y = {n:[] for n in self.targets.keys()} | |||
indices = [] | |||
for idx, x, y in batch: | |||
indices.append(idx) | |||
for n, v in x.items(): | |||
batch_x[n].append(v) | |||
for n, v in y.items(): | |||
batch_y[n].append(v) | |||
def pad_batch(batch_dict, field_array): | |||
for n, vlist in batch_dict.items(): | |||
f = field_array[n] | |||
if f.padder is None: | |||
batch_dict[n] = np.array(vlist) | |||
else: | |||
data = f.pad(vlist) | |||
if not self.as_numpy: | |||
try: | |||
data, flag = _to_tensor(data, f.dtype) | |||
except TypeError as e: | |||
print(f"Field {n} cannot be converted to torch.tensor.") | |||
raise e | |||
batch_dict[n] = data | |||
return batch_dict | |||
return (indices, | |||
pad_batch(batch_x, self.inputs), | |||
pad_batch(batch_y, self.targets)) | |||
def set_idx_list(self, idx_list): | |||
if len(idx_list) != len(self.idx_list): | |||
raise ValueError | |||
self.idx_list = idx_list | |||
def __getattr__(self, item): | |||
if hasattr(self.dataset, item): | |||
return getattr(self.dataset, item) | |||
else: | |||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | |||
batch_x, batch_y = {}, {} | |||
indices = self.idx_list[self.curidx:endidx] | |||
self.cur_batch_indices = indices | |||
for field_name, field in self.dataset.get_all_fields().items(): | |||
if field.is_target or field.is_input: | |||
batch = field.get(indices) | |||
if not self.as_numpy and field.padder is not None: | |||
batch = _to_tensor(batch, field.dtype) | |||
if field.is_target: | |||
batch_y[field_name] = batch | |||
if field.is_input: | |||
batch_x[field_name] = batch | |||
self.curidx = endidx | |||
return batch_x, batch_y | |||
raise AttributeError("'DataSetGetter' object has no attribute '{}'".format(item)) | |||
class SamplerAdapter(torch.utils.data.Sampler): | |||
def __init__(self, sampler, dataset): | |||
self.sampler = sampler | |||
self.dataset = dataset | |||
def __iter__(self): | |||
""" | |||
Iterate over the dataset and fetch batch data. The fetch process doesn't block the iteration.
:return: | |||
""" | |||
if self.prefetch: | |||
return self._run_batch_iter(self) | |||
def batch_iter(): | |||
self.init_iter() | |||
while 1: | |||
res = self.fetch_one() | |||
if res is None: | |||
break | |||
yield res | |||
return batch_iter() | |||
return iter(self.sampler(self.dataset)) | |||
class BatchIter: | |||
def __init__(self): | |||
self.dataiter = None | |||
self.num_batches = None | |||
self.cur_batch_indices = None | |||
self.batch_size = None | |||
def init_iter(self): | |||
self.idx_list = self.sampler(self.dataset) | |||
self.curidx = 0 | |||
self.lengths = self.dataset.get_length() | |||
pass | |||
@staticmethod | |||
def get_num_batches(num_samples, batch_size, drop_last): | |||
num_batches = num_samples // batch_size | |||
if not drop_last and (num_samples % batch_size > 0): | |||
num_batches += 1 | |||
return num_batches | |||
def __iter__(self): | |||
self.init_iter() | |||
for indices, batch_x, batch_y in self.dataiter: | |||
self.cur_batch_indices = indices | |||
yield batch_x, batch_y | |||
def get_batch_indices(self): | |||
return self.cur_batch_indices | |||
def __len__(self): | |||
return self.num_batches | |||
def get_batch_indices(self): | |||
""" | |||
取得当前batch在DataSet中所在的index下标序列 | |||
:return list(int) indexes: 下标序列 | |||
""" | |||
return self.cur_batch_indices | |||
@staticmethod | |||
def _run_fetch(batch, q): | |||
try: | |||
global _python_is_exit | |||
batch.init_iter() | |||
# print('start fetch') | |||
while 1: | |||
res = batch.fetch_one() | |||
# print('fetch one') | |||
while 1: | |||
try: | |||
q.put(res, timeout=3) | |||
break | |||
except Full: | |||
if _python_is_exit: | |||
return | |||
if res is None: | |||
# print('fetch done, waiting processing') | |||
break | |||
# print('fetch exit') | |||
except Exception as e: | |||
q.put(e) | |||
finally: | |||
q.join() | |||
@staticmethod | |||
def _run_batch_iter(batch): | |||
q = mp.JoinableQueue(maxsize=10) | |||
fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q)) | |||
fetch_p.daemon = True | |||
fetch_p.start() | |||
# print('fork fetch process') | |||
while 1: | |||
try: | |||
res = q.get(timeout=1) | |||
q.task_done() | |||
# print('get fetched') | |||
if res is None: | |||
break | |||
elif isinstance(res, Exception): | |||
raise res | |||
yield res | |||
except Empty as e: | |||
if fetch_p.is_alive(): | |||
continue | |||
else: | |||
break | |||
fetch_p.terminate() | |||
fetch_p.join() | |||
# print('iter done') | |||
@property | |||
def dataset(self): | |||
return self.dataiter.dataset | |||
class DataSetIter(BatchIter): | |||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None): | |||
super().__init__() | |||
assert isinstance(dataset, DataSet) | |||
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
dataset = DataSetGetter(dataset, as_numpy) | |||
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None | |||
self.dataiter = torch.utils.data.DataLoader( | |||
dataset=dataset, batch_size=batch_size, sampler=sampler, | |||
collate_fn=collate_fn, num_workers=num_workers, | |||
pin_memory=pin_memory, drop_last=drop_last, | |||
timeout=timeout, worker_init_fn=worker_init_fn) | |||
self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last) | |||
self.batch_size = batch_size | |||
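A minimal usage sketch of the new ``DataSetIter`` (the ``words``/``label`` field names are invented; it assumes the exports added to ``fastNLP/__init__.py`` above):

.. code-block:: python

    from fastNLP import DataSet, DataSetIter

    ds = DataSet({'words': [[1, 2, 3], [4, 5], [6]], 'label': [0, 1, 0]})
    ds.set_input('words')      # input fields are padded and batched into batch_x
    ds.set_target('label')     # target fields go into batch_y

    for batch_x, batch_y in DataSetIter(ds, batch_size=2):
        print(batch_x['words'].shape, batch_y['label'])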
class TorchLoaderIter(BatchIter): | |||
def __init__(self, dataset): | |||
super().__init__() | |||
assert isinstance(dataset, torch.utils.data.DataLoader) | |||
self.dataiter = dataset | |||
self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last) | |||
self.batch_size = dataset.batch_size | |||
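``TorchLoaderIter`` simply wraps an existing ``torch.utils.data.DataLoader``; since ``BatchIter.__iter__`` unpacks ``(indices, batch_x, batch_y)`` triples, the wrapped loader's ``collate_fn`` has to produce that shape. A hedged sketch with made-up tensors:

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from fastNLP import TorchLoaderIter

    data = TensorDataset(torch.randn(8, 5), torch.randint(0, 2, (8,)))

    def collate(samples):
        # mimic DataSetGetter.collate_fn: (indices, batch_x dict, batch_y dict)
        xs = torch.stack([x for x, _ in samples])
        ys = torch.stack([y for _, y in samples])
        return list(range(len(samples))), {'x': xs}, {'y': ys}

    loader = DataLoader(data, batch_size=4, collate_fn=collate)
    for batch_x, batch_y in TorchLoaderIter(loader):
        print(batch_x['x'].shape, batch_y['y'])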
class OnlineDataGettter: | |||
# TODO | |||
pass | |||
def _to_tensor(batch, dtype): | |||
class OnlineDataIter(BatchIter): | |||
# TODO | |||
def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None, **kwargs): | |||
super().__init__() | |||
def _to_tensor(batch, field_dtype): | |||
try: | |||
if dtype in (int, np.int8, np.int16, np.int32, np.int64): | |||
batch = torch.LongTensor(batch) | |||
if dtype in (float, np.float32, np.float64): | |||
batch = torch.FloatTensor(batch) | |||
except: | |||
pass | |||
return batch | |||
if field_dtype is not None and isinstance(field_dtype, type)\ | |||
and issubclass(field_dtype, Number) \ | |||
and not isinstance(batch, torch.Tensor): | |||
if issubclass(batch.dtype.type, np.floating): | |||
new_batch = torch.as_tensor(batch).float() # 默认使用float32 | |||
elif issubclass(batch.dtype.type, np.integer): | |||
new_batch = torch.as_tensor(batch).long() # 复用内存地址,避免复制 | |||
else: | |||
new_batch = torch.as_tensor(batch) | |||
return new_batch, True | |||
else: | |||
return batch, False | |||
except Exception as e: | |||
raise e |
@@ -438,26 +438,29 @@ class EarlyStopCallback(Callback): | |||
class FitlogCallback(Callback): | |||
""" | |||
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` | |||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 | |||
dict的方式传入。如果仅传入DataSet, 则被命名为test | |||
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为`test`
:param int verbose: 是否在终端打印内容,0不打印 | |||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||
:param bool log_exception: fitlog是否记录发生的exception信息 | |||
""" | |||
# 还没有被导出到 fastNLP 层 | |||
# 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` | |||
def __init__(self, data=None, tester=None, verbose=0, log_exception=False): | |||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
self._log_exception = log_exception | |||
assert isinstance(log_loss_every, int) and log_loss_every>=0 | |||
if tester is not None: | |||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | |||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | |||
@@ -477,7 +480,9 @@ class FitlogCallback(Callback): | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
self.verbose = verbose | |||
self._log_loss_every = log_loss_every | |||
self._avg_loss = 0 | |||
def on_train_begin(self): | |||
if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None: | |||
raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.") | |||
@@ -490,8 +495,12 @@ class FitlogCallback(Callback): | |||
fitlog.add_progress(total_steps=self.n_steps) | |||
def on_backward_begin(self, loss): | |||
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch) | |||
if self._log_loss_every>0: | |||
self._avg_loss += loss.item() | |||
if self.step%self._log_loss_every==0: | |||
fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch) | |||
self._avg_loss = 0 | |||
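A sketch of how the new ``log_loss_every`` parameter might be used (``model``, ``train_data``, ``dev_data``, ``extra_test``, ``loss`` and ``metric`` are assumed to exist; the averaged training loss is then written to fitlog every 100 steps):

.. code-block:: python

    from fastNLP import Trainer
    from fastNLP.core.callback import FitlogCallback

    cb = FitlogCallback(data={'test': extra_test}, log_loss_every=100, verbose=1)
    trainer = Trainer(train_data, model, loss=loss, metrics=metric,
                      dev_data=dev_data, callbacks=[cb])
    trainer.train()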
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): | |||
if better_result: | |||
eval_result = deepcopy(eval_result) | |||
@@ -518,7 +527,7 @@ class FitlogCallback(Callback): | |||
def on_exception(self, exception): | |||
fitlog.finish(status=1) | |||
if self._log_exception: | |||
fitlog.add_other(str(exception), name='except_info') | |||
fitlog.add_other(repr(exception), name='except_info') | |||
class LRScheduler(Callback): | |||
@@ -285,7 +285,8 @@ from .field import AutoPadder | |||
from .field import FieldArray | |||
from .instance import Instance | |||
from .utils import _get_func_signature | |||
from .field import AppendToTargetOrInputException | |||
from .field import SetInputOrTargetException | |||
class DataSet(object): | |||
""" | |||
@@ -422,7 +423,7 @@ class DataSet(object): | |||
if len(self.field_arrays) == 0: | |||
# DataSet has no field yet | |||
for name, field in instance.fields.items(): | |||
field = field.tolist() if isinstance(field, np.ndarray) else field | |||
# field = field.tolist() if isinstance(field, np.ndarray) else field | |||
self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 | |||
else: | |||
if len(self.field_arrays) != len(instance.fields): | |||
@@ -431,7 +432,11 @@ class DataSet(object): | |||
.format(len(self.field_arrays), len(instance.fields))) | |||
for name, field in instance.fields.items(): | |||
assert name in self.field_arrays | |||
self.field_arrays[name].append(field) | |||
try: | |||
self.field_arrays[name].append(field) | |||
except AppendToTargetOrInputException as e: | |||
print(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name, fieldarray): | |||
""" | |||
@@ -549,6 +554,7 @@ class DataSet(object): | |||
self.field_arrays[new_name].name = new_name | |||
else: | |||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||
return self | |||
def set_target(self, *field_names, flag=True): | |||
""" | |||
@@ -565,7 +571,11 @@ class DataSet(object): | |||
assert isinstance(flag, bool), "Only bool type supported." | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
self.field_arrays[name].is_target = flag | |||
try: | |||
self.field_arrays[name].is_target = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as target.") | |||
raise e | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
@@ -581,7 +591,11 @@ class DataSet(object): | |||
""" | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
self.field_arrays[name].is_input = flag | |||
try: | |||
self.field_arrays[name].is_input = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") | |||
raise e | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
@@ -748,7 +762,20 @@ class DataSet(object): | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def add_seq_len(self, field_name:str, new_field_name='seq_len'): | |||
""" | |||
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入new_field_name指定的field(默认为seq_len)。
:param str field_name: 需要统计长度的field的名称。
:param str new_field_name: 保存长度的field的名称,默认为seq_len。
:return: self
""" | |||
if self.has_field(field_name=field_name): | |||
self.apply_field(len, field_name, new_field_name=new_field_name) | |||
else: | |||
raise KeyError(f"Field:{field_name} not found.") | |||
return self | |||
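A quick illustration of the new method on a toy DataSet (field values invented):

.. code-block:: python

    from fastNLP import DataSet

    ds = DataSet({'words': [['this', 'is', 'fine'], ['short']]})
    ds.add_seq_len('words')                       # stores len(cell) in a new 'seq_len' field
    print(ds[0]['seq_len'], ds[1]['seq_len'])     # 3 1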
def drop(self, func, inplace=True): | |||
""" | |||
func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者加入到返回的DataSet中。 | |||
@@ -778,7 +805,7 @@ class DataSet(object): | |||
""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `ratio` 这么多数据,第二个DataSet拥有 `(1-ratio)` 这么多数据 | |||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `(1-ratio)` 这么多数据,第二个DataSet拥有`ratio`这么多数据 | |||
:return: [DataSet, DataSet] | |||
""" | |||
assert isinstance(ratio, float) | |||
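With the corrected semantics documented above, ``ratio`` is the share that goes into the second DataSet; a short sketch assuming some DataSet ``ds``:

.. code-block:: python

    train_ds, dev_ds = ds.split(0.1)      # ~90% of instances stay in train_ds, ~10% in dev_ds
    print(len(train_ds), len(dev_ds))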
@@ -1,251 +1,164 @@ | |||
""" | |||
field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fastNLP.DataSet` 中一列的存储方式, | |||
原理部分请参考 :doc:`fastNLP.core.dataset` | |||
""" | |||
__all__ = [ | |||
"FieldArray", | |||
"Padder", | |||
"AutoPadder", | |||
"EngChar2DPadder" | |||
] | |||
from copy import deepcopy | |||
from numbers import Number | |||
import torch | |||
import numpy as np | |||
class FieldArray(object): | |||
""" | |||
别名::class:`fastNLP.FieldArray` :class:`fastNLP.core.field.FieldArray` | |||
FieldArray 是用于保存 :class:`~fastNLP.DataSet` 中一个field的类型。 | |||
:param str name: FieldArray的名称 | |||
:param list,numpy.ndarray content: 列表的元素可以为list,int,float, | |||
:param bool is_target: 这个field是否是一个target field。 | |||
:param bool is_input: 这个field是否是一个input field。 | |||
:param padder: :class:`~fastNLP.Padder` 类型。赋值给fieldarray的padder的对象会被deepcopy一份,需要修改padder参数必须通过 | |||
fieldarray.set_pad_val()。默认为None,即使用 :class:`~fastNLP.AutoPadder` 。 | |||
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, | |||
就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 | |||
""" | |||
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): | |||
from typing import Any | |||
from abc import abstractmethod | |||
from copy import deepcopy | |||
from collections import Counter | |||
class SetInputOrTargetException(Exception): | |||
def __init__(self, msg, index=None, field_name=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index # 标示在哪个数据遭遇到问题了 | |||
self.field_name = field_name # 标示当前field的名称 | |||
class AppendToTargetOrInputException(Exception): | |||
def __init__(self, msg, index=None, field_name=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index # 标示在哪个数据遭遇到问题了 | |||
self.field_name = field_name # 标示当前field的名称 | |||
class FieldArray: | |||
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): | |||
if len(content)==0: | |||
raise RuntimeError("Empty fieldarray is not allowed.") | |||
_content = content | |||
try: | |||
_content = list(_content) | |||
except BaseException as e: | |||
print(f"Cannot convert content(of type:{type(content)}) into list.") | |||
raise e | |||
self.name = name | |||
if isinstance(content, list): | |||
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list | |||
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] | |||
for idx, item in enumerate(content): | |||
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) | |||
# 将[np.array] 转化为 list of list | |||
# 也可以支持[array, array, array]的情况 | |||
if isinstance(item, np.ndarray): | |||
content[idx] = content[idx].tolist() | |||
elif isinstance(content, np.ndarray): | |||
content = content.tolist() # convert np.ndarray into 2-D list | |||
else: | |||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||
if len(content) == 0: | |||
raise RuntimeError("Cannot initialize FieldArray with empty list.") | |||
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 | |||
self.content_dim = None # 表示content是多少维的list | |||
self.content = _content | |||
self._ignore_type = ignore_type | |||
# 根据input的情况设置input,target等 | |||
self._cell_ndim = None # 多少维度 | |||
self.dtype = None # 最内层的element都是什么类型的 | |||
self._is_input = False | |||
self._is_target = False | |||
if is_input: | |||
self.is_input = is_input | |||
if is_target: | |||
self.is_target = is_target | |||
if padder is None: | |||
padder = AutoPadder(pad_val=0) | |||
else: | |||
assert isinstance(padder, Padder), "padder must be of type Padder." | |||
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." | |||
padder = deepcopy(padder) | |||
self.set_padder(padder) | |||
self.ignore_type = ignore_type | |||
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array | |||
self.pytype = None | |||
self.dtype = None | |||
self._is_input = None | |||
self._is_target = None | |||
if is_input is not None or is_target is not None: | |||
self.is_input = is_input | |||
self.is_target = is_target | |||
def _set_dtype(self): | |||
if self.ignore_type is False: | |||
self.pytype = self._type_detection(self.content) | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
@property | |||
def ignore_type(self): | |||
return self._ignore_type | |||
@ignore_type.setter | |||
def ignore_type(self, value): | |||
if value: | |||
self._cell_ndim = None | |||
self.dtype = None | |||
self._ignore_type = value | |||
@property | |||
def is_input(self): | |||
return self._is_input | |||
@is_input.setter | |||
def is_input(self, value): | |||
""" | |||
当 field_array.is_input = True / False 时被调用 | |||
""" | |||
if value is True: | |||
self._set_dtype() | |||
# 如果(value为True)且(_is_input和_is_target都是False)且(ignore_type为False) | |||
if value is True and \ | |||
self._is_target is False and \ | |||
self._ignore_type is False: | |||
self._check_dtype_and_ndim() | |||
if value is False and self._is_target is False: | |||
self.dtype = None | |||
self._cell_ndim = None | |||
self._is_input = value | |||
@property | |||
def is_target(self): | |||
return self._is_target | |||
@is_target.setter | |||
def is_target(self, value): | |||
""" | |||
当 field_array.is_target = True / False 时被调用 | |||
""" | |||
if value is True: | |||
self._set_dtype() | |||
if value is True and \ | |||
self._is_input is False and \ | |||
self._ignore_type is False: | |||
self._check_dtype_and_ndim() | |||
if value is False and self._is_input is False: | |||
self.dtype = None | |||
self._cell_ndim = None | |||
self._is_target = value | |||
def _type_detection(self, content): | |||
""" | |||
当该field被设置为is_input或者is_target时被调用 | |||
""" | |||
if len(content) == 0: | |||
raise RuntimeError("Empty list in Field {}.".format(self.name)) | |||
type_set = set([type(item) for item in content]) | |||
if list in type_set: | |||
if len(type_set) > 1: | |||
# list 跟 非list 混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
# >1维list | |||
inner_type_set = set() | |||
for l in content: | |||
[inner_type_set.add(type(obj)) for obj in l] | |||
if list not in inner_type_set: | |||
# 二维list | |||
self.content_dim = 2 | |||
return self._basic_type_detection(inner_type_set) | |||
else: | |||
if len(inner_type_set) == 1: | |||
# >2维list | |||
inner_inner_type_set = set() | |||
for _2d_list in content: | |||
for _1d_list in _2d_list: | |||
[inner_inner_type_set.add(type(obj)) for obj in _1d_list] | |||
if list in inner_inner_type_set: | |||
raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") | |||
# 3维list | |||
self.content_dim = 3 | |||
return self._basic_type_detection(inner_inner_type_set) | |||
else: | |||
# list 跟 非list 混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(inner_type_set))) | |||
else: | |||
# 一维list | |||
for content_type in type_set: | |||
if content_type not in self.BASIC_TYPES: | |||
raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( | |||
self.name, self.BASIC_TYPES, content_type)) | |||
self.content_dim = 1 | |||
return self._basic_type_detection(type_set) | |||
def _basic_type_detection(self, type_set): | |||
""" | |||
:param type_set: a set of Python types | |||
:return: one of self.BASIC_TYPES | |||
""" | |||
if len(type_set) == 1: | |||
return type_set.pop() | |||
elif len(type_set) == 2: | |||
# 有多个basic type; 可能需要up-cast | |||
if float in type_set and int in type_set: | |||
# up-cast int to float | |||
return float | |||
else: | |||
# str 跟 int 或者 float 混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
def _check_dtype_and_ndim(self): | |||
""" | |||
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | |||
通过将直接报错. | |||
:return: | |||
""" | |||
cell_0 = self.content[0] | |||
index = 0 | |||
try: | |||
type_0, dim_0 = _get_ele_type_and_dim(cell_0) | |||
for cell in self.content[1:]: | |||
index += 1 | |||
type_i, dim_i = _get_ele_type_and_dim(cell) | |||
if type_i!=type_0: | |||
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||
".".format(type_i, index, type_0)) | |||
if dim_0!=dim_i: | |||
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||
"dimension:{}.".format(dim_i, index, dim_0)) | |||
self._cell_ndim = dim_0 | |||
self.dtype = type_0 | |||
except SetInputOrTargetException as e: | |||
e.index = index | |||
raise e | |||
def append(self, val:Any): | |||
""" | |||
:param val: 把该val append到fieldarray。 | |||
:return: | |||
""" | |||
if (self._is_target or self._is_input) and self._ignore_type is False: | |||
type_, dim_ = _get_ele_type_and_dim(val) | |||
if self.dtype!=type_: | |||
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " | |||
f"previous values(type:{self.dtype}).") | |||
if self._cell_ndim!=dim_: | |||
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " | |||
f"previous values(dim:{self._cell_ndim}).") | |||
self.content.append(val) | |||
else: | |||
# str, int, float混在一起 | |||
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
def _1d_list_check(self, val): | |||
"""如果不是1D list就报错 | |||
""" | |||
type_set = set((type(obj) for obj in val)) | |||
if any(obj not in self.BASIC_TYPES for obj in type_set): | |||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) | |||
self._basic_type_detection(type_set) | |||
# otherwise: _basic_type_detection will raise error | |||
return True | |||
def _2d_list_check(self, val): | |||
"""如果不是2D list 就报错 | |||
""" | |||
type_set = set(type(obj) for obj in val) | |||
if list(type_set) != [list]: | |||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
inner_type_set = set() | |||
for l in val: | |||
for obj in l: | |||
inner_type_set.add(type(obj)) | |||
self._basic_type_detection(inner_type_set) | |||
return True | |||
@staticmethod | |||
def _map_to_np_type(basic_type): | |||
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} | |||
return type_mapping[basic_type] | |||
def __repr__(self): | |||
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) | |||
def append(self, val): | |||
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 | |||
的内容是匹配的。 | |||
:param Any val: 需要append的值。 | |||
""" | |||
if self.ignore_type is False: | |||
if isinstance(val, list): | |||
pass | |||
elif isinstance(val, tuple): # 确保最外层是list | |||
val = list(val) | |||
elif isinstance(val, np.ndarray): | |||
val = val.tolist() | |||
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||
pass | |||
else: | |||
raise RuntimeError( | |||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||
if self.is_input is True or self.is_target is True: | |||
if type(val) == list: | |||
if len(val) == 0: | |||
raise ValueError("Cannot append an empty list.") | |||
if self.content_dim == 2 and self._1d_list_check(val): | |||
# 1维list检查 | |||
pass | |||
elif self.content_dim == 3 and self._2d_list_check(val): | |||
# 2维list检查 | |||
pass | |||
else: | |||
raise RuntimeError( | |||
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||
# scalar检查 | |||
if type(val) == float and self.pytype == int: | |||
self.pytype = float | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
else: | |||
raise RuntimeError( | |||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||
self.content.append(val) | |||
self.content.append(val) | |||
def __getitem__(self, indices): | |||
return self.get(indices, pad=False) | |||
def __setitem__(self, idx, val): | |||
assert isinstance(idx, int) | |||
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 | |||
type_, dim_ = _get_ele_type_and_dim(val) | |||
if self.dtype!=type_: | |||
raise RuntimeError(f"Value(type:{type_}) are of different types with " | |||
f"other values(type:{self.dtype}).") | |||
if self._cell_ndim!=dim_: | |||
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " | |||
f"previous values(dim:{self._cell_ndim}).") | |||
self.content[idx] = val | |||
def get(self, indices, pad=True): | |||
""" | |||
根据给定的indices返回内容 | |||
@@ -257,14 +170,17 @@ class FieldArray(object): | |||
if isinstance(indices, int): | |||
return self.content[indices] | |||
if self.is_input is False and self.is_target is False: | |||
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) | |||
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) | |||
contents = [self.content[i] for i in indices] | |||
if self.padder is None or pad is False: | |||
return np.array(contents) | |||
else: | |||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) | |||
return self.pad(contents) | |||
def pad(self, contents): | |||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||
def set_padder(self, padder): | |||
""" | |||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | |||
@@ -276,7 +192,7 @@ class FieldArray(object): | |||
self.padder = deepcopy(padder) | |||
else: | |||
self.padder = None | |||
def set_pad_val(self, pad_val): | |||
""" | |||
修改padder的pad_val. | |||
@@ -286,7 +202,7 @@ class FieldArray(object): | |||
if self.padder is not None: | |||
self.padder.set_pad_val(pad_val) | |||
return self | |||
def __len__(self): | |||
""" | |||
Returns the size of FieldArray. | |||
@@ -294,7 +210,7 @@ class FieldArray(object): | |||
:return int length: | |||
""" | |||
return len(self.content) | |||
def to(self, other): | |||
""" | |||
将other的属性复制给本FieldArray(other必须为FieldArray类型). | |||
@@ -303,22 +219,225 @@ class FieldArray(object): | |||
:param other: :class:`~fastNLP.FieldArray` 从哪个field拷贝属性 | |||
:return: :class:`~fastNLP.FieldArray` | |||
""" | |||
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) | |||
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) | |||
self.ignore_type = other.ignore_type | |||
self.is_input = other.is_input | |||
self.is_target = other.is_target | |||
self.padder = other.padder | |||
self.ignore_type = other.ignore_type | |||
return self | |||
def split(self, sep:str=None, inplace:bool=True): | |||
""" | |||
依次对自身的元素调用.split()方法。只有当本field的元素为str时,该方法才有意义。
:param sep: 分割符,如果为None则直接调用str.split()。 | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[List[str]] or self | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
new_contents.append(cell.split(sep)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def int(self, inplace:bool=True): | |||
""" | |||
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([int(value) for value in cell]) | |||
else: | |||
new_contents.append(int(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e
return self._after_process(new_contents, inplace=inplace) | |||
def float(self, inplace=True): | |||
""" | |||
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([float(value) for value in cell]) | |||
else: | |||
new_contents.append(float(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def bool(self, inplace=True): | |||
""" | |||
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([bool(value) for value in cell]) | |||
else: | |||
new_contents.append(bool(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def lower(self, inplace=True): | |||
""" | |||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([value.lower() for value in cell]) | |||
else: | |||
new_contents.append(cell.lower()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def upper(self, inplace=True): | |||
""" | |||
将本field中的值调用cell.upper(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([value.upper() for value in cell]) | |||
else: | |||
new_contents.append(cell.upper()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
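These helpers can be chained because ``inplace=True`` replaces ``self.content``; a small sketch with made-up string content:

.. code-block:: python

    from fastNLP import FieldArray

    fa = FieldArray('nums', ['3 4 5', '1 2'])
    fa.split()           # content becomes [['3', '4', '5'], ['1', '2']]
    fa.int()             # content becomes [[3, 4, 5], [1, 2]]
    print(fa.content)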
def _is_iterable(content): | |||
def value_count(self): | |||
""" | |||
返回该field下不同value的数量。多用于统计label数量 | |||
:return: Counter, key是label,value是出现次数 | |||
""" | |||
count = Counter() | |||
def cum(cell): | |||
if _is_iterable(cell) and not isinstance(cell, str): | |||
for cell_ in cell: | |||
cum(cell_) | |||
else: | |||
count[cell] += 1 | |||
for cell in self.content: | |||
cum(cell) | |||
return count | |||
def _after_process(self, new_contents, inplace): | |||
""" | |||
当调用处理函数之后,决定是否要替换field。 | |||
:param new_contents: | |||
:param inplace: | |||
:return: self或者生成的content | |||
""" | |||
if inplace: | |||
self.content = new_contents | |||
try: | |||
self.is_input = self.is_input | |||
self.is_target = self.is_target
except SetInputOrTargetException as e: | |||
print("The newly generated field cannot be set as input or target.") | |||
raise e | |||
return self | |||
else: | |||
return new_contents | |||
def _get_ele_type_and_dim(cell:Any, dim=0): | |||
""" | |||
识别cell的类别与dimension的数量 | |||
numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html | |||
:param cell: | |||
:param dim: | |||
:return: | |||
""" | |||
if isinstance(cell, (str, Number, np.bool_)): | |||
if hasattr(cell, 'dtype'): | |||
return cell.dtype.type, dim | |||
return type(cell), dim | |||
elif isinstance(cell, list): | |||
dim += 1 | |||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
types = set([i for i,j in res]) | |||
dims = set([j for i,j in res]) | |||
if len(types)>1: | |||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
elif len(types)==0: | |||
raise SetInputOrTargetException("Empty value encountered.") | |||
if len(dims)>1: | |||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
return types.pop(), dims.pop() | |||
elif isinstance(cell, torch.Tensor): | |||
return cell.dtype, cell.dim() + dim  # 0维的情况,例如torch.mean的结果dim为0
elif isinstance(cell, np.ndarray): | |||
if cell.dtype != np.dtype('O'): # 如果不是object的话说明是well-formatted的了 | |||
return cell.dtype.type, cell.ndim + dim # dtype.type返回的会是np.int32, np.float等 | |||
# 否则需要继续往下iterate | |||
dim += 1 | |||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
types = set([i for i,j in res]) | |||
dims = set([j for i,j in res]) | |||
if len(types)>1: | |||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
elif len(types)==0: | |||
raise SetInputOrTargetException("Empty value encountered.") | |||
if len(dims)>1: | |||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
return types.pop(), dims.pop() | |||
else: # 包含tuple, set, dict以及其它的类型 | |||
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | |||
def _is_iterable(value): | |||
# 检查是否是iterable的, duck typing | |||
try: | |||
_ = (e for e in content) | |||
except TypeError: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
return True | |||
class Padder: | |||
@@ -327,32 +446,35 @@ class Padder: | |||
所有padder都需要继承这个类,并覆盖__call__方法。 | |||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||
.. py:function:: __call__(self, contents, field_name, field_ele_dtype): | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param str, field_name: field的名称。 | |||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | |||
:return: np.array([padded_element]) | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
self.pad_val = pad_val | |||
def set_pad_val(self, pad_val): | |||
self.pad_val = pad_val | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
@abstractmethod | |||
def __call__(self, contents, field_name, field_ele_dtype, dim:int): | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param str, field_name: field的名称。 | |||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | |||
:param np.int64,np.float64,np.str,None field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,
则该值为None。
:param dim: 这个field的维度。当ignore_type为True时,该值为None | |||
:return: np.array([padded_element]) | |||
Example:: | |||
@@ -394,50 +516,86 @@ class AutoPadder(Padder): | |||
根据contents的数据自动判定是否需要做padding。 | |||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
型为np.str, [[1,2], ...]的元素类型为np.int64)的数据不为(np.int64, np.float64)则不会进行pad | |||
型为str, [[1,2], ...]的元素类型为int)的数据不为数值类型则不会进行pad | |||
2 如果元素类型为数值类型,比如np.int64, np.float64, int, float, torch.int64等 | |||
2 如果元素类型为(np.int64, np.float64), | |||
2.1 如果该field的内容为数值类型(包括int, float等),比如为seq_len, 则不进行padding | |||
2.1 如果该field的内容为(np.int64, np.float64),比如为seq_len, 则不进行padding | |||
2.2 如果该field的内容等价于一维list, 那么会将Batch中的List pad为一样长。 | |||
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 | |||
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad | |||
2.3 如果该field的内容等价于二维list,那么会按照英语character padding的方式进行padding。如果是character padding建议使用 | |||
:class: fastNLP.EngChar2DPadder. | |||
2.4 如果该field的内容等价于三维list,则如果每个instance在每个维度上相等,会组成一个batch的tensor返回,这种情况应该是为图片 | |||
的情况。 | |||
3 其它情况不进行处理,返回一个np.array类型。 | |||
""" | |||
def __init__(self, pad_val=0): | |||
""" | |||
:param pad_val: int, padding的位置使用该index | |||
""" | |||
super().__init__(pad_val=pad_val) | |||
def _is_two_dimension(self, contents): | |||
""" | |||
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 | |||
:param contents: | |||
:return: | |||
""" | |||
value = contents[0] | |||
if isinstance(value, (np.ndarray, list)): | |||
value = value[0] | |||
if isinstance(value, (np.ndarray, list)): | |||
return False | |||
return True | |||
return False | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
if not _is_iterable(contents[0]): | |||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||
max_len = max([len(content) for content in contents]) | |||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content in enumerate(contents): | |||
array[i][:len(content)] = content | |||
elif field_ele_dtype is None: | |||
array = np.array(contents) # 当ignore_type=True时,直接返回contents | |||
else: # should only be str | |||
array = np.array([content for content in contents]) | |||
return array | |||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||
if field_ele_dtype: | |||
if dim>3: | |||
return np.array(contents) | |||
if isinstance(field_ele_dtype, type) and \ | |||
(issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)): | |||
if dim==0: | |||
array = np.array(contents, dtype=field_ele_dtype) | |||
elif dim==1: | |||
max_len = max(map(len, contents)) | |||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
array[i, :len(content_i)] = content_i | |||
elif dim==2: | |||
max_len = max(map(len, contents)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in contents]) | |||
array = np.full((len(contents), max_len, max_word_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
for j, content_ii in enumerate(content_i): | |||
array[i, j, :len(content_ii)] = content_ii | |||
else: | |||
shape = np.shape(contents) | |||
if len(shape)==4: # 说明各dimension是相同的大小 | |||
array = np.array(contents, dtype=field_ele_dtype) | |||
else: | |||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
return array | |||
elif str(field_ele_dtype).startswith('torch'): | |||
if dim==0: | |||
tensor = torch.tensor(contents).to(field_ele_dtype) | |||
elif dim==1: | |||
max_len = max(map(len, contents)) | |||
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
tensor[i, :len(content_i)] = torch.tensor(content_i) | |||
elif dim==2: | |||
max_len = max(map(len, contents)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in contents]) | |||
tensor = torch.full((len(contents), max_len, max_word_len), fill_value=self.pad_val, | |||
dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
for j, content_ii in enumerate(content_i): | |||
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) | |||
else: | |||
shapes = set([np.shape(content_i) for content_i in contents]) | |||
if len(shapes)>1: | |||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
shape = shapes.pop() | |||
if len(shape)==3: | |||
tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) | |||
else: | |||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
return tensor | |||
else: | |||
return np.array(contents) # 不进行任何操作 | |||
else: | |||
return np.array(contents) | |||
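A minimal sketch of the dim-aware behaviour for the common 1-dimensional integer case (the field name is invented):

.. code-block:: python

    from fastNLP.core.field import AutoPadder

    padder = AutoPadder(pad_val=0)
    batch = [[1, 2, 3], [4, 5]]
    print(padder(batch, field_name='words', field_ele_dtype=int, dim=1))
    # [[1 2 3]
    #  [4 5 0]]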
class EngChar2DPadder(Padder): | |||
@@ -463,7 +621,7 @@ class EngChar2DPadder(Padder): | |||
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | |||
""" | |||
def __init__(self, pad_val=0, pad_length=0): | |||
""" | |||
:param pad_val: int, pad的位置使用该index | |||
@@ -471,32 +629,10 @@ class EngChar2DPadder(Padder): | |||
都pad或截取到该长度. | |||
""" | |||
super().__init__(pad_val=pad_val) | |||
self.pad_length = pad_length | |||
def _exactly_three_dims(self, contents, field_name): | |||
""" | |||
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character | |||
:param contents: | |||
:param field_name: str | |||
:return: | |||
""" | |||
if not isinstance(contents, list): | |||
raise TypeError("contents should be a list, not {}.".format(type(contents))) | |||
value = contents[0] | |||
try: | |||
value = value[0] | |||
except: | |||
raise ValueError("Field:{} only has one dimension.".format(field_name)) | |||
try: | |||
value = value[0] | |||
except: | |||
raise ValueError("Field:{} only has two dimensions.".format(field_name)) | |||
if _is_iterable(value): | |||
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||
""" | |||
期望输入类似于 | |||
[ | |||
@@ -510,24 +646,24 @@ class EngChar2DPadder(Padder): | |||
:param field_ele_dtype | |||
:return: | |||
""" | |||
if field_ele_dtype not in (np.int64, np.float64): | |||
if field_ele_dtype not in (np.int64, np.float64, int, float): | |||
raise TypeError('dtype of Field:{} should be np.int64, np.float64, int or float to do 2D padding, got {}.'.format(
field_name, field_ele_dtype | |||
)) | |||
self._exactly_three_dims(contents, field_name) | |||
assert dim==2, f"Field:{field_name} has {dim} dimensions, but EngChar2DPadder only supports input with 2 dimensions."
if self.pad_length < 1: | |||
max_char_length = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents])) | |||
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) | |||
else: | |||
max_char_length = self.pad_length | |||
max_sent_length = max(len(word_lst) for word_lst in contents) | |||
batch_size = len(contents) | |||
dtype = type(contents[0][0][0]) | |||
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | |||
dtype=dtype) | |||
for b_idx, word_lst in enumerate(contents): | |||
for c_idx, char_lst in enumerate(word_lst): | |||
chars = char_lst[:max_char_length] | |||
padded_array[b_idx, c_idx, :len(chars)] = chars | |||
return padded_array |
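A small sketch of the padder on an invented character-id batch, using the new ``dim`` argument:

.. code-block:: python

    from fastNLP.core.field import EngChar2DPadder

    padder = EngChar2DPadder(pad_val=0)
    chars = [[[1, 2], [3]],        # sentence 1: words of 2 and 1 characters
             [[4, 5, 6]]]          # sentence 2: one word of 3 characters
    print(padder(chars, field_name='chars', field_ele_dtype=int, dim=2).shape)
    # (2, 2, 3): (batch, max words, max chars)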
@@ -34,14 +34,23 @@ class LossBase(object): | |||
""" | |||
def __init__(self): | |||
self.param_map = {} | |||
self._param_map = {} # key是fun的参数,value是以该值从传入的dict取出value | |||
self._checked = False | |||
@property | |||
def param_map(self): | |||
if len(self._param_map) == 0: # 如果为空说明还没有初始化 | |||
func_spect = inspect.getfullargspec(self.get_loss) | |||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||
for arg in func_args: | |||
self._param_map[arg] = arg | |||
return self._param_map | |||
def get_loss(self, *args, **kwargs): | |||
raise NotImplementedError | |||
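A subclass is expected to implement ``get_loss`` and register its argument mapping through ``_init_param_map``; a hypothetical sketch (``MyL1Loss`` is not part of fastNLP):

.. code-block:: python

    import torch.nn.functional as F
    from fastNLP.core.losses import LossBase

    class MyL1Loss(LossBase):
        def __init__(self, pred=None, target=None):
            super().__init__()
            self._init_param_map(pred=pred, target=target)

        def get_loss(self, pred, target):
            return F.l1_loss(pred, target)

    # map the model's 'output' key to pred and the dataset's 'score' key to target
    loss = MyL1Loss(pred='output', target='score')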
def _init_param_map(self, key_map=None, **kwargs): | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
:param dict key_map: 表示key的映射关系 | |||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | |||
@@ -53,30 +62,30 @@ class LossBase(object): | |||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | |||
for key, value in key_map.items(): | |||
if value is None: | |||
self.param_map[key] = key | |||
self._param_map[key] = key | |||
continue | |||
if not isinstance(key, str): | |||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | |||
if not isinstance(value, str): | |||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | |||
self.param_map[key] = value | |||
self._param_map[key] = value | |||
value_counter[value].add(key) | |||
for key, value in kwargs.items(): | |||
if value is None: | |||
self.param_map[key] = key | |||
self._param_map[key] = key | |||
continue | |||
if not isinstance(value, str): | |||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | |||
self.param_map[key] = value | |||
self._param_map[key] = value | |||
value_counter[value].add(key) | |||
for value, key_set in value_counter.items(): | |||
if len(key_set) > 1: | |||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | |||
# check consistence between signature and param_map | |||
# check consistence between signature and _param_map | |||
func_spect = inspect.getfullargspec(self.get_loss) | |||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||
for func_param, input_param in self.param_map.items(): | |||
for func_param, input_param in self._param_map.items(): | |||
if func_param not in func_args: | |||
raise NameError( | |||
f"Parameter `{func_param}` is not in {_get_func_signature(self.get_loss)}. Please check the " | |||
@@ -96,7 +105,7 @@ class LossBase(object): | |||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||
""" | |||
fast_param = {} | |||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||
fast_param['pred'] = list(pred_dict.values())[0] | |||
fast_param['target'] = list(target_dict.values())[0] | |||
return fast_param | |||
@@ -115,49 +124,41 @@ class LossBase(object): | |||
return loss | |||
if not self._checked: | |||
# 1. check consistence between signature and param_map | |||
# 1. check consistence between signature and _param_map | |||
func_spect = inspect.getfullargspec(self.get_loss) | |||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||
for func_arg, input_arg in self.param_map.items(): | |||
for func_arg, input_arg in self._param_map.items(): | |||
if func_arg not in func_args: | |||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.get_loss)}.") | |||
# 2. only part of the param_map are passed, left are not | |||
# 2. only part of the _param_map are passed, left are not | |||
for arg in func_args: | |||
if arg not in self.param_map: | |||
self.param_map[arg] = arg # This param does not need mapping. | |||
if arg not in self._param_map: | |||
self._param_map[arg] = arg # This param does not need mapping. | |||
self._evaluate_args = func_args | |||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||
# need to wrap inputs in dict. | |||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} | |||
mapped_pred_dict = {} | |||
mapped_target_dict = {} | |||
duplicated = [] | |||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||
not_duplicate_flag = 0 | |||
if input_arg in self._reverse_param_map: | |||
mapped_arg = self._reverse_param_map[input_arg] | |||
not_duplicate_flag += 1 | |||
else: | |||
mapped_arg = input_arg | |||
for input_arg, mapped_arg in self._reverse_param_map.items(): | |||
if input_arg in pred_dict: | |||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||
not_duplicate_flag += 1 | |||
if input_arg in target_dict: | |||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||
not_duplicate_flag += 1 | |||
if not_duplicate_flag == 3: | |||
duplicated.append(input_arg) | |||
# missing | |||
if not self._checked: | |||
duplicated = [] | |||
for input_arg, mapped_arg in self._reverse_param_map.items(): | |||
if input_arg in pred_dict and input_arg in target_dict: | |||
duplicated.append(input_arg) | |||
check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict]) | |||
# replace missing. | |||
missing = check_res.missing | |||
replaced_missing = list(missing) | |||
for idx, func_arg in enumerate(missing): | |||
# Don't delete `` in this information, nor add `` | |||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||
f"in `{self.__class__.__name__}`)" | |||
check_res = _CheckRes(missing=replaced_missing, | |||
@@ -170,6 +171,8 @@ class LossBase(object): | |||
if check_res.missing or check_res.duplicated: | |||
raise _CheckError(check_res=check_res, | |||
func_signature=_get_func_signature(self.get_loss)) | |||
self._checked = True | |||
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | |||
loss = self.get_loss(**refined_args) | |||
@@ -204,15 +207,12 @@ class LossFunc(LossBase): | |||
super(LossFunc, self).__init__() | |||
_check_function_or_method(func) | |||
self.get_loss = func | |||
if key_map is not None: | |||
if not isinstance(key_map, dict): | |||
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") | |||
self.param_map = key_map | |||
if len(kwargs) > 0: | |||
for key, val in kwargs.items(): | |||
self.param_map.update({key: val}) | |||
self._init_param_map(key_map, **kwargs) | |||
self.get_loss = func | |||
class CrossEntropyLoss(LossBase): | |||
@@ -232,12 +232,16 @@ class CrossEntropyLoss(LossBase): | |||
""" | |||
def __init__(self, pred=None, target=None, padding_idx=-100): | |||
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际需要(16,4) | |||
super(CrossEntropyLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
self.padding_idx = padding_idx | |||
def get_loss(self, pred, target): | |||
if pred.dim()>2: | |||
if pred.size()[:2]==target.size(): | |||
# F.cross_entropy在计算时,如果pred是(16, 10 ,4), 会在第二维上去log_softmax, 所以需要交换一下位置 | |||
pred = pred.transpose(1, 2) | |||
return F.cross_entropy(input=pred, target=target, | |||
ignore_index=self.padding_idx) | |||
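# A minimal usage sketch, assuming the field names 'output' and 'label': the pred/target
# arguments only define the key mapping, and a (batch, seq_len, n_classes) prediction is
# transposed before F.cross_entropy as shown above.
import torch
from fastNLP.core.losses import CrossEntropyLoss

loss_fn = CrossEntropyLoss(pred='output', target='label')
pred = torch.randn(16, 10, 4)            # (batch, seq_len, n_classes)
target = torch.randint(0, 4, (16, 10))   # (batch, seq_len)
print(loss_fn.get_loss(pred, target))    # scalar tensor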
@@ -22,7 +22,7 @@ from .utils import _check_arg_dict_list | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
from .vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
class MetricBase(object): | |||
""" | |||
@@ -115,17 +115,28 @@ class MetricBase(object): | |||
""" | |||
def __init__(self): | |||
self.param_map = {} # key is param in function, value is input param. | |||
self._param_map = {} # key is param in function, value is input param. | |||
self._checked = False | |||
@property | |||
def param_map(self): | |||
if len(self._param_map) == 0: # 如果为空说明还没有初始化 | |||
func_spect = inspect.getfullargspec(self.evaluate) | |||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||
for arg in func_args: | |||
self._param_map[arg] = arg | |||
return self._param_map | |||
@abstractmethod | |||
def evaluate(self, *args, **kwargs): | |||
raise NotImplementedError | |||
@abstractmethod | |||
def get_metric(self, reset=True): | |||
raise NotImplementedError | |||
def _init_param_map(self, key_map=None, **kwargs): | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
:param dict key_map: 表示key的映射关系 | |||
:param kwargs: key word args里面的每一个的键-值对都会被构造成映射关系 | |||
@@ -137,30 +148,30 @@ class MetricBase(object): | |||
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map))) | |||
for key, value in key_map.items(): | |||
if value is None: | |||
self.param_map[key] = key | |||
self._param_map[key] = key | |||
continue | |||
if not isinstance(key, str): | |||
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.") | |||
if not isinstance(value, str): | |||
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.") | |||
self.param_map[key] = value | |||
self._param_map[key] = value | |||
value_counter[value].add(key) | |||
for key, value in kwargs.items(): | |||
if value is None: | |||
self.param_map[key] = key | |||
self._param_map[key] = key | |||
continue | |||
if not isinstance(value, str): | |||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | |||
self.param_map[key] = value | |||
self._param_map[key] = value | |||
value_counter[value].add(key) | |||
for value, key_set in value_counter.items(): | |||
if len(key_set) > 1: | |||
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") | |||
# check consistence between signature and param_map | |||
# check consistency between signature and _param_map | |||
func_spect = inspect.getfullargspec(self.evaluate) | |||
func_args = [arg for arg in func_spect.args if arg != 'self'] | |||
for func_param, input_param in self.param_map.items(): | |||
for func_param, input_param in self._param_map.items(): | |||
if func_param not in func_args: | |||
raise NameError( | |||
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " | |||
@@ -175,7 +186,7 @@ class MetricBase(object): | |||
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping. | |||
""" | |||
fast_param = {} | |||
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: | |||
fast_param['pred'] = list(pred_dict.values())[0] | |||
fast_param['target'] = list(target_dict.values())[0] | |||
return fast_param | |||
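# A minimal sketch, assuming the field names 'output' and 'label': a concrete MetricBase
# subclass such as AccuracyMetric relies on the same _param_map machinery, and the
# two-entry fast path above skips the mapping when exactly one pred and one target
# value are passed in.
import torch
from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric(pred='output', target='label')
metric({'output': torch.tensor([[0.1, 0.9], [0.8, 0.2]])},   # pred_dict
       {'label': torch.tensor([1, 0])})                      # target_dict
print(metric.get_metric())                                   # e.g. {'acc': 1.0}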
@@ -204,42 +215,35 @@ class MetricBase(object): | |||
if not self._checked: | |||
if not callable(self.evaluate): | |||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | |||
# 1. check consistence between signature and param_map | |||
# 1. check consistency between signature and _param_map | |||
func_spect = inspect.getfullargspec(self.evaluate) | |||
func_args = set([arg for arg in func_spect.args if arg != 'self']) | |||
for func_arg, input_arg in self.param_map.items(): | |||
for func_arg, input_arg in self._param_map.items(): | |||
if func_arg not in func_args: | |||
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") | |||
# 2. only part of the param_map are passed, left are not | |||
# 2. only part of the _param_map are passed, left are not | |||
for arg in func_args: | |||
if arg not in self.param_map: | |||
self.param_map[arg] = arg # This param does not need mapping. | |||
if arg not in self._param_map: | |||
self._param_map[arg] = arg # This param does not need mapping. | |||
self._evaluate_args = func_args | |||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self._param_map.items()} | |||
# need to wrap inputs in dict. | |||
mapped_pred_dict = {} | |||
mapped_target_dict = {} | |||
duplicated = [] | |||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||
not_duplicate_flag = 0 | |||
if input_arg in self._reverse_param_map: | |||
mapped_arg = self._reverse_param_map[input_arg] | |||
not_duplicate_flag += 1 | |||
else: | |||
mapped_arg = input_arg | |||
for input_arg, mapped_arg in self._reverse_param_map.items(): | |||
if input_arg in pred_dict: | |||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||
not_duplicate_flag += 1 | |||
if input_arg in target_dict: | |||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||
not_duplicate_flag += 1 | |||
if not_duplicate_flag == 3: | |||
duplicated.append(input_arg) | |||
# missing | |||
if not self._checked: | |||
duplicated = [] | |||
for input_arg, mapped_arg in self._reverse_param_map.items(): | |||
if input_arg in pred_dict and input_arg in target_dict: | |||
duplicated.append(input_arg) | |||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | |||
# only check missing. | |||
# replace missing. | |||
@@ -247,7 +251,7 @@ class MetricBase(object): | |||
replaced_missing = list(missing) | |||
for idx, func_arg in enumerate(missing): | |||
# Don't delete `` in this information, nor add `` | |||
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||
replaced_missing[idx] = f"{self._param_map[func_arg]}" + f"(assign to `{func_arg}` " \ | |||
f"in `{self.__class__.__name__}`)" | |||
check_res = _CheckRes(missing=replaced_missing, | |||
@@ -260,10 +264,10 @@ class MetricBase(object): | |||
if check_res.missing or check_res.duplicated: | |||
raise _CheckError(check_res=check_res, | |||
func_signature=_get_func_signature(self.evaluate)) | |||
self._checked = True | |||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | |||
self.evaluate(**refined_args) | |||
self._checked = True | |||
return | |||
@@ -409,6 +413,37 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||
] | |||
def _bioes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bioes_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bioes_tag, label = tag[:1], tag[2:] | |||
if bioes_tag in ('b', 's'): | |||
spans.append((label, [idx, idx])) | |||
elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
elif bioes_tag == 'o': | |||
pass | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bioes_tag = bioes_tag | |||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
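# Quick illustration of the BIOES decoding above: spans come back as
# (label, (start, end)) with a right-open interval.
print(_bioes_tag_to_spans(['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']))
# -> [('singer', (1, 4))]
print(_bioes_tag_to_spans(['S-loc', 'O', 'B-per', 'E-per']))
# -> [('loc', (0, 1)), ('per', (2, 4))]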
def _bio_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的list,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||
@@ -438,7 +473,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
class SpanFPreRecMetric(MetricBase): | |||
""" | |||
r""" | |||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
@@ -469,15 +504,15 @@ class SpanFPreRecMetric(MetricBase): | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。 | |||
:param str encoding_type: 目前支持bio, bmes | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||
个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||
label的f1, pre, rec | |||
:param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | |||
分别计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | |||
@@ -497,6 +532,8 @@ class SpanFPreRecMetric(MetricBase): | |||
self.tag_to_span_func = _bio_tag_to_spans | |||
elif self.encoding_type == 'bmeso': | |||
self.tag_to_span_func = _bmeso_tag_to_spans | |||
elif self.encoding_type == 'bioes': | |||
self.tag_to_span_func = _bioes_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||
@@ -699,17 +736,17 @@ def _pred_topk(y_prob, k=1): | |||
class SQuADMetric(MetricBase): | |||
""" | |||
r""" | |||
别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||
SQuAD数据集metric | |||
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | |||
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` | |||
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` | |||
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` | |||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||
:param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` | |||
:param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | |||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | |||
@@ -6,7 +6,7 @@ from collections import defaultdict | |||
import torch | |||
from . import Batch | |||
from . import DataSetIter | |||
from . import DataSet | |||
from . import SequentialSampler | |||
from .utils import _build_args | |||
@@ -44,8 +44,7 @@ class Predictor(object): | |||
self.network.eval() | |||
batch_output = defaultdict(list) | |||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False, | |||
prefetch=False) | |||
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||
if hasattr(self.network, "predict"): | |||
predict_func = self.network.predict | |||
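# A minimal sketch, assuming `dataset` is a fastNLP DataSet with inputs/targets already
# marked: DataSetIter replaces the old Batch class, and extra worker processes are
# requested through num_workers rather than a prefetch flag.
from fastNLP import DataSetIter, SequentialSampler

data_iterator = DataSetIter(dataset, batch_size=32,
                            sampler=SequentialSampler(), num_workers=2)
for batch_x, batch_y in data_iterator:
    pass   # batch_x / batch_y are dicts mapping field names to padded tensors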
@@ -37,7 +37,7 @@ import warnings | |||
import torch | |||
import torch.nn as nn | |||
from .batch import Batch | |||
from .batch import BatchIter, DataSetIter | |||
from .dataset import DataSet | |||
from .metrics import _prepare_metrics | |||
from .sampler import SequentialSampler | |||
@@ -82,7 +82,7 @@ class Tester(object): | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
""" | |||
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): | |||
super(Tester, self).__init__() | |||
if not isinstance(data, DataSet): | |||
@@ -96,6 +96,14 @@ class Tester(object): | |||
self._model = _move_model_to_device(model, device=device) | |||
self.batch_size = batch_size | |||
self.verbose = verbose | |||
if isinstance(data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
dataset=data, batch_size=batch_size, num_workers=num_workers, sampler=SequentialSampler()) | |||
elif isinstance(data, BatchIter): | |||
self.data_iterator = data | |||
else: | |||
raise TypeError("data type {} not support".format(type(data))) | |||
# 如果是DataParallel将没有办法使用predict方法 | |||
if isinstance(self._model, nn.DataParallel): | |||
@@ -112,7 +120,10 @@ class Tester(object): | |||
raise TypeError(f"`{_model_name}.predict` must be callable to be used " | |||
f"for evaluation, not `{type(self._predict_func)}`.") | |||
else: | |||
self._predict_func = self._model.forward | |||
if isinstance(model, nn.DataParallel): | |||
self._predict_func = self._model.module.forward | |||
else: | |||
self._predict_func = self._model.forward | |||
def test(self): | |||
"""开始进行验证,并返回验证结果。 | |||
@@ -124,7 +135,7 @@ class Tester(object): | |||
self._model_device = _get_model_device(self._model) | |||
network = self._model | |||
self._mode(network, is_test=True) | |||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||
data_iterator = self.data_iterator | |||
eval_results = {} | |||
try: | |||
with torch.no_grad(): | |||
@@ -311,8 +311,9 @@ try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from .utils import _pseudo_tqdm as tqdm | |||
import warnings | |||
from .batch import Batch | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import CallbackManager, CallbackException | |||
from .dataset import DataSet | |||
from .losses import _prepare_losser | |||
@@ -320,7 +321,6 @@ from .metrics import _prepare_metrics | |||
from .optimizer import Optimizer | |||
from .sampler import Sampler | |||
from .sampler import RandomSampler | |||
from .sampler import SequentialSampler | |||
from .tester import Tester | |||
from .utils import _CheckError | |||
from .utils import _build_args | |||
@@ -351,6 +351,8 @@ class Trainer(object): | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||
:param int n_epochs: 需要优化迭代多少次。 | |||
@@ -367,7 +369,6 @@ class Trainer(object): | |||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | |||
:param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。 | |||
保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||
:param prefetch: bool, 是否使用额外的进程对产生batch数据。理论上会使得Batch迭代更快。 | |||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
@@ -394,16 +395,17 @@ class Trainer(object): | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
batch_size=32, sampler=None, update_every=1, | |||
n_epochs=10, print_every=5, | |||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||
num_workers=0, n_epochs=10, print_every=5, | |||
dev_data=None, metrics=None, metric_key=None, | |||
validate_every=-1, save_path=None, | |||
prefetch=False, use_tqdm=True, device=None, | |||
callbacks=None, | |||
check_code_level=0): | |||
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, | |||
callbacks=None, check_code_level=0): | |||
if prefetch and num_workers==0: | |||
num_workers = 1 | |||
if prefetch: | |||
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") | |||
super(Trainer, self).__init__() | |||
if not isinstance(train_data, DataSet): | |||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||
@@ -430,18 +432,30 @@ class Trainer(object): | |||
if metric_key is not None: | |||
self.increase_better = False if metric_key[0] == "-" else True | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
elif len(metrics) > 0: | |||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | |||
else: | |||
self.metric_key = None | |||
# prepare loss | |||
losser = _prepare_losser(loss) | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, Sampler): | |||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||
if check_code_level > -1: | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
if sampler is None: | |||
sampler = RandomSampler() | |||
if isinstance(train_data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last) | |||
elif isinstance(train_data, BatchIter): | |||
self.data_iterator = train_data | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(train_data))) | |||
self.model = _move_model_to_device(model, device=device) | |||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
_check_code(dataset=train_data, model=self.model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
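# A minimal construction sketch, assuming train_data/dev_data (DataSet) and model
# (nn.Module) exist: num_workers and drop_last replace the deprecated prefetch flag,
# and update_every accumulates gradients across several steps.
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric

trainer = Trainer(train_data=train_data, model=model,
                  loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                  batch_size=32, update_every=4,   # behaves like an effective batch of 128
                  num_workers=2, drop_last=True,
                  dev_data=dev_data)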
@@ -460,13 +474,9 @@ class Trainer(object): | |||
self.best_dev_epoch = None | |||
self.best_dev_step = None | |||
self.best_dev_perf = None | |||
self.sampler = sampler if sampler is not None else RandomSampler() | |||
self.prefetch = prefetch | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
self.model = _move_model_to_device(self.model, device=device) | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
elif isinstance(optimizer, Optimizer): | |||
@@ -493,13 +503,16 @@ class Trainer(object): | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True): | |||
def train(self, load_best_model=True, on_exception='auto'): | |||
""" | |||
使用该函数使Trainer开始训练。 | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效, | |||
如果True, trainer将在返回之前重新加载dev表现最好的模型参数。 | |||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||
最好的模型参数。 | |||
:param str on_exception: 在训练过程遭遇exception,并被 :class:`~fastNLP.core.callback.Callback` 的on_exception()处理后,是否继续抛出异常。 | |||
支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; | |||
'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. | |||
:return dict: 返回一个字典类型的数据, | |||
内含以下内容:: | |||
@@ -528,10 +541,16 @@ class Trainer(object): | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
self.callback_manager.on_train_end() | |||
except (CallbackException, KeyboardInterrupt) as e: | |||
except BaseException as e: | |||
self.callback_manager.on_exception(e) | |||
if on_exception == 'auto': | |||
if not isinstance(e, (CallbackException, KeyboardInterrupt)): | |||
raise e | |||
elif on_exception == 'raise': | |||
raise e | |||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | |||
if self.dev_data is not None and self.best_dev_perf is not None: | |||
print( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf), ) | |||
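# A short sketch of the on_exception modes (trainer as in the sketch further above):
# 'auto' swallows only CallbackException / KeyboardInterrupt after the callbacks have
# run, 'raise' re-raises everything, 'ignore' suppresses everything.
try:
    trainer.train(load_best_model=True, on_exception='raise')
except Exception as e:
    print('training aborted:', e)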
@@ -559,12 +578,14 @@ class Trainer(object): | |||
self.step = 0 | |||
self.epoch = 0 | |||
start = time.time() | |||
if isinstance(self.model, nn.DataParallel): | |||
self._forward_func = self.model.module.forward | |||
else: | |||
self._forward_func = self.model.forward | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar | |||
avg_loss = 0 | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
prefetch=self.prefetch) | |||
data_iterator = self.data_iterator | |||
self.batch_per_epoch = data_iterator.num_batches | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
@@ -664,11 +685,11 @@ class Trainer(object): | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
x = _build_args(network.forward, **x) | |||
x = _build_args(self._forward_func, **x) | |||
y = network(**x) | |||
if not isinstance(y, dict): | |||
raise TypeError( | |||
f"The return value of {_get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") | |||
return y | |||
def _grad_backward(self, loss): | |||
@@ -737,7 +758,9 @@ class Trainer(object): | |||
:return bool value: True means current results on dev set is the best. | |||
""" | |||
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) | |||
indicator, indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) | |||
if self.metric_key is None: | |||
self.metric_key = indicator | |||
is_better = True | |||
if self.best_metric_indicator is None: | |||
# first-time validation | |||
@@ -776,15 +799,34 @@ def _get_value_info(_dict): | |||
strs.append(_str) | |||
return strs | |||
from numbers import Number | |||
from .batch import _to_tensor | |||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, | |||
check_level=0): | |||
# check get_loss 方法 | |||
model_devcie = model.parameters().__next__().device | |||
model_device = _get_model_device(model=model) | |||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
def _iter(): | |||
start_idx = 0 | |||
while start_idx<len(dataset): | |||
batch_x = {} | |||
batch_y = {} | |||
for field_name, field in dataset.get_all_fields().items(): | |||
indices = list(range(start_idx, min(start_idx+batch_size, len(dataset)))) | |||
if field.is_target or field.is_input: | |||
batch = field.get(indices) | |||
if field.dtype is not None and \ | |||
issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor): | |||
batch, _ = _to_tensor(batch, field.dtype) | |||
if field.is_target: | |||
batch_y[field_name] = batch | |||
if field.is_input: | |||
batch_x[field_name] = batch | |||
yield (batch_x, batch_y) | |||
start_idx += batch_size | |||
for batch_count, (batch_x, batch_y) in enumerate(_iter()): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_device) | |||
# forward check | |||
if batch_count == 0: | |||
@@ -806,8 +848,11 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
print(info_str) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
if isinstance(model, nn.DataParallel): | |||
forward_func = model.module.forward | |||
else: | |||
forward_func = model.forward | |||
refined_batch_x = _build_args(forward_func, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
func_signature = _get_func_signature(model.forward) | |||
if not isinstance(pred_dict, dict): | |||
@@ -852,26 +897,16 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||
loss, metrics = metrics | |||
if isinstance(metrics, dict): | |||
if len(metrics) == 1: | |||
# only single metric, just use it | |||
metric_dict = list(metrics.values())[0] | |||
metrics_name = list(metrics.keys())[0] | |||
else: | |||
metrics_name = metric_list[0].__class__.__name__ | |||
if metrics_name not in metrics: | |||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | |||
metric_dict = metrics[metrics_name] | |||
metric_dict = list(metrics.values())[0] # 取第一个metric | |||
if len(metric_dict) == 1: | |||
if metric_key is None: | |||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | |||
elif len(metric_dict) > 1 and metric_key is None: | |||
raise RuntimeError( | |||
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") | |||
else: | |||
# metric_key is set | |||
if metric_key not in metric_dict: | |||
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") | |||
indicator_val = metric_dict[metric_key] | |||
indicator = metric_key | |||
else: | |||
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) | |||
return indicator_val | |||
return indicator, indicator_val |
@@ -3,7 +3,8 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户 | |||
""" | |||
__all__ = [ | |||
"cache_results", | |||
"seq_len_to_mask" | |||
"seq_len_to_mask", | |||
"Option", | |||
] | |||
import _pickle | |||
@@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
'varargs']) | |||
class Option(dict): | |||
"""a dict can treat keys as attributes""" | |||
def __getattr__(self, item): | |||
try: | |||
return self.__getitem__(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __setattr__(self, key, value): | |||
if key.startswith('__') and key.endswith('__'): | |||
raise AttributeError(key) | |||
self.__setitem__(key, value) | |||
def __delattr__(self, item): | |||
try: | |||
self.pop(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __getstate__(self): | |||
return self | |||
def __setstate__(self, state): | |||
self.update(state) | |||
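# Minimal illustration of the Option helper defined above: keys double as attributes,
# everything else behaves like a plain dict.
from fastNLP.core.utils import Option

opt = Option(max_size=10000, min_freq=2)
opt.min_freq = 1                       # attribute write goes through __setattr__
print(opt['min_freq'], opt.max_size)   # -> 1 10000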
def _prepare_cache_filepath(filepath): | |||
""" | |||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | |||
@@ -258,6 +285,7 @@ def _get_model_device(model): | |||
:param model: nn.Module | |||
:return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 | |||
""" | |||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding | |||
assert isinstance(model, nn.Module) | |||
parameters = list(model.parameters()) | |||
@@ -268,6 +296,13 @@ def _get_model_device(model): | |||
def _build_args(func, **kwargs): | |||
""" | |||
根据func的初始化参数,从kwargs中选择func需要的参数 | |||
:param func: callable | |||
:param kwargs: 参数 | |||
:return:dict. func中用到的参数 | |||
""" | |||
spect = inspect.getfullargspec(func) | |||
if spect.varkw is not None: | |||
return kwargs | |||
@@ -608,7 +643,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
warnings.warn(message=_unused_warn) | |||
def seq_len_to_mask(seq_len): | |||
def seq_len_to_mask(seq_len, max_len=None): | |||
""" | |||
将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | |||
@@ -624,20 +659,26 @@ def seq_len_to_mask(seq_len): | |||
>>> mask = seq_len_to_mask(seq_len) | |||
>>> print(mask.shape) | |||
(14, 15) | |||
>>> seq_len = torch.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len, max_len=100) | |||
>>> print(mask.size()) | |||
torch.Size([14, 100]) | |||
:param np.ndarray,torch.LongTensor seq_len: shape将是(B,) | |||
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 | |||
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 | |||
:return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8 | |||
""" | |||
if isinstance(seq_len, np.ndarray): | |||
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | |||
max_len = int(seq_len.max()) | |||
max_len = int(max_len) if max_len else int(seq_len.max()) | |||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
elif isinstance(seq_len, torch.Tensor): | |||
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | |||
batch_size = seq_len.size(0) | |||
max_len = seq_len.max().long() | |||
max_len = int(max_len) if max_len else seq_len.max().long() | |||
broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) | |||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | |||
else: | |||
@@ -1,11 +1,27 @@ | |||
__all__ = [ | |||
"Vocabulary" | |||
"Vocabulary", | |||
"VocabularyOption", | |||
] | |||
from functools import wraps | |||
from collections import Counter | |||
from collections import Counter, defaultdict | |||
from .dataset import DataSet | |||
from .utils import Option | |||
from functools import partial | |||
import numpy as np | |||
class VocabularyOption(Option): | |||
def __init__(self, | |||
max_size=None, | |||
min_freq=None, | |||
padding='<pad>', | |||
unknown='<unk>'): | |||
super().__init__( | |||
max_size=max_size, | |||
min_freq=min_freq, | |||
padding=padding, | |||
unknown=unknown | |||
) | |||
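# A minimal sketch: VocabularyOption bundles Vocabulary's constructor arguments so a
# configuration can be passed around (e.g. into a loader's process()) and unpacked with **.
from fastNLP.core.vocabulary import Vocabulary, VocabularyOption

vocab_opt = VocabularyOption(max_size=20000, min_freq=2)
vocab = Vocabulary(**vocab_opt)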
def _check_build_vocab(func): | |||
@@ -74,7 +90,9 @@ class Vocabulary(object): | |||
self.word2idx = None | |||
self.idx2word = None | |||
self.rebuild = True | |||
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | |||
self._no_create_word = defaultdict(int) | |||
@_check_build_status | |||
def update(self, word_lst): | |||
"""依次增加序列中词在词典中的出现频率 | |||
@@ -133,7 +151,7 @@ class Vocabulary(object): | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
self.rebuild = False | |||
def build_reverse_vocab(self): | |||
""" | |||
基于 "word to index" dict, 构建 "index to word" dict. | |||
@@ -225,8 +243,12 @@ class Vocabulary(object): | |||
raise e | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
def from_dataset(self, *datasets, field_name): | |||
@property | |||
def _no_create_word_length(self): | |||
return len(self._no_create_word) | |||
def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): | |||
""" | |||
使用dataset的对应field中词构建词典:: | |||
@@ -238,6 +260,13 @@ class Vocabulary(object): | |||
构建词典所使用的 field(s), 支持一个或多个field | |||
若有多个 DataSet, 每个DataSet都必须有这些field. | |||
目前仅支持的field结构: ``str`` , ``list(str)`` , ``list(list(str))`` | |||
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain | |||
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | |||
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | |||
如果一个词出现在了train中,但是没在预训练模型中,embedding会为它用unk初始化,但它是单独的一个vector,如果 | |||
finetune embedding的话,这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector, | |||
而应该让它指向unk这个vector的值。所以只位于no_create_entry_dataset中的token,将首先从预训练的词表中寻找它的表示, | |||
如果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。 | |||
:return self: | |||
""" | |||
if isinstance(field_name, str): | |||
@@ -245,19 +274,28 @@ class Vocabulary(object): | |||
elif not isinstance(field_name, list): | |||
raise TypeError('invalid argument field_name: {}'.format(field_name)) | |||
def construct_vocab(ins): | |||
def construct_vocab(ins, no_create_entry=False): | |||
for fn in field_name: | |||
field = ins[fn] | |||
if isinstance(field, str): | |||
if no_create_entry and field not in self.word_count: | |||
self._no_create_word[field] += 1 | |||
self.add_word(field) | |||
elif isinstance(field, list): | |||
if not isinstance(field[0], list): | |||
self.add_word_lst(field) | |||
elif isinstance(field, (list, np.ndarray)): | |||
if not isinstance(field[0], (list, np.ndarray)): | |||
for word in field: | |||
if no_create_entry and word not in self.word_count: | |||
self._no_create_word[word] += 1 | |||
self.add_word(word) | |||
else: | |||
if isinstance(field[0][0], list): | |||
if isinstance(field[0][0], (list, np.ndarray)): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
[self.add_word_lst(w) for w in field] | |||
for words in field: | |||
for word in words: | |||
if no_create_entry and word not in self.word_count: | |||
self._no_create_word[word] += 1 | |||
self.add_word(word) | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
@@ -266,9 +304,27 @@ class Vocabulary(object): | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
raise e | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
raise TypeError("Only DataSet type is allowed.") | |||
if no_create_entry_dataset is not None: | |||
partial_construct_vocab = partial(construct_vocab, no_create_entry=True) | |||
if isinstance(no_create_entry_dataset, DataSet): | |||
no_create_entry_dataset.apply(partial_construct_vocab) | |||
elif isinstance(no_create_entry_dataset, list): | |||
for dataset in no_create_entry_dataset: | |||
if not isinstance(dataset, DataSet): | |||
raise TypeError("Only DataSet type is allowed.") | |||
dataset.apply(partial_construct_vocab) | |||
return self | |||
def _is_word_no_create_entry(self, word): | |||
""" | |||
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | |||
:param word: str | |||
:return: bool | |||
""" | |||
return word in self._no_create_word | |||
def to_index(self, w): | |||
""" | |||
将词转为数字. 若词不在词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 | |||
@@ -323,6 +379,7 @@ class Vocabulary(object): | |||
self.word2idx = None | |||
self.idx2word = None | |||
self.rebuild = True | |||
self._no_create_word.clear() | |||
def __getstate__(self): | |||
"""Use to prepare data for pickle. | |||
@@ -344,5 +401,7 @@ class Vocabulary(object): | |||
def __repr__(self): | |||
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | |||
@_check_build_vocab | |||
def __iter__(self): | |||
return iter(list(self.word_count.keys())) | |||
for word, index in self.word2idx.items(): | |||
yield word, index |
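# A minimal usage sketch, assuming train_data/dev_data/test_data are DataSet objects
# with a 'words' field: words that only occur in dev/test are registered as
# no-create-entry, so with a pretrained embedding they fall back to the pretrained
# vector (or <unk>) instead of receiving a trainable entry of their own; the new
# __iter__ yields (word, index) pairs.
from fastNLP import Vocabulary

vocab = Vocabulary()
vocab.from_dataset(train_data, field_name='words',
                   no_create_entry_dataset=[dev_data, test_data])
for word, index in vocab:
    print(word, index)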
@@ -26,6 +26,6 @@ __all__ = [ | |||
] | |||
from .embed_loader import EmbedLoader | |||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \ | |||
PeopleDailyCorpusLoader, Conll2003Loader | |||
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, \ | |||
SNLILoader, SSTLoader, PeopleDailyCorpusLoader, Conll2003Loader | |||
from .model_io import ModelLoader, ModelSaver |
@@ -1,10 +1,14 @@ | |||
__all__ = [ | |||
"BaseLoader" | |||
"BaseLoader", | |||
'DataInfo', | |||
'DataSetLoader', | |||
] | |||
import _pickle as pickle | |||
import os | |||
from typing import Union, Dict | |||
import os | |||
from ..core.dataset import DataSet | |||
class BaseLoader(object): | |||
""" | |||
@@ -51,24 +55,169 @@ class BaseLoader(object): | |||
return obj | |||
class DataLoaderRegister: | |||
_readers = {} | |||
@classmethod | |||
def set_reader(cls, reader_cls, read_fn_name): | |||
# def wrapper(reader_cls): | |||
if read_fn_name in cls._readers: | |||
raise KeyError( | |||
'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, | |||
read_fn_name)) | |||
if hasattr(reader_cls, 'load'): | |||
cls._readers[read_fn_name] = reader_cls().load | |||
return reader_cls | |||
@classmethod | |||
def get_reader(cls, read_fn_name): | |||
if read_fn_name in cls._readers: | |||
return cls._readers[read_fn_name] | |||
raise AttributeError('no read function: {}'.format(read_fn_name)) | |||
# TODO 这个类使用在何处? | |||
def _download_from_url(url, path): | |||
try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from ..core.utils import _pseudo_tqdm as tqdm | |||
import requests | |||
"""Download file""" | |||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||
chunk_size = 16 * 1024 | |||
total_size = int(r.headers.get('Content-length', 0)) | |||
with open(path, "wb") as file, \ | |||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||
for chunk in r.iter_content(chunk_size): | |||
if chunk: | |||
file.write(chunk) | |||
t.update(len(chunk)) | |||
def _uncompress(src, dst): | |||
import zipfile | |||
import gzip | |||
import tarfile | |||
import os | |||
def unzip(src, dst): | |||
with zipfile.ZipFile(src, 'r') as f: | |||
f.extractall(dst) | |||
def ungz(src, dst): | |||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||
length = 16 * 1024 # 16KB | |||
buf = f.read(length) | |||
while buf: | |||
uf.write(buf) | |||
buf = f.read(length) | |||
def untar(src, dst): | |||
with tarfile.open(src, 'r:gz') as f: | |||
f.extractall(dst) | |||
fn, ext = os.path.splitext(src) | |||
_, ext_2 = os.path.splitext(fn) | |||
if ext == '.zip': | |||
unzip(src, dst) | |||
elif ext == '.gz' and ext_2 != '.tar': | |||
ungz(src, dst) | |||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': | |||
untar(src, dst) | |||
else: | |||
raise ValueError('unsupported file {}'.format(src)) | |||
class DataInfo: | |||
""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
""" | |||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||
self.vocabs = vocabs or {} | |||
self.embeddings = embeddings or {} | |||
self.datasets = datasets or {} | |||
def __repr__(self): | |||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||
for name, dataset in self.datasets.items(): | |||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||
for name, vocab in self.vocabs.items(): | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
return _str | |||
class DataSetLoader: | |||
""" | |||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 | |||
开发者至少应该编写如下内容: | |||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||
**process 函数中可以 调用load 函数或 _load 函数** | |||
""" | |||
URL = '' | |||
DATA_DIR = '' | |||
ROOT_DIR = '.fastnlp/datasets/' | |||
UNCOMPRESS = True | |||
def _download(self, url: str, pdir: str, uncompress=True) -> str: | |||
""" | |||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||
:param url: 下载的网站 | |||
:param pdir: 下载到的目录 | |||
:param uncompress: 是否自动解压缩 | |||
:return: 数据的存放路径 | |||
""" | |||
fn = os.path.basename(url) | |||
path = os.path.join(pdir, fn) | |||
"""check data exists""" | |||
if not os.path.exists(path): | |||
os.makedirs(pdir, exist_ok=True) | |||
_download_from_url(url, path) | |||
if uncompress: | |||
dst = os.path.join(pdir, 'data') | |||
if not os.path.exists(dst): | |||
_uncompress(path, dst) | |||
return dst | |||
return path | |||
def download(self): | |||
return self._download( | |||
self.URL, | |||
os.path.join(self.ROOT_DIR, self.DATA_DIR), | |||
uncompress=self.UNCOMPRESS) | |||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 | |||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||
""" | |||
if isinstance(paths, str): | |||
return self._load(paths) | |||
return {name: self._load(path) for name, path in paths.items()} | |||
def _load(self, path: str) -> DataSet: | |||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||
:param str path: 文件路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: | |||
""" | |||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | |||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | |||
返回的 :class:`DataInfo` 对象有如下属性: | |||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||
- embeddings: (可选) 数据集对应的词嵌入 | |||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||
:param paths: 原始数据读取的路径 | |||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||
:return: 返回一个 DataInfo | |||
""" | |||
raise NotImplementedError |
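# A minimal sketch of the contract described above (ToyLoader and the paths are
# illustrative): a subclass only needs to implement _load(); the inherited load() then
# accepts either a single path or a {name: path} dict.
from fastNLP import DataSet, Instance
from fastNLP.io.base_loader import DataSetLoader


class ToyLoader(DataSetLoader):
    def _load(self, path: str) -> DataSet:
        ds = DataSet()
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                ds.append(Instance(raw_words=line.strip().split()))
        return ds


loader = ToyLoader()
splits = loader.load({'train': 'train.txt', 'dev': 'dev.txt'})   # {'train': DataSet, 'dev': DataSet}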
@@ -0,0 +1,95 @@ | |||
from typing import Iterable | |||
from nltk import Tree | |||
from ..base_loader import DataInfo, DataSetLoader | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||
class SSTLoader(DataSetLoader): | |||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||
DATA_DIR = 'sst/' | |||
""" | |||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||
读取SST数据集, DataSet包含fields:: | |||
words: list(str) 需要分类的文本 | |||
target: str 文本的标签 | |||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
""" | |||
def __init__(self, subtree=False, fine_grained=False): | |||
self.subtree = subtree | |||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||
'3': 'positive', '4': 'very positive'} | |||
if not fine_grained: | |||
tag_v['0'] = tag_v['1'] | |||
tag_v['4'] = tag_v['3'] | |||
self.tag_v = tag_v | |||
def _load(self, path): | |||
""" | |||
:param str path: 存储数据的路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
datas = [] | |||
for l in f: | |||
datas.extend([(s, self.tag_v[t]) | |||
for s, t in self._get_one(l, self.subtree)]) | |||
ds = DataSet() | |||
for words, tag in datas: | |||
ds.append(Instance(words=words, target=tag)) | |||
return ds | |||
@staticmethod | |||
def _get_one(data, subtree): | |||
tree = Tree.fromstring(data) | |||
if subtree: | |||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||
return [(tree.leaves(), tree.label())] | |||
def process(self, | |||
paths, | |||
train_ds: Iterable[str] = None, | |||
src_vocab_op: VocabularyOption = None, | |||
tgt_vocab_op: VocabularyOption = None, | |||
src_embed_op: EmbeddingOption = None): | |||
input_name, target_name = 'words', 'target' | |||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||
info = DataInfo(datasets=self.load(paths)) | |||
_train_ds = [info.datasets[name] | |||
for name in train_ds] if train_ds else info.datasets.values() | |||
src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||
src_vocab.index_dataset( | |||
*info.datasets.values(), | |||
field_name=input_name, new_field_name=input_name) | |||
tgt_vocab.index_dataset( | |||
*info.datasets.values(), | |||
field_name=target_name, new_field_name=target_name) | |||
info.vocabs = { | |||
input_name: src_vocab, | |||
target_name: tgt_vocab | |||
} | |||
if src_embed_op is not None: | |||
src_embed_op.vocab = src_vocab | |||
init_emb = EmbedLoader.load_with_vocab(**src_embed_op) | |||
info.embeddings[input_name] = init_emb | |||
return info | |||
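# A minimal usage sketch (the split paths are assumptions): process() loads every split,
# builds the word/target vocabularies from the split(s) named in train_ds and indexes
# all splits in place.
from fastNLP.io import SSTLoader

loader = SSTLoader(subtree=False, fine_grained=False)
info = loader.process({'train': 'train.txt', 'dev': 'dev.txt', 'test': 'test.txt'},
                      train_ds=['train'])
print(info.vocabs['words'])
print(info.datasets['train'])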
@@ -13,8 +13,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 | |||
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 | |||
""" | |||
__all__ = [ | |||
'DataInfo', | |||
'DataSetLoader', | |||
'CSVLoader', | |||
'JsonLoader', | |||
'ConllLoader', | |||
@@ -25,158 +23,18 @@ __all__ = [ | |||
'Conll2003Loader', | |||
] | |||
from nltk.tree import Tree | |||
import os | |||
from nltk import Tree | |||
from typing import Union, Dict | |||
from ..core.vocabulary import Vocabulary | |||
from ..core.dataset import DataSet | |||
from ..core.instance import Instance | |||
from ..core.vocabulary import Vocabulary | |||
from .file_reader import _read_csv, _read_json, _read_conll | |||
from typing import Union, Dict | |||
import os | |||
def _download_from_url(url, path): | |||
try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from ..core.utils import _pseudo_tqdm as tqdm | |||
import requests | |||
"""Download file""" | |||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||
chunk_size = 16 * 1024 | |||
total_size = int(r.headers.get('Content-length', 0)) | |||
with open(path, "wb") as file, \ | |||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||
for chunk in r.iter_content(chunk_size): | |||
if chunk: | |||
file.write(chunk) | |||
t.update(len(chunk)) | |||
return | |||
def _uncompress(src, dst): | |||
import zipfile | |||
import gzip | |||
import tarfile | |||
import os | |||
def unzip(src, dst): | |||
with zipfile.ZipFile(src, 'r') as f: | |||
f.extractall(dst) | |||
def ungz(src, dst): | |||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||
length = 16 * 1024 # 16KB | |||
buf = f.read(length) | |||
while buf: | |||
uf.write(buf) | |||
buf = f.read(length) | |||
def untar(src, dst): | |||
with tarfile.open(src, 'r:gz') as f: | |||
f.extractall(dst) | |||
fn, ext = os.path.splitext(src) | |||
_, ext_2 = os.path.splitext(fn) | |||
if ext == '.zip': | |||
unzip(src, dst) | |||
elif ext == '.gz' and ext_2 != '.tar': | |||
ungz(src, dst) | |||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': | |||
untar(src, dst) | |||
else: | |||
raise ValueError('unsupported file {}'.format(src)) | |||
class DataInfo: | |||
""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
""" | |||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||
self.vocabs = vocabs or {} | |||
self.embeddings = embeddings or {} | |||
self.datasets = datasets or {} | |||
class DataSetLoader: | |||
""" | |||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||
定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 | |||
开发者至少应该编写如下内容: | |||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||
**process 函数中可以 调用load 函数或 _load 函数** | |||
""" | |||
def _download(self, url: str, path: str, uncompress=True) -> str: | |||
""" | |||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||
:param url: 下载的网站 | |||
:param path: 下载到的目录 | |||
:param uncompress: 是否自动解压缩 | |||
:return: 数据的存放路径 | |||
""" | |||
pdir = os.path.dirname(path) | |||
os.makedirs(pdir, exist_ok=True) | |||
_download_from_url(url, path) | |||
if uncompress: | |||
dst = os.path.join(pdir, 'data') | |||
_uncompress(path, dst) | |||
return dst | |||
return path | |||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 | |||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||
""" | |||
if isinstance(paths, str): | |||
return self._load(paths) | |||
return {name: self._load(path) for name, path in paths.items()} | |||
def _load(self, path: str) -> DataSet: | |||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||
:param str path: 文件路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: | |||
""" | |||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | |||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | |||
返回的 :class:`DataInfo` 对象有如下属性: | |||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||
- embeddings: (可选) 数据集对应的词嵌入 | |||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||
:param paths: 原始数据读取的路径 | |||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||
:return: 返回一个 DataInfo | |||
""" | |||
raise NotImplementedError | |||
from .base_loader import DataSetLoader, DataInfo | |||
from .data_loader.sst import SSTLoader | |||
from ..core.const import Const | |||
from ..modules.encoder._bert import BertTokenizer | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
@@ -185,12 +43,12 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
读取人民日报数据集 | |||
""" | |||
def __init__(self, pos=True, ner=True): | |||
super(PeopleDailyCorpusLoader, self).__init__() | |||
self.pos = pos | |||
self.ner = ner | |||
def _load(self, data_path): | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
sents = f.readlines() | |||
@@ -235,7 +93,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
example.append(sent_ner) | |||
examples.append(example) | |||
return self.convert(examples) | |||
def convert(self, data): | |||
""" | |||
@@ -263,7 +121,8 @@ class ConllLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader` | |||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html | |||
读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为 | |||
该符号在conll 2003中被用为文档分割符。 | |||
列号从0开始, 每列对应内容为:: | |||
@@ -286,7 +145,7 @@ class ConllLoader(DataSetLoader): | |||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | |||
""" | |||
def __init__(self, headers, indexes=None, dropna=False): | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
@@ -300,7 +159,7 @@ class ConllLoader(DataSetLoader): | |||
if len(indexes) != len(headers): | |||
raise ValueError | |||
self.indexes = indexes | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
@@ -318,7 +177,7 @@ class Conll2003Loader(ConllLoader): | |||
关于数据集的更多信息,参考: | |||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'tokens', 'pos', 'chunks', 'ner', | |||
@@ -356,56 +215,6 @@ def _cut_long_sentence(sent, max_sample_length=200): | |||
return cutted_sentence | |||
class SSTLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||
读取SST数据集, DataSet包含fields:: | |||
words: list(str) 需要分类的文本 | |||
target: str 文本的标签 | |||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||
""" | |||
def __init__(self, subtree=False, fine_grained=False): | |||
self.subtree = subtree | |||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||
'3': 'positive', '4': 'very positive'} | |||
if not fine_grained: | |||
tag_v['0'] = tag_v['1'] | |||
tag_v['4'] = tag_v['3'] | |||
self.tag_v = tag_v | |||
def _load(self, path): | |||
""" | |||
:param str path: 存储数据的路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
datas = [] | |||
for l in f: | |||
datas.extend([(s, self.tag_v[t]) | |||
for s, t in self._get_one(l, self.subtree)]) | |||
ds = DataSet() | |||
for words, tag in datas: | |||
ds.append(Instance(words=words, target=tag)) | |||
return ds | |||
@staticmethod | |||
def _get_one(data, subtree): | |||
tree = Tree.fromstring(data) | |||
if subtree: | |||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||
return [(tree.leaves(), tree.label())] | |||
class JsonLoader(DataSetLoader): | |||
""" | |||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` | |||
@@ -419,7 +228,7 @@ class JsonLoader(DataSetLoader): | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super(JsonLoader, self).__init__() | |||
self.dropna = dropna | |||
@@ -430,7 +239,7 @@ class JsonLoader(DataSetLoader): | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
@@ -454,27 +263,27 @@ class SNLILoader(JsonLoader): | |||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self): | |||
fields = { | |||
'sentence1_parse': 'words1', | |||
'sentence2_parse': 'words2', | |||
'gold_label': 'target', | |||
'sentence1_parse': Const.INPUTS(0), | |||
'sentence2_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
super(SNLILoader, self).__init__(fields=fields) | |||
def _load(self, path): | |||
ds = super(SNLILoader, self)._load(path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
ds.apply(lambda ins: parse_tree( | |||
ins['words1']), new_field_name='words1') | |||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: parse_tree( | |||
ins['words2']), new_field_name='words2') | |||
ds.drop(lambda x: x['target'] == '-') | |||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
@@ -562,12 +371,12 @@ class CSVLoader(DataSetLoader): | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, data in _read_csv(path, headers=self.headers, | |||
@@ -582,7 +391,7 @@ def _add_seg_tag(data): | |||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||
:return: list of ([word], [pos]) | |||
""" | |||
_processed = [] | |||
for word_list, pos_list, _, _ in data: | |||
new_sample = [] | |||
@@ -1,5 +1,6 @@ | |||
__all__ = [ | |||
"EmbedLoader" | |||
"EmbedLoader", | |||
"EmbeddingOption", | |||
] | |||
import os | |||
@@ -9,6 +10,21 @@ import numpy as np | |||
from ..core.vocabulary import Vocabulary | |||
from .base_loader import BaseLoader | |||
from ..core.utils import Option | |||
class EmbeddingOption(Option): | |||
def __init__(self, | |||
embed_filepath=None, | |||
dtype=np.float32, | |||
normalize=True, | |||
error='ignore'): | |||
super().__init__( | |||
embed_filepath=embed_filepath, | |||
dtype=dtype, | |||
normalize=normalize, | |||
error=error | |||
) | |||
class EmbedLoader(BaseLoader): | |||
@@ -20,9 +36,9 @@ class EmbedLoader(BaseLoader): | |||
def __init__(self): | |||
super(EmbedLoader, self).__init__() | |||
@staticmethod | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, error='ignore'): | |||
""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
@@ -31,6 +47,8 @@ class EmbedLoader(BaseLoader): | |||
:param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 | |||
没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||
:param dtype: 读出的embedding的类型 | |||
:param str padding: 词表中padding的token | |||
:param str unknown: 词表中unknown的token | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 | |||
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 | |||
@@ -54,9 +72,16 @@ class EmbedLoader(BaseLoader): | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
parts = line.strip().split() | |||
if parts[0] in vocab: | |||
index = vocab.to_index(parts[0]) | |||
matrix[index] = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) | |||
word = ''.join(parts[:-dim]) | |||
nums = parts[-dim:] | |||
# 对齐unk与pad | |||
if word==padding and vocab.padding is not None: | |||
word = vocab.padding | |||
elif word==unknown and vocab.unknown is not None: | |||
word = vocab.unknown | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
matrix[index] = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) | |||
hit_flags[index] = True | |||
except Exception as e: | |||
if error == 'ignore': | |||
@@ -87,14 +112,14 @@ class EmbedLoader(BaseLoader): | |||
:param str embed_filepath: 预训练的embedding的路径。 | |||
:param dtype: 读出的embedding的类型 | |||
:param str padding: the padding tag for vocabulary. | |||
:param str unknown: the unknown tag for vocabulary. | |||
:param str padding: 词表中的padding的token. 并以此用做vocab的padding。 | |||
:param str unknown: 词表中的unknown的token. 并以此用做vocab的unknown。 | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 | |||
方在于词表有空行或者词表出现了维度不一致。 | |||
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
:return numpy.ndarray: Vocabulary Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||
:return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||
是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 | |||
""" | |||
vocab = Vocabulary(padding=padding, unknown=unknown) | |||
vec_dict = {} | |||
@@ -111,15 +136,16 @@ class EmbedLoader(BaseLoader): | |||
for idx, line in enumerate(f, start=start): | |||
try: | |||
parts = line.strip().split() | |||
word = parts[0] | |||
if dim == -1: | |||
dim = len(parts) - 1 | |||
vec = np.fromstring(' '.join(parts[1:]), sep=' ', dtype=dtype, count=dim) | |||
word = ''.join(parts[:-dim]) | |||
nums = parts[-dim:] | |||
vec = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) | |||
vec_dict[word] = vec | |||
vocab.add_word(word) | |||
if unknown is not None and unknown == word: | |||
found_unknown = True | |||
if found_pad is not None and padding == word: | |||
if padding is not None and padding == word: | |||
found_pad = True | |||
except Exception as e: | |||
if error == 'ignore': | |||
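# The parts[:-dim] / parts[-dim:] split above is what keeps tokens that themselves
# contain whitespace (they occur in some GloVe dumps) from corrupting the vector.
# A standalone sketch of the same parsing on a toy 3-dimensional line:
import numpy as np

line = ". . 0.1 0.2 0.3"       # a token that itself contains a space
dim = 3
parts = line.strip().split()
word = ''.join(parts[:-dim])   # everything except the last dim fields is the token (joined without spaces)
nums = parts[-dim:]
vec = np.fromstring(' '.join(nums), sep=' ', dtype=np.float32, count=dim)
print(word, vec)               # .. [0.1 0.2 0.3]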
@@ -90,11 +90,12 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
return sample | |||
with open(path, 'r', encoding=encoding) as f: | |||
sample = [] | |||
start = next(f) | |||
if '-DOCSTART-' not in start: | |||
start = next(f).strip() | |||
if '-DOCSTART-' not in start and start!='': | |||
sample.append(start.split()) | |||
for line_idx, line in enumerate(f, 1): | |||
if line.startswith('\n'): | |||
line = line.strip() | |||
if line=='': | |||
if len(sample): | |||
try: | |||
res = parse_conll(sample) | |||
@@ -107,7 +108,8 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split()) | |||
if not line.startswith('-DOCSTART-'): | |||
sample.append(line.split()) | |||
if len(sample) > 0: | |||
try: | |||
res = parse_conll(sample) | |||
@@ -115,4 +117,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
except Exception as e: | |||
if dropna: | |||
return | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
print('invalid instance at line: {}'.format(line_idx)) | |||
raise e |
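# A self-contained sketch of the sentence-splitting behaviour the hunk above implements
# (drop -DOCSTART- headers and '#' comments, flush a sample on blank lines); the toy
# CoNLL snippet is made up:
conll_text = """-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O

German JJ B-NP B-MISC
"""

samples, sample = [], []
for line in conll_text.splitlines():
    line = line.strip()
    if line == '':
        if sample:
            samples.append(sample)
            sample = []
    elif line.startswith('#') or line.startswith('-DOCSTART-'):
        continue
    else:
        sample.append(line.split())
if sample:
    samples.append(sample)
print(len(samples))   # 2 sentences; the -DOCSTART- header is dropped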
@@ -0,0 +1,255 @@ | |||
import os | |||
from pathlib import Path | |||
from urllib.parse import urlparse | |||
import re | |||
import requests | |||
import tempfile | |||
from tqdm import tqdm | |||
import shutil | |||
import hashlib | |||
def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path: | |||
""" | |||
给定一个url或者本地文件路径,先在cache_dir下寻找该文件是否存在,如果不存在则去下载,并将下载的文件放入cache_dir中。
""" | |||
if cache_dir is None: | |||
dataset_cache = Path(get_defalt_path()) | |||
else: | |||
dataset_cache = cache_dir | |||
parsed = urlparse(url_or_filename) | |||
if parsed.scheme in ("http", "https"): | |||
# URL, so get it from the cache (downloading if necessary) | |||
return get_from_cache(url_or_filename, dataset_cache) | |||
elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists(): | |||
# File, and it exists. | |||
return Path(url_or_filename) | |||
elif parsed.scheme == "": | |||
# File, but it doesn't exist. | |||
raise FileNotFoundError("file {} not found".format(url_or_filename)) | |||
else: | |||
# Something unknown | |||
raise ValueError( | |||
"unable to parse {} as a URL or as a local path".format(url_or_filename) | |||
) | |||
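# A hedged usage sketch: the URL is hypothetical, and passing an explicit cache_dir
# means the FASTNLP_CACHE_DIR environment-variable branch is never consulted.
from pathlib import Path

path = cached_path('http://example.com/file/download/txt_test-bcb4fe65.txt',
                   cache_dir=Path('caches'))
print(path)   # e.g. caches/txt_test.txt once downloaded (and decompressed if it was an archive)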
def get_filepath(filepath): | |||
""" | |||
如果filepath是一个目录且其中只有一个文件,则直接返回该文件的全路径;否则原样返回filepath。
:param filepath: | |||
:return: | |||
""" | |||
if os.path.isdir(filepath): | |||
files = os.listdir(filepath) | |||
if len(files)==1: | |||
return os.path.join(filepath, files[0]) | |||
else: | |||
return filepath | |||
return filepath | |||
def get_defalt_path(): | |||
""" | |||
获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_DIR设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。
:return: | |||
""" | |||
if 'FASTNLP_CACHE_DIR' in os.environ: | |||
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') | |||
if os.path.exists(fastnlp_cache_dir): | |||
return fastnlp_cache_dir | |||
raise RuntimeError("Some errors happens on cache directory.") | |||
else: | |||
raise RuntimeError("There function is not available right now.") | |||
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) | |||
return fastnlp_cache_dir | |||
def _get_base_url(name): | |||
# 返回的URL结尾必须是/ | |||
if 'FASTNLP_BASE_URL' in os.environ: | |||
fastnlp_base_url = os.environ['FASTNLP_BASE_URL'] | |||
return fastnlp_base_url | |||
raise RuntimeError("There function is not available right now.") | |||
def split_filename_suffix(filepath): | |||
""" | |||
给定filepath返回对应的name和suffix | |||
:param filepath: | |||
:return: filename, suffix | |||
""" | |||
filename = os.path.basename(filepath) | |||
if filename.endswith('.tar.gz'): | |||
return filename[:-7], '.tar.gz' | |||
return os.path.splitext(filename) | |||
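# A quick check of the special-cased double suffix (file names are illustrative):
print(split_filename_suffix('/tmp/model-abc123.tar.gz'))   # ('model-abc123', '.tar.gz')
print(split_filename_suffix('vocab.txt'))                  # ('vocab', '.txt')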
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
""" | |||
尝试在cache_dir中寻找url定义的资源; 如果没有找到,则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。
如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。 | |||
""" | |||
cache_dir.mkdir(parents=True, exist_ok=True) | |||
filename = re.sub(r".+/", "", url) | |||
dir_name, suffix = split_filename_suffix(filename) | |||
sep_index = dir_name[::-1].find('-')
if sep_index < 0:
check_sum = None
else:
check_sum = dir_name[-sep_index:]
sep_index = len(dir_name) if sep_index == -1 else -sep_index - 1
dir_name = dir_name[:sep_index]
# 寻找与它名字匹配的内容, 而不关心后缀 | |||
match_dir_name = match_file(dir_name, cache_dir) | |||
if match_dir_name: | |||
dir_name = match_dir_name | |||
cache_path = cache_dir / dir_name | |||
# get cache path to put the file | |||
if cache_path.exists(): | |||
return get_filepath(cache_path) | |||
# make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 | |||
response = requests.head(url, headers={"User-Agent": "fastNLP"}) | |||
if response.status_code != 200: | |||
raise IOError( | |||
f"HEAD request failed for url {url} with status code {response.status_code}." | |||
) | |||
# add ETag to filename if it exists | |||
# etag = response.headers.get("ETag") | |||
if not cache_path.exists(): | |||
# Download to temporary file, then copy to cache dir once finished. | |||
# Otherwise you get corrupt cache entries if the download gets interrupted. | |||
fd, temp_filename = tempfile.mkstemp() | |||
print("%s not found in cache, downloading to %s"%(url, temp_filename)) | |||
# GET file object | |||
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) | |||
content_length = req.headers.get("Content-Length") | |||
total = int(content_length) if content_length is not None else None | |||
progress = tqdm(unit="B", total=total) | |||
sha256 = hashlib.sha256() | |||
with open(temp_filename, "wb") as temp_file: | |||
for chunk in req.iter_content(chunk_size=1024): | |||
if chunk: # filter out keep-alive new chunks | |||
progress.update(len(chunk)) | |||
temp_file.write(chunk) | |||
sha256.update(chunk) | |||
# check sum | |||
digit = sha256.hexdigest()[:8] | |||
if check_sum:
assert digit == check_sum, "File corrupted during download."
progress.close() | |||
print(f"Finish download from {url}.") | |||
# 开始解压 | |||
delete_temp_dir = None | |||
if suffix in ('.zip', '.tar.gz'): | |||
uncompress_temp_dir = tempfile.mkdtemp() | |||
delete_temp_dir = uncompress_temp_dir | |||
print(f"Start to uncompress file to {uncompress_temp_dir}.") | |||
if suffix == '.zip': | |||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
else: | |||
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
filenames = os.listdir(uncompress_temp_dir) | |||
if len(filenames)==1: | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): | |||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||
cache_path.mkdir(parents=True, exist_ok=True) | |||
print("Finish un-compressing file.") | |||
else: | |||
uncompress_temp_dir = temp_filename | |||
cache_path = str(cache_path) + suffix | |||
success = False | |||
try: | |||
# 复制到指定的位置 | |||
print(f"Copy file to {cache_path}.") | |||
if os.path.isdir(uncompress_temp_dir): | |||
for filename in os.listdir(uncompress_temp_dir): | |||
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename) | |||
else: | |||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||
success = True | |||
except Exception as e: | |||
print(e) | |||
raise e | |||
finally: | |||
if not success: | |||
if cache_path.exists(): | |||
if cache_path.is_file(): | |||
os.remove(cache_path) | |||
else: | |||
shutil.rmtree(cache_path) | |||
if delete_temp_dir: | |||
shutil.rmtree(delete_temp_dir) | |||
os.close(fd) | |||
os.remove(temp_filename) | |||
return get_filepath(cache_path) | |||
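# get_from_cache assumes resources are published as '<name>-<first 8 hex chars of sha256><suffix>'.
# A hypothetical helper (not part of fastNLP) for producing such a name on the serving side:
import hashlib

def cache_name(local_path, name, suffix):
    """Build a '<name>-<sha256[:8]><suffix>' file name for a local resource."""
    sha256 = hashlib.sha256()
    with open(local_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha256.update(chunk)
    return '{}-{}{}'.format(name, sha256.hexdigest()[:8], suffix)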
def unzip_file(file: Path, to: Path): | |||
# extract all contents of a .zip archive into the given directory
from zipfile import ZipFile | |||
with ZipFile(file, "r") as zipObj: | |||
# Extract all the contents of the zip file into the target directory
zipObj.extractall(to) | |||
def untar_gz_file(file:Path, to:Path): | |||
import tarfile | |||
with tarfile.open(file, 'r:gz') as tar: | |||
tar.extractall(to) | |||
def match_file(dir_name:str, cache_dir:str)->str: | |||
""" | |||
匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。 | |||
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 | |||
:param dir_name: 需要匹配的名称 | |||
:param cache_dir: 在该目录下找匹配dir_name是否存在 | |||
:return: str | |||
""" | |||
files = os.listdir(cache_dir) | |||
matched_filenames = [] | |||
for file_name in files: | |||
if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name): | |||
matched_filenames.append(file_name) | |||
if len(matched_filenames)==0: | |||
return '' | |||
elif len(matched_filenames)==1: | |||
return matched_filenames[-1] | |||
else: | |||
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") | |||
if __name__ == '__main__': | |||
cache_dir = Path('caches') | |||
cache_dir = None | |||
# 需要对cache_dir进行测试 | |||
base_url = 'http://0.0.0.0:8888/file/download' | |||
# if True: | |||
# for filename in os.listdir(cache_dir): | |||
# if os.path.isdir(os.path.join(cache_dir, filename)): | |||
# shutil.rmtree(os.path.join(cache_dir, filename)) | |||
# else: | |||
# os.remove(os.path.join(cache_dir, filename)) | |||
# 1. 测试.txt文件 | |||
print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir)) | |||
# 2. 测试.zip文件(只有一个文件) | |||
print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir)) | |||
# 3. 测试.zip文件(有多个文件) | |||
print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir)) | |||
# 4. 测试.tar.gz文件 | |||
print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir)) | |||
# 5. 测试.tar.gz多个文件 | |||
print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir)) | |||
# 6. 测试.pkl文件 |
@@ -10,6 +10,35 @@ from ..core.const import Const | |||
from ..modules.encoder import BertModel | |||
class BertConfig: | |||
def __init__( | |||
self, | |||
vocab_size=30522, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act="gelu", | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02 | |||
): | |||
self.vocab_size = vocab_size | |||
self.hidden_size = hidden_size | |||
self.num_hidden_layers = num_hidden_layers | |||
self.num_attention_heads = num_attention_heads | |||
self.intermediate_size = intermediate_size | |||
self.hidden_act = hidden_act | |||
self.hidden_dropout_prob = hidden_dropout_prob | |||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||
self.max_position_embeddings = max_position_embeddings | |||
self.type_vocab_size = type_vocab_size | |||
self.initializer_range = initializer_range | |||
class BertForSequenceClassification(BaseModel): | |||
"""BERT model for classification. | |||
This module is composed of the BERT model with a linear layer on top of | |||
@@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel): | |||
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | |||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
num_labels = 2 | |||
model = BertForSequenceClassification(config, num_labels) | |||
model = BertForSequenceClassification(num_labels, config) | |||
logits = model(input_ids, token_type_ids, input_mask) | |||
``` | |||
""" | |||
def __init__(self, config, num_labels, bert_dir): | |||
def __init__(self, num_labels, config=None, bert_dir=None): | |||
super(BertForSequenceClassification, self).__init__() | |||
self.num_labels = num_labels | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, num_labels) | |||
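# A sketch of the new argument order (num_labels first, config and bert_dir optional).
# The config below is deliberately tiny so the randomly initialised model is cheap to
# build; no pretrained directory is needed when bert_dir is None.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertForSequenceClassification(num_labels=2, config=config)   # random init
# model = BertForSequenceClassification(num_labels=2, bert_dir='your-bert-file-dir')  # pretrained weights
print(sum(p.numel() for p in model.parameters()))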
@@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel): | |||
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | |||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
num_choices = 2 | |||
model = BertForMultipleChoice(config, num_choices, bert_dir) | |||
model = BertForMultipleChoice(num_choices, config, bert_dir) | |||
logits = model(input_ids, token_type_ids, input_mask) | |||
``` | |||
""" | |||
def __init__(self, config, num_choices, bert_dir): | |||
def __init__(self, num_choices, config=None, bert_dir=None): | |||
super(BertForMultipleChoice, self).__init__() | |||
self.num_choices = num_choices | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, 1) | |||
@@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel): | |||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
num_labels = 2 | |||
bert_dir = 'your-bert-file-dir' | |||
model = BertForTokenClassification(config, num_labels, bert_dir) | |||
model = BertForTokenClassification(num_labels, config, bert_dir) | |||
logits = model(input_ids, token_type_ids, input_mask) | |||
``` | |||
""" | |||
def __init__(self, config, num_labels, bert_dir): | |||
def __init__(self, num_labels, config=None, bert_dir=None): | |||
super(BertForTokenClassification, self).__init__() | |||
self.num_labels = num_labels | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.classifier = nn.Linear(config.hidden_size, num_labels) | |||
@@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel): | |||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask) | |||
``` | |||
""" | |||
def __init__(self, config, bert_dir): | |||
def __init__(self, config=None, bert_dir=None): | |||
super(BertForQuestionAnswering, self).__init__() | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
if config is None: | |||
config = BertConfig() | |||
self.bert = BertModel(**config.__dict__) | |||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | |||
# self.dropout = nn.Dropout(config.hidden_dropout_prob) | |||
self.qa_outputs = nn.Linear(config.hidden_size, 2) | |||
@@ -7,6 +7,7 @@ import torch.nn as nn | |||
from ..core.const import Const as C | |||
from ..modules import encoder | |||
from fastNLP import seq_len_to_mask | |||
class CNNText(torch.nn.Module): | |||
@@ -21,15 +22,13 @@ class CNNText(torch.nn.Module): | |||
:param int num_classes: 一共有多少类 | |||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||
:param int padding: 对句子前后的pad的大小, 用0填充。 | |||
:param float dropout: Dropout的大小 | |||
""" | |||
def __init__(self, init_embed, | |||
num_classes, | |||
kernel_nums=(3, 4, 5), | |||
kernel_sizes=(3, 4, 5), | |||
padding=0, | |||
kernel_nums=(30, 40, 50), | |||
kernel_sizes=(1, 3, 5), | |||
dropout=0.5): | |||
super(CNNText, self).__init__() | |||
@@ -38,8 +37,7 @@ class CNNText(torch.nn.Module): | |||
self.conv_pool = encoder.ConvMaxpool( | |||
in_channels=self.embed.embedding_dim, | |||
out_channels=kernel_nums, | |||
kernel_sizes=kernel_sizes, | |||
padding=padding) | |||
kernel_sizes=kernel_sizes) | |||
self.dropout = nn.Dropout(dropout) | |||
self.fc = nn.Linear(sum(kernel_nums), num_classes) | |||
@@ -51,7 +49,11 @@ class CNNText(torch.nn.Module): | |||
:return output: dict of torch.LongTensor, [batch_size, num_classes] | |||
""" | |||
x = self.embed(words) # [N,L] -> [N,L,C] | |||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | |||
if seq_len is not None: | |||
mask = seq_len_to_mask(seq_len) | |||
x = self.conv_pool(x, mask) | |||
else: | |||
x = self.conv_pool(x) # [N,L,C] -> [N,C] | |||
x = self.dropout(x) | |||
x = self.fc(x) # [N,C] -> [N, N_class] | |||
return {C.OUTPUT: x} | |||
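# A sketch of the new seq_len path (padded positions are masked before max-pooling);
# the vocabulary size and batch are made up, and init_embed is assumed to accept the
# (vocab_size, embed_dim) tuple form.
import torch
from fastNLP.models import CNNText

model = CNNText((100, 50), num_classes=5, dropout=0.1)
words = torch.randint(1, 100, (4, 20))      # batch of 4 sentences, padded to length 20
seq_len = torch.tensor([20, 15, 12, 9])     # true lengths
out = model(words, seq_len=seq_len)         # dict holding the [4, 5] logits under Const.OUTPUT
print(out)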
@@ -9,7 +9,7 @@ from torch import nn | |||
from ..utils import initial_parameter | |||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): | |||
""" | |||
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` | |||
@@ -17,7 +17,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||
:param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 | |||
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。 | |||
:param str encoding_type: 支持"bio", "bmes", "bmeso"。 | |||
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。 | |||
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; | |||
为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); | |||
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容 | |||
@@ -58,7 +58,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=True): | |||
def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): | |||
""" | |||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO"。 | |||
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 'bioes'。 | |||
:param str from_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
:param str from_label: 比如"PER", "LOC"等label | |||
:param str to_tag: 比如"B", "M"之类的标注tag. 还包括start, end等两种特殊tag | |||
@@ -134,9 +134,19 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label | |||
return to_tag in ['b', 's', 'end', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) | |||
elif encoding_type == 'bioes': | |||
if from_tag == 'start': | |||
return to_tag in ['b', 's', 'o'] | |||
elif from_tag == 'b': | |||
return to_tag in ['i', 'e'] and from_label == to_label | |||
elif from_tag == 'i': | |||
return to_tag in ['i', 'e'] and from_label == to_label | |||
elif from_tag in ['e', 's', 'o']: | |||
return to_tag in ['b', 's', 'end', 'o'] | |||
else: | |||
raise ValueError("Unexpect tag type {}. Expect only 'B', 'I', 'E', 'S', 'O'.".format(from_tag)) | |||
else: | |||
raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) | |||
raise ValueError("Only support BIO, BMES, BMESO, BIOES encoding type, got {}.".format(encoding_type)) | |||
class ConditionalRandomField(nn.Module): | |||
@@ -7,6 +7,12 @@ __all__ = [ | |||
"ConvMaxpool", | |||
"Embedding", | |||
"StaticEmbedding", | |||
"ElmoEmbedding", | |||
"BertEmbedding", | |||
"StackEmbedding", | |||
"LSTMCharEmbedding", | |||
"CNNCharEmbedding", | |||
"LSTM", | |||
@@ -18,10 +24,12 @@ __all__ = [ | |||
"VarLSTM", | |||
"VarGRU" | |||
] | |||
from .bert import BertModel | |||
from ._bert import BertModel | |||
from .bert import BertWordPieceEncoder | |||
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder | |||
from .conv_maxpool import ConvMaxpool | |||
from .embedding import Embedding | |||
from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \ | |||
StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding | |||
from .lstm import LSTM | |||
from .star_transformer import StarTransformer | |||
from .transformer import TransformerEncoder | |||
@@ -0,0 +1,961 @@ | |||
""" | |||
这个页面的代码很大程度上参考了https://github.com/huggingface/pytorch-pretrained-BERT的代码 | |||
""" | |||
from ...core.vocabulary import Vocabulary | |||
import collections | |||
import unicodedata | |||
from ...io.file_utils import _get_base_url, cached_path | |||
import numpy as np | |||
from itertools import chain | |||
import copy | |||
import json | |||
import math | |||
import os | |||
import torch | |||
from torch import nn | |||
import glob | |||
CONFIG_FILE = 'bert_config.json' | |||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||
def gelu(x): | |||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
def swish(x): | |||
return x * torch.sigmoid(x) | |||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
class BertLayerNorm(nn.Module): | |||
def __init__(self, hidden_size, eps=1e-12): | |||
super(BertLayerNorm, self).__init__() | |||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
self.variance_epsilon = eps | |||
def forward(self, x): | |||
u = x.mean(-1, keepdim=True) | |||
s = (x - u).pow(2).mean(-1, keepdim=True) | |||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||
return self.weight * x + self.bias | |||
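# BertLayerNorm reimplements layer normalisation by hand (matching the original BERT
# code); numerically it should agree with torch.nn.LayerNorm over the last dimension.
# A quick equivalence check, assuming a current PyTorch:
import torch
from torch import nn

x = torch.randn(2, 5, 16)
ln = BertLayerNorm(16)
ref = nn.LayerNorm(16, eps=1e-12)
ref.load_state_dict({'weight': ln.weight.data, 'bias': ln.bias.data})
print(torch.allclose(ln(x), ref(x), atol=1e-6))   # True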
class BertEmbeddings(nn.Module): | |||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||
super(BertEmbeddings, self).__init__() | |||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||
# any TensorFlow checkpoint file | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
def forward(self, input_ids, token_type_ids=None): | |||
seq_length = input_ids.size(1) | |||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
if token_type_ids is None: | |||
token_type_ids = torch.zeros_like(input_ids) | |||
words_embeddings = self.word_embeddings(input_ids) | |||
position_embeddings = self.position_embeddings(position_ids) | |||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||
embeddings = self.LayerNorm(embeddings) | |||
embeddings = self.dropout(embeddings) | |||
return embeddings | |||
class BertSelfAttention(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||
super(BertSelfAttention, self).__init__() | |||
if hidden_size % num_attention_heads != 0: | |||
raise ValueError( | |||
"The hidden size (%d) is not a multiple of the number of attention " | |||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||
self.num_attention_heads = num_attention_heads | |||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||
def transpose_for_scores(self, x): | |||
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward(self, hidden_states, attention_mask): | |||
mixed_query_layer = self.query(hidden_states) | |||
mixed_key_layer = self.key(hidden_states) | |||
mixed_value_layer = self.value(hidden_states) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||
attention_scores = attention_scores + attention_mask | |||
# Normalize the attention scores to probabilities. | |||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||
# This is actually dropping out entire tokens to attend to, which might | |||
# seem a bit unusual, but is taken from the original Transformer paper. | |||
attention_probs = self.dropout(attention_probs) | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
return context_layer | |||
class BertSelfOutput(nn.Module): | |||
def __init__(self, hidden_size, hidden_dropout_prob): | |||
super(BertSelfOutput, self).__init__() | |||
self.dense = nn.Linear(hidden_size, hidden_size) | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class BertAttention(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||
super(BertAttention, self).__init__() | |||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||
def forward(self, input_tensor, attention_mask): | |||
self_output = self.self(input_tensor, attention_mask) | |||
attention_output = self.output(self_output, input_tensor) | |||
return attention_output | |||
class BertIntermediate(nn.Module): | |||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||
super(BertIntermediate, self).__init__() | |||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||
if isinstance(hidden_act, str) else hidden_act | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.intermediate_act_fn(hidden_states) | |||
return hidden_states | |||
class BertOutput(nn.Module): | |||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||
super(BertOutput, self).__init__() | |||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class BertLayer(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
intermediate_size, hidden_act): | |||
super(BertLayer, self).__init__() | |||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
hidden_dropout_prob) | |||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||
def forward(self, hidden_states, attention_mask): | |||
attention_output = self.attention(hidden_states, attention_mask) | |||
intermediate_output = self.intermediate(attention_output) | |||
layer_output = self.output(intermediate_output, attention_output) | |||
return layer_output | |||
class BertEncoder(nn.Module): | |||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
hidden_dropout_prob, | |||
intermediate_size, hidden_act): | |||
super(BertEncoder, self).__init__() | |||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
intermediate_size, hidden_act) | |||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||
all_encoder_layers = [] | |||
for layer_module in self.layer: | |||
hidden_states = layer_module(hidden_states, attention_mask) | |||
if output_all_encoded_layers: | |||
all_encoder_layers.append(hidden_states) | |||
if not output_all_encoded_layers: | |||
all_encoder_layers.append(hidden_states) | |||
return all_encoder_layers | |||
class BertPooler(nn.Module): | |||
def __init__(self, hidden_size): | |||
super(BertPooler, self).__init__() | |||
self.dense = nn.Linear(hidden_size, hidden_size) | |||
self.activation = nn.Tanh() | |||
def forward(self, hidden_states): | |||
# We "pool" the model by simply taking the hidden state corresponding | |||
# to the first token. | |||
first_token_tensor = hidden_states[:, 0] | |||
pooled_output = self.dense(first_token_tensor) | |||
pooled_output = self.activation(pooled_output) | |||
return pooled_output | |||
class BertModel(nn.Module): | |||
"""BERT(Bidirectional Embedding Representations from Transformers). | |||
如果你想使用预训练好的权重矩阵,请在以下网址下载. | |||
sources:: | |||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||
用预训练权重矩阵来建立BERT模型:: | |||
model = BertModel.from_pretrained("path/to/weights/directory") | |||
用随机初始化权重矩阵来建立BERT模型:: | |||
model = BertModel() | |||
:param int vocab_size: 词表大小,默认值为30522,为BERT English uncased版本的词表大小
:param int hidden_size: 隐层大小,默认值为768,为BERT base的版本 | |||
:param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本 | |||
:param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本 | |||
:param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本 | |||
:param str hidden_act: FFN隐藏层激活函数,默认值为``gelu`` | |||
:param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1 | |||
:param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1 | |||
:param int max_position_embeddings: 最大的序列长度,默认值为512, | |||
:param int type_vocab_size: 最大segment数量,默认值为2 | |||
:param float initializer_range: 初始化权重范围,默认值为0.02
""" | |||
def __init__(self, vocab_size=30522, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act="gelu", | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02): | |||
super(BertModel, self).__init__() | |||
self.hidden_size = hidden_size | |||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||
type_vocab_size, hidden_dropout_prob) | |||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||
hidden_act) | |||
self.pooler = BertPooler(hidden_size) | |||
self.initializer_range = initializer_range | |||
self.apply(self.init_bert_weights) | |||
def init_bert_weights(self, module): | |||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||
elif isinstance(module, BertLayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
if isinstance(module, nn.Linear) and module.bias is not None: | |||
module.bias.data.zero_() | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||
if attention_mask is None: | |||
attention_mask = torch.ones_like(input_ids) | |||
if token_type_ids is None: | |||
token_type_ids = torch.zeros_like(input_ids) | |||
# We create a 3D attention mask from a 2D tensor mask. | |||
# Sizes are [batch_size, 1, 1, to_seq_length] | |||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||
# this attention mask is more simple than the triangular masking of causal attention | |||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||
# masked positions, this operation will create a tensor which is 0.0 for | |||
# positions we want to attend and -10000.0 for masked positions. | |||
# Since we are adding it to the raw scores before the softmax, this is | |||
# effectively the same as removing these entirely. | |||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||
encoded_layers = self.encoder(embedding_output, | |||
extended_attention_mask, | |||
output_all_encoded_layers=output_all_encoded_layers) | |||
sequence_output = encoded_layers[-1] | |||
pooled_output = self.pooler(sequence_output) | |||
if not output_all_encoded_layers: | |||
encoded_layers = encoded_layers[-1] | |||
return encoded_layers, pooled_output | |||
@classmethod | |||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||
# Load config | |||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||
config = json.load(open(config_file, "r")) | |||
# config = BertConfig.from_json_file(config_file) | |||
# logger.info("Model config {}".format(config)) | |||
# Instantiate model. | |||
model = cls(*inputs, **config, **kwargs) | |||
if state_dict is None: | |||
files = glob.glob(os.path.join(pretrained_model_dir, '*.bin')) | |||
if len(files)==0: | |||
raise FileNotFoundError(f"There is no *.bin file in {pretrained_model_dir}") | |||
elif len(files)>1: | |||
raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}") | |||
weights_path = files[0] | |||
state_dict = torch.load(weights_path) | |||
old_keys = [] | |||
new_keys = [] | |||
for key in state_dict.keys(): | |||
new_key = None | |||
if 'gamma' in key: | |||
new_key = key.replace('gamma', 'weight') | |||
if 'beta' in key: | |||
new_key = key.replace('beta', 'bias') | |||
if new_key: | |||
old_keys.append(key) | |||
new_keys.append(new_key) | |||
for old_key, new_key in zip(old_keys, new_keys): | |||
state_dict[new_key] = state_dict.pop(old_key) | |||
missing_keys = [] | |||
unexpected_keys = [] | |||
error_msgs = [] | |||
# copy state_dict so _load_from_state_dict can modify it | |||
metadata = getattr(state_dict, '_metadata', None) | |||
state_dict = state_dict.copy() | |||
if metadata is not None: | |||
state_dict._metadata = metadata | |||
def load(module, prefix=''): | |||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||
module._load_from_state_dict( | |||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||
for name, child in module._modules.items(): | |||
if child is not None: | |||
load(child, prefix + name + '.') | |||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||
if len(missing_keys) > 0: | |||
print("Weights of {} not initialized from pretrained model: {}".format( | |||
model.__class__.__name__, missing_keys)) | |||
if len(unexpected_keys) > 0: | |||
print("Weights from pretrained model not used in {}: {}".format( | |||
model.__class__.__name__, unexpected_keys)) | |||
return model | |||
def whitespace_tokenize(text): | |||
"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||
text = text.strip() | |||
if not text: | |||
return [] | |||
tokens = text.split() | |||
return tokens | |||
class WordpieceTokenizer(object): | |||
"""Runs WordPiece tokenization.""" | |||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | |||
self.vocab = vocab | |||
self.unk_token = unk_token | |||
self.max_input_chars_per_word = max_input_chars_per_word | |||
def tokenize(self, text): | |||
"""Tokenizes a piece of text into its word pieces. | |||
This uses a greedy longest-match-first algorithm to perform tokenization | |||
using the given vocabulary. | |||
For example: | |||
input = "unaffable" | |||
output = ["un", "##aff", "##able"] | |||
Args: | |||
text: A single token or whitespace separated tokens. This should have | |||
already been passed through `BasicTokenizer`. | |||
Returns: | |||
A list of wordpiece tokens. | |||
""" | |||
output_tokens = [] | |||
for token in whitespace_tokenize(text): | |||
chars = list(token) | |||
if len(chars) > self.max_input_chars_per_word: | |||
output_tokens.append(self.unk_token) | |||
continue | |||
is_bad = False | |||
start = 0 | |||
sub_tokens = [] | |||
while start < len(chars): | |||
end = len(chars) | |||
cur_substr = None | |||
while start < end: | |||
substr = "".join(chars[start:end]) | |||
if start > 0: | |||
substr = "##" + substr | |||
if substr in self.vocab: | |||
cur_substr = substr | |||
break | |||
end -= 1 | |||
if cur_substr is None: | |||
is_bad = True | |||
break | |||
sub_tokens.append(cur_substr) | |||
start = end | |||
if is_bad: | |||
output_tokens.append(self.unk_token) | |||
else: | |||
output_tokens.extend(sub_tokens) | |||
return output_tokens | |||
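# A self-contained sketch of the greedy longest-match-first behaviour documented above,
# with a toy vocab (the real vocab comes from a pretrained model's vocab.txt):
toy_vocab = {'un': 0, '##aff': 1, '##able': 2, '[UNK]': 3}
wp = WordpieceTokenizer(vocab=toy_vocab)
print(wp.tokenize('unaffable'))   # ['un', '##aff', '##able']
print(wp.tokenize('xyz'))         # ['[UNK]'] -- no piece of "xyz" is in the toy vocab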
def load_vocab(vocab_file): | |||
"""Loads a vocabulary file into a dictionary.""" | |||
vocab = collections.OrderedDict() | |||
index = 0 | |||
with open(vocab_file, "r", encoding="utf-8") as reader: | |||
while True: | |||
token = reader.readline() | |||
if not token: | |||
break | |||
token = token.strip() | |||
vocab[token] = index | |||
index += 1 | |||
return vocab | |||
class BasicTokenizer(object): | |||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||
def __init__(self, | |||
do_lower_case=True, | |||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||
"""Constructs a BasicTokenizer. | |||
Args: | |||
do_lower_case: Whether to lower case the input. | |||
""" | |||
self.do_lower_case = do_lower_case | |||
self.never_split = never_split | |||
def tokenize(self, text): | |||
"""Tokenizes a piece of text.""" | |||
text = self._clean_text(text) | |||
# This was added on November 1st, 2018 for the multilingual and Chinese | |||
# models. This is also applied to the English models now, but it doesn't | |||
# matter since the English models were not trained on any Chinese data | |||
# and generally don't have any Chinese data in them (there are Chinese | |||
# characters in the vocabulary because Wikipedia does have some Chinese | |||
# words in the English Wikipedia.). | |||
text = self._tokenize_chinese_chars(text) | |||
orig_tokens = whitespace_tokenize(text) | |||
split_tokens = [] | |||
for token in orig_tokens: | |||
if self.do_lower_case and token not in self.never_split: | |||
token = token.lower() | |||
token = self._run_strip_accents(token) | |||
split_tokens.extend(self._run_split_on_punc(token)) | |||
output_tokens = whitespace_tokenize(" ".join(split_tokens)) | |||
return output_tokens | |||
def _run_strip_accents(self, text): | |||
"""Strips accents from a piece of text.""" | |||
text = unicodedata.normalize("NFD", text) | |||
output = [] | |||
for char in text: | |||
cat = unicodedata.category(char) | |||
if cat == "Mn": | |||
continue | |||
output.append(char) | |||
return "".join(output) | |||
def _run_split_on_punc(self, text): | |||
"""Splits punctuation on a piece of text.""" | |||
if text in self.never_split: | |||
return [text] | |||
chars = list(text) | |||
i = 0 | |||
start_new_word = True | |||
output = [] | |||
while i < len(chars): | |||
char = chars[i] | |||
if _is_punctuation(char): | |||
output.append([char]) | |||
start_new_word = True | |||
else: | |||
if start_new_word: | |||
output.append([]) | |||
start_new_word = False | |||
output[-1].append(char) | |||
i += 1 | |||
return ["".join(x) for x in output] | |||
def _tokenize_chinese_chars(self, text): | |||
"""Adds whitespace around any CJK character.""" | |||
output = [] | |||
for char in text: | |||
cp = ord(char) | |||
if self._is_chinese_char(cp): | |||
output.append(" ") | |||
output.append(char) | |||
output.append(" ") | |||
else: | |||
output.append(char) | |||
return "".join(output) | |||
def _is_chinese_char(self, cp): | |||
"""Checks whether CP is the codepoint of a CJK character.""" | |||
# This defines a "chinese character" as anything in the CJK Unicode block: | |||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||
# | |||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||
# despite its name. The modern Korean Hangul alphabet is a different block, | |||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||
# space-separated words, so they are not treated specially and handled | |||
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or # | |||
(cp >= 0x3400 and cp <= 0x4DBF) or # | |||
(cp >= 0x20000 and cp <= 0x2A6DF) or # | |||
(cp >= 0x2A700 and cp <= 0x2B73F) or # | |||
(cp >= 0x2B740 and cp <= 0x2B81F) or # | |||
(cp >= 0x2B820 and cp <= 0x2CEAF) or | |||
(cp >= 0xF900 and cp <= 0xFAFF) or # | |||
(cp >= 0x2F800 and cp <= 0x2FA1F)): # | |||
return True | |||
return False | |||
def _clean_text(self, text): | |||
"""Performs invalid character removal and whitespace cleanup on text.""" | |||
output = [] | |||
for char in text: | |||
cp = ord(char) | |||
if cp == 0 or cp == 0xfffd or _is_control(char): | |||
continue | |||
if _is_whitespace(char): | |||
output.append(" ") | |||
else: | |||
output.append(char) | |||
return "".join(output) | |||
def _is_whitespace(char): | |||
"""Checks whether `chars` is a whitespace character.""" | |||
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such. | |||
if char == " " or char == "\t" or char == "\n" or char == "\r": | |||
return True | |||
cat = unicodedata.category(char) | |||
if cat == "Zs": | |||
return True | |||
return False | |||
def _is_control(char): | |||
"""Checks whether `chars` is a control character.""" | |||
# These are technically control characters but we count them as whitespace | |||
# characters. | |||
if char == "\t" or char == "\n" or char == "\r": | |||
return False | |||
cat = unicodedata.category(char) | |||
if cat.startswith("C"): | |||
return True | |||
return False | |||
def _is_punctuation(char): | |||
"""Checks whether `chars` is a punctuation character.""" | |||
cp = ord(char) | |||
# We treat all non-letter/number ASCII as punctuation. | |||
# Characters such as "^", "$", and "`" are not in the Unicode | |||
# Punctuation class but we treat them as punctuation anyways, for | |||
# consistency. | |||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or | |||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | |||
return True | |||
cat = unicodedata.category(char) | |||
if cat.startswith("P"): | |||
return True | |||
return False | |||
class BertTokenizer(object): | |||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | |||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||
"""Constructs a BertTokenizer. | |||
Args: | |||
vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||
do_lower_case: Whether to lower case the input | |||
Only has an effect when do_wordpiece_only=False | |||
do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||
max_len: An artificial maximum length to truncate tokenized sequences to; | |||
Effective maximum length is always the minimum of this | |||
value (if specified) and the underlying BERT model's | |||
sequence length. | |||
never_split: List of tokens which will never be split during tokenization. | |||
Only has an effect when do_wordpiece_only=False | |||
""" | |||
if not os.path.isfile(vocab_file): | |||
raise ValueError( | |||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) | |||
self.vocab = load_vocab(vocab_file) | |||
self.ids_to_tokens = collections.OrderedDict( | |||
[(ids, tok) for tok, ids in self.vocab.items()]) | |||
self.do_basic_tokenize = do_basic_tokenize | |||
if do_basic_tokenize: | |||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, | |||
never_split=never_split) | |||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||
self.max_len = max_len if max_len is not None else int(1e12) | |||
def _reinit_on_new_vocab(self, vocab): | |||
""" | |||
在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质 | |||
:param vocab: | |||
:return: | |||
""" | |||
self.vocab = vocab | |||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||
def tokenize(self, text): | |||
split_tokens = [] | |||
if self.do_basic_tokenize: | |||
for token in self.basic_tokenizer.tokenize(text): | |||
for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||
split_tokens.append(sub_token) | |||
else: | |||
split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||
return split_tokens | |||
def convert_tokens_to_ids(self, tokens): | |||
"""Converts a sequence of tokens into ids using the vocab.""" | |||
ids = [] | |||
for token in tokens: | |||
ids.append(self.vocab[token]) | |||
if len(ids) > self.max_len: | |||
print( | |||
"Token indices sequence length is longer than the specified maximum " | |||
" sequence length for this BERT model ({} > {}). Running this" | |||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) | |||
) | |||
return ids | |||
def convert_ids_to_tokens(self, ids): | |||
"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||
tokens = [] | |||
for i in ids: | |||
tokens.append(self.ids_to_tokens[i]) | |||
return tokens | |||
def save_vocabulary(self, vocab_path): | |||
"""Save the tokenizer vocabulary to a directory or file.""" | |||
index = 0 | |||
if os.path.isdir(vocab_path): | |||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||
else: | |||
vocab_file = vocab_path | |||
with open(vocab_file, "w", encoding="utf-8") as writer: | |||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): | |||
if index != token_index: | |||
print("Saving vocabulary to {}: vocabulary indices are not consecutive." | |||
" Please check that the vocabulary is not corrupted!".format(vocab_file)) | |||
index = token_index | |||
writer.write(token + u'\n') | |||
index += 1 | |||
return vocab_file | |||
@classmethod | |||
def from_pretrained(cls, model_dir, *inputs, **kwargs): | |||
""" | |||
给定path,直接读取vocab. | |||
""" | |||
pretrained_model_name_or_path = os.path.join(model_dir, VOCAB_NAME) | |||
print("loading vocabulary file {}".format(pretrained_model_name_or_path)) | |||
max_len = 512 | |||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | |||
# Instantiate tokenizer. | |||
tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs) | |||
return tokenizer | |||
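# A runnable sketch using a throw-away vocab file; a real vocab.txt ships inside a
# pretrained BERT directory and would be loaded via BertTokenizer.from_pretrained(model_dir).
import os
import tempfile

vocab_file = os.path.join(tempfile.mkdtemp(), 'vocab.txt')
with open(vocab_file, 'w', encoding='utf-8') as f:
    f.write('\n'.join(['[PAD]', '[UNK]', '[CLS]', '[SEP]', 'hello', 'world', '##s']))

tokenizer = BertTokenizer(vocab_file, do_lower_case=True)
tokens = tokenizer.tokenize('Hello worlds!')
print(tokens)                                    # ['hello', 'world', '##s', '[UNK]']
print(tokenizer.convert_tokens_to_ids(tokens))   # [4, 5, 6, 1]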
VOCAB_NAME = 'vocab.txt' | |||
class _WordBertModel(nn.Module): | |||
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
self.encoder = BertModel.from_pretrained(model_dir) | |||
# 检查encoder_layer_number是否合理 | |||
encoder_layer_number = len(self.encoder.encoder.layer) | |||
self.layers = list(map(int, layers.split(','))) | |||
for layer in self.layers: | |||
if layer<0: | |||
assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
else: | |||
assert layer<encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
assert pool_method in ('avg', 'max', 'first', 'last') | |||
self.pool_method = pool_method | |||
self.include_cls_sep = include_cls_sep | |||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | |||
print("Start to generating word pieces for word.") | |||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||
word_piece_dict = {'[CLS]':1, '[SEP]':1} # 用到的word_piece以及新增的 | |||
found_count = 0 | |||
for word, index in vocab: | |||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||
word = '[PAD]' | |||
elif index == vocab.unknown_idx: | |||
word = '[UNK]' | |||
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
if len(word_pieces)==1: | |||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # 说明这个词不在原始的word里面 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
continue | |||
for word_piece in word_pieces: | |||
word_piece_dict[word_piece] = 1 | |||
found_count += 1 | |||
original_embed = self.encoder.embeddings.word_embeddings.weight.data | |||
# 特殊词汇要特殊处理 | |||
embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed | |||
new_word_piece_vocab = collections.OrderedDict() | |||
for index, token in enumerate(['[PAD]', '[UNK]']): | |||
word_piece_dict.pop(token, None) | |||
embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]] | |||
new_word_piece_vocab[token] = index | |||
for token in word_piece_dict.keys(): | |||
if token in self.tokenzier.vocab: | |||
embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab[token]] | |||
else: | |||
embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab['[UNK]']] | |||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||
self.encoder.embeddings.word_embeddings = embed | |||
word_to_wordpieces = [] | |||
word_pieces_lengths = [] | |||
for word, index in vocab: | |||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||
word = '[PAD]' | |||
elif index == vocab.unknown_idx: | |||
word = '[UNK]' | |||
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||
word_to_wordpieces.append(word_pieces) | |||
word_pieces_lengths.append(len(word_pieces)) | |||
print("Found(Or seg into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
self._cls_index = self.tokenzier.vocab['[CLS]'] | |||
self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
self._pad_index = vocab.padding_idx | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # needed for padding the generated word_pieces
self.word_to_wordpieces = np.array(word_to_wordpieces) | |||
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) | |||
print("Successfully generate word pieces.") | |||
def forward(self, words): | |||
""" | |||
:param words: torch.LongTensor, batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size when include_cls_sep is True
""" | |||
batch_size, max_word_len = words.size() | |||
seq_len = words.ne(self._pad_index).sum(dim=-1) | |||
batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len | |||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) | |||
max_word_piece_length = word_pieces_lengths.max().item() | |||
# +2 because [CLS] and [SEP] need to be added
word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) | |||
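# word_pieces starts out filled with the word-piece pad id; position 0 will hold [CLS] and
# position word_pieces_lengths + 1 of every sample will hold [SEP]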
word_pieces[:, 0].fill_(self._cls_index) | |||
batch_indexes = torch.arange(batch_size).to(words) | |||
word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index | |||
attn_masks = torch.zeros_like(word_pieces) | |||
# 1. get the word-piece ids of the words, along with the corresponding span boundaries
word_indexes = words.tolist() | |||
for i in range(batch_size): | |||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) | |||
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) | |||
attn_masks[i, :len(word_pieces_i)+2].fill_(1) | |||
# 2. run the encoder, then pool the word-piece hidden states back to word level
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | |||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, | |||
output_all_encoded_layers=True) | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||
if self.include_cls_sep: | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
bert_outputs[-1].size(-1)) | |||
s_shift = 1 | |||
else: | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | |||
bert_outputs[-1].size(-1)) | |||
s_shift = 0 | |||
batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | |||
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | |||
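# batch_word_pieces_cum_length[i, j] is the word-piece offset (excluding [CLS]) at which word j of
# sample i starts, so word j spans the word pieces [cum[j], cum[j+1]); the loops below pool over
# these spans according to pool_method ('first', 'last', 'max' or mean)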
for l_index, l in enumerate(self.layers): | |||
output_layer = bert_outputs[l] | |||
# collapse the word-piece representations back to word representations
truncate_output_layer = output_layer[:, 1:-1] # drop [CLS] and [SEP]; batch_size x len x hidden_size
outputs_seq_len = seq_len + s_shift | |||
if self.pool_method == 'first': | |||
for i in range(batch_size): | |||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # start position of each word
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size | |||
elif self.pool_method == 'last': | |||
for i in range(batch_size): | |||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i]+1] - 1 # end position of each word
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] | |||
elif self.pool_method == 'max': | |||
for i in range(batch_size): | |||
for j in range(seq_len[i]): | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] | |||
outputs[l_index, i, j+s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2) | |||
else: | |||
for i in range(batch_size): | |||
for j in range(seq_len[i]): | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] | |||
outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) | |||
if self.include_cls_sep: | |||
outputs[l_index, :, 0] = output_layer[:, 0] | |||
outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift] | |||
# 3. the final embedding result
return outputs | |||
class _WordPieceBertModel(nn.Module): | |||
""" | |||
This module computes representations directly at the word-piece level.
""" | |||
def __init__(self, model_dir:str, layers:str='-1'): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
self.encoder = BertModel.from_pretrained(model_dir) | |||
# check that the requested layer indices are valid for this encoder
encoder_layer_number = len(self.encoder.encoder.layer) | |||
self.layers = list(map(int, layers.split(','))) | |||
for layer in self.layers: | |||
if layer<0: | |||
assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
else: | |||
assert layer<encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
self._cls_index = self.tokenzier.vocab['[CLS]'] | |||
self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # needed for padding the generated word_pieces
def index_dataset(self, *datasets, field_name): | |||
""" | |||
Use the BERT tokenizer to generate a new word_pieces field for the datasets and set it as input. If a sequence
does not start with [CLS] and end with [SEP], they are added, and the pad value of the word_pieces field is set to BERT's pad value.
:param datasets: DataSet objects
:param field_name: the field to index on
:return: | |||
""" | |||
def convert_words_to_word_pieces(words): | |||
word_pieces = [] | |||
for word in words: | |||
tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens) | |||
word_pieces.extend(word_piece_ids) | |||
if word_pieces[0]!=self._cls_index: | |||
word_pieces.insert(0, self._cls_index) | |||
if word_pieces[-1]!=self._sep_index:
word_pieces.append(self._sep_index)  # [SEP] must be the last token
return word_pieces | |||
for index, dataset in enumerate(datasets): | |||
try: | |||
dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces', | |||
is_input=True) | |||
dataset.set_pad_val('word_pieces', self._wordpiece_pad_index) | |||
except Exception as e: | |||
print(f"Exception happens when processing the {index} dataset.") | |||
raise e | |||
def forward(self, word_pieces, token_type_ids=None): | |||
""" | |||
:param word_pieces: torch.LongTensor, batch_size x max_len | |||
:param token_type_ids: torch.LongTensor, batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size, or num_layers x batch_size x (max_len+2) x hidden_size
""" | |||
batch_size, max_len = word_pieces.size() | |||
attn_masks = word_pieces.ne(self._wordpiece_pad_index) | |||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||
output_all_encoded_layers=True) | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||
outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1))) | |||
for l_index, l in enumerate(self.layers): | |||
outputs[l_index] = bert_outputs[l] | |||
return outputs | |||
@@ -0,0 +1,788 @@ | |||
""" | |||
The code in this file is heavily based on https://github.com/HIT-SCIR/ELMoForManyLangs/tree/master/elmoformanylangs
""" | |||
from typing import Optional, Tuple, List, Callable | |||
import os | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence | |||
from ...core.vocabulary import Vocabulary | |||
import json | |||
from ..utils import get_dropout_mask | |||
import codecs | |||
from torch import autograd | |||
class LstmCellWithProjection(torch.nn.Module): | |||
""" | |||
An LSTM with Recurrent Dropout and a projected and clipped hidden state and | |||
memory. Note: this implementation is slower than the native Pytorch LSTM because | |||
it cannot make use of CUDNN optimizations for stacked RNNs due to the
variational dropout and the custom nature of the cell state.
Parameters | |||
---------- | |||
input_size : ``int``, required. | |||
The dimension of the inputs to the LSTM. | |||
hidden_size : ``int``, required. | |||
The dimension of the outputs of the LSTM. | |||
cell_size : ``int``, required. | |||
The dimension of the memory cell used for the LSTM. | |||
go_forward: ``bool``, optional (default = True) | |||
The direction in which the LSTM is applied to the sequence. | |||
Forwards by default, or backwards if False. | |||
recurrent_dropout_probability: ``float``, optional (default = 0.0) | |||
The dropout probability to be used in a dropout scheme as stated in | |||
`A Theoretically Grounded Application of Dropout in Recurrent Neural Networks | |||
<https://arxiv.org/abs/1512.05287>`_ . Implementation wise, this simply | |||
applies a fixed dropout mask per sequence to the recurrent connection of the | |||
LSTM. | |||
state_projection_clip_value: ``float``, optional, (default = None) | |||
The magnitude with which to clip the hidden_state after projecting it. | |||
memory_cell_clip_value: ``float``, optional, (default = None) | |||
The magnitude with which to clip the memory cell. | |||
Returns | |||
------- | |||
output_accumulator : ``torch.FloatTensor`` | |||
The outputs of the LSTM for each timestep. A tensor of shape | |||
(batch_size, max_timesteps, hidden_size) where for a given batch | |||
element, all outputs past the sequence length for that batch are | |||
zero tensors. | |||
final_state: ``Tuple[torch.FloatTensor, torch.FloatTensor]`` | |||
The final (state, memory) states of the LSTM, with shape | |||
(1, batch_size, hidden_size) and (1, batch_size, cell_size) | |||
respectively. The first dimension is 1 in order to match the Pytorch | |||
API for returning stacked LSTM states. | |||
""" | |||
def __init__(self, | |||
input_size: int, | |||
hidden_size: int, | |||
cell_size: int, | |||
go_forward: bool = True, | |||
recurrent_dropout_probability: float = 0.0, | |||
memory_cell_clip_value: Optional[float] = None, | |||
state_projection_clip_value: Optional[float] = None) -> None: | |||
super(LstmCellWithProjection, self).__init__() | |||
# Required to be wrapped with a :class:`PytorchSeq2SeqWrapper`. | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.cell_size = cell_size | |||
self.go_forward = go_forward | |||
self.state_projection_clip_value = state_projection_clip_value | |||
self.memory_cell_clip_value = memory_cell_clip_value | |||
self.recurrent_dropout_probability = recurrent_dropout_probability | |||
# We do the projections for all the gates all at once. | |||
self.input_linearity = torch.nn.Linear(input_size, 4 * cell_size, bias=False) | |||
self.state_linearity = torch.nn.Linear(hidden_size, 4 * cell_size, bias=True) | |||
# Additional projection matrix for making the hidden state smaller. | |||
self.state_projection = torch.nn.Linear(cell_size, hidden_size, bias=False) | |||
self.reset_parameters() | |||
def reset_parameters(self): | |||
# Use sensible default initializations for parameters. | |||
nn.init.orthogonal_(self.input_linearity.weight.data) | |||
nn.init.orthogonal_(self.state_linearity.weight.data) | |||
self.state_linearity.bias.data.fill_(0.0) | |||
# Initialize forget gate biases to 1.0 as per An Empirical | |||
# Exploration of Recurrent Network Architectures, (Jozefowicz, 2015). | |||
self.state_linearity.bias.data[self.cell_size:2 * self.cell_size].fill_(1.0) | |||
def forward(self, # pylint: disable=arguments-differ | |||
inputs: torch.FloatTensor, | |||
batch_lengths: List[int], | |||
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None): | |||
""" | |||
Parameters | |||
---------- | |||
inputs : ``torch.FloatTensor``, required. | |||
A tensor of shape (batch_size, num_timesteps, input_size) | |||
to apply the LSTM over. | |||
batch_lengths : ``List[int]``, required. | |||
A list of length batch_size containing the lengths of the sequences in batch. | |||
initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None) | |||
A tuple (state, memory) representing the initial hidden state and memory | |||
of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the | |||
``memory`` has shape (1, batch_size, cell_size). | |||
Returns | |||
------- | |||
output_accumulator : ``torch.FloatTensor`` | |||
The outputs of the LSTM for each timestep. A tensor of shape | |||
(batch_size, max_timesteps, hidden_size) where for a given batch | |||
element, all outputs past the sequence length for that batch are | |||
zero tensors. | |||
final_state : ``Tuple[torch.FloatTensor, torch.FloatTensor]``
A tuple (state, memory) representing the final hidden state and memory
of the LSTM. The ``state`` has shape (1, batch_size, hidden_size) and the | |||
``memory`` has shape (1, batch_size, cell_size). | |||
""" | |||
batch_size = inputs.size()[0] | |||
total_timesteps = inputs.size()[1] | |||
# We have to use this '.data.new().fill_' pattern to create tensors with the correct | |||
# type - forward has no knowledge of whether these are torch.Tensors or torch.cuda.Tensors. | |||
output_accumulator = inputs.data.new(batch_size, | |||
total_timesteps, | |||
self.hidden_size).fill_(0) | |||
if initial_state is None: | |||
full_batch_previous_memory = inputs.data.new(batch_size, | |||
self.cell_size).fill_(0) | |||
full_batch_previous_state = inputs.data.new(batch_size, | |||
self.hidden_size).fill_(0) | |||
else: | |||
full_batch_previous_state = initial_state[0].squeeze(0) | |||
full_batch_previous_memory = initial_state[1].squeeze(0) | |||
current_length_index = batch_size - 1 if self.go_forward else 0 | |||
if self.recurrent_dropout_probability > 0.0 and self.training: | |||
dropout_mask = get_dropout_mask(self.recurrent_dropout_probability, | |||
full_batch_previous_state) | |||
else: | |||
dropout_mask = None | |||
for timestep in range(total_timesteps): | |||
# The index depends on which end we start. | |||
index = timestep if self.go_forward else total_timesteps - timestep - 1 | |||
# What we are doing here is finding the index into the batch dimension | |||
# which we need to use for this timestep, because the sequences have | |||
# variable length, so once the index is greater than the length of this | |||
# particular batch sequence, we no longer need to do the computation for | |||
# this sequence. The key thing to recognise here is that the batch inputs | |||
# must be _ordered_ by length from longest (first in batch) to shortest | |||
# (last) so initially, we are going forwards with every sequence and as we | |||
# pass the index at which the shortest elements of the batch finish, | |||
# we stop picking them up for the computation. | |||
if self.go_forward: | |||
while batch_lengths[current_length_index] <= index: | |||
current_length_index -= 1 | |||
# If we're going backwards, we are _picking up_ more indices. | |||
else: | |||
# First conditional: Are we already at the maximum number of elements in the batch? | |||
# Second conditional: Does the next shortest sequence beyond the current batch | |||
# index require computation at this timestep?
while current_length_index < (len(batch_lengths) - 1) and \ | |||
batch_lengths[current_length_index + 1] > index: | |||
current_length_index += 1 | |||
# Actually get the slices of the batch which we | |||
# need for the computation at this timestep. | |||
# shape (batch_size, cell_size) | |||
previous_memory = full_batch_previous_memory[0: current_length_index + 1].clone() | |||
# Shape (batch_size, hidden_size) | |||
previous_state = full_batch_previous_state[0: current_length_index + 1].clone() | |||
# Shape (batch_size, input_size) | |||
timestep_input = inputs[0: current_length_index + 1, index] | |||
# Do the projections for all the gates all at once. | |||
# Both have shape (batch_size, 4 * cell_size) | |||
projected_input = self.input_linearity(timestep_input) | |||
projected_state = self.state_linearity(previous_state) | |||
# Main LSTM equations using relevant chunks of the big linear | |||
# projections of the hidden state and inputs. | |||
input_gate = torch.sigmoid(projected_input[:, (0 * self.cell_size):(1 * self.cell_size)] + | |||
projected_state[:, (0 * self.cell_size):(1 * self.cell_size)]) | |||
forget_gate = torch.sigmoid(projected_input[:, (1 * self.cell_size):(2 * self.cell_size)] + | |||
projected_state[:, (1 * self.cell_size):(2 * self.cell_size)]) | |||
memory_init = torch.tanh(projected_input[:, (2 * self.cell_size):(3 * self.cell_size)] + | |||
projected_state[:, (2 * self.cell_size):(3 * self.cell_size)]) | |||
output_gate = torch.sigmoid(projected_input[:, (3 * self.cell_size):(4 * self.cell_size)] + | |||
projected_state[:, (3 * self.cell_size):(4 * self.cell_size)]) | |||
memory = input_gate * memory_init + forget_gate * previous_memory | |||
# Here is the non-standard part of this LSTM cell; first, we clip the | |||
# memory cell, then we project the output of the timestep to a smaller size | |||
# and again clip it. | |||
if self.memory_cell_clip_value: | |||
# pylint: disable=invalid-unary-operand-type | |||
memory = torch.clamp(memory, -self.memory_cell_clip_value, self.memory_cell_clip_value) | |||
# shape (current_length_index, cell_size) | |||
pre_projection_timestep_output = output_gate * torch.tanh(memory) | |||
# shape (current_length_index, hidden_size) | |||
timestep_output = self.state_projection(pre_projection_timestep_output) | |||
if self.state_projection_clip_value: | |||
# pylint: disable=invalid-unary-operand-type | |||
timestep_output = torch.clamp(timestep_output, | |||
-self.state_projection_clip_value, | |||
self.state_projection_clip_value) | |||
# Only do dropout if the dropout prob is > 0.0 and we are in training mode. | |||
if dropout_mask is not None: | |||
timestep_output = timestep_output * dropout_mask[0: current_length_index + 1] | |||
# We've been doing computation with less than the full batch, so here we create a new | |||
# variable for the whole batch at this timestep and insert the result for the
# relevant elements of the batch into it. | |||
full_batch_previous_memory = full_batch_previous_memory.data.clone() | |||
full_batch_previous_state = full_batch_previous_state.data.clone() | |||
full_batch_previous_memory[0:current_length_index + 1] = memory | |||
full_batch_previous_state[0:current_length_index + 1] = timestep_output | |||
output_accumulator[0:current_length_index + 1, index] = timestep_output | |||
# Mimic the pytorch API by returning state in the following shape: | |||
# (num_layers * num_directions, batch_size, ...). As this | |||
# LSTM cell cannot be stacked, the first dimension here is just 1. | |||
final_state = (full_batch_previous_state.unsqueeze(0), | |||
full_batch_previous_memory.unsqueeze(0)) | |||
return output_accumulator, final_state | |||
class LstmbiLm(nn.Module): | |||
def __init__(self, config): | |||
super(LstmbiLm, self).__init__() | |||
self.config = config | |||
self.encoder = nn.LSTM(self.config['encoder']['projection_dim'], | |||
self.config['encoder']['dim'], | |||
num_layers=self.config['encoder']['n_layers'], | |||
bidirectional=True, | |||
batch_first=True, | |||
dropout=self.config['dropout']) | |||
self.projection = nn.Linear(self.config['encoder']['dim'], self.config['encoder']['projection_dim'], bias=True) | |||
def forward(self, inputs, seq_len): | |||
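# pack_padded_sequence requires sequences sorted by length in descending order; sort here and
# restore the original sample order after running the LSTM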
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||
inputs = inputs[sort_idx] | |||
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True)
output, hx = self.encoder(inputs, None) # -> [N,L,C]
output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
output = output[unsort_idx] | |||
forward, backward = output.split(self.config['encoder']['dim'], 2) | |||
return torch.cat([self.projection(forward), self.projection(backward)], dim=2) | |||
class ElmobiLm(torch.nn.Module): | |||
def __init__(self, config): | |||
super(ElmobiLm, self).__init__() | |||
self.config = config | |||
input_size = config['encoder']['projection_dim'] | |||
hidden_size = config['encoder']['projection_dim'] | |||
cell_size = config['encoder']['dim'] | |||
num_layers = config['encoder']['n_layers'] | |||
memory_cell_clip_value = config['encoder']['cell_clip'] | |||
state_projection_clip_value = config['encoder']['proj_clip'] | |||
recurrent_dropout_probability = config['dropout'] | |||
self.input_size = input_size | |||
self.hidden_size = hidden_size | |||
self.num_layers = num_layers | |||
self.cell_size = cell_size | |||
forward_layers = [] | |||
backward_layers = [] | |||
lstm_input_size = input_size | |||
go_forward = True | |||
for layer_index in range(num_layers): | |||
forward_layer = LstmCellWithProjection(lstm_input_size, | |||
hidden_size, | |||
cell_size, | |||
go_forward, | |||
recurrent_dropout_probability, | |||
memory_cell_clip_value, | |||
state_projection_clip_value) | |||
backward_layer = LstmCellWithProjection(lstm_input_size, | |||
hidden_size, | |||
cell_size, | |||
not go_forward, | |||
recurrent_dropout_probability, | |||
memory_cell_clip_value, | |||
state_projection_clip_value) | |||
lstm_input_size = hidden_size | |||
self.add_module('forward_layer_{}'.format(layer_index), forward_layer) | |||
self.add_module('backward_layer_{}'.format(layer_index), backward_layer) | |||
forward_layers.append(forward_layer) | |||
backward_layers.append(backward_layer) | |||
self.forward_layers = forward_layers | |||
self.backward_layers = backward_layers | |||
def forward(self, inputs, seq_len): | |||
""" | |||
:param inputs: batch_size x max_len x embed_size | |||
:param seq_len: batch_size | |||
:return: torch.FloatTensor. num_layers x batch_size x max_len x hidden_size | |||
""" | |||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||
inputs = inputs[sort_idx] | |||
inputs = nn.utils.rnn.pack_padded_sequence(inputs, sort_lens, batch_first=True) | |||
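# _lstm_forward returns a tensor of shape (num_layers, batch_size, max_len, 2 * hidden_size);
# the batch dimension is still sorted by length and gets restored below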
output, _ = self._lstm_forward(inputs, None) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
output = output[:, unsort_idx] | |||
return output | |||
def _lstm_forward(self, | |||
inputs: PackedSequence, | |||
initial_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None) -> \ | |||
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: | |||
""" | |||
Parameters | |||
---------- | |||
inputs : ``PackedSequence``, required. | |||
A batch first ``PackedSequence`` to run the stacked LSTM over. | |||
initial_state : ``Tuple[torch.Tensor, torch.Tensor]``, optional, (default = None) | |||
A tuple (state, memory) representing the initial hidden state and memory | |||
of the LSTM, with shape (num_layers, batch_size, 2 * hidden_size) and | |||
(num_layers, batch_size, 2 * cell_size) respectively. | |||
Returns | |||
------- | |||
output_sequence : ``torch.FloatTensor`` | |||
The encoded sequence of shape (num_layers, batch_size, sequence_length, hidden_size) | |||
final_states: ``Tuple[torch.FloatTensor, torch.FloatTensor]`` | |||
The per-layer final (state, memory) states of the LSTM, with shape | |||
(num_layers, batch_size, 2 * hidden_size) and (num_layers, batch_size, 2 * cell_size) | |||
respectively. The last dimension is duplicated because it contains the state/memory | |||
for both the forward and backward layers. | |||
""" | |||
if initial_state is None: | |||
hidden_states: List[Optional[Tuple[torch.Tensor, | |||
torch.Tensor]]] = [None] * len(self.forward_layers) | |||
elif initial_state[0].size()[0] != len(self.forward_layers): | |||
raise Exception("Initial states were passed to forward() but the number of " | |||
"initial states does not match the number of layers.") | |||
else: | |||
hidden_states = list(zip(initial_state[0].split(1, 0), initial_state[1].split(1, 0))) | |||
inputs, batch_lengths = pad_packed_sequence(inputs, batch_first=True) | |||
forward_output_sequence = inputs | |||
backward_output_sequence = inputs | |||
final_states = [] | |||
sequence_outputs = [] | |||
for layer_index, state in enumerate(hidden_states): | |||
forward_layer = getattr(self, 'forward_layer_{}'.format(layer_index)) | |||
backward_layer = getattr(self, 'backward_layer_{}'.format(layer_index)) | |||
forward_cache = forward_output_sequence | |||
backward_cache = backward_output_sequence | |||
if state is not None: | |||
forward_hidden_state, backward_hidden_state = state[0].split(self.hidden_size, 2) | |||
forward_memory_state, backward_memory_state = state[1].split(self.cell_size, 2) | |||
forward_state = (forward_hidden_state, forward_memory_state) | |||
backward_state = (backward_hidden_state, backward_memory_state) | |||
else: | |||
forward_state = None | |||
backward_state = None | |||
forward_output_sequence, forward_state = forward_layer(forward_output_sequence, | |||
batch_lengths, | |||
forward_state) | |||
backward_output_sequence, backward_state = backward_layer(backward_output_sequence, | |||
batch_lengths, | |||
backward_state) | |||
# Skip connections, just adding the input to the output. | |||
if layer_index != 0: | |||
forward_output_sequence += forward_cache | |||
backward_output_sequence += backward_cache | |||
sequence_outputs.append(torch.cat([forward_output_sequence, | |||
backward_output_sequence], -1)) | |||
# Append the state tuples in a list, so that we can return | |||
# the final states for all the layers. | |||
final_states.append((torch.cat([forward_state[0], backward_state[0]], -1), | |||
torch.cat([forward_state[1], backward_state[1]], -1))) | |||
stacked_sequence_outputs: torch.FloatTensor = torch.stack(sequence_outputs) | |||
# Stack the hidden state and memory for each layer into 2 tensors of shape | |||
# (num_layers, batch_size, hidden_size) and (num_layers, batch_size, cell_size) | |||
# respectively. | |||
final_hidden_states, final_memory_states = zip(*final_states) | |||
final_state_tuple: Tuple[torch.FloatTensor, | |||
torch.FloatTensor] = (torch.cat(final_hidden_states, 0), | |||
torch.cat(final_memory_states, 0)) | |||
return stacked_sequence_outputs, final_state_tuple | |||
class LstmTokenEmbedder(nn.Module): | |||
def __init__(self, config, word_emb_layer, char_emb_layer): | |||
super(LstmTokenEmbedder, self).__init__() | |||
self.config = config | |||
self.word_emb_layer = word_emb_layer | |||
self.char_emb_layer = char_emb_layer | |||
self.output_dim = config['encoder']['projection_dim'] | |||
emb_dim = 0 | |||
if word_emb_layer is not None: | |||
emb_dim += word_emb_layer.n_d | |||
if char_emb_layer is not None: | |||
emb_dim += char_emb_layer.n_d * 2 | |||
self.char_lstm = nn.LSTM(char_emb_layer.n_d, char_emb_layer.n_d, num_layers=1, bidirectional=True, | |||
batch_first=True, dropout=config['dropout']) | |||
self.projection = nn.Linear(emb_dim, self.output_dim, bias=True) | |||
def forward(self, words, chars): | |||
embs = [] | |||
if self.word_emb_layer is not None: | |||
if hasattr(self, 'words_to_words'): | |||
words = self.words_to_words[words] | |||
word_emb = self.word_emb_layer(words) | |||
embs.append(word_emb) | |||
if self.char_emb_layer is not None: | |||
batch_size, seq_len, _ = chars.shape | |||
chars = chars.view(batch_size * seq_len, -1) | |||
chars_emb = self.char_emb_layer(chars) | |||
# TODO: seq_len should be taken into account here
_, (chars_outputs, __) = self.char_lstm(chars_emb) | |||
chars_outputs = chars_outputs.contiguous().view(-1, self.config['token_embedder']['char_dim'] * 2) | |||
embs.append(chars_outputs) | |||
token_embedding = torch.cat(embs, dim=2) | |||
return self.projection(token_embedding) | |||
class ConvTokenEmbedder(nn.Module): | |||
def __init__(self, config, word_emb_layer, char_emb_layer): | |||
super(ConvTokenEmbedder, self).__init__() | |||
self.config = config | |||
self.word_emb_layer = word_emb_layer | |||
self.char_emb_layer = char_emb_layer | |||
self.output_dim = config['encoder']['projection_dim'] | |||
self.emb_dim = 0 | |||
if word_emb_layer is not None: | |||
self.emb_dim += word_emb_layer.weight.size(1) | |||
if char_emb_layer is not None: | |||
self.convolutions = [] | |||
cnn_config = config['token_embedder'] | |||
filters = cnn_config['filters'] | |||
char_embed_dim = cnn_config['char_dim'] | |||
for i, (width, num) in enumerate(filters): | |||
conv = torch.nn.Conv1d( | |||
in_channels=char_embed_dim, | |||
out_channels=num, | |||
kernel_size=width, | |||
bias=True | |||
) | |||
self.convolutions.append(conv) | |||
self.convolutions = nn.ModuleList(self.convolutions) | |||
self.n_filters = sum(f[1] for f in filters) | |||
self.n_highway = cnn_config['n_highway'] | |||
self.highways = Highway(self.n_filters, self.n_highway, activation=torch.nn.functional.relu) | |||
self.emb_dim += self.n_filters | |||
self.projection = nn.Linear(self.emb_dim, self.output_dim, bias=True) | |||
def forward(self, words, chars): | |||
embs = [] | |||
if self.word_emb_layer is not None: | |||
if hasattr(self, 'words_to_words'): | |||
words = self.words_to_words[words] | |||
word_emb = self.word_emb_layer(words) | |||
embs.append(word_emb) | |||
if self.char_emb_layer is not None: | |||
batch_size, seq_len, _ = chars.size() | |||
chars = chars.view(batch_size * seq_len, -1) | |||
character_embedding = self.char_emb_layer(chars) | |||
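# Conv1d expects (batch, channels, length), so transpose the character embeddings to
# (batch_size * seq_len, char_dim, max_chars) before applying the convolutions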
character_embedding = torch.transpose(character_embedding, 1, 2) | |||
cnn_config = self.config['token_embedder'] | |||
if cnn_config['activation'] == 'tanh': | |||
activation = torch.nn.functional.tanh | |||
elif cnn_config['activation'] == 'relu': | |||
activation = torch.nn.functional.relu | |||
else: | |||
raise Exception("Unknown activation") | |||
convs = [] | |||
for i in range(len(self.convolutions)): | |||
convolved = self.convolutions[i](character_embedding) | |||
# (batch_size * sequence_length, n_filters for this width) | |||
convolved, _ = torch.max(convolved, dim=-1) | |||
convolved = activation(convolved) | |||
convs.append(convolved) | |||
char_emb = torch.cat(convs, dim=-1) | |||
char_emb = self.highways(char_emb) | |||
embs.append(char_emb.view(batch_size, -1, self.n_filters)) | |||
token_embedding = torch.cat(embs, dim=2) | |||
return self.projection(token_embedding) | |||
class Highway(torch.nn.Module): | |||
""" | |||
A `Highway layer <https://arxiv.org/abs/1505.00387>`_ does a gated combination of a linear | |||
transformation and a non-linear transformation of its input. :math:`y = g * x + (1 - g) * | |||
f(A(x))`, where :math:`A` is a linear transformation, :math:`f` is an element-wise | |||
non-linearity, and :math:`g` is an element-wise gate, computed as :math:`sigmoid(B(x))`. | |||
This module will apply a fixed number of highway layers to its input, returning the final | |||
result. | |||
Parameters | |||
---------- | |||
input_dim : ``int`` | |||
The dimensionality of :math:`x`. We assume the input has shape ``(batch_size, | |||
input_dim)``. | |||
num_layers : ``int``, optional (default=``1``) | |||
The number of highway layers to apply to the input. | |||
activation : ``Callable[[torch.Tensor], torch.Tensor]``, optional (default=``torch.nn.functional.relu``) | |||
The non-linearity to use in the highway layers. | |||
""" | |||
def __init__(self, | |||
input_dim: int, | |||
num_layers: int = 1, | |||
activation: Callable[[torch.Tensor], torch.Tensor] = torch.nn.functional.relu) -> None: | |||
super(Highway, self).__init__() | |||
self._input_dim = input_dim | |||
self._layers = torch.nn.ModuleList([torch.nn.Linear(input_dim, input_dim * 2) | |||
for _ in range(num_layers)]) | |||
self._activation = activation | |||
for layer in self._layers: | |||
# We should bias the highway layer to just carry its input forward. We do that by | |||
# setting the bias on `B(x)` to be positive, because that means `g` will be biased to | |||
# be high, so we will carry the input forward. The bias on `B(x)` is the second half
# of the bias vector in each Linear layer. | |||
layer.bias[input_dim:].data.fill_(1) | |||
def forward(self, inputs: torch.Tensor) -> torch.Tensor: # pylint: disable=arguments-differ | |||
current_input = inputs | |||
for layer in self._layers: | |||
projected_input = layer(current_input) | |||
linear_part = current_input | |||
# NOTE: if you modify this, think about whether you should modify the initialization | |||
# above, too. | |||
nonlinear_part = projected_input[:, (0 * self._input_dim):(1 * self._input_dim)] | |||
gate = projected_input[:, (1 * self._input_dim):(2 * self._input_dim)] | |||
nonlinear_part = self._activation(nonlinear_part) | |||
gate = torch.sigmoid(gate) | |||
current_input = gate * linear_part + (1 - gate) * nonlinear_part | |||
return current_input | |||
class _ElmoModel(nn.Module): | |||
""" | |||
This Module does all the heavy lifting for ElmoEmbedding. Its jobs include:
(1) loading the model according to the config;
(2) adjusting the model's embeddings according to the vocab and initializing them correctly;
(3) keeping a word-to-chars mapping so that the conversion happens automatically at lookup time;
(4) maintaining a token embedding that allows word representations to be cached.
""" | |||
def __init__(self, model_dir:str, vocab:Vocabulary=None, cache_word_reprs:bool=False): | |||
super(_ElmoModel, self).__init__() | |||
config = json.load(open(os.path.join(model_dir, 'structure_config.json'), 'r')) | |||
self.config = config | |||
OOV_TAG = '<oov>' | |||
PAD_TAG = '<pad>' | |||
BOS_TAG = '<bos>' | |||
EOS_TAG = '<eos>' | |||
BOW_TAG = '<bow>' | |||
EOW_TAG = '<eow>' | |||
# load the pretrained token-embedder weights here
token_embedder_states = torch.load(os.path.join(model_dir, 'token_embedder.pkl'), map_location='cpu') | |||
# For the model trained with word form word encoder. | |||
if config['token_embedder']['word_dim'] > 0: | |||
word_lexicon = {} | |||
with codecs.open(os.path.join(model_dir, 'word.dic'), 'r', encoding='utf-8') as fpi: | |||
for line in fpi: | |||
tokens = line.strip().split('\t') | |||
if len(tokens) == 1: | |||
tokens.insert(0, '\u3000') | |||
token, i = tokens | |||
word_lexicon[token] = int(i) | |||
# a few sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOS_TAG, EOS_TAG]: | |||
assert special_word in word_lexicon, f"{special_word} not found in word.dic." | |||
# adapt the word embedding to the given vocab
pre_word_embedding = token_embedder_states.pop('word_emb_layer.embedding.weight') | |||
word_emb_layer = nn.Embedding(len(vocab)+2, config['token_embedder']['word_dim']) # the two extra rows are for <bos> and <eos>
found_word_count = 0 | |||
for word, index in vocab: | |||
if index == vocab.unknown_idx: # fastNLP uses <unk> for unknown while this model uses <oov>, so map it explicitly
index_in_pre = word_lexicon[OOV_TAG] | |||
found_word_count += 1 | |||
elif index == vocab.padding_idx: # pad must be aligned as well
index_in_pre = word_lexicon[PAD_TAG] | |||
found_word_count += 1 | |||
elif word in word_lexicon: | |||
index_in_pre = word_lexicon[word] | |||
found_word_count += 1 | |||
else: | |||
index_in_pre = word_lexicon[OOV_TAG] | |||
word_emb_layer.weight.data[index] = pre_word_embedding[index_in_pre] | |||
print(f"{found_word_count} out of {len(vocab)} words were found in pretrained elmo embedding.") | |||
word_emb_layer.weight.data[-1] = pre_word_embedding[word_lexicon[EOS_TAG]] | |||
word_emb_layer.weight.data[-2] = pre_word_embedding[word_lexicon[BOS_TAG]] | |||
self.word_vocab = vocab | |||
else: | |||
word_emb_layer = None | |||
# For the model trained with character-based word encoder. | |||
if config['token_embedder']['char_dim'] > 0: | |||
char_lexicon = {} | |||
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: | |||
for line in fpi: | |||
tokens = line.strip().split('\t') | |||
if len(tokens) == 1: | |||
tokens.insert(0, '\u3000') | |||
token, i = tokens | |||
char_lexicon[token] = int(i) | |||
# a few sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: | |||
assert special_word in char_lexicon, f"{special_word} not found in char.dic." | |||
# build char_vocab from vocab
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
# make sure <bow> and <eow> are included
char_vocab.add_word(BOW_TAG) | |||
char_vocab.add_word(EOW_TAG) | |||
for word, index in vocab: | |||
char_vocab.add_word_lst(list(word)) | |||
# make sure the characters of <bos> and <eos> are included too
char_vocab.add_word_lst(list(BOS_TAG)) | |||
char_vocab.add_word_lst(list(EOS_TAG)) | |||
# adjust according to char_lexicon
char_emb_layer = nn.Embedding(len(char_vocab), int(config['token_embedder']['char_dim'])) | |||
pre_char_embedding = token_embedder_states.pop('char_emb_layer.embedding.weight') | |||
found_char_count = 0 | |||
for char, index in char_vocab: # adjust the character embedding
if char in char_lexicon: | |||
index_in_pre = char_lexicon.get(char) | |||
found_char_count += 1 | |||
else: | |||
index_in_pre = char_lexicon[OOV_TAG] | |||
char_emb_layer.weight.data[index] = pre_char_embedding[index_in_pre] | |||
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||
# build the words-to-chars mapping
if config['token_embedder']['name'].lower() == 'cnn': | |||
max_chars = config['token_embedder']['max_characters_per_token'] | |||
elif config['token_embedder']['name'].lower() == 'lstm': | |||
max_chars = max(map(lambda x: len(x[0]), vocab)) + 2 # plus two for <bow> and <eow>
else: | |||
raise ValueError('Unknown token_embedder: {0}'.format(config['token_embedder']['name'])) | |||
# two extra rows for <bos> and <eos>.
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab)+2, max_chars), | |||
fill_value=char_vocab.to_index(PAD_TAG), dtype=torch.long), | |||
requires_grad=False) | |||
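# words_to_chars_embedding[i] stores the fixed-length character-id sequence
# <bow> c1 ... cn <eow> <pad> ... for word i; rows len(vocab) and len(vocab)+1 are for <bos>/<eos>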
for word, index in vocab: | |||
if len(word)+2>max_chars: | |||
word = word[:max_chars-2] | |||
if index==vocab.padding_idx: # pad must be aligned with the pretrained lexicon
word = PAD_TAG | |||
elif index==vocab.unknown_idx: | |||
word = OOV_TAG | |||
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] | |||
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) | |||
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) | |||
for index, word in enumerate([BOS_TAG, EOS_TAG]): # add <bos> and <eos>
if len(word)+2>max_chars: | |||
word = word[:max_chars-2] | |||
char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [char_vocab.to_index(EOW_TAG)] | |||
char_ids += [char_vocab.to_index(PAD_TAG)]*(max_chars-len(char_ids)) | |||
self.words_to_chars_embedding[index+len(vocab)] = torch.LongTensor(char_ids) | |||
self.char_vocab = char_vocab | |||
else: | |||
char_emb_layer = None | |||
if config['token_embedder']['name'].lower() == 'cnn': | |||
self.token_embedder = ConvTokenEmbedder( | |||
config, word_emb_layer, char_emb_layer) | |||
elif config['token_embedder']['name'].lower() == 'lstm': | |||
self.token_embedder = LstmTokenEmbedder( | |||
config, word_emb_layer, char_emb_layer) | |||
self.token_embedder.load_state_dict(token_embedder_states, strict=False) | |||
if config['token_embedder']['word_dim'] > 0 and vocab._no_create_word_length > 0: # map indices that only appear in dev/test to unk
words_to_words = nn.Parameter(torch.arange(len(vocab)+2).long(), requires_grad=False) | |||
for word, idx in vocab: | |||
if vocab._is_word_no_create_entry(word): | |||
words_to_words[idx] = vocab.unknown_idx | |||
setattr(self.token_embedder, 'words_to_words', words_to_words) | |||
self.output_dim = config['encoder']['projection_dim'] | |||
if config['encoder']['name'].lower() == 'elmo': | |||
self.encoder = ElmobiLm(config) | |||
elif config['encoder']['name'].lower() == 'lstm': | |||
self.encoder = LstmbiLm(config) | |||
self.encoder.load_state_dict(torch.load(os.path.join(model_dir, 'encoder.pkl'), | |||
map_location='cpu')) | |||
self.bos_index = len(vocab) | |||
self.eos_index = len(vocab) + 1 | |||
self._pad_index = vocab.padding_idx | |||
if cache_word_reprs: | |||
if config['token_embedder']['char_dim']>0: # only useful when character inputs are used
print("Start to generate cache word representations.") | |||
batch_size = 320 | |||
num_batches = self.words_to_chars_embedding.size(0)//batch_size + \ | |||
int(self.words_to_chars_embedding.size(0)%batch_size!=0) | |||
self.cached_word_embedding = nn.Embedding(self.words_to_chars_embedding.size(0), | |||
config['encoder']['projection_dim']) | |||
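# fill the cache in mini-batches: run the character-based token embedder over every word id and
# store the result, so the heavier character encoder can be deleted afterwards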
with torch.no_grad(): | |||
for i in range(num_batches): | |||
words = torch.arange(i*batch_size, min((i+1)*batch_size, self.words_to_chars_embedding.size(0))).long() | |||
chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars | |||
word_reprs = self.token_embedder(words.unsqueeze(1), chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] | |||
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) | |||
print("Finish generating cached word representations. Going to delete the character encoder.") | |||
del self.token_embedder, self.words_to_chars_embedding | |||
else: | |||
print("There is no need to cache word representations, since no character information is used.") | |||
def forward(self, words): | |||
""" | |||
:param words: batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size | |||
""" | |||
# add <bos> and <eos>
batch_size, max_len = words.size() | |||
expanded_words = words.new_zeros(batch_size, max_len + 2) # pad index is always 0, so zeros are already padding
seq_len = words.ne(self._pad_index).sum(dim=-1) | |||
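# position 0 receives <bos> and position seq_len + 1 receives <eos>; the original words sit in
# between, so the effective sequence length grows by 2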
expanded_words[:, 1:-1] = words | |||
expanded_words[:, 0].fill_(self.bos_index) | |||
expanded_words[torch.arange(batch_size).to(words), seq_len+1] = self.eos_index | |||
seq_len = seq_len + 2 | |||
if hasattr(self, 'cached_word_embedding'): | |||
token_embedding = self.cached_word_embedding(expanded_words) | |||
else: | |||
if hasattr(self, 'words_to_chars_embedding'): | |||
chars = self.words_to_chars_embedding[expanded_words] | |||
else: | |||
chars = None | |||
token_embedding = self.token_embedder(expanded_words, chars) | |||
if self.config['encoder']['name'] == 'elmo': | |||
encoder_output = self.encoder(token_embedding, seq_len) | |||
if encoder_output.size(2) < max_len+2: | |||
dummy_tensor = encoder_output.new_zeros(encoder_output.size(0), batch_size, | |||
max_len + 2 - encoder_output.size(2), encoder_output.size(-1)) | |||
encoder_output = torch.cat([encoder_output, dummy_tensor], 2) | |||
sz = encoder_output.size() # 2, batch_size, max_len, hidden_size | |||
token_embedding = torch.cat([token_embedding, token_embedding], dim=2).view(1, sz[1], sz[2], sz[3]) | |||
encoder_output = torch.cat([token_embedding, encoder_output], dim=0) | |||
elif self.config['encoder']['name'] == 'lstm': | |||
encoder_output = self.encoder(token_embedding, seq_len) | |||
else: | |||
raise ValueError('Unknown encoder: {0}'.format(self.config['encoder']['name'])) | |||
# strip <bos> and <eos>. The removal is not exact per sample, but it should not affect the final result.
encoder_output = encoder_output[:, :, 1:-1] | |||
return encoder_output |
@@ -1,377 +1,92 @@ | |||
""" | |||
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||
""" | |||
import copy | |||
import json | |||
import math | |||
import os | |||
import torch | |||
from torch import nn | |||
from ...io.file_utils import _get_base_url, cached_path | |||
from ._bert import _WordPieceBertModel, BertModel | |||
CONFIG_FILE = 'bert_config.json' | |||
MODEL_WEIGHTS = 'pytorch_model.bin' | |||
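# GELU (Gaussian Error Linear Unit) activation in its exact, erf-based form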
def gelu(x): | |||
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) | |||
def swish(x): | |||
return x * torch.sigmoid(x) | |||
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} | |||
class BertLayerNorm(nn.Module): | |||
def __init__(self, hidden_size, eps=1e-12): | |||
super(BertLayerNorm, self).__init__() | |||
self.weight = nn.Parameter(torch.ones(hidden_size)) | |||
self.bias = nn.Parameter(torch.zeros(hidden_size)) | |||
self.variance_epsilon = eps | |||
def forward(self, x): | |||
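# normalize over the last dimension using a biased variance estimate, then apply the learned
# element-wise scale (weight) and shift (bias)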
u = x.mean(-1, keepdim=True) | |||
s = (x - u).pow(2).mean(-1, keepdim=True) | |||
x = (x - u) / torch.sqrt(s + self.variance_epsilon) | |||
return self.weight * x + self.bias | |||
class BertEmbeddings(nn.Module): | |||
def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size, hidden_dropout_prob): | |||
super(BertEmbeddings, self).__init__() | |||
self.word_embeddings = nn.Embedding(vocab_size, hidden_size) | |||
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) | |||
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size) | |||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load | |||
# any TensorFlow checkpoint file | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
def forward(self, input_ids, token_type_ids=None): | |||
seq_length = input_ids.size(1) | |||
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) | |||
position_ids = position_ids.unsqueeze(0).expand_as(input_ids) | |||
if token_type_ids is None: | |||
token_type_ids = torch.zeros_like(input_ids) | |||
words_embeddings = self.word_embeddings(input_ids) | |||
position_embeddings = self.position_embeddings(position_ids) | |||
token_type_embeddings = self.token_type_embeddings(token_type_ids) | |||
embeddings = words_embeddings + position_embeddings + token_type_embeddings | |||
embeddings = self.LayerNorm(embeddings) | |||
embeddings = self.dropout(embeddings) | |||
return embeddings | |||
class BertSelfAttention(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob): | |||
super(BertSelfAttention, self).__init__() | |||
if hidden_size % num_attention_heads != 0: | |||
raise ValueError( | |||
"The hidden size (%d) is not a multiple of the number of attention " | |||
"heads (%d)" % (hidden_size, num_attention_heads)) | |||
self.num_attention_heads = num_attention_heads | |||
self.attention_head_size = int(hidden_size / num_attention_heads) | |||
self.all_head_size = self.num_attention_heads * self.attention_head_size | |||
self.query = nn.Linear(hidden_size, self.all_head_size) | |||
self.key = nn.Linear(hidden_size, self.all_head_size) | |||
self.value = nn.Linear(hidden_size, self.all_head_size) | |||
self.dropout = nn.Dropout(attention_probs_dropout_prob) | |||
def transpose_for_scores(self, x): | |||
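# reshape (batch, seq_len, all_head_size) -> (batch, num_heads, seq_len, head_size) so that
# attention can be computed independently for every head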
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) | |||
x = x.view(*new_x_shape) | |||
return x.permute(0, 2, 1, 3) | |||
def forward(self, hidden_states, attention_mask): | |||
mixed_query_layer = self.query(hidden_states) | |||
mixed_key_layer = self.key(hidden_states) | |||
mixed_value_layer = self.value(hidden_states) | |||
query_layer = self.transpose_for_scores(mixed_query_layer) | |||
key_layer = self.transpose_for_scores(mixed_key_layer) | |||
value_layer = self.transpose_for_scores(mixed_value_layer) | |||
# Take the dot product between "query" and "key" to get the raw attention scores. | |||
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) | |||
attention_scores = attention_scores / math.sqrt(self.attention_head_size) | |||
# Apply the attention mask is (precomputed for all layers in BertModel forward() function) | |||
attention_scores = attention_scores + attention_mask | |||
# Normalize the attention scores to probabilities. | |||
attention_probs = nn.Softmax(dim=-1)(attention_scores) | |||
# This is actually dropping out entire tokens to attend to, which might | |||
# seem a bit unusual, but is taken from the original Transformer paper. | |||
attention_probs = self.dropout(attention_probs) | |||
context_layer = torch.matmul(attention_probs, value_layer) | |||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous() | |||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) | |||
context_layer = context_layer.view(*new_context_layer_shape) | |||
return context_layer | |||
class BertSelfOutput(nn.Module): | |||
def __init__(self, hidden_size, hidden_dropout_prob): | |||
super(BertSelfOutput, self).__init__() | |||
self.dense = nn.Linear(hidden_size, hidden_size) | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class BertAttention(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob): | |||
super(BertAttention, self).__init__() | |||
self.self = BertSelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob) | |||
self.output = BertSelfOutput(hidden_size, hidden_dropout_prob) | |||
def forward(self, input_tensor, attention_mask): | |||
self_output = self.self(input_tensor, attention_mask) | |||
attention_output = self.output(self_output, input_tensor) | |||
return attention_output | |||
class BertIntermediate(nn.Module): | |||
def __init__(self, hidden_size, intermediate_size, hidden_act): | |||
super(BertIntermediate, self).__init__() | |||
self.dense = nn.Linear(hidden_size, intermediate_size) | |||
self.intermediate_act_fn = ACT2FN[hidden_act] \ | |||
if isinstance(hidden_act, str) else hidden_act | |||
def forward(self, hidden_states): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.intermediate_act_fn(hidden_states) | |||
return hidden_states | |||
class BertOutput(nn.Module): | |||
def __init__(self, hidden_size, intermediate_size, hidden_dropout_prob): | |||
super(BertOutput, self).__init__() | |||
self.dense = nn.Linear(intermediate_size, hidden_size) | |||
self.LayerNorm = BertLayerNorm(hidden_size, eps=1e-12) | |||
self.dropout = nn.Dropout(hidden_dropout_prob) | |||
def forward(self, hidden_states, input_tensor): | |||
hidden_states = self.dense(hidden_states) | |||
hidden_states = self.dropout(hidden_states) | |||
hidden_states = self.LayerNorm(hidden_states + input_tensor) | |||
return hidden_states | |||
class BertLayer(nn.Module): | |||
def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
intermediate_size, hidden_act): | |||
super(BertLayer, self).__init__() | |||
self.attention = BertAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
hidden_dropout_prob) | |||
self.intermediate = BertIntermediate(hidden_size, intermediate_size, hidden_act) | |||
self.output = BertOutput(hidden_size, intermediate_size, hidden_dropout_prob) | |||
def forward(self, hidden_states, attention_mask): | |||
attention_output = self.attention(hidden_states, attention_mask) | |||
intermediate_output = self.intermediate(attention_output) | |||
layer_output = self.output(intermediate_output, attention_output) | |||
return layer_output | |||
class BertEncoder(nn.Module): | |||
def __init__(self, num_hidden_layers, hidden_size, num_attention_heads, attention_probs_dropout_prob, | |||
hidden_dropout_prob, | |||
intermediate_size, hidden_act): | |||
super(BertEncoder, self).__init__() | |||
layer = BertLayer(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob, | |||
intermediate_size, hidden_act) | |||
self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(num_hidden_layers)]) | |||
def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True): | |||
all_encoder_layers = [] | |||
for layer_module in self.layer: | |||
hidden_states = layer_module(hidden_states, attention_mask) | |||
if output_all_encoded_layers: | |||
all_encoder_layers.append(hidden_states) | |||
if not output_all_encoded_layers: | |||
all_encoder_layers.append(hidden_states) | |||
return all_encoder_layers | |||
class BertPooler(nn.Module): | |||
def __init__(self, hidden_size): | |||
super(BertPooler, self).__init__() | |||
self.dense = nn.Linear(hidden_size, hidden_size) | |||
self.activation = nn.Tanh() | |||
def forward(self, hidden_states): | |||
# We "pool" the model by simply taking the hidden state corresponding | |||
# to the first token. | |||
first_token_tensor = hidden_states[:, 0] | |||
pooled_output = self.dense(first_token_tensor) | |||
pooled_output = self.activation(pooled_output) | |||
return pooled_output | |||
class BertModel(nn.Module): | |||
"""BERT(Bidirectional Embedding Representations from Transformers). | |||
If you want to use pretrained weights, download them from the URLs below.
sources:: | |||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz", | |||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased.tar.gz", | |||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz", | |||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased.tar.gz", | |||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased.tar.gz", | |||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased.tar.gz", | |||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese.tar.gz", | |||
Build a BERT model from pretrained weights::
model = BertModel.from_pretrained("path/to/weights/directory") | |||
Build a BERT model with randomly initialized weights::
model = BertModel() | |||
:param int vocab_size: vocabulary size; defaults to 30522, the vocab size of the English uncased BERT
:param int hidden_size: hidden size; defaults to 768 (BERT base)
:param int num_hidden_layers: number of hidden layers; defaults to 12 (BERT base)
:param int num_attention_heads: number of attention heads; defaults to 12 (BERT base)
:param int intermediate_size: hidden size of the feed-forward layer; defaults to 3072 (BERT base)
:param str hidden_act: activation function of the feed-forward layer; defaults to ``gelu``
:param float hidden_dropout_prob: dropout of the feed-forward layer; defaults to 0.1
:param float attention_probs_dropout_prob: dropout of the attention layers; defaults to 0.1
:param int max_position_embeddings: maximum sequence length; defaults to 512
:param int type_vocab_size: maximum number of segments; defaults to 2
:param float initializer_range: range used to initialize the weights; defaults to 0.02
class BertWordPieceEncoder(nn.Module): | |||
""" | |||
Loads a BERT model; after loading, call index_dataset to generate the word_pieces field in the datasets.
def __init__(self, vocab_size=30522, | |||
hidden_size=768, | |||
num_hidden_layers=12, | |||
num_attention_heads=12, | |||
intermediate_size=3072, | |||
hidden_act="gelu", | |||
hidden_dropout_prob=0.1, | |||
attention_probs_dropout_prob=0.1, | |||
max_position_embeddings=512, | |||
type_vocab_size=2, | |||
initializer_range=0.02, **kwargs): | |||
super(BertModel, self).__init__() | |||
self.embeddings = BertEmbeddings(vocab_size, hidden_size, max_position_embeddings, | |||
type_vocab_size, hidden_dropout_prob) | |||
self.encoder = BertEncoder(num_hidden_layers, hidden_size, num_attention_heads, | |||
attention_probs_dropout_prob, hidden_dropout_prob, intermediate_size, | |||
hidden_act) | |||
self.pooler = BertPooler(hidden_size) | |||
self.initializer_range = initializer_range | |||
self.apply(self.init_bert_weights) | |||
def init_bert_weights(self, module): | |||
if isinstance(module, (nn.Linear, nn.Embedding)): | |||
# Slightly different from the TF version which uses truncated_normal for initialization | |||
# cf https://github.com/pytorch/pytorch/pull/5617 | |||
module.weight.data.normal_(mean=0.0, std=self.initializer_range) | |||
elif isinstance(module, BertLayerNorm): | |||
module.bias.data.zero_() | |||
module.weight.data.fill_(1.0) | |||
if isinstance(module, nn.Linear) and module.bias is not None: | |||
module.bias.data.zero_() | |||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True): | |||
if attention_mask is None: | |||
attention_mask = torch.ones_like(input_ids) | |||
if token_type_ids is None: | |||
token_type_ids = torch.zeros_like(input_ids) | |||
# We create a 3D attention mask from a 2D tensor mask. | |||
# Sizes are [batch_size, 1, 1, to_seq_length] | |||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |||
# this attention mask is more simple than the triangular masking of causal attention | |||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) | |||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |||
# masked positions, this operation will create a tensor which is 0.0 for | |||
# positions we want to attend and -10000.0 for masked positions. | |||
# Since we are adding it to the raw scores before the softmax, this is | |||
# effectively the same as removing these entirely. | |||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility | |||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 | |||
embedding_output = self.embeddings(input_ids, token_type_ids) | |||
encoded_layers = self.encoder(embedding_output, | |||
extended_attention_mask, | |||
output_all_encoded_layers=output_all_encoded_layers) | |||
sequence_output = encoded_layers[-1] | |||
pooled_output = self.pooler(sequence_output) | |||
if not output_all_encoded_layers: | |||
encoded_layers = encoded_layers[-1] | |||
return encoded_layers, pooled_output | |||
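# Illustrative forward call (a hedged sketch, not part of this module; `input_ids`,
# `token_type_ids` and `attention_mask` are assumed LongTensors of shape [batch_size, seq_len]
# built elsewhere):
#
#   >>> model = BertModel.from_pretrained("path/to/weights/directory")
#   >>> all_layers, pooled = model(input_ids, token_type_ids, attention_mask,
#   ...                            output_all_encoded_layers=True)
#   >>> all_layers[-1].shape   # [batch_size, seq_len, hidden_size]
#   >>> pooled.shape           # [batch_size, hidden_size]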
@classmethod | |||
def from_pretrained(cls, pretrained_model_dir, state_dict=None, *inputs, **kwargs): | |||
# Load config | |||
config_file = os.path.join(pretrained_model_dir, CONFIG_FILE) | |||
config = json.load(open(config_file, "r")) | |||
# config = BertConfig.from_json_file(config_file) | |||
# logger.info("Model config {}".format(config)) | |||
# Instantiate model. | |||
model = cls(*inputs, **config, **kwargs) | |||
if state_dict is None: | |||
weights_path = os.path.join(pretrained_model_dir, MODEL_WEIGHTS) | |||
state_dict = torch.load(weights_path) | |||
old_keys = [] | |||
new_keys = [] | |||
for key in state_dict.keys(): | |||
new_key = None | |||
if 'gamma' in key: | |||
new_key = key.replace('gamma', 'weight') | |||
if 'beta' in key: | |||
new_key = key.replace('beta', 'bias') | |||
if new_key: | |||
old_keys.append(key) | |||
new_keys.append(new_key) | |||
for old_key, new_key in zip(old_keys, new_keys): | |||
state_dict[new_key] = state_dict.pop(old_key) | |||
missing_keys = [] | |||
unexpected_keys = [] | |||
error_msgs = [] | |||
# copy state_dict so _load_from_state_dict can modify it | |||
metadata = getattr(state_dict, '_metadata', None) | |||
state_dict = state_dict.copy() | |||
if metadata is not None: | |||
state_dict._metadata = metadata | |||
def load(module, prefix=''): | |||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) | |||
module._load_from_state_dict( | |||
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs) | |||
for name, child in module._modules.items(): | |||
if child is not None: | |||
load(child, prefix + name + '.') | |||
load(model, prefix='' if hasattr(model, 'bert') else 'bert.') | |||
if len(missing_keys) > 0: | |||
print("Weights of {} not initialized from pretrained model: {}".format( | |||
model.__class__.__name__, missing_keys)) | |||
if len(unexpected_keys) > 0: | |||
print("Weights from pretrained model not used in {}: {}".format( | |||
model.__class__.__name__, unexpected_keys)) | |||
return model | |||
:param fastNLP.Vocabulary vocab: 词表 | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased`` | |||
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
:param bool requires_grad: 是否需要gradient。 | |||
""" | |||
def __init__(self, model_dir_or_name:str='en-base-uncased', layers:str='-1', | |||
requires_grad:bool=False): | |||
super().__init__() | |||
PRETRAIN_URL = _get_base_url('bert') | |||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isdir(model_dir_or_name): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers) | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
self.requires_grad = requires_grad | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters()]) | |||
if len(requires_grads)==1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
param.requires_grad = value | |||
@property | |||
def embed_size(self): | |||
return self._embed_size | |||
def index_datasets(self, *datasets, field_name): | |||
""" | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | |||
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | |||
:param datasets: DataSet对象 | |||
:param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 | |||
:return: | |||
""" | |||
self.model.index_dataset(*datasets, field_name=field_name) | |||
def forward(self, word_pieces, token_type_ids=None): | |||
""" | |||
计算word_pieces的bert embedding表示。传入的word_pieces中应该自行包含[CLS]与[SEP]的tag。
:param word_pieces: batch_size x max_len
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话 | |||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
""" | |||
outputs = self.model(word_pieces, token_type_ids) | |||
outputs = torch.cat([*outputs], dim=-1) | |||
return outputs |
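# A minimal usage sketch (hypothetical; `train_data` is assumed to be a fastNLP DataSet whose
# 'words' field holds List[str] sentences, and `word_pieces` a LongTensor batch produced from it):
#
#   >>> encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1')
#   >>> encoder.index_datasets(train_data, field_name='words')   # adds the 'word_pieces' input field
#   >>> reprs = encoder(word_pieces)   # [batch_size, max_len, 768 * len(layers)]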
@@ -22,10 +22,10 @@ class ConvolutionCharEncoder(nn.Module): | |||
:param initial_method: 初始化参数的方式, 默认为`xavier normal` | |||
""" | |||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None): | |||
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None): | |||
super(ConvolutionCharEncoder, self).__init__() | |||
self.convs = nn.ModuleList([ | |||
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4)) | |||
nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, kernels[i]//2)) | |||
for i in range(len(kernels))]) | |||
initial_parameter(self, initial_method) | |||
@@ -5,9 +5,6 @@ import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from ..utils import initial_parameter | |||
class ConvMaxpool(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.ConvMaxpool` :class:`fastNLP.modules.encoder.conv_maxpool.ConvMaxpool` | |||
@@ -19,20 +16,15 @@ class ConvMaxpool(nn.Module): | |||
:param int in_channels: 输入channel的大小,一般是embedding的维度; 或encoder的output维度 | |||
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 | |||
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 | |||
:param int stride: 见pytorch Conv1D文档。所有kernel共享一个stride。 | |||
:param int padding: 见pytorch Conv1D文档。所有kernel共享一个padding。 | |||
:param int dilation: 见pytorch Conv1D文档。所有kernel共享一个dilation。 | |||
:param int groups: 见pytorch Conv1D文档。所有kernel共享一个groups。 | |||
:param bool bias: 见pytorch Conv1D文档。所有kernel共享一个bias。 | |||
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh | |||
:param str initial_method: str。 | |||
""" | |||
def __init__(self, in_channels, out_channels, kernel_sizes, | |||
stride=1, padding=0, dilation=1, | |||
groups=1, bias=True, activation="relu", initial_method=None): | |||
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"): | |||
super(ConvMaxpool, self).__init__() | |||
for kernel_size in kernel_sizes: | |||
assert kernel_size % 2 == 1, "kernel size has to be an odd number."
# convolution | |||
if isinstance(kernel_sizes, (list, tuple, int)): | |||
if isinstance(kernel_sizes, int) and isinstance(out_channels, int): | |||
@@ -49,11 +41,11 @@ class ConvMaxpool(nn.Module): | |||
in_channels=in_channels, | |||
out_channels=oc, | |||
kernel_size=ks, | |||
stride=stride, | |||
padding=padding, | |||
dilation=dilation, | |||
groups=groups, | |||
bias=bias) | |||
stride=1, | |||
padding=ks//2, | |||
dilation=1, | |||
groups=1, | |||
bias=None) | |||
for oc, ks in zip(out_channels, kernel_sizes)]) | |||
else: | |||
@@ -70,9 +62,7 @@ class ConvMaxpool(nn.Module): | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: relu, tanh, sigmoid") | |||
initial_parameter(self, initial_method) | |||
def forward(self, x, mask=None): | |||
""" | |||
@@ -86,7 +76,7 @@ class ConvMaxpool(nn.Module): | |||
xs = [self.activation(conv(x)) for conv in self.convs] # [[N,C,L], ...] | |||
if mask is not None: | |||
mask = mask.unsqueeze(1) # B x 1 x L | |||
xs = [x.masked_fill_(mask, float('-inf')) for x in xs] | |||
xs = [x.masked_fill_(mask.eq(0), float('-inf')) for x in xs] | |||
# max-pooling | |||
xs = [F.max_pool1d(input=i, kernel_size=i.size(2)).squeeze(2) | |||
for i in xs] # [[N, C], ...] | |||
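# Shape sketch for the simplified ConvMaxpool (a hedged example; it assumes, as the class
# docstring states, that forward takes x as [batch, seq_len, in_channels] and a mask with 1
# marking real tokens):
#
#   >>> conv_pool = ConvMaxpool(in_channels=300, out_channels=[100, 100], kernel_sizes=[1, 3])
#   >>> out = conv_pool(x, mask)   # x: [batch, seq_len, 300], mask: [batch, seq_len]
#   >>> out.shape                  # [batch, sum(out_channels)] == [batch, 200]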
@@ -1,50 +1,964 @@ | |||
__all__ = [ | |||
"Embedding" | |||
"Embedding", | |||
"StaticEmbedding", | |||
"ElmoEmbedding", | |||
"BertEmbedding", | |||
"StackEmbedding", | |||
"LSTMCharEmbedding", | |||
"CNNCharEmbedding", | |||
] | |||
import torch.nn as nn | |||
from ..utils import get_embeddings | |||
from .lstm import LSTM | |||
from ...core.vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
import torch | |||
import numpy as np | |||
import torch.nn.functional as F | |||
import os | |||
from ._elmo import _ElmoModel | |||
from ...io.file_utils import cached_path, _get_base_url | |||
from ._bert import _WordBertModel | |||
from typing import List | |||
import warnings | |||
from ...core.dataset import DataSet | |||
from ...core.batch import DataSetIter | |||
from ...core.sampler import SequentialSampler | |||
from ...core.utils import _move_model_to_device, _get_model_device | |||
class Embedding(nn.Embedding): | |||
class Embedding(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding` | |||
Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度""" | |||
def __init__(self, init_embed, padding_idx=None, dropout=0.0, sparse=False, max_norm=None, norm_type=2, | |||
scale_grad_by_freq=False): | |||
def __init__(self, init_embed, dropout=0.0, dropout_word=0, unk_index=None): | |||
""" | |||
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int), | |||
第一个int为vocab_size, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param None,int padding_idx: 该index的Embedding将一直为0. | |||
第一个int为vocab_size, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding;
也可以传入TokenEmbedding对象 | |||
:param float dropout: 对Embedding的输出的dropout。 | |||
:param bool sparse: 如果为True,则对Embedding的梯度将是sparse的,参考Pytorch Embedding获取更多信息。 | |||
:param None,float max_norm: 每个vector最大的norm能为多大 | |||
:param int norm_type: norm的类型 | |||
:param bool scale_grad_by_freq: 如果为True,将会把梯度除以这个词出现的次数. | |||
:param float dropout_word: 按照一定比例随机将word设置为unk的idx,这样可以使得unk这个token得到足够的训练 | |||
:param int unk_index: drop word时替换为的index,如果init_embed为TokenEmbedding不需要传入该值。 | |||
""" | |||
embed = get_embeddings(init_embed) | |||
num_embeddings, embedding_dim = embed.weight.size() | |||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx, | |||
max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, | |||
sparse=sparse, _weight=embed.weight.data) | |||
del embed | |||
super(Embedding, self).__init__() | |||
self.embed = get_embeddings(init_embed) | |||
self.dropout = nn.Dropout(dropout) | |||
if not isinstance(self.embed, TokenEmbedding): | |||
self._embed_size = self.embed.weight.size(1) | |||
if dropout_word>0 and not isinstance(unk_index, int): | |||
raise ValueError("When drop word is set, you need to pass in the unk_index.") | |||
else: | |||
self._embed_size = self.embed.embed_size | |||
unk_index = self.embed.get_word_vocab().unknown_idx | |||
self.unk_index = unk_index | |||
self.dropout_word = dropout_word | |||
def forward(self, x): | |||
""" | |||
:param torch.LongTensor x: [batch, seq_len] | |||
:return: torch.Tensor : [batch, seq_len, embed_dim] | |||
""" | |||
x = super().forward(x) | |||
if self.dropout_word>0 and self.training: | |||
mask = torch.ones_like(x).float() * self.dropout_word | |||
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | |||
x = x.masked_fill(mask, self.unk_index) | |||
x = self.embed(x) | |||
return self.dropout(x) | |||
@property | |||
def num_embedding(self)->int: | |||
if isinstance(self.embed, nn.Embedding): | |||
return self.embed.weight.size(0) | |||
else: | |||
return self.embed.num_embedding | |||
def __len__(self): | |||
return len(self.embed) | |||
@property | |||
def embed_size(self) -> int: | |||
return self._embed_size | |||
@property | |||
def embedding_dim(self) -> int: | |||
return self._embed_size | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
if not isinstance(self.embed, TokenEmbedding): | |||
return self.embed.weight.requires_grad | |||
else: | |||
return self.embed.requires_grad | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
if not isinstance(self.embed, TokenEmbedding): | |||
self.embed.weight.requires_grad = value | |||
else: | |||
self.embed.requires_grad = value | |||
@property | |||
def size(self): | |||
if isinstance(self.embed, TokenEmbedding): | |||
return self.embed.size | |||
else: | |||
return self.embed.weight.size() | |||
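# Hypothetical construction examples for the refactored Embedding wrapper (`vocab` and `words`
# are assumptions, not defined in this file):
#
#   >>> embed = Embedding((10000, 300), dropout=0.1)                              # random init, vocab_size=10000, dim=300
#   >>> embed = Embedding(StaticEmbedding(vocab, 'en-glove-6b-50'), dropout=0.1)  # wrap a TokenEmbedding
#   >>> out = embed(words)   # words: LongTensor [batch, seq_len] -> out: [batch, seq_len, embed.embed_size]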
class TokenEmbedding(nn.Module): | |||
def __init__(self, vocab): | |||
super(TokenEmbedding, self).__init__() | |||
assert vocab.padding_idx is not None, "Your vocabulary must have padding."
self._word_vocab = vocab | |||
self._word_pad_index = vocab.padding_idx | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for param in self.parameters()]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for param in self.parameters(): | |||
param.requires_grad = value | |||
def __len__(self): | |||
return len(self._word_vocab) | |||
@property | |||
def embed_size(self) -> int: | |||
return self._embed_size | |||
@property | |||
def num_embedding(self) -> int: | |||
return len(self._word_vocab) | |||
def get_word_vocab(self): | |||
""" | |||
返回embedding的词典。 | |||
:return: Vocabulary | |||
""" | |||
return self._word_vocab | |||
@property | |||
def size(self): | |||
return torch.Size([self.num_embedding, self._embed_size])
class StaticEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.modules.StaticEmbedding` :class:`fastNLP.modules.encoder.embedding.StaticEmbedding` | |||
StaticEmbedding组件. 给定embedding的名称,根据vocab从embedding中抽取相应的数据。该Embedding就可以按照正常的embedding使用了
Example:: | |||
>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50') | |||
:param vocab: Vocabulary. 需要根据该词表从预训练embedding中抽取相应的词向量,不可为None。
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding | |||
的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d, | |||
`en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | |||
:param requires_grad: 是否需要gradient. 默认为True | |||
:param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.xavier_uniform_ | |||
。调用该方法时传入一个tensor对象。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None): | |||
super(StaticEmbedding, self).__init__(vocab) | |||
# 优先定义需要下载的static embedding有哪些。这里估计需要自己搞一个server, | |||
PRETRAIN_STATIC_FILES = { | |||
'en': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz', | |||
'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz", | |||
'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz", | |||
'en-fasttext': "cc.en.300.vec-d53187b2.gz", | |||
'cn': "tencent_cn-dab24577.tar.gz", | |||
'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz", | |||
} | |||
# 得到cache_path | |||
if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | |||
PRETRAIN_URL = _get_base_url('static') | |||
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_path = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_path = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
# 读取embedding | |||
embedding, hit_flags = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) | |||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||
padding_idx=vocab.padding_idx, | |||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||
sparse=False, _weight=embedding) | |||
if vocab._no_create_word_length > 0: # 需要映射,使得来自于dev, test的idx指向unk | |||
words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False) | |||
for word, idx in vocab: | |||
if vocab._is_word_no_create_entry(word) and not hit_flags[idx]: | |||
words_to_words[idx] = vocab.unknown_idx | |||
self.words_to_words = words_to_words | |||
self._embed_size = self.embedding.weight.size(1) | |||
self.requires_grad = requires_grad | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'words_to_words' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_words' in name: | |||
continue | |||
param.requires_grad = value | |||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore', init_method=None): | |||
""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
:param str embed_filepath: 预训练的embedding的路径。 | |||
:param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 | |||
没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||
:param dtype: 读出的embedding的类型 | |||
:param str padding: 词表中padding的token | |||
:param str unknown: 词表中unknown的token | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 | |||
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 | |||
:param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.xavier_uniform_
:return torch.tensor: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||
if not os.path.exists(embed_filepath): | |||
raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) | |||
if init_method is None: | |||
init_method = nn.init.xavier_uniform_ | |||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||
found_count = 0 | |||
line = f.readline().strip() | |||
parts = line.split() | |||
start_idx = 0 | |||
if len(parts) == 2: | |||
dim = int(parts[1]) | |||
start_idx += 1 | |||
else: | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = torch.zeros(len(vocab), dim) | |||
init_method(matrix) | |||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
parts = line.strip().split() | |||
word = ''.join(parts[:-dim]) | |||
nums = parts[-dim:] | |||
# 对齐unk与pad | |||
if word == padding and vocab.padding is not None: | |||
word = vocab.padding | |||
elif word == unknown and vocab.unknown is not None: | |||
word = vocab.unknown | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | |||
found_count += 1 | |||
hit_flags[index] = True | |||
except Exception as e: | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
if normalize: | |||
matrix /= (torch.norm(matrix, dim=1, keepdim=True) + 1e-12) | |||
return matrix, hit_flags | |||
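# Format note with illustrative toy content (not real data): the header sniffing above
# distinguishes the two supported text formats.
#   word2vec style - first line is a "<vocab_size> <dim>" header, e.g.:
#       2 4
#       the 0.1 0.2 0.3 0.4
#       cat 0.5 0.6 0.7 0.8
#   glove style - no header, every line is "<word> <floats>":
#       the 0.1 0.2 0.3 0.4
#       cat 0.5 0.6 0.7 0.8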
def forward(self, words): | |||
""" | |||
传入words的index | |||
:param words: torch.LongTensor, [batch_size, max_len] | |||
:return: torch.FloatTensor, [batch_size, max_len, embed_size] | |||
""" | |||
if hasattr(self, 'words_to_words'): | |||
words = self.words_to_words[words] | |||
return self.embedding(words) | |||
class ContextualEmbedding(TokenEmbedding): | |||
def __init__(self, vocab: Vocabulary): | |||
super(ContextualEmbedding, self).__init__(vocab) | |||
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True): | |||
""" | |||
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 | |||
Example:: | |||
>>> embed.add_sentence_cache(train_data, dev_data)  # embed为任一ContextualEmbedding子类实例
:param datasets: DataSet对象 | |||
:param batch_size: int, 生成cache的sentence表示时使用的batch的大小 | |||
:param device: 参考 :class:`fastNLP.Trainer` 的device
:param delete_weights: 是否在生成了cache之后删除权重,在不需要finetune动态模型的情况下,删除权重会大量减少内存占用。
:return: | |||
""" | |||
for index, dataset in enumerate(datasets): | |||
try: | |||
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed." | |||
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input." | |||
except Exception as e: | |||
print(f"Exception happens at {index} dataset.") | |||
raise e | |||
sent_embeds = {} | |||
_move_model_to_device(self, device=device) | |||
device = _get_model_device(self) | |||
pad_index = self._word_vocab.padding_idx | |||
print("Start to calculate sentence representations.") | |||
with torch.no_grad(): | |||
for index, dataset in enumerate(datasets): | |||
try: | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_x, batch_y in batch: | |||
words = batch_x['words'].to(device) | |||
words_list = words.tolist() | |||
seq_len = words.ne(pad_index).sum(dim=-1) | |||
max_len = words.size(1) | |||
# 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。 | |||
seq_len_from_behind = (max_len - seq_len).tolist()
word_embeds = self(words).detach().cpu().numpy() | |||
for b in range(words.size(0)): | |||
length = seq_len_from_behind[b] | |||
if length==0: | |||
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b] | |||
else: | |||
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length] | |||
except Exception as e: | |||
print(f"Exception happens at {index} dataset.") | |||
raise e | |||
print("Finish calculating sentence representations.") | |||
self.sent_embeds = sent_embeds | |||
if delete_weights: | |||
self._delete_model_weights() | |||
def _get_sent_reprs(self, words): | |||
""" | |||
获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None | |||
:param words: torch.LongTensor | |||
:return: | |||
""" | |||
if hasattr(self, 'sent_embeds'): | |||
words_list = words.tolist() | |||
seq_len = words.ne(self._word_pad_index).sum(dim=-1) | |||
_embeds = [] | |||
for b in range(len(words)): | |||
words_i = tuple(words_list[b][:seq_len[b]]) | |||
embed = self.sent_embeds[words_i] | |||
_embeds.append(embed) | |||
max_sent_len = max(map(len, _embeds)) | |||
embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float, | |||
device=words.device) | |||
for i, embed in enumerate(_embeds): | |||
embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device) | |||
return embeds | |||
return None | |||
@abstractmethod | |||
def _delete_model_weights(self): | |||
"""删除计算表示的模型以节省资源""" | |||
raise NotImplementedError | |||
def remove_sentence_cache(self): | |||
""" | |||
删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 | |||
:return: | |||
""" | |||
del self.sent_embeds | |||
class ElmoEmbedding(ContextualEmbedding): | |||
""" | |||
别名::class:`fastNLP.modules.ElmoEmbedding` :class:`fastNLP.modules.encoder.embedding.ElmoEmbedding` | |||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。 | |||
我们提供的ELMo预训练模型来自 https://github.com/HIT-SCIR/ELMoForManyLangs | |||
Example:: | |||
>>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True) | |||
:param vocab: 词表 | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称, | |||
目前支持的ELMo包括{`en` : 英文版本的ELMo, `cn` : 中文版本的ELMo,}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载 | |||
:param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | |||
按照这个顺序concat起来。默认为'2'。 | |||
:param requires_grad: bool, 该层是否需要gradient. 默认为False | |||
:param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding, | |||
并删除character encoder,之后将直接使用cache的embedding。默认为False。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', | |||
layers: str='2', requires_grad: bool=False, cache_word_reprs: bool=False): | |||
super(ElmoEmbedding, self).__init__(vocab) | |||
layers = list(map(int, layers.split(','))) | |||
assert len(layers) > 0, "Must choose one output" | |||
for layer in layers: | |||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||
self.layers = layers | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
PRETRAINED_ELMO_MODEL_DIR = {'en': 'elmo_en-d39843fe.tar.gz', | |||
'cn': 'elmo_cn-5e9b34e2.tar.gz'} | |||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('elmo') | |||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | |||
self.requires_grad = requires_grad | |||
self._embed_size = len(self.layers) * self.model.config['encoder']['projection_dim'] * 2 | |||
def forward(self, words: torch.LongTensor): | |||
""" | |||
计算words的elmo embedding表示。根据elmo文章中的介绍,ELMo实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的
embedding被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens;
backward_hiddens]. | |||
:param words: batch_size x max_len | |||
:return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers)) | |||
""" | |||
outputs = self._get_sent_reprs(words) | |||
if outputs is not None: | |||
return outputs | |||
outputs = self.model(words) | |||
if len(self.layers) == 1: | |||
outputs = outputs[self.layers[0]] | |||
else: | |||
outputs = torch.cat([*outputs[self.layers]], dim=-1) | |||
return outputs | |||
def _delete_model_weights(self): | |||
del self.layers, self.model | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'words_to_chars_embedding' not in name and 'words_to_words' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'words_to_words' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
class BertEmbedding(ContextualEmbedding): | |||
""" | |||
别名::class:`fastNLP.modules.BertEmbedding` :class:`fastNLP.modules.encoder.embedding.BertEmbedding` | |||
使用BERT对words进行encode的Embedding。建议将输入的words长度限制在450以内,而不要使用512。这是由于预训练的bert模型长 | |||
度限制为512个token,而因为输入的word是未进行word piece分割的,在分割之后长度可能会超过最大长度限制。 | |||
Example:: | |||
>>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1') | |||
:param fastNLP.Vocabulary vocab: 词表 | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased``. | |||
:param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces | |||
中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。 | |||
:param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 | |||
会使得word embedding的结果比输入的结果长两个token。在使用 :class:`StackEmbedding` 时可能会遇到问题。
:param bool requires_grad: 是否需要gradient。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
pool_method: str='first', include_cls_sep: bool=False, requires_grad: bool=False): | |||
super(BertEmbedding, self).__init__(vocab) | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep) | |||
self.requires_grad = requires_grad | |||
self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size | |||
def _delete_model_weights(self): | |||
del self.model | |||
def forward(self, words): | |||
""" | |||
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | |||
删除这两个token的表示。 | |||
:param torch.LongTensor words: [batch_size, max_len] | |||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
""" | |||
outputs = self._get_sent_reprs(words) | |||
if outputs is not None: | |||
return outputs | |||
outputs = self.model(words) | |||
outputs = torch.cat([*outputs], dim=-1) | |||
return outputs | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的大小 | |||
:return: torch.Size() | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
return self.weight.size() | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'word_pieces_lengths' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1): | |||
""" | |||
给定一个word的vocabulary生成character的vocabulary. | |||
:param vocab: 用于构建character词表的word词表
:param min_freq: character的最小出现次数
:return: | |||
""" | |||
char_vocab = Vocabulary(min_freq=min_freq) | |||
for word, index in vocab: | |||
if not vocab._is_word_no_create_entry(word): | |||
char_vocab.add_word_lst(list(word)) | |||
return char_vocab | |||
class CNNCharEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding` | |||
使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool | |||
-> fc. 不同的kernel大小的filter结果是concat起来的。
Example:: | |||
>>> cnn_char_embed = CNNCharEmbedding(vocab) | |||
:param vocab: 词表 | |||
:param embed_size: 该word embedding的大小,默认值为50. | |||
:param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50. | |||
:param dropout: 以多大的概率drop | |||
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. | |||
:param kernel_sizes: kernel的大小. 默认值为[5, 3, 1]. | |||
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. | |||
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. | |||
:param min_char_freq: character的最少出现次数。默认值为2. | |||
""" | |||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, | |||
filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max', | |||
activation='relu', min_char_freq: int=2): | |||
super(CNNCharEmbedding, self).__init__(vocab) | |||
for kernel in kernel_sizes: | |||
assert kernel % 2 == 1, "Only odd kernel is allowed." | |||
assert pool_method in ('max', 'avg') | |||
self.dropout = nn.Dropout(dropout, inplace=True) | |||
self.pool_method = pool_method | |||
# activation function | |||
if isinstance(activation, str): | |||
if activation.lower() == 'relu': | |||
self.activation = F.relu | |||
elif activation.lower() == 'sigmoid': | |||
self.activation = F.sigmoid | |||
elif activation.lower() == 'tanh': | |||
self.activation = F.tanh | |||
elif activation is None: | |||
self.activation = lambda x: x | |||
elif callable(activation): | |||
self.activation = activation | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
print("Start constructing character vocabulary.") | |||
# 建立char的词表 | |||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq) | |||
self.char_pad_index = self.char_vocab.padding_idx | |||
print(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long), | |||
requires_grad=False) | |||
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||
self.word_lengths[index] = len(word) | |||
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
self.convs = nn.ModuleList([nn.Conv1d( | |||
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2) | |||
for i in range(len(kernel_sizes))]) | |||
self._embed_size = embed_size | |||
self.fc = nn.Linear(sum(filter_nums), embed_size) | |||
self.init_param() | |||
def forward(self, words): | |||
""" | |||
输入words的index后,生成对应的words的表示。 | |||
:param words: [batch_size, max_len] | |||
:return: [batch_size, max_len, embed_size] | |||
""" | |||
batch_size, max_len = words.size() | |||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | |||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||
max_word_len = word_lengths.max() | |||
chars = chars[:, :, :max_word_len] | |||
# 为1的地方为mask | |||
chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len 为1(True)的位置是padding
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | |||
chars = self.dropout(chars)
reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1) | |||
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M | |||
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) | |||
for conv in self.convs] | |||
conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters) | |||
conv_chars = self.activation(conv_chars) | |||
if self.pool_method == 'max': | |||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters) | |||
else: | |||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||
chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | |||
chars = self.fc(chars) | |||
return chars | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
params = [] | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' not in name and 'word_lengths' not in name: | |||
params.append(param.requires_grad) | |||
requires_grads = set(params) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
def init_param(self): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset | |||
continue | |||
if param.data.dim()>1: | |||
nn.init.xavier_uniform_(param, 1) | |||
else: | |||
nn.init.uniform_(param, -1, 1) | |||
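# Hypothetical end-to-end shapes for CNNCharEmbedding (`vocab` and `words` are assumptions;
# `words` must hold indices from the same vocab the embedding was built with):
#
#   >>> char_embed = CNNCharEmbedding(vocab, embed_size=50, filter_nums=[40, 30, 20], kernel_sizes=[5, 3, 1])
#   >>> out = char_embed(words)   # words: LongTensor [batch, max_len] -> out: [batch, max_len, 50]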
class LSTMCharEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.modules.LSTMCharEmbedding` :class:`fastNLP.modules.encoder.embedding.LSTMCharEmbedding` | |||
使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool | |||
Example:: | |||
>>> lstm_char_embed = LSTMCharEmbedding(vocab) | |||
:param vocab: 词表 | |||
:param embed_size: embedding的大小。默认值为50. | |||
:param char_emb_size: character的embedding的大小。默认值为50. | |||
:param dropout: 以多大概率drop | |||
:param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50. | |||
:param pool_method: 支持'max', 'avg' | |||
:param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数. | |||
:param min_char_freq: character的最小出现次数。默认值为2. | |||
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, dropout:float=0.5, hidden_size=50, | |||
pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True): | |||
super(LSTMCharEmbedding, self).__init__(vocab) | |||
assert hidden_size % 2 == 0, "Only an even hidden_size is allowed."
assert pool_method in ('max', 'avg') | |||
self.pool_method = pool_method | |||
self.dropout = nn.Dropout(dropout, inplace=True) | |||
# activation function | |||
if isinstance(activation, str): | |||
if activation.lower() == 'relu': | |||
self.activation = F.relu | |||
elif activation.lower() == 'sigmoid': | |||
self.activation = F.sigmoid | |||
elif activation.lower() == 'tanh': | |||
self.activation = F.tanh | |||
elif activation is None: | |||
self.activation = lambda x: x | |||
elif callable(activation): | |||
self.activation = activation | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
print("Start constructing character vocabulary.") | |||
# 建立char的词表 | |||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq) | |||
self.char_pad_index = self.char_vocab.padding_idx | |||
print(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
self.max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long), | |||
requires_grad=False) | |||
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否 | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||
self.word_lengths[index] = len(word) | |||
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
self.fc = nn.Linear(hidden_size, embed_size) | |||
hidden_size = hidden_size // 2 if bidirectional else hidden_size | |||
self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True) | |||
self._embed_size = embed_size | |||
self.bidirectional = bidirectional | |||
def forward(self, words): | |||
""" | |||
输入words的index后,生成对应的words的表示。 | |||
:param words: [batch_size, max_len] | |||
:return: [batch_size, max_len, embed_size] | |||
""" | |||
batch_size, max_len = words.size() | |||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | |||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||
max_word_len = word_lengths.max() | |||
chars = chars[:, :, :max_word_len] | |||
# 为mask的地方为1 | |||
chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len 为1(True)的位置是padding
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | |||
chars = self.dropout(chars) | |||
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1) | |||
char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len) | |||
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1) | |||
# B x M x M x H | |||
lstm_chars = self.activation(lstm_chars) | |||
if self.pool_method == 'max': | |||
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||
chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H | |||
else: | |||
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||
chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | |||
chars = self.fc(chars) | |||
return chars | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
params = [] | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' not in name and 'word_lengths' not in name: | |||
params.append(param.requires_grad)
requires_grads = set(params) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
class StackEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.modules.StackEmbedding` :class:`fastNLP.modules.encoder.embedding.StackEmbedding` | |||
支持将多个embedding集合成一个embedding。 | |||
Example:: | |||
>>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
>>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
>>> stack_embed = StackEmbedding([embed_1, embed_2])
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | |||
""" | |||
def __init__(self, embeds: List[TokenEmbedding]): | |||
vocabs = [] | |||
for embed in embeds: | |||
vocabs.append(embed.get_word_vocab()) | |||
_vocab = vocabs[0] | |||
for vocab in vocabs[1:]: | |||
assert vocab == _vocab, "All embeddings should use the same word vocabulary." | |||
super(StackEmbedding, self).__init__(_vocab) | |||
assert isinstance(embeds, list) | |||
for embed in embeds: | |||
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." | |||
self.embeds = nn.ModuleList(embeds) | |||
self._embed_size = sum([embed.embed_size for embed in self.embeds]) | |||
def append(self, embed: TokenEmbedding): | |||
""" | |||
添加一个embedding到结尾。 | |||
:param embed: | |||
:return: | |||
""" | |||
assert isinstance(embed, TokenEmbedding) | |||
self.embeds.append(embed) | |||
def pop(self): | |||
""" | |||
弹出最后一个embed | |||
:return: | |||
""" | |||
return self.embeds.pop() | |||
@property | |||
def embed_size(self): | |||
return self._embed_size | |||
@property | |||
def requires_grad(self): | |||
""" | |||
Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
:return: | |||
""" | |||
requires_grads = set([embed.requires_grad for embed in self.embeds])
if len(requires_grads)==1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for embed in self.embeds:
embed.requires_grad = value | |||
def forward(self, words): | |||
""" | |||
得到多个embedding的结果,并把结果按照顺序concat起来。 | |||
:param words: batch_size x max_len | |||
:return: 返回的shape和当前这个stack embedding中embedding的组成有关 | |||
""" | |||
outputs = [] | |||
for embed in self.embeds: | |||
outputs.append(embed(words)) | |||
return torch.cat(outputs, dim=-1) | |||
@@ -11,16 +11,18 @@ import torch.nn as nn | |||
import torch.nn.utils.rnn as rnn | |||
from ..utils import initial_parameter | |||
from torch import autograd | |||
class LSTM(nn.Module): | |||
""" | |||
别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM` | |||
LSTM 模块, 轻量封装的Pytorch LSTM | |||
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化 | |||
为1; 且可以应对DataParallel中LSTM的使用问题。 | |||
:param input_size: 输入 `x` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度 | |||
:param hidden_size: 隐状态 `h` 的特征维度. | |||
:param num_layers: rnn的层数. Default: 1 | |||
:param dropout: 层间dropout概率. Default: 0 | |||
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` | |||
@@ -30,23 +32,35 @@ class LSTM(nn.Module): | |||
""" | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||
bidirectional=False, bias=True, initial_method=None): | |||
bidirectional=False, bias=True): | |||
super(LSTM, self).__init__() | |||
self.batch_first = batch_first | |||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||
dropout=dropout, bidirectional=bidirectional) | |||
initial_parameter(self, initial_method) | |||
self.init_param() | |||
def init_param(self): | |||
for name, param in self.named_parameters(): | |||
if 'bias' in name: | |||
# based on https://github.com/pytorch/pytorch/issues/750#issuecomment-280671871 | |||
param.data.fill_(0) | |||
n = param.size(0) | |||
start, end = n // 4, n // 2 | |||
param.data[start:end].fill_(1) | |||
else: | |||
nn.init.xavier_uniform_(param) | |||
def forward(self, x, seq_len=None, h0=None, c0=None): | |||
""" | |||
:param x: [batch, seq_len, input_size] 输入序列 | |||
:param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` | |||
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. Default: ``None`` | |||
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全1向量. Default: ``None`` | |||
:param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||
:param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全0向量. Default: ``None`` | |||
:return (output, ht): [batch, seq_len, hidden_size*num_direction] 输出序列
和 [batch, hidden_size*num_direction] 最后时刻隐状态. | |||
""" | |||
batch_size, max_len, _ = x.size() | |||
if h0 is not None and c0 is not None: | |||
hx = (h0, c0) | |||
else: | |||
@@ -65,6 +79,15 @@ class LSTM(nn.Module): | |||
output = output[unsort_idx] | |||
else: | |||
output = output[:, unsort_idx] | |||
# 解决LSTM无法在DataParallel下使用的问题: https://github.com/pytorch/pytorch/issues/1591
if self.batch_first: | |||
if output.size(1) < max_len: | |||
dummy_tensor = output.new_zeros(batch_size, max_len - output.size(1), output.size(-1)) | |||
output = torch.cat([output, dummy_tensor], 1) | |||
else: | |||
if output.size(0) < max_len: | |||
dummy_tensor = output.new_zeros(max_len - output.size(0), batch_size, output.size(-1))
output = torch.cat([output, dummy_tensor], 0) | |||
else: | |||
output, hx = self.lstm(x, hx) | |||
return output, hx |
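# Usage sketch for the wrapped LSTM (hypothetical tensors; `x` and `seq_len` are assumptions):
#
#   >>> lstm = LSTM(input_size=300, hidden_size=100, bidirectional=True)
#   >>> output, (h_n, c_n) = lstm(x, seq_len)   # x: [batch, max_len, 300], seq_len: [batch]
#   >>> output.shape                            # [batch, max_len, 200]; padded positions are zero-filled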
@@ -82,7 +82,7 @@ def get_embeddings(init_embed): | |||
if isinstance(init_embed, tuple): | |||
res = nn.Embedding( | |||
num_embeddings=init_embed[0], embedding_dim=init_embed[1]) | |||
elif isinstance(init_embed, nn.Embedding): | |||
elif isinstance(init_embed, nn.Module): | |||
res = init_embed | |||
elif isinstance(init_embed, torch.Tensor): | |||
res = nn.Embedding.from_pretrained(init_embed, freeze=False) | |||
@@ -130,3 +130,17 @@ def summary(model: nn.Module): | |||
strings = [bar] + strings + [bar] | |||
print('\n'.join(strings)) | |||
return total, total_train, total_nontrain | |||
def get_dropout_mask(drop_p: float, tensor: torch.Tensor): | |||
""" | |||
根据tensor的形状,生成一个mask;被drop的位置为0,保留的位置被放大为1/(1-drop_p)
:param drop_p: float, 以多大的概率置为0。
:param tensor: torch.Tensor
:return: torch.FloatTensor. 与tensor一样的shape | |||
""" | |||
mask_x = torch.ones_like(tensor) | |||
nn.functional.dropout(mask_x, p=drop_p,
training=True, inplace=True)
return mask_x |
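# Illustrative use (variational/locked dropout style): sample one mask and reuse it across
# timesteps. `h` and `h_t` are assumed hidden-state tensors, not names from this file.
#
#   >>> mask = get_dropout_mask(0.3, h)   # ~30% zeros, surviving entries scaled by 1/(1-0.3)
#   >>> h_t = h_t * mask                  # apply the same mask at every step t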
@@ -184,11 +184,8 @@ def train(path): | |||
m.weight.requires_grad = True | |||
# Trainer | |||
trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||
loss=ParserLoss(), metrics=ParserMetric(), metric_key='UAS', | |||
**train_args.data, | |||
optimizer=fastNLP.Adam(**optim_args.data), | |||
save_path=path, | |||
trainer = Trainer(train_data=train_data, model=model, optimizer=fastNLP.Adam(**optim_args.data), loss=ParserLoss(), | |||
dev_data=dev_data, metrics=ParserMetric(), metric_key='UAS', save_path=path, | |||
callbacks=[MyCallback()]) | |||
# Start training | |||
@@ -89,11 +89,11 @@ def train(train_data_path, dev_data_path, checkpoint=None, save=None): | |||
model = torch.load(checkpoint) | |||
# call trainer to train | |||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
target="truth", | |||
seq_lens="word_seq_origin_len"), | |||
dev_data=dev_data, metric_key="f", | |||
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path=save) | |||
trainer = Trainer(dataset, model, loss=None, n_epochs=20, print_every=10, dev_data=dev_data, | |||
metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
target="truth", | |||
seq_lens="word_seq_origin_len"), metric_key="f", save_path=save, | |||
use_tqdm=True) | |||
trainer.train(load_best_model=True) | |||
# save model & pipeline | |||
@@ -2,43 +2,28 @@ | |||
这里复现了在fastNLP中实现的模型,旨在达到与论文中相符的性能。 | |||
复现的模型有: | |||
- Star-Transformer | |||
- [Star-Transformer](Star_transformer/) | |||
- ... | |||
# 任务复现 | |||
## Text Classification (文本分类) | |||
- still in progress | |||
## Matching (自然语言推理/句子匹配) | |||
- still in progress | |||
## Sequence Labeling (序列标注) | |||
- still in progress | |||
## Coreference resolution (指代消解) | |||
- still in progress | |||
## Summarization (摘要) | |||
- still in progress | |||
## Star-Transformer | |||
[reference](https://arxiv.org/abs/1902.09113) | |||
### Performance (still in progress) | |||
|任务| 数据集 | SOTA | 模型表现 | | |||
|------|------| ------| ------| | |||
|Pos Tagging|CTB 9.0|-|ACC 92.31| | |||
|Pos Tagging|CONLL 2012|-|ACC 96.51| | |||
|Named Entity Recognition|CONLL 2012|-|F1 85.66| | |||
|Text Classification|SST|-|49.18| | |||
|Natural Language Inference|SNLI|-|83.76| | |||
### Usage | |||
``` python | |||
# for sequence labeling(ner, pos tagging, etc) | |||
from fastNLP.models.star_transformer import STSeqLabel | |||
model = STSeqLabel( | |||
vocab_size=10000, num_cls=50, | |||
emb_dim=300) | |||
# for sequence classification | |||
from fastNLP.models.star_transformer import STSeqCls | |||
model = STSeqCls( | |||
vocab_size=10000, num_cls=50, | |||
emb_dim=300) | |||
# for natural language inference | |||
from fastNLP.models.star_transformer import STNLICls | |||
model = STNLICls( | |||
vocab_size=10000, num_cls=50, | |||
emb_dim=300) | |||
``` | |||
## ... |
@@ -0,0 +1,34 @@ | |||
# Star-Transformer | |||
paper: [Star-Transformer](https://arxiv.org/abs/1902.09113) | |||
## Performance (still in progress) | |||
|任务| 数据集 | SOTA | 模型表现 | | |||
|------|------| ------| ------| | |||
|Pos Tagging|CTB 9.0|-|ACC 92.31| | |||
|Pos Tagging|CONLL 2012|-|ACC 96.51| | |||
|Named Entity Recognition|CONLL 2012|-|F1 85.66| | |||
|Text Classification|SST|-|49.18| | |||
|Natural Language Inference|SNLI|-|83.76| | |||
## Usage | |||
``` python | |||
# for sequence labeling(ner, pos tagging, etc) | |||
from fastNLP.models.star_transformer import STSeqLabel | |||
model = STSeqLabel( | |||
vocab_size=10000, num_cls=50, | |||
emb_dim=300) | |||
# for sequence classification | |||
from fastNLP.models.star_transformer import STSeqCls | |||
model = STSeqCls( | |||
vocab_size=10000, num_cls=50, | |||
emb_dim=300) | |||
# for natural language inference | |||
from fastNLP.models.star_transformer import STNLICls | |||
model = STNLICls( | |||
vocab_size=10000, num_cls=50, | |||
emb_dim=300) | |||
``` |
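
A rough sketch of wiring one of these models into the fastNLP `Trainer` follows, assuming fastNLP's `Trainer`, `CrossEntropyLoss` and `AccuracyMetric`; `train_data`/`dev_data` and the hyper-parameters are placeholders, not the settings used for the numbers above.

``` python
# hypothetical wiring; assumes train_data/dev_data are fastNLP DataSets whose
# input/target fields already match the model's forward signature
import torch
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.models.star_transformer import STSeqCls

model = STSeqCls(vocab_size=10000, num_cls=5, emb_dim=300)
trainer = Trainer(train_data=train_data, model=model,
                  optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
                  loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                  dev_data=dev_data, n_epochs=10, batch_size=32)
trainer.train()
```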
@@ -149,14 +149,10 @@ def train(): | |||
) if x.requires_grad and x.size(0) != len(word_v)] | |||
optim_cfg = [{'params': model.enc.embedding.parameters(), 'lr': g_args.lr*0.1}, | |||
{'params': ex_param, 'lr': g_args.lr, 'weight_decay': g_args.w_decay}, ] | |||
trainer = FN.Trainer(model=model, train_data=train_data, dev_data=dev_data, | |||
loss=loss, metrics=metric, metric_key=metric_key, | |||
optimizer=torch.optim.Adam(optim_cfg), | |||
n_epochs=g_args.ep, batch_size=g_args.bsz, print_every=10, validate_every=3000, | |||
device=device, | |||
use_tqdm=False, prefetch=False, | |||
save_path=g_args.log, | |||
callbacks=[MyCallback()]) | |||
trainer = FN.Trainer(train_data=train_data, model=model, optimizer=torch.optim.Adam(optim_cfg), loss=loss, | |||
batch_size=g_args.bsz, n_epochs=g_args.ep, print_every=10, dev_data=dev_data, metrics=metric, | |||
metric_key=metric_key, validate_every=3000, save_path=g_args.log, use_tqdm=False, | |||
device=device, callbacks=[MyCallback()]) | |||
trainer.train() | |||
tester = FN.Tester(data=test_data, model=model, metrics=metric, | |||
@@ -0,0 +1,223 @@ | |||
import os | |||
from nltk import Tree | |||
from typing import Union, Dict | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.io.base_loader import DataInfo | |||
from fastNLP.io.dataset_loader import JsonLoader | |||
from fastNLP.io.file_utils import _get_base_url, cached_path | |||
from fastNLP.modules.encoder._bert import BertTokenizer | |||
class MatchingLoader(JsonLoader): | |||
""" | |||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader` | |||
读取Matching任务的数据集 | |||
""" | |||
def __init__(self, fields=None, paths: dict=None): | |||
super(MatchingLoader, self).__init__(fields=fields) | |||
self.paths = paths | |||
def _load(self, path): | |||
return super(MatchingLoader, self)._load(path) | |||
def process(self, paths: Union[str, Dict[str, str]], dataset_name=None, | |||
to_lower=False, char_information=False, seq_len_type: str=None, | |||
bert_tokenizer: str=None, get_index=True, set_input: Union[list, str, bool]=True, | |||
set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataInfo: | |||
if isinstance(set_input, str): | |||
set_input = [set_input] | |||
if isinstance(set_target, str): | |||
set_target = [set_target] | |||
if isinstance(set_input, bool): | |||
auto_set_input = set_input | |||
else: | |||
auto_set_input = False | |||
if isinstance(set_target, bool): | |||
auto_set_target = set_target | |||
else: | |||
auto_set_target = False | |||
if isinstance(paths, str): | |||
if os.path.isdir(paths): | |||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||
else: | |||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||
else: | |||
path = paths | |||
data_info = DataInfo() | |||
for data_name in path.keys(): | |||
data_info.datasets[data_name] = self._load(path[data_name]) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if auto_set_input: | |||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||
if auto_set_target: | |||
data_set.set_target(Const.TARGET) | |||
if to_lower: | |||
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||
is_input=auto_set_input) | |||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||
is_input=auto_set_input) | |||
if bert_tokenizer is not None: | |||
PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip', | |||
'en-base-uncased': 'bert-base-uncased-3413b23c.zip', | |||
'en-base-cased': 'bert-base-cased-f89bfe08.zip', | |||
'en-large-uncased': 'bert-large-uncased-20939f45.zip', | |||
'en-large-cased': 'bert-large-cased-e0cf90fc.zip', | |||
'cn': 'bert-base-chinese-29d0a84a.zip', | |||
'cn-base': 'bert-base-chinese-29d0a84a.zip', | |||
'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip', | |||
'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip', | |||
} | |||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer.lower()]  # use the lower-cased key checked above
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# otherwise, check whether a local model directory was given
elif os.path.isdir(bert_tokenizer): | |||
model_dir = bert_tokenizer | |||
else: | |||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||
is_input=auto_set_input) | |||
if isinstance(concat, bool): | |||
concat = 'default' if concat else None | |||
if concat is not None: | |||
if isinstance(concat, str): | |||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||
'default': ['', '<sep>', '', '']} | |||
if concat.lower() in CONCAT_MAP: | |||
concat = CONCAT_MAP[concat] | |||
else: | |||
concat = 4 * [concat] | |||
assert len(concat) == 4, \
f'Please provide a list of 4 symbols: the start of the first sentence, the end of the ' \
f'first sentence, the start of the second sentence, and the end of the second ' \
f'sentence. Your input is {concat}'
for data_name, data_set in data_info.datasets.items(): | |||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||
is_input=auto_set_input) | |||
if seq_len_type is not None: | |||
if seq_len_type == 'seq_len':
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.TARGET), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'mask': | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [1] * len(x[fields]), | |||
new_field_name=fields.replace(Const.INPUT, Const.TARGET), | |||
is_input=auto_set_input) | |||
elif seq_len_type == 'bert': | |||
for data_name, data_set in data_info.datasets.items(): | |||
if Const.INPUT not in data_set.get_field_names(): | |||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||
f'got {data_set.get_field_names()}') | |||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||
data_set_list = [d for n, d in data_info.datasets.items()] | |||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||
if bert_tokenizer is not None: | |||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||
else: | |||
words_vocab = Vocabulary() | |||
words_vocab = words_vocab.from_dataset(*data_set_list, | |||
field_name=[n for n in data_set_list[0].get_field_names() | |||
if (Const.INPUT in n)]) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab = target_vocab.from_dataset(*data_set_list, field_name=Const.TARGET) | |||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||
if get_index: | |||
for data_name, data_set in data_info.datasets.items(): | |||
for fields in data_set.get_field_names(): | |||
if Const.INPUT in fields: | |||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||
is_input=auto_set_input) | |||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||
is_input=auto_set_input, is_target=auto_set_target) | |||
for data_name, data_set in data_info.datasets.items(): | |||
if isinstance(set_input, list): | |||
data_set.set_input(*set_input)
if isinstance(set_target, list): | |||
data_set.set_target(*set_target)
return data_info | |||
class SNLILoader(MatchingLoader): | |||
""" | |||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader` | |||
Reads the SNLI data set. The returned DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
""" | |||
def __init__(self, paths: dict=None): | |||
fields = { | |||
'sentence1_parse': Const.INPUTS(0), | |||
'sentence2_parse': Const.INPUTS(1), | |||
'gold_label': Const.TARGET, | |||
} | |||
paths = paths if paths is not None else { | |||
'train': 'snli_1.0_train.jsonl', | |||
'dev': 'snli_1.0_dev.jsonl', | |||
'test': 'snli_1.0_test.jsonl'} | |||
super(SNLILoader, self).__init__(fields=fields, paths=paths) | |||
def _load(self, path): | |||
ds = super(SNLILoader, self)._load(path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||
ds.apply(lambda ins: parse_tree( | |||
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||
return ds | |||
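# Usage sketch (comments only, not part of the original file): the default `paths` above expect the
# snli_1.0_{train,dev,test}.jsonl files inside the given directory; '/path/to/snli_1.0/' is a placeholder.
#
#   data_info = SNLILoader().process('/path/to/snli_1.0/', to_lower=True, seq_len_type='seq_len',
#                                    get_index=True, concat=False)
#   train_set = data_info.datasets['train']
#   words_vocab = data_info.vocabs[Const.INPUT]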
@@ -0,0 +1,44 @@ | |||
import os | |||
import torch | |||
from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const | |||
from fastNLP.io.dataset_loader import MatchingLoader | |||
from reproduction.matching.model.bert import BertForNLI | |||
from reproduction.matching.model.esim import ESIMModel | |||
bert_dirs = 'path/to/bert/dir' | |||
# load data set | |||
# data_info = MatchingLoader(data_format='snli', for_model='bert', bert_dir=bert_dirs).process(... | |||
data_info = MatchingLoader(data_format='snli', for_model='esim').process( | |||
{'train': './data/snli/snli_1.0_train.jsonl', | |||
'dev': './data/snli/snli_1.0_dev.jsonl', | |||
'test': './data/snli/snli_1.0_test.jsonl'}, | |||
input_field=[Const.TARGET] | |||
) | |||
# model = BertForNLI(bert_dir=bert_dirs) | |||
model = ESIMModel(data_info.embeddings['elmo'],) | |||
trainer = Trainer(train_data=data_info.datasets['train'], model=model, | |||
optimizer=Adam(lr=1e-4, model_params=model.parameters()), | |||
batch_size=torch.cuda.device_count() * 24, n_epochs=20, print_every=-1, | |||
dev_data=data_info.datasets['dev'], | |||
metrics=AccuracyMetric(), metric_key='acc', device=[i for i in range(torch.cuda.device_count())], | |||
check_code_level=-1) | |||
trainer.train(load_best_model=True) | |||
tester = Tester( | |||
data=data_info.datasets['test'], | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=torch.cuda.device_count() * 12, | |||
device=[i for i in range(torch.cuda.device_count())], | |||
) | |||
tester.test() | |||
@@ -0,0 +1,41 @@ | |||
import torch | |||
import torch.nn as nn | |||
from fastNLP.core.const import Const | |||
from fastNLP.models import BaseModel | |||
from fastNLP.modules.encoder.bert import BertModel | |||
class BertForNLI(BaseModel): | |||
# TODO: still in progress | |||
def __init__(self, class_num=3, bert_dir=None): | |||
super(BertForNLI, self).__init__() | |||
if bert_dir is not None: | |||
self.bert = BertModel.from_pretrained(bert_dir) | |||
else: | |||
self.bert = BertModel() | |||
hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1) | |||
self.classifier = nn.Linear(hidden_size, class_num) | |||
def forward(self, words, seq_len1, seq_len2, target=None): | |||
""" | |||
:param torch.Tensor words: [batch_size, seq_len] input_ids | |||
:param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids | |||
:param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask | |||
:param torch.Tensor target: [batch] | |||
:return: | |||
""" | |||
_, pooled_output = self.bert(words, seq_len1, seq_len2) | |||
logits = self.classifier(pooled_output) | |||
if target is not None: | |||
loss_func = torch.nn.CrossEntropyLoss() | |||
loss = loss_func(logits, target) | |||
return {Const.OUTPUT: logits, Const.LOSS: loss} | |||
return {Const.OUTPUT: logits} | |||
def predict(self, words, seq_len1, seq_len2, target=None): | |||
return self.forward(words, seq_len1, seq_len2) | |||
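# Usage sketch (comments only): the argument names follow the docstring above, i.e. `words` carries
# the input_ids, `seq_len1` the token_type_ids and `seq_len2` the attention_mask; the tensors below
# are placeholders.
#
#   model = BertForNLI(class_num=3, bert_dir='/path/to/bert/dir')
#   out = model(words=input_ids, seq_len1=token_type_ids, seq_len2=attention_mask, target=labels)
#   loss, logits = out[Const.LOSS], out[Const.OUTPUT]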
@@ -0,0 +1,182 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from torch.nn import CrossEntropyLoss | |||
from fastNLP.models import BaseModel | |||
from fastNLP.modules.encoder.embedding import TokenEmbedding | |||
from fastNLP.modules.encoder.lstm import LSTM | |||
from fastNLP.core.const import Const | |||
from fastNLP.core.utils import seq_len_to_mask | |||
class ESIMModel(BaseModel): | |||
def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3, | |||
dropout_embed=0.1): | |||
super(ESIMModel, self).__init__() | |||
self.embedding = init_embedding | |||
self.dropout_embed = EmbedDropout(p=dropout_embed) | |||
if hidden_size is None: | |||
hidden_size = self.embedding.embed_size | |||
self.rnn = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||
# self.rnn = LSTM(self.embedding.embed_size, hidden_size, dropout=dropout_rate, bidirectional=True) | |||
self.interfere = nn.Sequential(nn.Dropout(p=dropout_rate), | |||
nn.Linear(8 * hidden_size, hidden_size), | |||
nn.ReLU()) | |||
nn.init.xavier_uniform_(self.interfere[1].weight.data) | |||
self.bi_attention = SoftmaxAttention() | |||
self.rnn_high = BiRNN(self.embedding.embed_size, hidden_size, dropout_rate=dropout_rate) | |||
# self.rnn_high = LSTM(hidden_size, hidden_size, dropout=dropout_rate, bidirectional=True) | |||
self.classifier = nn.Sequential(nn.Dropout(p=dropout_rate), | |||
nn.Linear(8 * hidden_size, hidden_size), | |||
nn.Tanh(), | |||
nn.Dropout(p=dropout_rate), | |||
nn.Linear(hidden_size, num_labels)) | |||
nn.init.xavier_uniform_(self.classifier[1].weight.data) | |||
nn.init.xavier_uniform_(self.classifier[4].weight.data) | |||
def forward(self, words1, words2, seq_len1, seq_len2, target=None): | |||
mask1 = seq_len_to_mask(seq_len1) | |||
mask2 = seq_len_to_mask(seq_len2) | |||
a0 = self.embedding(words1) # B * len * emb_dim | |||
b0 = self.embedding(words2) | |||
a0, b0 = self.dropout_embed(a0), self.dropout_embed(b0) | |||
a = self.rnn(a0, mask1.byte()) # a: [B, PL, 2 * H] | |||
b = self.rnn(b0, mask2.byte()) | |||
ai, bi = self.bi_attention(a, mask1, b, mask2) | |||
a_ = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 8 * H] | |||
b_ = torch.cat((b, bi, b - bi, b * bi), dim=2) | |||
a_f = self.interfere(a_) | |||
b_f = self.interfere(b_) | |||
a_h = self.rnn_high(a_f, mask1.byte()) # ma: [B, PL, 2 * H] | |||
b_h = self.rnn_high(b_f, mask2.byte()) | |||
a_avg = self.mean_pooling(a_h, mask1, dim=1) | |||
a_max, _ = self.max_pooling(a_h, mask1, dim=1) | |||
b_avg = self.mean_pooling(b_h, mask2, dim=1) | |||
b_max, _ = self.max_pooling(b_h, mask2, dim=1) | |||
out = torch.cat((a_avg, a_max, b_avg, b_max), dim=1) # v: [B, 8 * H] | |||
logits = torch.tanh(self.classifier(out)) | |||
if target is not None: | |||
loss_fct = CrossEntropyLoss() | |||
loss = loss_fct(logits, target) | |||
return {Const.LOSS: loss, Const.OUTPUT: logits} | |||
else: | |||
return {Const.OUTPUT: logits} | |||
def predict(self, **kwargs): | |||
return self.forward(**kwargs) | |||
# input [batch_size, len , hidden] | |||
# mask [batch_size, len] (111...00) | |||
@staticmethod | |||
def mean_pooling(input, mask, dim=1): | |||
masks = mask.view(mask.size(0), mask.size(1), -1).float() | |||
return torch.sum(input * masks, dim=dim) / torch.sum(masks, dim=1) | |||
@staticmethod | |||
def max_pooling(input, mask, dim=1): | |||
my_inf = 10e12 | |||
masks = mask.view(mask.size(0), mask.size(1), -1) | |||
masks = masks.expand(-1, -1, input.size(2)).float() | |||
return torch.max(input + masks.le(0.5).float() * -my_inf, dim=dim) | |||
class EmbedDropout(nn.Dropout): | |||
def forward(self, sequences_batch): | |||
ones = sequences_batch.data.new_ones(sequences_batch.shape[0], sequences_batch.shape[-1]) | |||
dropout_mask = nn.functional.dropout(ones, self.p, self.training, inplace=False) | |||
return dropout_mask.unsqueeze(1) * sequences_batch | |||
class BiRNN(nn.Module): | |||
def __init__(self, input_size, hidden_size, dropout_rate=0.3): | |||
super(BiRNN, self).__init__() | |||
self.dropout_rate = dropout_rate | |||
self.rnn = nn.LSTM(input_size, hidden_size, | |||
num_layers=1, | |||
bidirectional=True, | |||
batch_first=True) | |||
def forward(self, x, x_mask): | |||
# Sort x | |||
lengths = x_mask.data.eq(1).long().sum(1).squeeze() | |||
_, idx_sort = torch.sort(lengths, dim=0, descending=True) | |||
_, idx_unsort = torch.sort(idx_sort, dim=0) | |||
lengths = list(lengths[idx_sort]) | |||
x = x.index_select(0, idx_sort) | |||
# Pack it up | |||
rnn_input = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) | |||
# Apply dropout to input | |||
if self.dropout_rate > 0: | |||
dropout_input = F.dropout(rnn_input.data, p=self.dropout_rate, training=self.training) | |||
rnn_input = nn.utils.rnn.PackedSequence(dropout_input, rnn_input.batch_sizes) | |||
output = self.rnn(rnn_input)[0] | |||
# Unpack everything | |||
output = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)[0] | |||
output = output.index_select(0, idx_unsort) | |||
if output.size(1) != x_mask.size(1): | |||
padding = torch.zeros(output.size(0), | |||
x_mask.size(1) - output.size(1), | |||
output.size(2)).type(output.data.type()) | |||
output = torch.cat([output, padding], 1) | |||
return output | |||
def masked_softmax(tensor, mask): | |||
tensor_shape = tensor.size() | |||
reshaped_tensor = tensor.view(-1, tensor_shape[-1]) | |||
# Reshape the mask so it matches the size of the input tensor. | |||
while mask.dim() < tensor.dim(): | |||
mask = mask.unsqueeze(1) | |||
mask = mask.expand_as(tensor).contiguous().float() | |||
reshaped_mask = mask.view(-1, mask.size()[-1]) | |||
result = F.softmax(reshaped_tensor * reshaped_mask, dim=-1) | |||
result = result * reshaped_mask | |||
# 1e-13 is added to avoid divisions by zero. | |||
result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) | |||
return result.view(*tensor_shape) | |||
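# Behaviour note: after the final renormalization, positions with mask == 0 end up with (near-)zero
# probability and each row is effectively a softmax over its unmasked logits only, e.g. a row
# [2.0, 1.0, 3.0] with mask [1, 1, 0] behaves like softmax([2.0, 1.0]) followed by a trailing 0.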
def weighted_sum(tensor, weights, mask): | |||
w_sum = weights.bmm(tensor) | |||
while mask.dim() < w_sum.dim(): | |||
mask = mask.unsqueeze(1) | |||
mask = mask.transpose(-1, -2) | |||
mask = mask.expand_as(w_sum).contiguous().float() | |||
return w_sum * mask | |||
class SoftmaxAttention(nn.Module): | |||
def forward(self, premise_batch, premise_mask, hypothesis_batch, hypothesis_mask): | |||
similarity_matrix = premise_batch.bmm(hypothesis_batch.transpose(2, 1) | |||
.contiguous()) | |||
prem_hyp_attn = masked_softmax(similarity_matrix, hypothesis_mask) | |||
hyp_prem_attn = masked_softmax(similarity_matrix.transpose(1, 2) | |||
.contiguous(), | |||
premise_mask) | |||
attended_premises = weighted_sum(hypothesis_batch, | |||
prem_hyp_attn, | |||
premise_mask) | |||
attended_hypotheses = weighted_sum(premise_batch, | |||
hyp_prem_attn, | |||
hypothesis_mask) | |||
return attended_premises, attended_hypotheses |
@@ -0,0 +1,10 @@ | |||
import unittest | |||
from ..data import MatchingDataLoader | |||
from fastNLP.core.vocabulary import Vocabulary | |||
class TestMatchingDataLoader(unittest.TestCase):
def test_case1(self): | |||
snli_loader = MatchingDataLoader() | |||
# TODO: still in progress | |||
@@ -0,0 +1,249 @@ | |||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
from typing import Union, Dict, List, Iterator | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from functools import partial | |||
class SigHanLoader(DataSetLoader): | |||
""" | |||
A description of the task can be found at http://sighan.cs.uchicago.edu/
The expected format is one sentence per line, with words separated by spaces, for example:
共同 创造 美好 的 新 世纪 —— 二○○一年 新年
女士 们 , 先生 们 , 同志 们 , 朋友 们 :
Reads the SIGHAN data sets; the returned DataSet contains the fields:
raw_chars: list(str), one Chinese character per element
chars: list(int), the index of each character in the vocabulary
target: list(int), depends on the chosen target_type
:param target_type: type of the target, currently "bmes" or "shift_relay"
""" | |||
def __init__(self, target_type:str): | |||
super().__init__() | |||
if target_type.lower() not in ('bmes', 'shift_relay'): | |||
raise ValueError("target_type only supports 'bmes', 'shift_relay'.") | |||
self.target_type = target_type | |||
if target_type == 'bmes':
self._word_len_to_target = self._word_len_to_bmes
elif target_type == 'shift_relay':
self._word_len_to_target = self._word_lens_to_relay
@staticmethod | |||
def _word_lens_to_relay(word_lens: Iterator[int]): | |||
""" | |||
Converts word lengths such as [1, 2, 3, ...] into shift-relay targets like [0, 1, 0, 2, 1, 0],
where each position stores how many characters remain until the end of its word;
:param word_lens: | |||
:return: {'target': , 'end_seg_mask':, 'start_seg_mask':} | |||
""" | |||
tags = [] | |||
end_seg_mask = [] | |||
start_seg_mask = [] | |||
for word_len in word_lens: | |||
tags.extend([idx for idx in range(word_len - 1, -1, -1)]) | |||
end_seg_mask.extend([0] * (word_len - 1) + [1]) | |||
start_seg_mask.extend([1] + [0] * (word_len - 1)) | |||
return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask} | |||
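# Worked example (comments only): word_lens [2, 3, 1] gives
#   target         = [1, 0, 2, 1, 0, 0]   (characters remaining until the end of each word)
#   end_seg_mask   = [0, 1, 0, 0, 1, 1]   (1 at the last character of each word)
#   start_seg_mask = [1, 0, 1, 0, 0, 1]   (1 at the first character of each word)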
@staticmethod | |||
def _word_len_to_bmes(word_lens: Iterator[int]) -> Dict[str, List[str]]:
"""
:param word_lens: the length of each word
:return: {'target': list of BMES tags, one per character}
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
if word_len==1: | |||
tags.append('S') | |||
else: | |||
tags.append('B') | |||
for _ in range(word_len-2): | |||
tags.append('M') | |||
tags.append('E') | |||
return {'target':tags} | |||
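# Worked example (comments only): word_lens [1, 3, 2] gives
#   target = ['S', 'B', 'M', 'E', 'B', 'E']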
@staticmethod | |||
def _gen_bigram(chars:List[str])->List[str]: | |||
""" | |||
:param chars: | |||
:return: | |||
""" | |||
return [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])] | |||
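# Worked example (comments only): chars ['上', '海', '人'] gives
#   ['上海', '海人', '人<eos>']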
def load(self, path:str, bigram:bool=False)->DataSet: | |||
""" | |||
:param path: str | |||
:param bigram: whether to add a bigram feature
:return: | |||
""" | |||
dataset = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if not line:  # skip empty lines
continue | |||
parts = line.split() | |||
word_lens = map(len, parts) | |||
chars = list(''.join(parts)) | |||
tags = self._word_len_to_target(word_lens) | |||
assert len(chars)==len(tags['target']) | |||
dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars))) | |||
if len(dataset)==0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
if bigram: | |||
dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams') | |||
return dataset | |||
def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None, | |||
char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None, | |||
bigram_embed_opt:EmbeddingOption=None, L:int=4): | |||
""" | |||
支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如 | |||
Option:: | |||
共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词 | |||
( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 ) | |||
女士 们 , 先生 们 , 同志 们 , 朋友 们 : | |||
paths支持两种格式,第一种是str,第二种是Dict[str, str]. | |||
Option:: | |||
# 1. str类型 | |||
# 1.1 传入具体的文件路径 | |||
data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容 | |||
# 包含以下的内容data.vocabs['chars']:Vocabulary对象, | |||
# data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值 | |||
# data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项 | |||
# data.datasets['train']: DataSet对象 | |||
# 包含的field有: | |||
# raw_chars: list[str], 每个元素是一个汉字 | |||
# chars: list[int], 每个元素是汉字对应的index | |||
# target: list[int], 根据encoding_type有对应的变化 | |||
# 1.2 传入一个目录, 里面必须包含train.txt文件 | |||
data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt | |||
# 包含以下的内容data.vocabs['chars']: Vocabulary对象 | |||
# data.vocabs['target']:Vocabulary对象 | |||
# data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象; | |||
# data.datasets['train']: DataSet对象 | |||
# 包含的field有: | |||
# raw_chars: list[str], 每个元素是一个汉字 | |||
# chars: list[int], 每个元素是汉字对应的index | |||
# target: list[int], 根据encoding_type有对应的变化 | |||
# data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样 | |||
# 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key | |||
paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'} | |||
data = SigHanLoader(paths).process(paths) | |||
# 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致 | |||
:param paths: 支持传入目录,文件路径,以及dict。 | |||
:param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2 | |||
:param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding | |||
:param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。 | |||
为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e<eos> | |||
:param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效 | |||
:param L: 当target_type为shift_relay时传入的segment长度 | |||
:return: | |||
""" | |||
# validate and normalize `paths` with check_dataloader_paths (recommended)
paths = check_dataloader_paths(paths) | |||
datasets = {} | |||
data = DataInfo() | |||
bigram = bigram_vocab_opt is not None | |||
for name, path in paths.items(): | |||
dataset = self.load(path, bigram=bigram) | |||
datasets[name] = dataset | |||
input_fields = [] | |||
target_fields = [] | |||
# build the character vocabulary
char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt) | |||
char_vocab.from_dataset(datasets['train'], field_name='raw_chars') | |||
char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars') | |||
data.vocabs[Const.CHAR_INPUT] = char_vocab | |||
input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET]) | |||
target_fields.append(Const.TARGET) | |||
# build the target vocabulary
if self.target_type == 'bmes': | |||
target_vocab = Vocabulary(unknown=None, padding=None) | |||
target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S']) | |||
target_vocab.index_dataset(*datasets.values(), field_name='target') | |||
data.vocabs[Const.TARGET] = target_vocab | |||
if char_embed_opt is not None: | |||
char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab) | |||
data.embeddings['chars'] = char_embed | |||
if bigram: | |||
bigram_vocab = Vocabulary(**bigram_vocab_opt) | |||
bigram_vocab.from_dataset(datasets['train'], field_name='bigrams') | |||
bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams') | |||
data.vocabs['bigrams'] = bigram_vocab | |||
if bigram_embed_opt is not None: | |||
bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab) | |||
data.embeddings['bigrams'] = bigram_embed | |||
input_fields.append('bigrams') | |||
if self.target_type == 'shift_relay': | |||
func = partial(self._clip_target, L=L) | |||
for name, dataset in datasets.items(): | |||
res = dataset.apply_field(func, field_name='target') | |||
relay_target = [res_i[0] for res_i in res] | |||
relay_mask = [res_i[1] for res_i in res] | |||
dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False) | |||
dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False) | |||
if self.target_type == 'shift_relay': | |||
input_fields.extend(['end_seg_mask']) | |||
target_fields.append('start_seg_mask') | |||
# register the datasets in the DataInfo
for name, dataset in datasets.items(): | |||
dataset.set_input(*input_fields) | |||
dataset.set_target(*target_fields) | |||
data.datasets[name] = dataset | |||
return data | |||
@staticmethod | |||
def _clip_target(target:List[int], L:int): | |||
""" | |||
只有在target_type为shift_relay的使用 | |||
:param target: List[int] | |||
:param L: | |||
:return: | |||
""" | |||
relay_target_i = [] | |||
tmp = [] | |||
for j in range(len(target) - 1): | |||
tmp.append(target[j]) | |||
if target[j] > target[j + 1]: | |||
pass | |||
else: | |||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||
tmp = [] | |||
# handle the trailing, unfinished segment
if len(tmp) == 0: | |||
relay_target_i.append(0) | |||
else: | |||
tmp.append(target[-1]) | |||
relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]]) | |||
relay_mask_i = [] | |||
j = 0 | |||
while j < len(target): | |||
seg_len = target[j] + 1 | |||
if target[j] < L: | |||
relay_mask_i.extend([0] * (seg_len)) | |||
else: | |||
relay_mask_i.extend([1] * (seg_len - L) + [0] * L) | |||
j = seg_len + j | |||
return relay_target_i, relay_mask_i | |||
@@ -0,0 +1,44 @@ | |||
from fastNLP.core.metrics import MetricBase | |||
class RelayMetric(MetricBase): | |||
def __init__(self, pred=None, pred_mask=None, target=None, start_seg_mask=None): | |||
super().__init__() | |||
self._init_param_map(pred=pred, pred_mask=pred_mask, target=target, start_seg_mask=start_seg_mask) | |||
self.tp = 0 | |||
self.rec = 0 | |||
self.pre = 0 | |||
def evaluate(self, pred, pred_mask, target, start_seg_mask): | |||
""" | |||
给定每个batch,累计一下结果。 | |||
:param pred: 预测的结果,为当前位置的开始的segment的(长度-1) | |||
:param pred_mask: 当前位置预测有segment开始 | |||
:param target: 当前位置开始的segment的(长度-1) | |||
:param start_seg_mask: 当前有segment结束 | |||
:return: | |||
""" | |||
self.tp += ((pred.long().eq(target.long())).__and__(pred_mask.byte().__and__(start_seg_mask.byte()))).sum().item() | |||
self.rec += start_seg_mask.sum().item() | |||
self.pre += pred_mask.sum().item() | |||
def get_metric(self, reset=True): | |||
""" | |||
在所有数据都计算结束之后,得到performance | |||
:param reset: | |||
:return: | |||
""" | |||
pre = self.tp/(self.pre + 1e-12) | |||
rec = self.tp/(self.rec + 1e-12) | |||
f = 2*pre*rec/(1e-12 + pre + rec) | |||
if reset: | |||
self.tp = 0 | |||
self.rec = 0 | |||
self.pre = 0 | |||
self.bigger_than_L = 0 | |||
return {'f': round(f, 6), 'pre': round(pre, 6), 'rec': round(rec, 6)} |
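# Worked example (comments only): with tp=8 correctly predicted segment starts, pre=10 predicted
# starts and rec=9 gold starts, get_metric() returns roughly
#   pre = 8/10 = 0.8, rec = 8/9 ≈ 0.888889, f = 2*pre*rec/(pre+rec) ≈ 0.842105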
@@ -0,0 +1,74 @@ | |||
from torch import nn | |||
import torch | |||
from fastNLP.modules import Embedding | |||
import numpy as np | |||
from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay | |||
from fastNLP.modules import LSTM | |||
class ShiftRelayCWSModel(nn.Module): | |||
""" | |||
该模型可以用于进行分词操作 | |||
包含两个方法, | |||
forward(chars, bigrams, seq_len) -> {'loss': batch_size,} | |||
predict(chars, bigrams) -> {'pred': batch_size x max_len, 'pred_mask': batch_size x max_len} | |||
pred是对当前segment的长度预测,pred_mask是仅在有预测的位置为1 | |||
:param char_embed: 预训练的Embedding或者embedding的shape | |||
:param bigram_embed: 预训练的Embedding或者embedding的shape | |||
:param hidden_size: LSTM的隐藏层大小 | |||
:param num_layers: LSTM的层数 | |||
:param L: SemiCRFShiftRelay的segment大小 | |||
:param num_bigram_per_char: 每个character对应的bigram的数量 | |||
:param drop_p: Dropout的大小 | |||
""" | |||
def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1, | |||
L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2): | |||
super().__init__() | |||
self.char_embedding = Embedding(char_embed, dropout=drop_p) | |||
self._pretrained_embed = False | |||
if isinstance(char_embed, np.ndarray): | |||
self._pretrained_embed = True | |||
self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p) | |||
self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True, | |||
batch_first=True) | |||
self.feature_fn = FeatureFunMax(hidden_size, L) | |||
self.semi_crf_relay = SemiCRFShiftRelay(L) | |||
self.feat_drop = nn.Dropout(drop_p) | |||
self.reset_param() | |||
# self.feature_fn.reset_parameters() | |||
def reset_param(self): | |||
for name, param in self.named_parameters(): | |||
if 'embedding' in name and self._pretrained_embed: | |||
continue | |||
if 'bias_hh' in name: | |||
nn.init.constant_(param, 0) | |||
elif 'bias_ih' in name: | |||
nn.init.constant_(param, 1) | |||
elif len(param.size()) < 2: | |||
nn.init.uniform_(param, -0.1, 0.1) | |||
else: | |||
nn.init.xavier_uniform_(param) | |||
def get_feats(self, chars, bigrams, seq_len): | |||
batch_size, max_len = chars.size() | |||
chars = self.char_embedding(chars) | |||
bigrams = self.bigram_embedding(bigrams) | |||
bigrams = bigrams.view(bigrams.size(0), max_len, -1) | |||
chars = torch.cat([chars, bigrams], dim=-1) | |||
feats, _ = self.lstm(chars, seq_len) | |||
feats = self.feat_drop(feats) | |||
logits, relay_logits = self.feature_fn(feats) | |||
return logits, relay_logits | |||
def forward(self, chars, bigrams, relay_target, relay_mask, end_seg_mask, seq_len): | |||
logits, relay_logits = self.get_feats(chars, bigrams, seq_len) | |||
loss = self.semi_crf_relay(logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len) | |||
return {'loss':loss} | |||
def predict(self, chars, bigrams, seq_len): | |||
logits, relay_logits = self.get_feats(chars, bigrams, seq_len) | |||
pred, pred_mask = self.semi_crf_relay.predict(logits, relay_logits, seq_len) | |||
return {'pred': pred, 'pred_mask': pred_mask} | |||
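# Construction sketch (comments only), mirroring the hyper-parameters in train_shift_relay.py:
#   data = SigHanLoader(target_type='shift_relay').process(...)   # yields a DataInfo
#   model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'],
#                              bigram_embed=data.embeddings['bigrams'],
#                              hidden_size=200, num_layers=1, L=4,
#                              num_bigram_per_char=1, drop_p=0.2)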
@@ -0,0 +1,197 @@ | |||
from torch import nn | |||
import torch | |||
import numpy as np | |||
class SemiCRFShiftRelay(nn.Module): | |||
""" | |||
A semi-CRF decoder with shift-relay nodes; decoding with tags is not supported yet.
""" | |||
def __init__(self, L): | |||
""" | |||
:param L: maximum segment length handled without the relay node
""" | |||
if L < 2:
raise RuntimeError("SemiCRFShiftRelay requires L >= 2.")
super().__init__() | |||
self.L = L | |||
def forward(self, logits, relay_logits, relay_target, relay_mask, end_seg_mask, seq_len): | |||
""" | |||
relay node是接下来L个字都不是它的结束。relay的状态是往后滑动1个位置 | |||
:param logits: batch_size x max_len x L, 当前位置往左边L个segment的分数,最后一维的0是长度为1的segment(即本身) | |||
:param relay_logits: batch_size x max_len, 当前位置是接下来L-1个位置都不是终点的分数 | |||
:param relay_target: batch_size x max_len 每个位置他的segment在哪里开始的。如果超过L,则一直保持为L-1。比如长度为 | |||
5的词,L=3, [0, 1, 2, 2, 2] | |||
:param relay_mask: batch_size x max_len, 在需要relay的地方为1, 长度为5的词, L=3时,为[1, 1, 1, 0, 0] | |||
:param end_seg_mask: batch_size x max_len, segment结束的地方为1。 | |||
:param seq_len: batch_size, 句子的长度 | |||
:return: loss: batch_size, | |||
""" | |||
batch_size, max_len, L = logits.size() | |||
# 当前时刻为relay node的分数是多少 | |||
relay_scores = logits.new_zeros(batch_size, max_len) | |||
# 当前时刻结束的分数是多少 | |||
scores = logits.new_zeros(batch_size, max_len+1) | |||
# golden的分数 | |||
gold_scores = relay_logits[:, 0].masked_fill(relay_mask[:, 0].eq(0), 0) + \ | |||
logits[:, 0, 0].masked_fill(end_seg_mask[:, 0].eq(0), 0) | |||
# 初始化 | |||
scores[:, 1] = logits[:, 0, 0] | |||
batch_i = torch.arange(batch_size).to(logits.device).long() | |||
relay_scores[:, 0] = relay_logits[:, 0] | |||
last_relay_index = max_len - self.L | |||
for t in range(1, max_len): | |||
real_L = min(t+1, L) | |||
flip_logits_t = logits[:, t, :real_L].flip(dims=[1]) # flip之后低0个位置为real_L-1的segment | |||
# 计算relay_scores的更新 | |||
if t<last_relay_index: | |||
# (1) 从正常位置跳转 | |||
tmp1 = relay_logits[:, t] + scores[:, t] # batch_size | |||
# (2) 从relay跳转 | |||
tmp2 = relay_logits[:, t] + relay_scores[:, t-1] # batch_size | |||
tmp1 = torch.stack([tmp1, tmp2], dim=0) | |||
relay_scores[:, t] = torch.logsumexp(tmp1, dim=0) | |||
# 计算scores的更新 | |||
# (1)从之前的位置跳转过来的 | |||
tmp1 = scores[:, t-real_L+1:t+1] + flip_logits_t # batch_size x L | |||
if t>self.L-1: | |||
# (2)从relay跳转过来的 | |||
tmp2 = relay_scores[:, t-self.L] # batch_size | |||
tmp2 = tmp2 + flip_logits_t[:, 0] # batch_size | |||
tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1) | |||
scores[:, t+1] = torch.logsumexp(tmp1, dim=-1) # 更新当前时刻的分数 | |||
# 计算golden | |||
seg_i = relay_target[:, t] # batch_size | |||
gold_segment_scores = logits[:, t][(batch_i, seg_i)].masked_fill(end_seg_mask[:, t].eq(0), 0) # batch_size, 后向从0到L长度的segment的分数 | |||
relay_score = relay_logits[:, t].masked_fill(relay_mask[:, t].eq(0), 0) | |||
gold_scores = gold_scores + relay_score + gold_segment_scores | |||
all_scores = scores.gather(dim=1, index=seq_len.unsqueeze(1)).squeeze(1) # batch_size | |||
return all_scores - gold_scores | |||
def predict(self, logits, relay_logits, seq_len): | |||
""" | |||
relay node是接下来L个字都不是它的结束。relay的状态是往后滑动L-1个位置 | |||
:param logits: batch_size x max_len x L, 当前位置左边L个segment的分数,最后一维的0是长度为1的segment(即本身) | |||
:param relay_logits: batch_size x max_len, 当前位置是接下来L-1个位置都不是终点的分数 | |||
:param seq_len: batch_size, 句子的长度 | |||
:return: pred: batch_size x max_len以该点开始的segment的(长度-1); pred_mask为1的地方预测有segment开始 | |||
""" | |||
batch_size, max_len, L = logits.size() | |||
# 当前时刻为relay node的分数是多少 | |||
max_relay_scores = logits.new_zeros(batch_size, max_len) | |||
relay_bt = seq_len.new_zeros(batch_size, max_len) # 当前结果是否来自于relay的结果 | |||
# 当前时刻结束的分数是多少 | |||
max_scores = logits.new_zeros(batch_size, max_len+1) | |||
bt = seq_len.new_zeros(batch_size, max_len) | |||
# 初始化 | |||
max_scores[:, 1] = logits[:, 0, 0] | |||
max_relay_scores[:, 0] = relay_logits[:, 0] | |||
last_relay_index = max_len - self.L | |||
for t in range(1, max_len): | |||
real_L = min(t+1, L) | |||
flip_logits_t = logits[:, t, :real_L].flip(dims=[1]) # flip之后低0个位置为real_L-1的segment | |||
# 计算relay_scores的更新 | |||
if t<last_relay_index: | |||
# (1) 从正常位置跳转 | |||
tmp1 = relay_logits[:, t] + max_scores[:, t] | |||
# (2) 从relay跳转 | |||
tmp2 = relay_logits[:, t] + max_relay_scores[:, t-1] # batch_size | |||
# 每个sample的倒数L位不能是relay了 | |||
tmp2 = tmp2.masked_fill(seq_len.le(t+L), float('-inf')) | |||
mask_i = tmp1.lt(tmp2) # 为1的位置为relay跳转 | |||
relay_bt[:, t].masked_fill_(mask_i, 1) | |||
max_relay_scores[:, t] = torch.max(tmp1, tmp2) | |||
# 计算scores的更新 | |||
# (1)从之前的位置跳转过来的 | |||
tmp1 = max_scores[:, t-real_L+1:t+1] + flip_logits_t # batch_size x L | |||
tmp1 = tmp1.flip(dims=[1]) # 0的位置代表长度为1的segment | |||
if self.L-1<t: | |||
# (2)从relay跳转过来的 | |||
tmp2 = max_relay_scores[:, t-self.L] # batch_size | |||
tmp2 = tmp2 + flip_logits_t[:, 0] | |||
tmp1 = torch.cat([tmp1, tmp2.unsqueeze(-1)], dim=-1) | |||
# 看哪个更大 | |||
max_score, pt = torch.max(tmp1, dim=1) | |||
max_scores[:, t+1] = max_score | |||
# mask_i = pt.ge(self.L) | |||
bt[:, t] = pt # 假设L=3, 那么对于0,1,2,3分别代表的是[t, t], [t-1, t], [t-2, t], [t-self.L(relay), t] | |||
# 需要把结果decode出来 | |||
pred = np.zeros((batch_size, max_len), dtype=int) | |||
pred_mask = np.zeros((batch_size, max_len), dtype=int) | |||
seq_len = seq_len.tolist() | |||
bt = bt.tolist() | |||
relay_bt = relay_bt.tolist() | |||
for b in range(batch_size): | |||
seq_len_i = seq_len[b] | |||
bt_i = bt[b][:seq_len_i] | |||
relay_bt_i = relay_bt[b][:seq_len_i] | |||
j = seq_len_i - 1 | |||
assert relay_bt_i[j]!=1 | |||
while j>-1: | |||
if bt_i[j]==self.L: | |||
seg_start_pos = j | |||
j = j-self.L | |||
while j > -1 and relay_bt_i[j] != 0:
j = j - 1 | |||
pred[b, j] = seg_start_pos - j | |||
pred_mask[b, j] = 1 | |||
else: | |||
length = bt_i[j] | |||
j = j - bt_i[j] | |||
pred_mask[b, j] = 1 | |||
pred[b, j] = length | |||
j = j - 1 | |||
return torch.LongTensor(pred).to(logits.device), torch.LongTensor(pred_mask).to(logits.device) | |||
class FeatureFunMax(nn.Module): | |||
def __init__(self, hidden_size:int, L:int): | |||
""" | |||
用于计算semi-CRF特征的函数。给定batch_size x max_len x hidden_size形状的输入,输出为batch_size x max_len x L的 | |||
分数,以及batch_size x max_len的relay的分数。两者的区别参考论文 TODO 补充 | |||
:param hidden_size: 输入特征的维度大小 | |||
:param L: 不包含relay node的segment的长度大小。 | |||
""" | |||
super().__init__() | |||
self.end_fc = nn.Linear(hidden_size, 1, bias=False) | |||
self.whole_w = nn.Parameter(torch.randn(L, hidden_size)) | |||
self.relay_fc = nn.Linear(hidden_size, 1) | |||
self.length_bias = nn.Parameter(torch.randn(L)) | |||
self.L = L | |||
def forward(self, logits): | |||
""" | |||
:param logits: batch_size x max_len x hidden_size | |||
:return: batch_size x max_len x L # 最后一维为左边segment的分数,0处为长度为1的segment | |||
batch_size x max_len, # 当前位置是接下来L-1个位置都不是终点的分数 | |||
""" | |||
batch_size, max_len, hidden_size = logits.size() | |||
# start_scores = self.start_fc(logits) # batch_size x max_len x 1 # 每个位置作为start的分数 | |||
tmp = logits.new_zeros(batch_size, max_len+self.L-1, hidden_size) | |||
tmp[:, -max_len:] = logits | |||
# batch_size x max_len x hidden_size x (self.L) -> batch_size x max_len x (self.L) x hidden_size | |||
start_logits = tmp.unfold(dimension=1, size=self.L, step=1).transpose(2, 3).flip(dims=[2]) | |||
end_scores = self.end_fc(logits) # batch_size x max_len x 1 | |||
# 计算relay的特征 | |||
relay_tmp = logits.new_zeros(batch_size, max_len, hidden_size) | |||
relay_tmp[:, :-self.L] = logits[:, self.L:] | |||
# batch_size x max_len x hidden_size | |||
relay_logits_max = torch.max(relay_tmp, logits) # end - start | |||
logits_max = torch.max(logits.unsqueeze(2), start_logits) # batch_size x max_len x L x hidden_size | |||
whole_scores = (logits_max*self.whole_w).sum(dim=-1) # batch_size x max_len x self.L | |||
# whole_scores = self.whole_fc().squeeze(-1) # bz x max_len x self.L | |||
# batch_size x max_len | |||
relay_scores = self.relay_fc(relay_logits_max).squeeze(-1) | |||
return whole_scores+end_scores+self.length_bias.view(1, 1, -1), relay_scores |
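# Shape sketch (assumed example sizes): with batch_size=2, max_len=5, hidden_size=8 and L=3,
# forward(logits of shape (2, 5, 8)) returns segment scores of shape (2, 5, 3) and relay scores
# of shape (2, 5), matching the description in the docstring above.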
@@ -0,0 +1,17 @@ | |||
import unittest | |||
from ..data.CWSDataLoader import SigHanLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
class TestCWSDataLoader(unittest.TestCase): | |||
def test_case1(self): | |||
cws_loader = SigHanLoader(target_type='bmes') | |||
data = cws_loader.process('pku_demo.txt') | |||
print(data.datasets) | |||
def test_case2(self):
cws_loader = SigHanLoader(target_type='bmes') | |||
data = cws_loader.process('pku_demo.txt', bigram_vocab_opt=VocabularyOption()) | |||
print(data.datasets) |
@@ -0,0 +1,64 @@ | |||
import os | |||
from fastNLP import cache_results | |||
from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader | |||
from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel | |||
from fastNLP.io.embed_loader import EmbeddingOption | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP import Trainer | |||
from torch.optim import Adam | |||
from fastNLP import BucketSampler | |||
from fastNLP import GradientClipCallback | |||
from reproduction.seqence_labelling.cws.model.metric import RelayMetric | |||
# use fastNLP's automatic result caching; note that only results smaller than 4GB can be cached
@cache_results(None) | |||
def prepare_data(): | |||
data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt, | |||
bigram_vocab_opt=bigram_vocab_opt, | |||
bigram_embed_opt=bigram_embed_opt, | |||
L=L) | |||
return data | |||
#########hyper | |||
L = 4 | |||
hidden_size = 200 | |||
num_layers = 1 | |||
drop_p = 0.2 | |||
lr = 0.02 | |||
#########hyper | |||
device = 0 | |||
# Do NOT hard-code absolute paths here: they may expose your user name on the server. Use relative
# paths, preferably with the data placed under the reproduction directory and listed in .gitignore.
file_dir = '/path/to/' | |||
char_embed_path = '/pretrain/vectors/1grams_t3_m50_corpus.txt' | |||
bigram_embed_path = '/pretrain/vectors/2grams_t3_m50_corpus.txt' | |||
bigram_vocab_opt = VocabularyOption(min_freq=3) | |||
char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path) | |||
bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path) | |||
data_name = os.path.basename(file_dir) | |||
cache_fp = 'caches/{}.pkl'.format(data_name) | |||
data = prepare_data(_cache_fp=cache_fp, _refresh=True) | |||
model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'], | |||
hidden_size=hidden_size, num_layers=num_layers, | |||
L=L, num_bigram_per_char=1, drop_p=drop_p) | |||
sampler = BucketSampler(batch_size=32) | |||
optimizer = Adam(model.parameters(), lr=lr) | |||
clipper = GradientClipCallback(clip_value=5, clip_type='value') | |||
callbacks = [clipper] | |||
# if pretrain: | |||
# fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until) | |||
# callbacks.append(fixer) | |||
trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler, | |||
update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(), | |||
metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks, | |||
check_code_level=0) | |||
trainer.train() |
@@ -0,0 +1,93 @@ | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
from typing import Union, Dict | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from fastNLP.io.dataset_loader import ConllLoader | |||
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | |||
class Conll2003DataLoader(DataSetLoader): | |||
def __init__(self, task:str='ner', encoding_type:str='bioes'): | |||
""" | |||
加载Conll2003格式的英语语料,该数据集的信息可以在https://www.clips.uantwerpen.be/conll2003/ner/找到。当task为pos | |||
时,返回的DataSet中target取值于第2列; 当task为chunk时,返回的DataSet中target取值于第3列;当task为ner时,返回 | |||
的DataSet中target取值于第4列。所有"-DOCSTART- -X- O O"将被忽略,这会导致数据的数量少于很多文献报道的值,但 | |||
鉴于"-DOCSTART- -X- O O"只是用于文档分割的符号,并不应该作为预测对象,所以我们忽略了数据中的-DOCTSTART-开头的行 | |||
ner与chunk任务读取后的数据的target将为encoding_type类型。pos任务读取后就是pos列的数据。 | |||
:param task: 指定需要标注任务。可选ner, pos, chunk | |||
""" | |||
assert task in ('ner', 'pos', 'chunk') | |||
index = {'ner':3, 'pos':1, 'chunk':2}[task] | |||
self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index]) | |||
self._tag_converters = None | |||
if task in ('ner', 'chunk'): | |||
self._tag_converters = [iob2] | |||
if encoding_type == 'bioes': | |||
self._tag_converters.append(iob2bioes) | |||
def load(self, path: str): | |||
dataset = self._loader.load(path) | |||
def convert_tag_schema(tags): | |||
for converter in self._tag_converters: | |||
tags = converter(tags) | |||
return tags | |||
if self._tag_converters: | |||
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
return dataset | |||
def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=True): | |||
""" | |||
读取并处理数据。数据中的'-DOCSTART-'开头的行会被忽略 | |||
:param paths: | |||
:param word_vocab_opt: vocabulary的初始化值 | |||
:param lower: 是否将所有字母转为小写 | |||
:return: | |||
""" | |||
# 读取数据 | |||
paths = check_dataloader_paths(paths) | |||
data = DataInfo() | |||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) | |||
if lower: | |||
dataset.words.lower() | |||
data.datasets[name] = dataset | |||
# build the word vocabulary
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | |||
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) | |||
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
data.vocabs[Const.INPUT] = word_vocab | |||
# cap words | |||
cap_word_vocab = Vocabulary() | |||
cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words', | |||
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) | |||
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') | |||
input_fields.append('cap_words') | |||
data.vocabs['cap_words'] = cap_word_vocab | |||
# build the target vocabulary
target_vocab = Vocabulary(unknown=None, padding=None) | |||
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
data.vocabs[Const.TARGET] = target_vocab | |||
for name, dataset in data.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) | |||
dataset.set_input(*input_fields) | |||
dataset.set_target(*target_fields) | |||
return data | |||
if __name__ == '__main__': | |||
pass |
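# Usage sketch (comments only), mirroring train_cnn_lstm_crf_conll2003.py; the directory path is a
# placeholder and should contain the CoNLL-2003 train/dev/test files:
#   data = Conll2003DataLoader(task='ner', encoding_type='bioes').process(
#       'path/to/conll2003', word_vocab_opt=VocabularyOption(min_freq=2), lower=False)
#   train_set = data.datasets['train']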
@@ -0,0 +1,152 @@ | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
from typing import Union, Dict | |||
from fastNLP import DataSet | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from fastNLP.io.dataset_loader import ConllLoader | |||
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2 | |||
class OntoNoteNERDataLoader(DataSetLoader): | |||
""" | |||
Reads OntoNotes data that has already been converted to the CoNLL format; see
https://github.com/yhcc/OntoNotes-5.0-NER for how to produce that format.
""" | |||
def __init__(self, encoding_type:str='bioes'): | |||
assert encoding_type in ('bioes', 'bio') | |||
self.encoding_type = encoding_type | |||
if encoding_type=='bioes': | |||
self.encoding_method = iob2bioes | |||
else: | |||
self.encoding_method = iob2 | |||
def load(self, path:str)->DataSet: | |||
""" | |||
给定一个文件路径,读取数据。返回的DataSet包含以下的field | |||
raw_words: List[str] | |||
target: List[str] | |||
:param path: | |||
:return: | |||
""" | |||
dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path) | |||
def convert_to_bio(tags): | |||
bio_tags = [] | |||
flag = None | |||
for tag in tags: | |||
label = tag.strip("()*") | |||
if '(' in tag: | |||
bio_label = 'B-' + label | |||
flag = label | |||
elif flag: | |||
bio_label = 'I-' + flag | |||
else: | |||
bio_label = 'O' | |||
if ')' in tag: | |||
flag = None | |||
bio_tags.append(bio_label) | |||
return self.encoding_method(bio_tags) | |||
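# Worked example (comments only), assuming the usual bracketed OntoNotes NER column:
#   ['(ORG*', '*)', '*', '(GPE)'] -> ['B-ORG', 'I-ORG', 'O', 'B-GPE'],
# which encoding_type='bioes' then maps to ['B-ORG', 'E-ORG', 'O', 'S-GPE'].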
def convert_word(words): | |||
converted_words = [] | |||
for word in words: | |||
word = word.replace('/.', '.')  # some sentence-final periods appear escaped as /.
if not word.startswith('-'): | |||
converted_words.append(word) | |||
continue | |||
# the following bracket tokens were escaped in the data; map them back to the original symbols
tfrs = {'-LRB-':'(', | |||
'-RRB-': ')', | |||
'-LSB-': '[', | |||
'-RSB-': ']', | |||
'-LCB-': '{', | |||
'-RCB-': '}' | |||
} | |||
if word in tfrs: | |||
converted_words.append(tfrs[word]) | |||
else: | |||
converted_words.append(word) | |||
return converted_words | |||
dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words') | |||
dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target') | |||
return dataset | |||
def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, | |||
lower:bool=True)->DataInfo: | |||
""" | |||
读取并处理数据。返回的DataInfo包含以下的内容 | |||
vocabs: | |||
word: Vocabulary | |||
target: Vocabulary | |||
datasets: | |||
train: DataSet | |||
words: List[int], 被设置为input | |||
target: int. label,被同时设置为input和target | |||
seq_len: int. 句子的长度,被同时设置为input和target | |||
raw_words: List[str] | |||
xxx(根据传入的paths可能有所变化) | |||
:param paths: | |||
:param word_vocab_opt: vocabulary的初始化值 | |||
:param lower: 是否使用小写 | |||
:return: | |||
""" | |||
paths = check_dataloader_paths(paths) | |||
data = DataInfo() | |||
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] | |||
target_fields = [Const.TARGET, Const.INPUT_LEN] | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT) | |||
if lower: | |||
dataset.words.lower() | |||
data.datasets[name] = dataset | |||
# build the word vocabulary
word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) | |||
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT, | |||
no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train']) | |||
word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
data.vocabs[Const.INPUT] = word_vocab | |||
# cap words | |||
cap_word_vocab = Vocabulary() | |||
cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words') | |||
cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words') | |||
input_fields.append('cap_words') | |||
data.vocabs['cap_words'] = cap_word_vocab | |||
# build the target vocabulary
target_vocab = Vocabulary(unknown=None, padding=None) | |||
target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET) | |||
data.vocabs[Const.TARGET] = target_vocab | |||
for name, dataset in data.datasets.items(): | |||
dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN) | |||
dataset.set_input(*input_fields) | |||
dataset.set_target(*target_fields) | |||
return data | |||
if __name__ == '__main__': | |||
loader = OntoNoteNERDataLoader() | |||
dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt') | |||
print(dataset.target.value_count()) | |||
print(dataset[:4]) | |||
""" | |||
train 115812 2200752 | |||
development 15680 304684 | |||
test 12217 230111 | |||
train 92403 1901772 | |||
valid 13606 279180 | |||
test 10258 204135 | |||
""" |
@@ -0,0 +1,49 @@ | |||
from typing import List | |||
def iob2(tags:List[str])->List[str]: | |||
""" | |||
Check that `tags` is a valid IOB sequence; IOB1 tags are converted to IOB2 in place.
:param tags: the tags to check and convert
""" | |||
for i, tag in enumerate(tags): | |||
if tag == "O": | |||
continue | |||
split = tag.split("-") | |||
if len(split) != 2 or split[0] not in ["I", "B"]: | |||
raise TypeError("The encoding schema is not a valid IOB type.") | |||
if split[0] == "B": | |||
continue | |||
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
elif tags[i - 1][1:] == tag[1:]: | |||
continue | |||
else: # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
return tags | |||
def iob2bioes(tags:List[str])->List[str]: | |||
""" | |||
Convert IOB tags to the BIOES encoding.
:param tags: | |||
:return: | |||
""" | |||
new_tags = [] | |||
for i, tag in enumerate(tags): | |||
if tag == 'O': | |||
new_tags.append(tag) | |||
else: | |||
split = tag.split('-')[0] | |||
if split == 'B': | |||
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('B-', 'S-')) | |||
elif split == 'I': | |||
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('I-', 'E-')) | |||
else: | |||
raise TypeError("Invalid IOB format.") | |||
return new_tags |
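# Worked examples (comments only):
#   iob2(['I-MISC', 'I-MISC', 'O', 'I-PER']) -> ['B-MISC', 'I-MISC', 'O', 'B-PER']
#   iob2bioes(['B-MISC', 'I-MISC', 'O', 'B-PER']) -> ['B-MISC', 'E-MISC', 'O', 'S-PER']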
@@ -0,0 +1,56 @@ | |||
import torch | |||
from torch import nn | |||
from fastNLP import seq_len_to_mask | |||
from fastNLP.modules import Embedding | |||
from fastNLP.modules import LSTM | |||
from fastNLP.modules import ConditionalRandomField, allowed_transitions | |||
import torch.nn.functional as F | |||
from fastNLP import Const | |||
class CNNBiLSTMCRF(nn.Module): | |||
def __init__(self, embed, char_embed, hidden_size, num_layers, tag_vocab, dropout=0.5, encoding_type='bioes'): | |||
super().__init__() | |||
self.embedding = Embedding(embed, dropout=0.5, dropout_word=0) | |||
self.char_embedding = Embedding(char_embed, dropout=0.5, dropout_word=0.01) | |||
self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim, | |||
hidden_size=hidden_size//2, num_layers=num_layers, | |||
bidirectional=True, batch_first=True) | |||
self.fc = nn.Linear(hidden_size, len(tag_vocab)) | |||
transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=True) | |||
self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=transitions) | |||
self.dropout = nn.Dropout(dropout, inplace=True) | |||
for name, param in self.named_parameters(): | |||
if 'fc' in name: | |||
if param.data.dim()>1: | |||
nn.init.xavier_uniform_(param) | |||
else: | |||
nn.init.constant_(param, 0) | |||
if 'crf' in name: | |||
nn.init.zeros_(param) | |||
def _forward(self, words, cap_words, seq_len, target=None): | |||
words = self.embedding(words) | |||
chars = self.char_embedding(cap_words) | |||
words = torch.cat([words, chars], dim=-1) | |||
outputs, _ = self.lstm(words, seq_len) | |||
self.dropout(outputs) | |||
logits = F.log_softmax(self.fc(outputs), dim=-1) | |||
if target is not None: | |||
loss = self.crf(logits, target, seq_len_to_mask(seq_len)) | |||
return {Const.LOSS: loss} | |||
else: | |||
pred, _ = self.crf.viterbi_decode(logits, seq_len_to_mask(seq_len)) | |||
return {Const.OUTPUT: pred} | |||
def forward(self, words, cap_words, seq_len, target): | |||
return self._forward(words, cap_words, seq_len, target) | |||
def predict(self, words, cap_words, seq_len): | |||
return self._forward(words, cap_words, seq_len, None) |
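# Construction sketch (comments only), mirroring train_cnn_lstm_crf_conll2003.py, where word_embed and
# char_embed are fastNLP embedding objects and `data` is the DataInfo from Conll2003DataLoader.process(...):
#   model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1,
#                        tag_vocab=data.vocabs[Const.TARGET], encoding_type='bioes')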
@@ -0,0 +1,33 @@ | |||
from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader | |||
from reproduction.seqence_labelling.ner.data.Conll2003Loader import iob2, iob2bioes | |||
import unittest | |||
class TestTagSchemaConverter(unittest.TestCase): | |||
def test_iob2(self): | |||
tags = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'] | |||
golden = ['B-ORG', 'O', 'B-MISC', 'O', 'O', 'O', 'B-MISC', 'O', 'O'] | |||
self.assertListEqual(golden, iob2(tags)) | |||
tags = ['I-ORG', 'O'] | |||
golden = ['B-ORG', 'O'] | |||
self.assertListEqual(golden, iob2(tags)) | |||
tags = ['I-MISC', 'I-MISC', 'O', 'I-PER', 'I-PER', 'O'] | |||
golden = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O'] | |||
self.assertListEqual(golden, iob2(tags)) | |||
def test_iob2bemso(self): | |||
tags = ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O'] | |||
golden = ['B-MISC', 'E-MISC', 'O', 'B-PER', 'E-PER', 'O'] | |||
self.assertListEqual(golden, iob2bioes(tags)) | |||
def test_conll2003_loader(): | |||
path = '/hdd/fudanNLP/fastNLP/others/data/conll2003/train.txt' | |||
loader = Conll2003DataLoader().load(path) | |||
print(loader[:3]) | |||
if __name__ == '__main__': | |||
test_conll2003_loader() |
@@ -0,0 +1,70 @@ | |||
from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding, BertEmbedding, ElmoEmbedding, LSTMCharEmbedding | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | |||
from fastNLP import Trainer | |||
from fastNLP import SpanFPreRecMetric | |||
from fastNLP import BucketSampler | |||
from fastNLP import Const | |||
from torch.optim import SGD, Adam | |||
from fastNLP import GradientClipCallback | |||
from fastNLP.core.callback import FitlogCallback, LRScheduler | |||
from torch.optim.lr_scheduler import LambdaLR | |||
from reproduction.seqence_labelling.ner.model.swats import SWATS | |||
import fitlog | |||
fitlog.debug() | |||
from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003DataLoader | |||
encoding_type = 'bioes' | |||
data = Conll2003DataLoader(encoding_type=encoding_type).process('../../../../others/data/conll2003', | |||
word_vocab_opt=VocabularyOption(min_freq=2), | |||
lower=False) | |||
print(data) | |||
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], | |||
kernel_sizes=[3]) | |||
# char_embed = LSTMCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30 ,char_emb_size=30) | |||
word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | |||
model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/wiki_en_100_50_case_2.txt', | |||
requires_grad=True) | |||
word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std() | |||
# import joblib | |||
# raw_data = joblib.load('/hdd/fudanNLP/fastNLP/others/NER-with-LS/data/conll_with_data.joblib') | |||
# def convert_to_ids(raw_words): | |||
# ids = [] | |||
# for word in raw_words: | |||
# id = raw_data['word_to_id'][word] | |||
# id = raw_data['id_to_emb_map'][id] | |||
# ids.append(id) | |||
# return ids | |||
# word_embed = raw_data['emb_matrix'] | |||
# for name, dataset in data.datasets.items(): | |||
# dataset.apply_field(convert_to_ids, field_name='raw_words', new_field_name=Const.INPUT) | |||
# word_embed = ElmoEmbedding(vocab=data.vocabs['cap_words'], | |||
# model_dir_or_name='/hdd/fudanNLP/fastNLP/others/pretrained_models/elmo_en', | |||
# requires_grad=True) | |||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||
encoding_type=encoding_type) | |||
callbacks = [ | |||
GradientClipCallback(clip_type='value', clip_value=5) | |||
, FitlogCallback({'test':data.datasets['test']}, verbose=1) | |||
] | |||
# optimizer = Adam(model.parameters(), lr=0.005) | |||
optimizer = SWATS(model.parameters(), verbose=True) | |||
# optimizer = SGD(model.parameters(), lr=0.008, momentum=0.9) | |||
# scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) | |||
# callbacks.append(scheduler) | |||
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(), | |||
device=1, dev_data=data.datasets['dev'], batch_size=10, | |||
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), | |||
callbacks=callbacks, num_workers=1, n_epochs=100) | |||
trainer.train() |
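# The script above rescales the pretrained word vectors so that their overall standard
# deviation is 1 before training. A minimal sketch of that single step on a random matrix
# (illustration only; the real script operates on word_embed.embedding.weight.data):
import torch

emb = torch.randn(1000, 100) * 0.05   # stand-in for pretrained vectors with a small scale
emb = emb / emb.std()                 # global rescaling, as done in the script
print(float(emb.std()))               # ~1.0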
@@ -0,0 +1,65 @@ | |||
import sys | |||
sys.path.append('../../..') | |||
from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding | |||
from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF | |||
from fastNLP import Trainer | |||
from fastNLP import SpanFPreRecMetric | |||
from fastNLP import BucketSampler | |||
from fastNLP import Const | |||
from torch.optim import SGD, Adam | |||
from torch.optim.lr_scheduler import LambdaLR | |||
from fastNLP import GradientClipCallback | |||
from fastNLP.core.callback import FitlogCallback, LRScheduler | |||
from reproduction.seqence_labelling.ner.model.swats import SWATS | |||
import fitlog | |||
fitlog.debug() | |||
from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader | |||
encoding_type = 'bioes' | |||
data = OntoNoteNERDataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/v4/english', | |||
lower=True) | |||
import joblib | |||
raw_data = joblib.load('/hdd/fudanNLP/fastNLP/others/NER-with-LS/data/ontonotes_with_data.joblib') | |||
def convert_to_ids(raw_words): | |||
ids = [] | |||
for word in raw_words: | |||
id = raw_data['word_to_id'][word] | |||
id = raw_data['id_to_emb_map'][id] | |||
ids.append(id) | |||
return ids | |||
word_embed = raw_data['emb_matrix'] | |||
for name, dataset in data.datasets.items(): | |||
dataset.apply_field(convert_to_ids, field_name='raw_words', new_field_name=Const.INPUT) | |||
print(data) | |||
char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], | |||
kernel_sizes=[3]) | |||
# word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], | |||
# model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt', | |||
# requires_grad=True) | |||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], | |||
encoding_type=encoding_type) | |||
callbacks = [GradientClipCallback(clip_value=5, clip_type='value'), | |||
FitlogCallback(data.datasets['test'], verbose=1)] | |||
optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9) | |||
scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) | |||
callbacks.append(scheduler) | |||
# optimizer = SWATS(model.parameters(), verbose=True) | |||
# optimizer = Adam(model.parameters(), lr=0.005) | |||
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(num_buckets=100), | |||
device=0, dev_data=data.datasets['dev'], batch_size=10, | |||
metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), | |||
callbacks=callbacks, num_workers=1, n_epochs=100) | |||
trainer.train() |
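# The OntoNotes script wraps a LambdaLR scheduler that decays the SGD learning rate by the
# factor 1 / (1 + 0.05 * epoch). A minimal sketch of what that schedule does to the learning
# rate; the tiny linear model here is only a placeholder for the real tagger.
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

toy_model = torch.nn.Linear(4, 2)
toy_optimizer = SGD(toy_model.parameters(), lr=0.01, momentum=0.9)
toy_scheduler = LambdaLR(toy_optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))

for epoch in range(5):
    toy_optimizer.step()   # stands in for one epoch of parameter updates
    toy_scheduler.step()
    print(epoch, toy_optimizer.param_groups[0]['lr'])  # the learning rate shrinks each epoch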
@@ -0,0 +1,75 @@ | |||
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io.base_loader import DataSetLoader, DataInfo | |||
from typing import Union, Dict, List, Iterator | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import Vocabulary | |||
from fastNLP import Const | |||
from reproduction.utils import check_dataloader_paths | |||
from functools import partial | |||
class MTL16Loader(DataSetLoader): | |||
""" | |||
    Reads the MTL16 dataset. The resulting DataSet contains the following fields: | |||
        words: list(str), the text to be classified | |||
        target: str, the label of the text | |||
    Data source: https://pan.baidu.com/s/1c2L6vdA | |||
""" | |||
def __init__(self): | |||
super(MTL16Loader, self).__init__() | |||
def _load(self, path): | |||
dataset = DataSet() | |||
with open(path, 'r', encoding="utf-8") as f: | |||
for line in f: | |||
line = line.strip() | |||
if not line: | |||
continue | |||
parts = line.split('\t') | |||
target = parts[0] | |||
words = parts[1].split() | |||
dataset.append(Instance(words=words, target=target)) | |||
if len(dataset)==0: | |||
raise RuntimeError(f"{path} has no valid data.") | |||
return dataset | |||
def process(self, | |||
paths: Union[str, Dict[str, str]], | |||
src_vocab_opt: VocabularyOption = None, | |||
tgt_vocab_opt: VocabularyOption = None, | |||
src_embed_opt: EmbeddingOption = None): | |||
paths = check_dataloader_paths(paths) | |||
datasets = {} | |||
info = DataInfo() | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
datasets[name] = dataset | |||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||
src_vocab.from_dataset(datasets['train'], field_name='words') | |||
src_vocab.index_dataset(*datasets.values(), field_name='words') | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||
tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||
tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||
info.vocabs = { | |||
"words": src_vocab, | |||
"target": tgt_vocab | |||
} | |||
info.datasets = datasets | |||
if src_embed_opt is not None: | |||
embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab) | |||
info.embeddings['words'] = embed | |||
return info |
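# MTL16Loader._load above expects one example per line in the form
# "<label>\t<space-separated tokens>". A tiny standalone sketch of that parsing step
# (illustration only; the helper name and the sample line are made up here):
def parse_mtl16_line(line):
    target, text = line.rstrip('\n').split('\t', 1)
    return {'words': text.split(), 'target': target}

print(parse_mtl16_line('1\tgood quality jeans at an affordable price'))
# {'words': ['good', 'quality', 'jeans', 'at', 'an', 'affordable', 'price'], 'target': '1'}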
@@ -0,0 +1,68 @@ | |||
import ast | |||
from fastNLP import DataSet, Instance, Vocabulary | |||
from fastNLP.core.vocabulary import VocabularyOption | |||
from fastNLP.io import JsonLoader | |||
from fastNLP.io.base_loader import DataInfo | |||
from fastNLP.io.embed_loader import EmbeddingOption | |||
from fastNLP.io.file_reader import _read_json | |||
from typing import Union, Dict | |||
from reproduction.Star_transformer.datasets import EmbedLoader | |||
from reproduction.utils import check_dataloader_paths | |||
class yelpLoader(JsonLoader): | |||
""" | |||
    Reads the Yelp dataset. The resulting DataSet contains the following fields: | |||
review_id: str, 22 character unique review id | |||
user_id: str, 22 character unique user id | |||
business_id: str, 22 character business id | |||
useful: int, number of useful votes received | |||
funny: int, number of funny votes received | |||
cool: int, number of cool votes received | |||
date: str, date formatted YYYY-MM-DD | |||
        words: list(str), the text to be classified | |||
        target: str, the label of the text | |||
    Data source: https://www.yelp.com/dataset/download | |||
    :param fine_grained: whether to use the SST-5 label standard; if ``False``, SST-2-style labels are used. Default: ``False`` | |||
""" | |||
def __init__(self, fine_grained=False): | |||
super(yelpLoader, self).__init__() | |||
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||
'4.0': 'positive', '5.0': 'very positive'} | |||
if not fine_grained: | |||
tag_v['1.0'] = tag_v['2.0'] | |||
tag_v['5.0'] = tag_v['4.0'] | |||
self.fine_grained = fine_grained | |||
self.tag_v = tag_v | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
d = ast.literal_eval(d) | |||
d["words"] = d.pop("text").split() | |||
d["target"] = self.tag_v[str(d.pop("stars"))] | |||
ds.append(Instance(**d)) | |||
return ds | |||
def process(self, paths: Union[str, Dict[str, str]], vocab_opt: VocabularyOption = None, | |||
embed_opt: EmbeddingOption = None): | |||
paths = check_dataloader_paths(paths) | |||
datasets = {} | |||
info = DataInfo() | |||
vocab = Vocabulary(min_freq=2) if vocab_opt is None else Vocabulary(**vocab_opt) | |||
for name, path in paths.items(): | |||
dataset = self.load(path) | |||
datasets[name] = dataset | |||
vocab.from_dataset(dataset, field_name="words") | |||
info.vocabs = vocab | |||
info.datasets = datasets | |||
if embed_opt is not None: | |||
embed = EmbedLoader.load_with_vocab(**embed_opt, vocab=vocab) | |||
info.embeddings['words'] = embed | |||
return info | |||
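# yelpLoader maps Yelp star ratings to sentiment labels and, when fine_grained is False,
# collapses the 5-way scheme by merging 1.0 into 2.0 and 5.0 into 4.0. A standalone sketch
# of that mapping (the helper name is illustrative):
def star_to_label(stars, fine_grained=False):
    tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral',
             '4.0': 'positive', '5.0': 'very positive'}
    if not fine_grained:
        tag_v['1.0'] = tag_v['2.0']
        tag_v['5.0'] = tag_v['4.0']
    return tag_v[str(stars)]

print(star_to_label(1.0))                     # 'negative'
print(star_to_label(1.0, fine_grained=True))  # 'very negative'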
@@ -0,0 +1 @@ | |||
# TODO |
@@ -0,0 +1 @@ | |||
# TODO |
@@ -0,0 +1 @@ | |||
# TODO |
@@ -0,0 +1,10 @@ | |||
1 the only thing better than these sunglasses is the customer service i got , after i dropped and broke the lenses on these i called 80 's purple and they actually sent me out a replacement free of charge . i was blown away | |||
0 this light worked for one day . i should have known better because in the past , i bought a tap light , and it worked for only a few days , too . do n't waste your money | |||
1 i 've tried 6 different nursing bras . this one , with the center snap closure , is the easiest to use . it is also the lightest and most comfortable , while providing good support . my only complaint is that after about 50 washes the underwire begins to poke free from the fabric . even when i try to sew it back into place , it breaks loose after a few washes . perhaps if i handwashed the bra instead of using a machine , it would last longer . this bra is less durabe than my other nursing bras ( particularly the leading lady bra , which seems to be indestructible ) , but it is well worth the sacrifice for comfort , lightness , and ease of use . it is by far my favorite | |||
0 i have had my bag for a couple of months . the liner on the inside has already ripped | |||
0 the photo is quite deceiving . this suit is made out of cheap polyester fabric that looks cheap , shiny , and is horrible to the touch . my three year olds hate the uncomfortable stiffness . spend the extra money for a decent fabric that is actually practical for a toddler if they really need a suit | |||
1 i had bought a bra of this model at a discount store , just got lucky . it quickly became my favorite , and i was glad to find it at amazon . | |||
0 lookslike it would be a nice product , but it 's only for very small babies up to 12 pounds and 23 inches . my baby is very long and just does n't fit - wish target/amazon would have been more upfront with the sizing | |||
0 i purchased the non-premium kit ( $ 9.99 ) with a silicone skin case cover and 2 screen protectors ( one for each screen ) , but it is the same case . the problem is that the silicone skin cover is slippery , twice as slippery as the nintendo lite without the cover . we thought that washing them in dove dish soap would wash away the slipperyness , but that did n't work . after handling the cover , your hands have a slippery residue on them . the other issue is that the cover is so thin that it is little more than scratch protection , not impact protection . the screen covers that come with the non-premium kit are ok , i guess , but one of them had 2 defect particles that were raised ( trust me , the screen was clean ) . i purchased 2 kits , and i had one screen protector defect and my wife accidentally broke one of the silicone covers hinge straps with little effort . i do not recommend this product at all | |||
1 good quality jeans at an affordable price . size is just right , quite comfortable | |||
0 not the best fabric , scratchy and see thru . you get what you pay for on these |
@@ -0,0 +1,20 @@ | |||
"{\"review_id\":\"Q1sbwvVQXV2734tPgoKj4Q\",\"user_id\":\"hG7b0MtEbXx5QzbzE6C_VA\",\"business_id\":\"ujmEBvifdJM6h6RLv4wQIg\",\"stars\":1.0,\"useful\":6,\"funny\":1,\"cool\":0,\"text\":\"Total bill for this horrible service? Over $8Gs. These crooks actually had the nerve to charge us $69 for 3 pills. I checked online the pills can be had for 19 cents EACH! Avoid Hospital ERs at all costs.\",\"date\":\"2013-05-07 04:34:36\"}\n" | |||
"{\"review_id\":\"GJXCdrto3ASJOqKeVWPi6Q\",\"user_id\":\"yXQM5uF2jS6es16SJzNHfg\",\"business_id\":\"NZnhc2sEQy3RmzKTZnqtwQ\",\"stars\":5.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"I *adore* Travis at the Hard Rock's new Kelly Cardenas Salon! I'm always a fan of a great blowout and no stranger to the chains that offer this service; however, Travis has taken the flawless blowout to a whole new level! \\n\\nTravis's greets you with his perfectly green swoosh in his otherwise perfectly styled black hair and a Vegas-worthy rockstar outfit. Next comes the most relaxing and incredible shampoo -- where you get a full head message that could cure even the very worst migraine in minutes --- and the scented shampoo room. Travis has freakishly strong fingers (in a good way) and use the perfect amount of pressure. That was superb! Then starts the glorious blowout... where not one, not two, but THREE people were involved in doing the best round-brush action my hair has ever seen. The team of stylists clearly gets along extremely well, as it's evident from the way they talk to and help one another that it's really genuine and not some corporate requirement. It was so much fun to be there! \\n\\nNext Travis started with the flat iron. The way he flipped his wrist to get volume all around without over-doing it and making me look like a Texas pagent girl was admirable. It's also worth noting that he didn't fry my hair -- something that I've had happen before with less skilled stylists. At the end of the blowout & style my hair was perfectly bouncey and looked terrific. The only thing better? That this awesome blowout lasted for days! \\n\\nTravis, I will see you every single time I'm out in Vegas. You make me feel beauuuutiful!\",\"date\":\"2017-01-14 21:30:33\"}\n" | |||
"{\"review_id\":\"2TzJjDVDEuAW6MR5Vuc1ug\",\"user_id\":\"n6-Gk65cPZL6Uz8qRm3NYw\",\"business_id\":\"WTqjgwHlXbSFevF32_DJVw\",\"stars\":5.0,\"useful\":3,\"funny\":0,\"cool\":0,\"text\":\"I have to say that this office really has it together, they are so organized and friendly! Dr. J. Phillipp is a great dentist, very friendly and professional. The dental assistants that helped in my procedure were amazing, Jewel and Bailey helped me to feel comfortable! I don't have dental insurance, but they have this insurance through their office you can purchase for $80 something a year and this gave me 25% off all of my dental work, plus they helped me get signed up for care credit which I knew nothing about before this visit! I highly recommend this office for the nice synergy the whole office has!\",\"date\":\"2016-11-09 20:09:03\"}\n" | |||
"{\"review_id\":\"yi0R0Ugj_xUx_Nek0-_Qig\",\"user_id\":\"dacAIZ6fTM6mqwW5uxkskg\",\"business_id\":\"ikCg8xy5JIg_NGPx-MSIDA\",\"stars\":5.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"Went in for a lunch. Steak sandwich was delicious, and the Caesar salad had an absolutely delicious dressing, with a perfect amount of dressing, and distributed perfectly across each leaf. I know I'm going on about the salad ... But it was perfect.\\n\\nDrink prices were pretty good.\\n\\nThe Server, Dawn, was friendly and accommodating. Very happy with her.\\n\\nIn summation, a great pub experience. Would go again!\",\"date\":\"2018-01-09 20:56:38\"}\n" | |||
"{\"review_id\":\"11a8sVPMUFtaC7_ABRkmtw\",\"user_id\":\"ssoyf2_x0EQMed6fgHeMyQ\",\"business_id\":\"b1b1eb3uo-w561D0ZfCEiQ\",\"stars\":1.0,\"useful\":7,\"funny\":0,\"cool\":0,\"text\":\"Today was my second out of three sessions I had paid for. Although my first session went well, I could tell Meredith had a particular enjoyment for her male clients over her female. However, I returned because she did my teeth fine and I was pleased with the results. When I went in today, I was in the whitening room with three other gentlemen. My appointment started out well, although, being a person who is in the service industry, I always attend to my female clientele first when a couple arrives. Unbothered by those signs, I waited my turn. She checked on me once after my original 30 minute timer to ask if I was ok. She attended my boyfriend on numerous occasions, as well as the other men, and would exit the room without even asking me or looking to see if I had any irritation. Half way through, another woman had showed up who she was explaining the deals to in the lobby. While she admits timers must be reset half way through the process, she reset my boyfriends, left, rest the gentleman furthest away from me who had time to come in, redeem his deal, get set, and gave his timer done, before me, then left, and at this point my time was at 10 minutes. So, she should have reset it 5 minutes ago, according to her. While I sat there patiently this whole time with major pain in my gums, i watched the time until the lamp shut off. Not only had she reset two others, explained deals to other guest, but she never once checked on my time. When my light turned off, I released the stance of my mouth to a more relaxed state, assuming I was only getting a thirty minute session instead of the usual 45, because she had yet to come in. At this point, the teeth formula was not only burning the gum she neglected for 25 minutes now, but it began to burn my lips. I began squealing and slapping my chair trying to get her attention from the other room in a panic. I was in so much pain, that by the time she entered the room I was already out of my chair. She finally then acknowledged me, and asked if she could put vitamin E on my gum burn (pictured below). At this point, she has treated two other gums burns, while neglecting me, and I was so irritated that I had to suffer, all I wanted was to leave. While I waited for my boyfriend, she kept harassing me about the issue. Saying, \\\"well burns come with teeth whitening.\\\" While I totally agree, and under justifiable circumstances would not be as irritate, it could have easily been avoid if she had checked on me even a second time, so I could let her know. Not only did she never check on my physical health, she couldn't even take two seconds to reset the timer, which she even admitted to me. Her accuse was that she was coming in to do it, but I had the light off for a solid two minutes before I couldn't stand the pain. She admitted it should be reset every 15 minutes, which means for 25 minutes she did not bother to help me at all. Her guest in the lobby then proceeded to attack me as well, simply because I wanted to leave after the way I was treated. I also expected a refund for not getting a complete session today, due to the neglect, and the fact I won't be returning for my last, she had failed to do that. She was even screaming from the door, and continued to until my boyfriend and I were down the steps. 
I have never in my life been more appalled by a grown woman's behavior, who claims to be in the business for \\\"10 years.\\\" Admit your wrongs, but don't make your guest feel unwelcome because you can't do you job properly.\",\"date\":\"2018-01-30 23:07:38\"}\n" | |||
"{\"review_id\":\"fdiNeiN_hoCxCMy2wTRW9g\",\"user_id\":\"w31MKYsNFMrjhWxxAb5wIw\",\"business_id\":\"eU_713ec6fTGNO4BegRaww\",\"stars\":4.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"I'll be the first to admit that I was not excited about going to La Tavolta. Being a food snob, when a group of friends suggested we go for dinner I looked online at the menu and to me there was nothing special and it seemed overpriced. Im also not big on ordering pasta when I go out. Alas, I was outnumbered. Thank goodness! I ordered the sea bass special. It was to die for. Cooked perfectly, seasoned perfectly, perfect portion. I can not say enough good things about this dish. When the server asked how it was he seemed very proud of the dish and said, \\\" doesn't she (the chef) do an incredible job?\\\" She does. \\n\\nMy hubby got the crab tortellini and also loved his. I heard \\\"mmmm this is so good\\\" from all around the table. Our waiter was super nice and even gave us free desserts because we were some of the last people in the restaurant. Service was very slow and the place was PACKED but we had our jugs of wine and a large group with good conversation so it didn't seem to bother anyone.\\n\\nSo-\\n\\nDo order the calamari and fried zucchini appetizers. Leave out the mussels. \\n\\nIf they have the sea bass special, I highly recommend it. The chicken parm and crab tortellini were also very good and very big. The chicken Romano was a bit bland. The house salads were teeny. \\n\\nDo make a reservation but still expect to wait for your food. Go with a large group of people and plan for it to be loud. Don't go with a date unless you're fighting and don't feel like hearing anything they have to say. Ask to sit in the side room if it's available.\",\"date\":\"2013-01-20 13:25:59\"}\n" | |||
"{\"review_id\":\"G7XHMxG0bx9oBJNECG4IFg\",\"user_id\":\"jlu4CztcSxrKx56ba1a5AQ\",\"business_id\":\"3fw2X5bZYeW9xCz_zGhOHg\",\"stars\":3.0,\"useful\":5,\"funny\":4,\"cool\":5,\"text\":\"Tracy dessert had a big name in Hong Kong and the one in First Markham place has been here for many years now! \\n\\nCame in for some Chinese dessert, and I must say their selection has increased tremendously over the years. I might as well add that the price has also increased tremendously as well. The waitress gave us tea, which I could taste had red date in it. Fancy!\\n\\nA simple taro with coconut with tapioca pearls was like $5.25 or something. Basically all the desserts were more than $5. That's crazy! I can literally just make this dessert at home and for a bowl, it would probably cost like $0.50. A few years ago, I think I can still get it for like $3-$4, which is more reasonable, but wow, more than $5 is a little over the top for this dessert. Though I must say, it is Tracy Dessert, and they are a little more on the expensive side. \\n\\nI also saw other items on the menu like fish balls, chicken wings, shaved ice. My friend got a mango drink with fresh mango in it! \\n\\nI'm also surprised how many people come to Tracy Dessert after work. We came on a Sunday and the tables were always filled. I think the amount of tables they had were just perfect because no one really waited for seats for a long time, but the tables kept filling up once a table was finished.\",\"date\":\"2016-05-07 01:21:02\"}\n" | |||
"{\"review_id\":\"8e9HxxLjjqc9ez5ezzN7iQ\",\"user_id\":\"d6xvYpyzcfbF_AZ8vMB7QA\",\"business_id\":\"zvO-PJCpNk4fgAVUnExYAA\",\"stars\":1.0,\"useful\":3,\"funny\":1,\"cool\":1,\"text\":\"This place has gone down hill. Clearly they have cut back on staff and food quality\\n\\nMany of the reviews were written before the menu changed. I've been going for years and the food quality has gone down hill.\\n\\nThe service is slow & my salad, which was $15, was as bad as it gets.\\n\\nIt's just not worth spending the money on this place when there are so many other options.\",\"date\":\"2010-10-05 19:12:35\"}\n" | |||
"{\"review_id\":\"qrffudO73zsslZbe8B9D3Q\",\"user_id\":\"sG_h0dIzTKWa3Q6fmb4u-g\",\"business_id\":\"b2jN2mm9Wf3RcrZCgfo1cg\",\"stars\":2.0,\"useful\":1,\"funny\":0,\"cool\":0,\"text\":\"I was really looking forward to visiting after having some of their beers. The \\\"Man O'War\\\" quickly became my favorite DIPA; the Rusulka Vanilla Stout is a good thick, sweet stout; and the Ironclad is a top notch IPA. \\nThe only big miss on their beers I've had is the Big Chuck Barleywine. It could probably benefit greatly with age, but at this age all there is to taste is the alcohol. \\nNonetheless, I had enough to convince me that the other beers I hadn't had from them would be top notch... and they are! \\nThe reason for the 2 stars should not reflect the quality of the brewers, they obviously know their craft well! \\nThe servers are great and friendly.... but relying on two servers to wait on 100+ customers says a lot about how inexperienced management must be. In fact, after waiting 15 mins at a dirty table I was finally able to track down someone I guessed was an employee to let them know we were even there! \\nAfter another 5+ mins, the GM finally stopped over to take our drink order. The smugness of this guy was amazing. The thought of offering a simple apology never seemed to enter into his head. \\nThis is the time a server finally stopped by to pick up the non-final check left by the party before us... who didn't seem very pleased when leaving. \\nThe toast & cheese was good, but by the time we were able to dig into their heartiest offering of food, saltines and butter may have been equally pleasing.\",\"date\":\"2015-01-18 14:04:18\"}\n" | |||
"{\"review_id\":\"RS_GTIT6836bCaPy637kNQ\",\"user_id\":\"nMeCE5-xsdleyxYuNZ_7rA\",\"business_id\":\"oxwGyA17NL6c5t1Etg5WgQ\",\"stars\":3.0,\"useful\":1,\"funny\":0,\"cool\":1,\"text\":\"It's a giant Best Buy with 66 registers. I don't get it. What's the big deal about this place??\",\"date\":\"2012-02-29 21:52:43\"}\n" | |||
"{\"review_id\":\"kbtscdyz6lvrtGjD1quQTg\",\"user_id\":\"FIk4lQQu1eTe2EpzQ4xhBA\",\"business_id\":\"8mIrX_LrOnAqWsB5JrOojQ\",\"stars\":4.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"Like walking back in time, every Saturday morning my sister and I was in a bowling league and after we were done, we'd spend a few quarters playing the pin ball machines until our mother came to pick us up.\\n\\nMy sister was daring and play the machines hard, she was afraid of that \\\"tilt\\\" showing up and freezing the game. I, on the other hand was a bit more gentler and wanted to make sure I got my quarter's worth.\\n\\nThis place has rows and rows of machines, some are really old and some are more of a mid 80's theme. There is even a Ms pac man! It was fun to spend an afternoon playing the machines and remembering all the fun of my early teen years.\",\"date\":\"2011-11-30 02:11:15\"}\n" | |||
"{\"review_id\":\"-I5umRTkhw15RqpKMl_o1Q\",\"user_id\":\"-mA3-1mN4JIEkqOtdbNXCQ\",\"business_id\":\"mRUVMJkUGxrByzMQ2MuOpA\",\"stars\":1.0,\"useful\":0,\"funny\":1,\"cool\":0,\"text\":\"Walked in around 4 on a Friday afternoon, we sat at a table just off the bar and walked out after 5 min or so. Don't even think they realized we walked in. However everyone at the bar noticed we walked in!!! Service was non existent at best. Not a good way for a new business to start out. Oh well, the location they are at has been about 5 different things over the past several years, so they will just be added to the list. SMDH!!!\",\"date\":\"2017-12-15 23:27:08\"}\n" | |||
"{\"review_id\":\"Z7wgXp98wYB57QdRY3HQ3w\",\"user_id\":\"GYNnVehQeXjty0xH7-6Fhw\",\"business_id\":\"FxLfqxdYPA6Z85PFKaqLrg\",\"stars\":4.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"Wow. So surprised at the one and two star reviews! We started with the most tender calamari. Although the marinara sauce was a bit bland, but a touch of salt made it just right. My husband had the veal with peppers and said it was so delicious and tender. The mashed potatoes were perfect. I had the salmon Diablo which was also delicious. Our salad was beautiful! Dressing was served on the salad and it was a nice amount. We ended our delicious meal with a piece of tiramisu. Our server Matt was right on!! Very pleasant and knowledgeable about the menu. Our appetizer, salad and entrees were timed perfectly. I love salad and did not mind that my entree was served while I was still eating it! No problem it let my dinner cool to just the right temp for me to eat it comfortably. \\nI wonder sometimes if people just don't appreciate relaxing and taking time to eat a wonderful and beautifully prepared meal. A wonderful atmosphere. So relaxing. The chairs are super comfortable too!!! We will certainly be back. \\nGive it a try. Don't always go by the reviews. \\nA bottle of Riesling, calamari app, two delicious entrees and dessert for $92! \\nWell with it.\",\"date\":\"2016-05-07 01:36:53\"}\n" | |||
"{\"review_id\":\"qlXw1JQ0UodW7qrmVgwCXw\",\"user_id\":\"bAhqAPoWaZYcyYi7bs024Q\",\"business_id\":\"LUN6swQYa4xJKaM_UEUOEw\",\"stars\":4.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"Michael from Red Carpet VIP is amazing ! I reached out because I needed help planning my soon to be sister in law's bachelorette. It was a group of 10 girls so I was a little overwhelmed but Michael saved the day! Everything was super smooth and easy! We got good deals and had the best time ever! We booked hotel and a bachelorette package for a great price. I have saved contact info because I will for sure reach out again on next Vegas trip!!!\",\"date\":\"2018-04-27 20:25:26\"}\n" | |||
"{\"review_id\":\"JVcjMhlavKKn3UIt9p9OXA\",\"user_id\":\"TpyOT5E16YASd7EWjLQlrw\",\"business_id\":\"AakkkTuGZA2KBodKi2_u8A\",\"stars\":1.0,\"useful\":1,\"funny\":1,\"cool\":0,\"text\":\"I cannot believe how things have changed in 3 years. I picked up duck congee sometime in the winter when my hubby was sick. I was very disappointed because the ginger fish sauce tasted like it had gone bad (it should never be bitter). Today, my hubby wanted to eat there since he was craving the duck congee and most places don't serve the duck & coleslaw side. We waited about 10 minutes to get our menu. After we placed our orders, we waited another 5 minutes to get the tea that most places bring with the menu. I could go on with the details but the gist of the story is they were understaffed or the staff was slow. The worst part of it was that the service. The servers make us feel bad for asking for anything (like when they took our order). We had arrived and placed our order before another couple bside us at least 10 minutes ahead but somehow, this couple received their pho before mine. They were almost done eating their pho before mine came out.\",\"date\":\"2012-07-16 00:37:14\"}\n" | |||
"{\"review_id\":\"svK3nBU7Rk8VfGorlrN52A\",\"user_id\":\"NJlxGtouq06hhC7sS2ECYw\",\"business_id\":\"YvrylyuWgbP90RgMqZQVnQ\",\"stars\":5.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"You can't really find anything wrong with this place, the pastas and pizzas are both amazing and high quality, the price is very reasonable, the owner and the staff are very friendly, if you're in downtown check this place out, a lot of people think just because it's downtown there are lots of options around but that's not always the case as there is also a lot of poor quality food in downtown as well.\",\"date\":\"2017-04-07 21:27:49\"}\n" | |||
"{\"review_id\":\"1wVA2-vQIuW_ClmXkDxqMQ\",\"user_id\":\"86J5DwcFk4f4In1Vxe2TvA\",\"business_id\":\"NyLYY8q1-H3hfsTwuwLPCg\",\"stars\":4.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"Great lunch today. Staff was very helpful in assisting with selections and knowledgeable on the ingredients. We enjoyed the BBQ chicken with tika masala sauce and really good naan bread. The biryani with chicken was also yummy! Fun to see the food being prepared in the tandoori ovens. Great addition to the fast casual scene in Cleveland.\",\"date\":\"2015-01-03 22:47:34\"}\n" | |||
"{\"review_id\":\"6BnQwlxRn7ZuWdzninM9sQ\",\"user_id\":\"JSrP-dUmLlwZiI7Dp3PQ2A\",\"business_id\":\"cHdJXLlKNWixBXpDwEGb_A\",\"stars\":3.0,\"useful\":1,\"funny\":7,\"cool\":1,\"text\":\"I love chinese food and I love mexican food. What can go wrong? A couple of things. First things first, this place is more of a \\\"rice bowl\\\" kind of place. I thought it was going to be more diverse as far as the menu goes, but its mainly rice bowls you get with different kinds of meats. The ordering was a little confusing at first, but one of the employees helped us out and I got the 2-item bowl and got the jade chicken and hengrenade chicken with all rice(jerk). I also ordered a jade chicken quesadilla on the side.\\n\\nI'm gonna admit, this place looks kinda dirty. I don't think Arizona uses those health department letter grade system like California does, but if I were to just judge by how it looked inside, i'd give it a \\\"C\\\" grade lol. We waited for about 15 minutes or so and finally got our food. We took it to go and ate at our hotel room. \\n\\nMmmm... the food was just alright. The jade chicken was nothing special. It tasted like any generic chinese fast food orange chicken\\/sesame chicken variant. The hengrenade chicken, although was the less spicier version of the jerk chicken, was still pretty spicy for me. Just be warned the jerk chicken is super spicy. If you aren't sure, ask for a sample at the restaurant before ordering, but it was way too spicy for me. \\n\\nThe jade chicken quesadilla was decent, but nothing special. Just imagine orange chicken in between a tortilla and cheese. A friend of mine ordered a jade chicken burrito and we were confused when we pulled it out of the bag because it was literally the size of Mcdonald's apple pie. If you order the burrito, be warned that it's a burrito for gnomes and smurfs, but he said it was tasty. \\n\\nThey provide a snicker doodle sugar cookie for each meal and it was decent, again nothing special. \\n\\nNot gonna lie, the next day my stomach felt like a little mexican dude and chinese dude were wrestling and throwing molotov cocktails inside. I used the bathroom like 5 times. I don't recommend eating this place if you have a lot to do the next day.\",\"date\":\"2015-04-01 16:30:00\"}\n" | |||
"{\"review_id\":\"rEITo90tpyKmEfNDp3Ou3A\",\"user_id\":\"6Fz_nus_OG4gar721OKgZA\",\"business_id\":\"6lj2BJ4tJeu7db5asGHQ4w\",\"stars\":5.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"We've been a huge Slim's fan since they opened one up in Texas about two years ago when we used to live there. This place never disappoints. They even have great salads and grilled chicken. Plus they have fresh brewed sweet tea, it's the best!\",\"date\":\"2017-05-26 01:23:19\"}\n" | |||
"{\"review_id\":\"4bUyL7lzoWzDZaJETAKREg\",\"user_id\":\"_N7Ndn29bpll_961oPeEfw\",\"business_id\":\"y-Iw6dZflNix4BdwIyTNGA\",\"stars\":3.0,\"useful\":0,\"funny\":0,\"cool\":0,\"text\":\"Good selection of classes of beers and mains. I've been here twice.\\n\\nFirst time I had the fried chicken. It was delicious, but be warned, extremely salty. I couldn't even finish the last piece of chicken after experiencing a salt overload.\\n\\nSecond time we came on a wednesday. We didn't know it was BBQ night, where they have a completely different menu, and don't offer anything from their original vegetarian-friendly menu. This menu has one vegetarian-friendly option - an eggplant sandwich. The vegetarian in my party said it was awful. Also, on BBQ night you choose 2 sides. Except they were out of all their sides except 2 - fries and potato salad. I can't say I was thrilled to have carb heavy sides with my carb heavy main. How do you run out of sides so early in the evening?\\n\\nService not so great.\\n\\nI'd avoid coming here on wednesdays.\",\"date\":\"2014-06-27 21:19:23\"}\n" |
@@ -0,0 +1,10 @@ | |||
import unittest | |||
from reproduction.text_classification.data.MTL16Loader import MTL16Loader | |||
class TestDataLoader(unittest.TestCase): | |||
def test_MTL16Loader(self): | |||
loader = MTL16Loader() | |||
data = loader.process('sample_MTL16.txt') | |||
print(data.datasets) | |||
@@ -0,0 +1,7 @@ | |||
import unittest | |||
from reproduction.text_classification.data.yelpLoader import yelpLoader | |||
class TestDatasetLoader(unittest.TestCase): | |||
def test_yelpLoader(self): | |||
ds = yelpLoader().load('sample_yelp.json') | |||
assert len(ds) == 20 |
@@ -0,0 +1,52 @@ | |||
import os | |||
from typing import Union, Dict | |||
def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]: | |||
""" | |||
    Check that the file paths passed to a dataloader are valid. If they are, a dict containing at least the key 'train' is returned, similar to | |||
    { | |||
        'train': '/some/path/to/', # always present; the vocabulary should be built on this split, the remaining files only need to be processed and indexed. | |||
        'test': 'xxx' # may or may not be present | |||
        ... | |||
    } | |||
    If paths is invalid, the corresponding error is raised directly. | |||
    :param paths: the path(s). May be a single file path (that file is then treated as the train file); a directory, in which train.txt, | |||
        test.txt and dev.txt are looked up; or a dict whose keys are user-defined file names and whose values are the corresponding file paths. | |||
    :return: | |||
""" | |||
if isinstance(paths, str): | |||
if os.path.isfile(paths): | |||
return {'train': paths} | |||
elif os.path.isdir(paths): | |||
train_fp = os.path.join(paths, 'train.txt') | |||
if not os.path.isfile(train_fp): | |||
raise FileNotFoundError(f"train.txt is not found in folder {paths}.") | |||
files = {'train': train_fp} | |||
for filename in ['dev.txt', 'test.txt']: | |||
fp = os.path.join(paths, filename) | |||
if os.path.isfile(fp): | |||
files[filename.split('.')[0]] = fp | |||
return files | |||
else: | |||
raise FileNotFoundError(f"{paths} is not a valid file path.") | |||
elif isinstance(paths, dict): | |||
if paths: | |||
if 'train' not in paths: | |||
raise KeyError("You have to include `train` in your dict.") | |||
for key, value in paths.items(): | |||
if isinstance(key, str) and isinstance(value, str): | |||
if not os.path.isfile(value): | |||
raise TypeError(f"{value} is not a valid file.") | |||
else: | |||
raise TypeError("All keys and values in paths should be str.") | |||
return paths | |||
else: | |||
raise ValueError("Empty paths is not allowed.") | |||
else: | |||
raise TypeError(f"paths only supports str and dict. not {type(paths)}.") | |||
@@ -1,4 +1,5 @@ | |||
numpy | |||
torch>=0.4.0 | |||
tqdm | |||
nltk | |||
nltk | |||
requests |
@@ -3,7 +3,7 @@ import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP import Batch | |||
from fastNLP import DataSetIter | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import SequentialSampler | |||
@@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase): | |||
dataset = construct_dataset( | |||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
dataset.set_target() | |||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
cnt = 0 | |||
for _, _ in batch: | |||
@@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||
self.assertEqual(len(x["x"]), 4) | |||
@@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase): | |||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
for x, y in iter: | |||
self.assertEqual(x["x"].shape, (4, 4)) | |||
self.assertEqual(y["y"].shape, (4, 4)) | |||
@@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase): | |||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
for x, y in iter: | |||
self.assertEqual(x["x"].shape, (4, 4)) | |||
self.assertEqual(y["y"].shape, (4, 4)) | |||
@@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase): | |||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
@@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase): | |||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
@@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase): | |||
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
@@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase): | |||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
print(x, y) | |||
@@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase): | |||
num_samples = 1000 | |||
dataset = generate_fake_dataset(num_samples) | |||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_x, batch_y in batch: | |||
pass | |||
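# The tests above rely on DataSetIter padding every field to the longest sequence in the
# batch, e.g. [[1, 2, 3, 4], [1, 2, 3], [1, 2], [1]] comes back as a (4, 4) array. A minimal
# numpy sketch of that zero-padding step, for illustration only:
import numpy as np

def pad_batch(seqs, pad_val=0):
    max_len = max(len(s) for s in seqs)
    out = np.full((len(seqs), max_len), pad_val, dtype=np.int64)
    for i, s in enumerate(seqs):
        out[i, :len(s)] = s
    return out

print(pad_batch([[1, 2, 3, 4], [1, 2, 3], [1, 2], [1]]).shape)  # (4, 4)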
@@ -12,6 +12,7 @@ from fastNLP import AccuracyMetric | |||
from fastNLP import SGD | |||
from fastNLP import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier | |||
from fastNLP.core.callback import EarlyStopError | |||
def prepare_env(): | |||
@@ -39,89 +40,50 @@ class TestCallback(unittest.TestCase): | |||
def test_gradient_clip(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=20, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=20, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)], check_code_level=2) | |||
trainer.train() | |||
def test_early_stop(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=20, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.01), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[EarlyStopCallback(5)]) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.01), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=20, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
callbacks=[EarlyStopCallback(5)], check_code_level=2) | |||
trainer.train() | |||
def test_lr_scheduler(self): | |||
data_set, model = prepare_env() | |||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=5, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=optimizer, | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||
trainer = Trainer(data_set, model, optimizer=optimizer, loss=BCELoss(pred="predict", target="y"), batch_size=32, | |||
n_epochs=5, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))], | |||
check_code_level=2) | |||
trainer.train() | |||
def test_KeyBoardInterrupt(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=5, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
callbacks=[ControlC(False)]) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, callbacks=[ControlC(False)], | |||
check_code_level=2) | |||
trainer.train() | |||
def test_LRFinder(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=5, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
callbacks=[LRFinder(len(data_set) // 32)]) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=5, print_every=50, use_tqdm=False, | |||
callbacks=[LRFinder(len(data_set) // 32)], check_code_level=2) | |||
trainer.train() | |||
def test_TensorboardCallback(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=5, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[TensorboardCallback("loss", "metric")]) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=5, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, | |||
callbacks=[TensorboardCallback("loss", "metric")], check_code_level=2) | |||
trainer.train() | |||
def test_readonly_property(self): | |||
@@ -140,16 +102,9 @@ class TestCallback(unittest.TestCase): | |||
print(self.optimizer) | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=total_epochs, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[MyCallback()]) | |||
trainer = Trainer(data_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=total_epochs, print_every=50, dev_data=data_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), use_tqdm=False, callbacks=[MyCallback()], | |||
check_code_level=2) | |||
trainer.train() | |||
assert passed_epochs == list(range(1, total_epochs + 1)) |
@@ -1,8 +1,55 @@ | |||
import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP import FieldArray | |||
from fastNLP.core.field import _get_ele_type_and_dim | |||
from fastNLP import AutoPadder | |||
class TestFieldArrayTypeDimDetect(unittest.TestCase): | |||
    """ | |||
    Test whether FieldArray correctly detects the element type and ndim. | |||
    """ | |||
def test_case1(self): | |||
        # 1.1 ordinary scalar types | |||
for value in [1, True, 1.0, 'abc']: | |||
type_ = type(value) | |||
_type, _dim = _get_ele_type_and_dim(cell=value) | |||
self.assertListEqual([_type, _dim], [type_, 0]) | |||
        # 1.2 mixed element types should raise an error | |||
        with self.assertRaises(Exception): | |||
            value = [1, 2, 1.0] | |||
            _get_ele_type_and_dim(value) | |||
        # tests with numpy inputs | |||
# 2.1 | |||
value = np.array([1, 2, 3]) | |||
type_ = value.dtype | |||
dim_ = 1 | |||
self.assertSequenceEqual(_get_ele_type_and_dim(cell=value), [type_, dim_]) | |||
# 2.2 | |||
        value = np.array([[1, 2], [3, 4, 5]])  # the char-embedding case (ragged rows) | |||
self.assertSequenceEqual([int, 2], _get_ele_type_and_dim(value)) | |||
# 2.3 | |||
value = np.zeros((3, 4)) | |||
self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value)) | |||
        # 2.4 inconsistent dimensions should raise an error | |||
with self.assertRaises(Exception): | |||
value = np.array([[1, 2], [3, [1]]]) | |||
_get_ele_type_and_dim(value) | |||
        # 2.5 mixed element types should raise an error | |||
with self.assertRaises(Exception): | |||
value = np.array([[1, 2], [3.0]]) | |||
_get_ele_type_and_dim(value) | |||
        # tests with torch tensors | |||
        # 3.1 the word-embedding case | |||
value = torch.zeros(3, 10) | |||
self.assertSequenceEqual([value.dtype, 2], _get_ele_type_and_dim(value)) | |||
        # 3.2 the char-embedding / image case | |||
value = torch.zeros(3, 32, 32) | |||
self.assertSequenceEqual([value.dtype, 3], _get_ele_type_and_dim(value)) | |||
class TestFieldArrayInit(unittest.TestCase): | |||
@@ -31,12 +78,6 @@ class TestFieldArrayInit(unittest.TestCase): | |||
        # 3-d list | |||
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) | |||
def test_init_v7(self): | |||
# list of array | |||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) | |||
self.assertEqual(fa.pytype, int) | |||
self.assertEqual(fa.dtype, np.int) | |||
def test_init_v4(self): | |||
        # 1-d list | |||
val = [1, 2, 3, 4] | |||
@@ -56,6 +97,11 @@ class TestFieldArrayInit(unittest.TestCase): | |||
fa.append(val) | |||
def test_init_v7(self): | |||
# list of array | |||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) | |||
self.assertEqual(fa.dtype, np.array([1]).dtype) | |||
def test_init_v8(self): | |||
        # 2-d list | |||
val = np.array([[1, 2], [3, 4]]) | |||
fa = FieldArray("x", [val], is_input=True) | |||
@@ -79,33 +125,23 @@ class TestFieldArray(unittest.TestCase): | |||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | |||
def test_type_conversion(self): | |||
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True) | |||
self.assertEqual(fa.pytype, float) | |||
self.assertEqual(fa.dtype, np.float64) | |||
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) | |||
fa.append(1.3333) | |||
self.assertEqual(fa.pytype, float) | |||
self.assertEqual(fa.dtype, np.float64) | |||
self.assertEqual(fa.dtype, int) | |||
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) | |||
fa.append(10) | |||
self.assertEqual(fa.pytype, float) | |||
self.assertEqual(fa.dtype, np.float64) | |||
fa.append(10.0) | |||
self.assertEqual(fa.dtype, float) | |||
fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True) | |||
fa.append("e") | |||
self.assertEqual(fa.dtype, np.str) | |||
self.assertEqual(fa.pytype, str) | |||
self.assertEqual(fa.dtype, str) | |||
def test_support_np_array(self): | |||
fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) | |||
self.assertEqual(fa.dtype, np.float64) | |||
self.assertEqual(fa.pytype, float) | |||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||
self.assertEqual(fa.dtype, np.float64) | |||
self.assertEqual(fa.pytype, float) | |||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | |||
# in this case, pytype is actually a float. We do not care about it. | |||
@@ -113,11 +149,10 @@ class TestFieldArray(unittest.TestCase): | |||
def test_nested_list(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True) | |||
self.assertEqual(fa.pytype, float) | |||
self.assertEqual(fa.dtype, np.float64) | |||
self.assertEqual(fa.dtype, float) | |||
def test_getitem_v1(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) | |||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
ans = fa[[0, 1]] | |||
self.assertTrue(isinstance(ans, np.ndarray)) | |||
@@ -150,7 +185,7 @@ class TestFieldArray(unittest.TestCase): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
fa.append(["str", 0, 0, 0, 1.89]) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True) | |||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | |||
self.assertEqual(len(fa), 3) | |||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | |||
@@ -163,33 +198,86 @@ class TestFieldArray(unittest.TestCase): | |||
fa = FieldArray("y", [(1, "1"), (2, "2"), (3, "3"), (4, "4")], is_target=True, ignore_type=True) | |||
class TestPadder(unittest.TestCase): | |||
class TestAutoPadder(unittest.TestCase): | |||
def test00(self): | |||
padder = AutoPadder() | |||
        # when no type information is given | |||
contents = [(1, 2), ('str', 'a')] | |||
padder(contents, None, None, None) | |||
def test01(self): | |||
""" | |||
        Test whether AutoPadder works as expected. | |||
:return: | |||
""" | |||
from fastNLP import AutoPadder | |||
        # test multi-dimensional bool, int, str and float contents | |||
# str | |||
padder = AutoPadder() | |||
content = ['This is a str', 'this is another str'] | |||
self.assertListEqual(content, padder(content, None, np.str).tolist()) | |||
self.assertListEqual(content, padder(content, None, str, 0).tolist()) | |||
content = [1, 2] | |||
self.assertListEqual(content, padder(content, None, np.int64).tolist()) | |||
content = [[1,2], [3], [4]] | |||
self.assertListEqual([[1,2], [3, 0], [4, 0]], | |||
padder(content, None, np.int64).tolist()) | |||
        # 1-d int | |||
content = [[1, 2, 3], [4,], [5, 6, 7, 8]] | |||
padded_content = [[1, 2, 3, 0], [4, 0, 0, 0], [5, 6, 7, 8]] | |||
self.assertListEqual(padder(content, None, int, 1).tolist(), padded_content) | |||
        # 2-d int | |||
padded_content = [[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]] | |||
content = [ | |||
[[1, 2, 3], [4, 5], [7,8,9,10]], | |||
[[1]] | |||
] | |||
self.assertListEqual(content, | |||
padder(content, None, np.int64).tolist()) | |||
[[1, 2, 3], [4, 5], [7, 8, 9, 10]], | |||
[[1]] | |||
] | |||
self.assertListEqual(padder(content, None, int, 2).tolist(), padded_content) | |||
        # 3-d images | |||
contents = [np.random.rand(3, 4, 4).tolist() for _ in range(5)] | |||
self.assertTrue(padder(contents, None, float, 3).shape==(5, 3, 4, 4)) | |||
        # higher dimensions are returned as-is | |||
contents = [np.random.rand(24, 3, 4, 4).tolist() for _ in range(5)] | |||
self.assertTrue(isinstance(padder(contents, None, float, 4), np.ndarray)) | |||
def test02(self): | |||
padder = AutoPadder() | |||
        # numpy inputs | |||
        # 0-d | |||
contents = np.arange(12) | |||
self.assertListEqual(padder(contents, None, contents.dtype, 0).tolist(), contents.tolist()) | |||
        # 1-d | |||
contents = np.arange(12).reshape((3, 4)) | |||
self.assertListEqual(padder(contents, None, contents.dtype, 1).tolist(), contents.tolist()) | |||
# 2-dim
contents = np.ones((3, 10, 5)) | |||
self.assertListEqual(padder(contents, None, contents.dtype, 2).tolist(), contents.tolist()) | |||
# 3-dim
contents = [np.random.rand(3, 4, 4) for _ in range(5)] | |||
l_contents = [content.tolist() for content in contents] | |||
self.assertListEqual(padder(contents, None, contents[0].dtype, 3).tolist(), l_contents) | |||
def test03(self): | |||
padder = AutoPadder() | |||
# test torch.Tensor inputs
# 0-dim
contents = torch.arange(12) | |||
r_contents = padder(contents, None, contents.dtype, 0) | |||
self.assertSequenceEqual(r_contents.tolist(), contents.tolist()) | |||
self.assertTrue(r_contents.dtype==contents.dtype) | |||
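# the dtype of a tensor input is expected to be preserved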
# 0-dim (list of scalar tensors)
contents = [torch.tensor(1) for _ in range(10)] | |||
self.assertSequenceEqual(padder(contents, None, torch.int64, 0).tolist(), contents) | |||
# 1-dim
contents = torch.randn(3, 4) | |||
padder(contents, None, torch.float64, 1) | |||
# 3-dim
contents = [torch.randn(3, 4, 4) for _ in range(5)] | |||
padder(contents, None, torch.float64, 3) | |||
class TestEngChar2DPadder(unittest.TestCase): | |||
def test01(self): | |||
""" | |||
Test that EngChar2DPadder can be used correctly.
@@ -198,38 +286,31 @@ class TestPadder(unittest.TestCase): | |||
padder = EngChar2DPadder(pad_length=0) | |||
contents = [1, 2] | |||
# must not be 1-dim
with self.assertRaises(ValueError): | |||
padder(contents, None, np.int64) | |||
# must not be 0-dim
with self.assertRaises(Exception): | |||
padder(contents, None, np.int64, 0) | |||
contents = [[1, 2]] | |||
# must not be 2-dim
with self.assertRaises(ValueError): | |||
padder(contents, None, np.int64) | |||
contents = [[[[1, 2]]]] | |||
# must not be 1-dim
with self.assertRaises(Exception): | |||
padder(contents, None, np.int64, 1) | |||
contents = [ | |||
[[[[1, 2]]]] | |||
] | |||
# must not be 3-dim or higher
with self.assertRaises(ValueError): | |||
padder(contents, None, np.int64) | |||
with self.assertRaises(Exception): | |||
padder(contents, None, np.int64, 3) | |||
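# a valid character field is 2-dim per instance: a list of words, each word a list of character indices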
contents = [ | |||
[[1, 2, 3], [4, 5], [7,8,9,10]], | |||
[[1]] | |||
] | |||
self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]], | |||
padder(contents, None, np.int64).tolist()) | |||
padder(contents, None, np.int64, 2).tolist()) | |||
padder = EngChar2DPadder(pad_length=5, pad_val=-100) | |||
self.assertListEqual( | |||
[[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]], | |||
[[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]], | |||
padder(contents, None, np.int64).tolist() | |||
padder(contents, None, np.int64, 2).tolist() | |||
) | |||
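# with pad_length=5 every word is padded out to exactly 5 characters, using pad_val=-100 as the filler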
def test_None_dtype(self): | |||
from fastNLP import AutoPadder | |||
padder = AutoPadder() | |||
content = [ | |||
[[1, 2, 3], [4, 5], [7, 8, 9, 10]], | |||
[[1]] | |||
] | |||
ans = padder(content, None, None).tolist() | |||
self.assertListEqual(content, ans) |
@@ -161,7 +161,15 @@ class TestAccuracyMetric(unittest.TestCase): | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_duplicate(self): | |||
# potential bug in 0.4.1: duplicated parameter names must not break the metric call
metric = AccuracyMetric(pred='predictions', target='targets') | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||
target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
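# the call must succeed even though the dicts also carry the default 'pred'/'target' keys alongside the explicitly mapped 'predictions'/'targets'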
def test_seq_len(self): | |||
N = 256 | |||
seq_len = torch.zeros(N).long() | |||
@@ -46,18 +46,10 @@ class TrainerTestGround(unittest.TestCase): | |||
model = NaiveClassifier(2, 1) | |||
trainer = Trainer(train_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
n_epochs=10, | |||
batch_size=32, | |||
print_every=50, | |||
validate_every=-1, | |||
dev_data=dev_set, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=True, | |||
save_path=None) | |||
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"), | |||
batch_size=32, n_epochs=10, print_every=50, dev_data=dev_set, | |||
metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None, | |||
use_tqdm=True, check_code_level=2) | |||
trainer.train() | |||
""" | |||
# should run correctly
@@ -83,10 +75,7 @@ class TrainerTestGround(unittest.TestCase): | |||
model = Model() | |||
with self.assertRaises(RuntimeError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model | |||
) | |||
trainer = Trainer(train_data=dataset, model=model) | |||
""" | |||
# the error message that should be produced
NameError: | |||
@@ -116,12 +105,7 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
trainer.train() | |||
""" | |||
# should run correctly
@@ -147,12 +131,7 @@ class TrainerTestGround(unittest.TestCase): | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
trainer.train() | |||
def test_trainer_suggestion4(self): | |||
@@ -175,12 +154,7 @@ class TrainerTestGround(unittest.TestCase): | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
def test_trainer_suggestion5(self): | |||
# check that the error message gives the user the right hint
@@ -203,12 +177,7 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer = Trainer(train_data=dataset, model=model, print_every=2, use_tqdm=False) | |||
def test_trainer_suggestion6(self): | |||
# check that the error message gives the user the right hint
@@ -233,14 +202,8 @@ class TrainerTestGround(unittest.TestCase): | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
dev_data=dataset, | |||
loss=CrossEntropyLoss(), | |||
metrics=AccuracyMetric(), | |||
use_tqdm=False, | |||
print_every=2) | |||
trainer = Trainer(train_data=dataset, model=model, loss=CrossEntropyLoss(), print_every=2, dev_data=dataset, | |||
metrics=AccuracyMetric(), use_tqdm=False) | |||
""" | |||
def test_trainer_multiprocess(self): | |||
@@ -18,7 +18,7 @@ class Model(nn.Module): | |||
self.param = nn.Parameter(torch.zeros(0)) | |||
class TestMoveModelDeivce(unittest.TestCase): | |||
class TestMoveModelDevice(unittest.TestCase): | |||
def test_case1(self): | |||
# test with a str device
model = Model() | |||
@@ -237,6 +237,10 @@ class TestSeqLenToMask(unittest.TestCase): | |||
with self.assertRaises(AssertionError): | |||
mask = seq_len_to_mask(seq_len) | |||
# 3. pad to a specified length
seq_len = np.random.randint(1, 10, size=(10,)) | |||
mask = seq_len_to_mask(seq_len, 100) | |||
self.assertEqual(100, mask.shape[1]) | |||
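# i.e. the mask is padded out to max_len=100 columns instead of stopping at seq_len.max()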
def test_pytorch_seq_len(self): | |||
# 1. random test
@@ -250,3 +254,8 @@ class TestSeqLenToMask(unittest.TestCase): | |||
seq_len = torch.randn(3, 4) | |||
with self.assertRaises(AssertionError): | |||
mask = seq_len_to_mask(seq_len) | |||
# 3. pad to a specified length
seq_len = torch.randint(1, 10, size=(10, )) | |||
mask = seq_len_to_mask(seq_len, 100) | |||
self.assertEqual(100, mask.size(1)) |
@@ -70,6 +70,24 @@ class TestAdd(unittest.TestCase): | |||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
vocab.index_dataset(dataset, field_name='char') | |||
def test_from_dataset_no_entry(self): | |||
# test that no_create_entry is set correctly
dataset = DataSet() | |||
start_char = 65 | |||
num_samples = 10 | |||
test_dataset = DataSet() | |||
for i in range(num_samples): | |||
char = [chr(start_char + i)] * 6 | |||
ins = Instance(char=char) | |||
dataset.append(ins) | |||
ins = Instance(char=[c+c for c in char]) | |||
test_dataset.append(ins) | |||
vocab = Vocabulary() | |||
vocab.from_dataset(dataset, field_name='char', no_create_entry_dataset=test_dataset) | |||
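# words seen only in no_create_entry_dataset are still indexed but flagged, which the _is_word_no_create_entry checks below rely on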
vocab.index_dataset(dataset, field_name='char') | |||
for i in range(num_samples): | |||
self.assertEqual(True, vocab._is_word_no_create_entry(chr(start_char + i)+chr(start_char + i))) | |||
class TestIndexing(unittest.TestCase): | |||
def test_len(self): | |||
@@ -100,13 +118,14 @@ class TestIndexing(unittest.TestCase): | |||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||
def test_iteration(self): | |||
vocab = Vocabulary() | |||
vocab = Vocabulary(padding=None, unknown=None) | |||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
"works", "well", "in", "most", "cases", "scales", "well"] | |||
vocab.update(text) | |||
text = set(text) | |||
for word in vocab: | |||
for word, idx in vocab: | |||
self.assertTrue(word in text) | |||
self.assertTrue(idx < len(vocab)) | |||
class TestOther(unittest.TestCase): | |||
@@ -1,6 +1,8 @@ | |||
import unittest | |||
import os | |||
from fastNLP.io import Conll2003Loader, PeopleDailyCorpusLoader, CSVLoader, SNLILoader, JsonLoader | |||
from fastNLP.io.dataset_loader import SSTLoader | |||
from reproduction.text_classification.data.yelpLoader import yelpLoader | |||
class TestDatasetLoader(unittest.TestCase): | |||
@@ -28,3 +30,34 @@ class TestDatasetLoader(unittest.TestCase): | |||
def test_JsonLoader(self): | |||
ds = JsonLoader().load('test/data_for_tests/sample_snli.jsonl') | |||
assert len(ds) == 3 | |||
def test_SST(self): | |||
train_data = """(3 (2 (2 The) (2 Rock)) (4 (3 (2 is) (4 (2 destined) (2 (2 (2 (2 (2 to) (2 (2 be) (2 (2 the) (2 (2 21st) (2 (2 (2 Century) (2 's)) (2 (3 new) (2 (2 ``) (2 Conan)))))))) (2 '')) (2 and)) (3 (2 that) (3 (2 he) (3 (2 's) (3 (2 going) (3 (2 to) (4 (3 (2 make) (3 (3 (2 a) (3 splash)) (2 (2 even) (3 greater)))) (2 (2 than) (2 (2 (2 (2 (1 (2 Arnold) (2 Schwarzenegger)) (2 ,)) (2 (2 Jean-Claud) (2 (2 Van) (2 Damme)))) (2 or)) (2 (2 Steven) (2 Segal))))))))))))) (2 .))) | |||
(4 (4 (4 (2 The) (4 (3 gorgeously) (3 (2 elaborate) (2 continuation)))) (2 (2 (2 of) (2 ``)) (2 (2 The) (2 (2 (2 Lord) (2 (2 of) (2 (2 the) (2 Rings)))) (2 (2 '') (2 trilogy)))))) (2 (3 (2 (2 is) (2 (2 so) (2 huge))) (2 (2 that) (3 (2 (2 (2 a) (2 column)) (2 (2 of) (2 words))) (2 (2 (2 (2 can) (1 not)) (3 adequately)) (2 (2 describe) (2 (3 (2 (2 co-writer\/director) (2 (2 Peter) (3 (2 Jackson) (2 's)))) (3 (2 expanded) (2 vision))) (2 (2 of) (2 (2 (2 J.R.R.) (2 (2 Tolkien) (2 's))) (2 Middle-earth))))))))) (2 .))) | |||
(3 (3 (2 (2 (2 (2 (2 Singer\/composer) (2 (2 Bryan) (2 Adams))) (2 (2 contributes) (2 (2 (2 a) (2 slew)) (2 (2 of) (2 songs))))) (2 (2 --) (2 (2 (2 (2 a) (2 (2 few) (3 potential))) (2 (2 (2 hits) (2 ,)) (2 (2 (2 a) (2 few)) (1 (1 (2 more) (1 (2 simply) (2 intrusive))) (2 (2 to) (2 (2 the) (2 story))))))) (2 --)))) (2 but)) (3 (4 (2 the) (3 (2 whole) (2 package))) (2 (3 certainly) (3 (2 captures) (2 (1 (2 the) (2 (2 (2 intended) (2 (2 ,) (2 (2 er) (2 ,)))) (3 spirit))) (2 (2 of) (2 (2 the) (2 piece)))))))) (2 .)) | |||
(2 (2 (2 You) (2 (2 'd) (2 (2 think) (2 (2 by) (2 now))))) (2 (2 America) (2 (2 (2 would) (1 (2 have) (2 (2 (2 had) (1 (2 enough) (2 (2 of) (2 (2 plucky) (2 (2 British) (1 eccentrics)))))) (4 (2 with) (4 (3 hearts) (3 (2 of) (3 gold))))))) (2 .)))) | |||
""" | |||
test_data = """(3 (2 Yet) (3 (2 (2 the) (2 act)) (3 (4 (3 (2 is) (3 (2 still) (4 charming))) (2 here)) (2 .)))) | |||
(4 (2 (2 Whether) (2 (2 (2 (2 or) (1 not)) (3 (2 you) (2 (2 're) (3 (3 enlightened) (2 (2 by) (2 (2 any) (2 (2 of) (2 (2 Derrida) (2 's))))))))) (2 (2 lectures) (2 (2 on) (2 (2 ``) (2 (2 (2 (2 (2 (2 the) (2 other)) (2 '')) (2 and)) (2 ``)) (2 (2 the) (2 self)))))))) (3 (2 ,) (3 (2 '') (3 (2 Derrida) (3 (3 (2 is) (4 (2 an) (4 (4 (2 undeniably) (3 (4 (3 fascinating) (2 and)) (4 playful))) (2 fellow)))) (2 .)))))) | |||
(4 (3 (2 (2 Just) (2 (2 the) (2 labour))) (3 (2 involved) (3 (2 in) (4 (2 creating) (3 (3 (2 the) (3 (3 layered) (2 richness))) (3 (2 of) (3 (2 (2 the) (2 imagery)) (2 (2 in) (3 (2 (2 this) (2 chiaroscuro)) (2 (2 of) (2 (2 (2 madness) (2 and)) (2 light)))))))))))) (3 (3 (2 is) (4 astonishing)) (2 .))) | |||
(3 (3 (2 Part) (3 (2 of) (4 (2 (2 the) (3 charm)) (2 (2 of) (2 (2 Satin) (2 Rouge)))))) (3 (3 (2 is) (3 (2 that) (3 (2 it) (2 (1 (2 avoids) (2 (2 the) (1 obvious))) (3 (2 with) (3 (3 (3 humour) (2 and)) (2 lightness))))))) (2 .))) | |||
(4 (2 (2 a) (2 (2 screenplay) (2 more))) (3 (4 ingeniously) (2 (2 constructed) (2 (2 (2 (2 than) (2 ``)) (2 Memento)) (2 ''))))) | |||
(3 (2 ``) (3 (2 (2 Extreme) (2 Ops)) (3 (2 '') (4 (4 (3 exceeds) (2 expectations)) (2 .))))) | |||
""" | |||
train, test = 'train--', 'test--' | |||
with open(train, 'w', encoding='utf-8') as f: | |||
f.write(train_data) | |||
with open(test, 'w', encoding='utf-8') as f: | |||
f.write(test_data) | |||
loader = SSTLoader() | |||
info = loader.process( | |||
{train: train, test: test}, | |||
train_ds=[train], | |||
src_vocab_op=dict(min_freq=2) | |||
) | |||
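# process() is expected to return an object exposing .vocabs and .datasets, two entries each, as asserted below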
assert len(list(info.vocabs.items())) == 2 | |||
assert len(list(info.datasets.items())) == 2 | |||
print(info.vocabs) | |||
print(info.datasets) | |||
os.remove(train), os.remove(test) |
@@ -130,11 +130,8 @@ class ModelRunner(): | |||
tester = Tester(data=data, model=model, metrics=metrics, | |||
batch_size=BATCH_SIZE, verbose=0) | |||
before_train = tester.test() | |||
trainer = Trainer(model=model, train_data=data, dev_data=None, | |||
n_epochs=N_EPOCHS, batch_size=BATCH_SIZE, | |||
loss=loss, | |||
save_path=None, | |||
use_tqdm=False) | |||
trainer = Trainer(train_data=data, model=model, loss=loss, batch_size=BATCH_SIZE, n_epochs=N_EPOCHS, | |||
dev_data=None, save_path=None, use_tqdm=False) | |||
trainer.train(load_best_model=False) | |||
after_train = tester.test() | |||
for metric_name, v1 in before_train.items(): | |||
@@ -2,20 +2,64 @@ import unittest | |||
import torch | |||
from fastNLP.models.bert import BertModel | |||
from fastNLP.models.bert import * | |||
class TestBert(unittest.TestCase): | |||
def test_bert_1(self): | |||
# model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese") | |||
model = BertModel(vocab_size=32000, hidden_size=768, | |||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
from fastNLP.core.const import Const | |||
model = BertForSequenceClassification(2) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||
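# a toy batch of 2 sequences of length 3; the trailing 0 in input_mask marks the padded position of the second sequence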
pred = model(input_ids, token_type_ids, input_mask) | |||
self.assertTrue(isinstance(pred, dict)) | |||
self.assertTrue(Const.OUTPUT in pred) | |||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) | |||
def test_bert_2(self): | |||
from fastNLP.core.const import Const | |||
model = BertForMultipleChoice(2) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||
pred = model(input_ids, token_type_ids, input_mask) | |||
self.assertTrue(isinstance(pred, dict)) | |||
self.assertTrue(Const.OUTPUT in pred) | |||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2)) | |||
def test_bert_3(self): | |||
from fastNLP.core.const import Const | |||
model = BertForTokenClassification(7) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||
pred = model(input_ids, token_type_ids, input_mask) | |||
self.assertTrue(isinstance(pred, dict)) | |||
self.assertTrue(Const.OUTPUT in pred) | |||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7)) | |||
def test_bert_4(self): | |||
from fastNLP.core.const import Const | |||
model = BertForQuestionAnswering() | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||
for layer in all_encoder_layers: | |||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | |||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) | |||
pred = model(input_ids, token_type_ids, input_mask) | |||
self.assertTrue(isinstance(pred, dict)) | |||
self.assertTrue(Const.OUTPUTS(0) in pred) | |||
self.assertTrue(Const.OUTPUTS(1) in pred) | |||
self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3)) | |||
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3)) |
@@ -1,6 +1,5 @@ | |||
import unittest | |||
import fastNLP | |||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
from .model_runner import * | |||
@@ -12,7 +12,6 @@ class TestCNNText(unittest.TestCase): | |||
model = CNNText(init_emb, | |||
NUM_CLS, | |||
kernel_nums=(1, 3, 5), | |||
kernel_sizes=(2, 2, 2), | |||
padding=0, | |||
kernel_sizes=(1, 3, 5), | |||
dropout=0.5) | |||
RUNNER.run_model_with_task(TEXT_CLS, model) |
@@ -10,14 +10,14 @@ class TestCRF(unittest.TestCase): | |||
id2label = {0: 'B', 1: 'I', 2:'O'} | |||
expected_res = {(0, 0), (0, 1), (0, 2), (0, 4), (1, 0), (1, 1), (1, 2), (1, 4), (2, 0), (2, 2), | |||
(2, 4), (3, 0), (3, 2)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||
id2label = {0: 'B', 1:'M', 2:'E', 3:'S'} | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 5), (3, 0), (3, 3), (3, 5), (4, 0), (4, 3)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
id2label = {0: 'B', 1: 'I', 2:'O', 3: '<pad>', 4:"<unk>"} | |||
allowed_transitions(id2label) | |||
allowed_transitions(id2label, include_start_end=True) | |||
labels = ['O'] | |||
for label in ['X', 'Y']: | |||
@@ -27,7 +27,7 @@ class TestCRF(unittest.TestCase): | |||
expected_res = {(0, 0), (0, 1), (0, 3), (0, 6), (1, 0), (1, 1), (1, 2), (1, 3), (1, 6), (2, 0), (2, 1), | |||
(2, 2), (2, 3), (2, 6), (3, 0), (3, 1), (3, 3), (3, 4), (3, 6), (4, 0), (4, 1), (4, 3), | |||
(4, 4), (4, 6), (5, 0), (5, 1), (5, 3)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label))) | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, include_start_end=True))) | |||
labels = [] | |||
for label in ['X', 'Y']: | |||
@@ -37,7 +37,7 @@ class TestCRF(unittest.TestCase): | |||
expected_res = {(0, 1), (0, 2), (1, 1), (1, 2), (2, 0), (2, 3), (2, 4), (2, 7), (2, 9), (3, 0), (3, 3), (3, 4), | |||
(3, 7), (3, 9), (4, 5), (4, 6), (5, 5), (5, 6), (6, 0), (6, 3), (6, 4), (6, 7), (6, 9), (7, 0), | |||
(7, 3), (7, 4), (7, 7), (7, 9), (8, 0), (8, 3), (8, 4), (8, 7)} | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES'))) | |||
self.assertSetEqual(expected_res, set(allowed_transitions(id2label, encoding_type='BMES', include_start_end=True))) | |||
def test_case2(self): | |||
# test that the CRF avoids decoding illegal transitions; verified against allennlp.
@@ -0,0 +1,21 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.models.bert import BertModel | |||
class TestBert(unittest.TestCase): | |||
def test_bert_1(self): | |||
model = BertModel(vocab_size=32000, hidden_size=768, | |||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||
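# the base BertModel returns one (batch, seq_len, hidden) tensor per encoder layer plus a (batch, hidden) pooled output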
for layer in all_encoder_layers: | |||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | |||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) |
@@ -60,17 +60,17 @@ class TestTutorial(unittest.TestCase): | |||
print(test_data[0]) | |||
# if you are working on projects such as reinforcement learning or GANs, you can also use these data-preprocessing tools
from fastNLP.core.batch import Batch | |||
from fastNLP.core.batch import DataSetIter | |||
from fastNLP.core.sampler import RandomSampler | |||
batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||
batch_iterator = DataSetIter(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||
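# DataSetIter yields (batch_x, batch_y) pairs of dicts: batch_x collects the input fields, batch_y the target fields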
for batch_x, batch_y in batch_iterator: | |||
print("batch_x has: ", batch_x) | |||
print("batch_y has: ", batch_y) | |||
break | |||
from fastNLP.models import CNNText | |||
model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1) | |||
model = CNNText((len(vocab), 50), num_classes=5, dropout=0.1) | |||
from fastNLP import Trainer | |||
from copy import deepcopy | |||
@@ -80,23 +80,19 @@ class TestTutorial(unittest.TestCase): | |||
test_data.rename_field('label', 'label_seq') | |||
loss = CrossEntropyLoss(pred="output", target="label_seq") | |||
metric = AccuracyMetric(pred="predict", target="label_seq") | |||
metric = AccuracyMetric(target="label_seq") | |||
# instantiate a Trainer, pass in the model and data, and train
# first overfit on test_data (to make sure the model implementation is correct)
copy_model = deepcopy(model) | |||
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, | |||
loss=loss, | |||
metrics=metric, | |||
save_path=None, | |||
batch_size=32, | |||
n_epochs=5) | |||
overfit_trainer = Trainer(train_data=test_data, model=copy_model, loss=loss, batch_size=32, n_epochs=5, | |||
dev_data=test_data, metrics=metric, save_path=None) | |||
overfit_trainer.train() | |||
# train on train_data and validate on test_data
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | |||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
metrics=AccuracyMetric(target="label_seq"), | |||
save_path=None, | |||
batch_size=32, | |||
n_epochs=5) | |||
@@ -106,7 +102,7 @@ class TestTutorial(unittest.TestCase): | |||
# use the Tester to evaluate performance on test_data
from fastNLP import Tester | |||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(target="label_seq"), | |||
batch_size=4) | |||
acc = tester.test() | |||
print(acc) | |||
@@ -143,17 +139,12 @@ class TestTutorial(unittest.TestCase): | |||
is_input=True) | |||
from fastNLP.models import CNNText | |||
model = CNNText((len(vocab), 50), num_classes=5, padding=2, dropout=0.1) | |||
model = CNNText((len(vocab), 50), num_classes=5, dropout=0.1) | |||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric, Adam | |||
trainer = Trainer(model=model, | |||
train_data=train_data, | |||
dev_data=dev_data, | |||
loss=CrossEntropyLoss(), | |||
optimizer= Adam(), | |||
metrics=AccuracyMetric(target='target') | |||
) | |||
trainer = Trainer(train_data=train_data, model=model, optimizer=Adam(), loss=CrossEntropyLoss(), | |||
dev_data=dev_data, metrics=AccuracyMetric(target='target')) | |||
trainer.train() | |||
print('Train finished!') | |||