@@ -438,26 +438,29 @@ class EarlyStopCallback(Callback): | |||||
class FitlogCallback(Callback): | class FitlogCallback(Callback): | ||||
""" | """ | ||||
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||||
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` | |||||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||||
:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | :param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | ||||
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 | DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 | ||||
dict的方式传入。如果仅传入DataSet, 则被命名为test | dict的方式传入。如果仅传入DataSet, 则被命名为test | ||||
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为`test` | :param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为`test` | ||||
:param int verbose: 是否在终端打印内容,0不打印 | |||||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||||
:param bool log_exception: fitlog是否记录发生的exception信息 | :param bool log_exception: fitlog是否记录发生的exception信息 | ||||
""" | """ | ||||
# 还没有被导出到 fastNLP 层 | |||||
# 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` | |||||
def __init__(self, data=None, tester=None, verbose=0, log_exception=False): | |||||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | |||||
super().__init__() | super().__init__() | ||||
self.datasets = {} | self.datasets = {} | ||||
self.testers = {} | self.testers = {} | ||||
self._log_exception = log_exception | self._log_exception = log_exception | ||||
assert isinstance(log_loss_every, int) and log_loss_every>=0 | |||||
if tester is not None: | if tester is not None: | ||||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | ||||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | ||||
@@ -477,7 +480,9 @@ class FitlogCallback(Callback): | |||||
raise TypeError("data receives dict[DataSet] or DataSet object.") | raise TypeError("data receives dict[DataSet] or DataSet object.") | ||||
self.verbose = verbose | self.verbose = verbose | ||||
self._log_loss_every = log_loss_every | |||||
self._avg_loss = 0 | |||||
def on_train_begin(self): | def on_train_begin(self): | ||||
if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None: | if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None: | ||||
raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.") | raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.") | ||||
@@ -490,8 +495,12 @@ class FitlogCallback(Callback): | |||||
fitlog.add_progress(total_steps=self.n_steps) | fitlog.add_progress(total_steps=self.n_steps) | ||||
def on_backward_begin(self, loss): | def on_backward_begin(self, loss): | ||||
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch) | |||||
if self._log_loss_every>0: | |||||
self._avg_loss += loss.item() | |||||
if self.step%self._log_loss_every==0: | |||||
fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch) | |||||
self._avg_loss = 0 | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): | def on_valid_end(self, eval_result, metric_key, optimizer, better_result): | ||||
if better_result: | if better_result: | ||||
eval_result = deepcopy(eval_result) | eval_result = deepcopy(eval_result) | ||||
@@ -518,7 +527,7 @@ class FitlogCallback(Callback): | |||||
def on_exception(self, exception): | def on_exception(self, exception): | ||||
fitlog.finish(status=1) | fitlog.finish(status=1) | ||||
if self._log_exception: | if self._log_exception: | ||||
fitlog.add_other(str(exception), name='except_info') | |||||
fitlog.add_other(repr(exception), name='except_info') | |||||
class LRScheduler(Callback): | class LRScheduler(Callback): | ||||
@@ -516,7 +516,7 @@ class EngChar2DPadder(Padder): | |||||
)) | )) | ||||
self._exactly_three_dims(contents, field_name) | self._exactly_three_dims(contents, field_name) | ||||
if self.pad_length < 1: | if self.pad_length < 1: | ||||
max_char_length = max(max([[len(char_lst) for char_lst in word_lst] for word_lst in contents])) | |||||
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) | |||||
else: | else: | ||||
max_char_length = self.pad_length | max_char_length = self.pad_length | ||||
max_sent_length = max(len(word_lst) for word_lst in contents) | max_sent_length = max(len(word_lst) for word_lst in contents) | ||||
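For reference, a toy ``contents`` value and the lengths the corrected expressions compute (illustration only)::

    # contents: one char-list per word, one word-list per sentence
    contents = [[['T', 'h', 'e'], ['c', 'a', 't']],
                [['h', 'i']]]
    max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])  # -> 3 (longest word)
    max(len(word_lst) for word_lst in contents)                                  # -> 2 (longest sentence)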
@@ -476,8 +476,8 @@ class SpanFPreRecMetric(MetricBase): | |||||
label的f1, pre, rec | label的f1, pre, rec | ||||
:param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | :param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro': | ||||
分别计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | 分别计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | ||||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param float beta: f_beta分数,:math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. | |||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
""" | """ | ||||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | ||||
@@ -708,8 +708,8 @@ class SQuADMetric(MetricBase): | |||||
:param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` | :param pred2: 参数映射表中`pred2`的映射关系,None表示映射关系为`pred2`->`pred2` | ||||
:param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` | :param target1: 参数映射表中`target1`的映射关系,None表示映射关系为`target1`->`target1` | ||||
:param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` | :param target2: 参数映射表中`target2`的映射关系,None表示映射关系为`target2`->`target2` | ||||
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 | |||||
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param float beta: f_beta分数,:math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`. | |||||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | :param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | ||||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | ||||
@@ -494,12 +494,14 @@ class Trainer(object): | |||||
self.callback_manager = CallbackManager(env={"trainer": self}, | self.callback_manager = CallbackManager(env={"trainer": self}, | ||||
callbacks=callbacks) | callbacks=callbacks) | ||||
def train(self, load_best_model=True): | |||||
def train(self, load_best_model=True, on_exception='ignore'): | |||||
""" | """ | ||||
使用该函数使Trainer开始训练。 | 使用该函数使Trainer开始训练。 | ||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效, | |||||
如果True, trainer将在返回之前重新加载dev表现最好的模型参数。 | |||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||||
最好的模型参数。 | |||||
:param str on_exception: 在训练过程遭遇exception,并被 :py:class:`Callback` 的on_exception()处理后,是否继续抛出异常。 | |||||
支持'ignore'与'raise': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出。 | |||||
:return dict: 返回一个字典类型的数据, | :return dict: 返回一个字典类型的数据, | ||||
内含以下内容:: | 内含以下内容:: | ||||
@@ -528,8 +530,10 @@ class Trainer(object): | |||||
self.callback_manager.on_train_begin() | self.callback_manager.on_train_begin() | ||||
self._train() | self._train() | ||||
self.callback_manager.on_train_end() | self.callback_manager.on_train_end() | ||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
except (CallbackException, KeyboardInterrupt, Exception) as e: | |||||
self.callback_manager.on_exception(e) | self.callback_manager.on_exception(e) | ||||
if on_exception=='raise': | |||||
raise e | |||||
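A usage sketch of the new ``on_exception`` switch (the trainer construction is abbreviated, names are placeholders)::

    trainer = Trainer(train_data, model, metrics=metric, dev_data=dev_data)
    try:
        trainer.train(on_exception='raise')   # re-raise after the callbacks' on_exception() has run
    except KeyboardInterrupt:
        print('stopped by user')
    # with the default on_exception='ignore', the exception is swallowed and
    # the code written after train() keeps running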
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | ||||
print( | print( | ||||
@@ -3,7 +3,8 @@ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户 | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"cache_results", | "cache_results", | ||||
"seq_len_to_mask" | |||||
"seq_len_to_mask", | |||||
"Example", | |||||
] | ] | ||||
import _pickle | import _pickle | ||||
@@ -21,6 +22,32 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||||
'varargs']) | 'varargs']) | ||||
class Example(dict): | |||||
"""a dict can treat keys as attributes""" | |||||
def __getattr__(self, item): | |||||
try: | |||||
return self.__getitem__(item) | |||||
except KeyError: | |||||
raise AttributeError(item) | |||||
def __setattr__(self, key, value): | |||||
if key.startswith('__') and key.endswith('__'): | |||||
raise AttributeError(key) | |||||
self.__setitem__(key, value) | |||||
def __delattr__(self, item): | |||||
try: | |||||
self.pop(item) | |||||
except KeyError: | |||||
raise AttributeError(item) | |||||
def __getstate__(self): | |||||
return self | |||||
def __setstate__(self, state): | |||||
self.update(state) | |||||
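A short sketch of the attribute-style access this class adds on top of ``dict``::

    ex = Example(lr=0.1, n_epochs=10)
    ex.batch_size = 32                     # same as ex['batch_size'] = 32
    assert ex.lr == 0.1 and ex['batch_size'] == 32
    del ex.n_epochs                        # same as ex.pop('n_epochs')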
def _prepare_cache_filepath(filepath): | def _prepare_cache_filepath(filepath): | ||||
""" | """ | ||||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | 检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | ||||
@@ -1,11 +1,26 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"Vocabulary" | |||||
"Vocabulary", | |||||
"VocabularyOption", | |||||
] | ] | ||||
from functools import wraps | from functools import wraps | ||||
from collections import Counter | from collections import Counter | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
from .utils import Example | |||||
class VocabularyOption(Example): | |||||
def __init__(self, | |||||
max_size=None, | |||||
min_freq=None, | |||||
padding='<pad>', | |||||
unknown='<unk>'): | |||||
super().__init__( | |||||
max_size=max_size, | |||||
min_freq=min_freq, | |||||
padding=padding, | |||||
unknown=unknown | |||||
) | |||||
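Since ``Example`` is a ``dict``, a ``VocabularyOption`` can be unpacked straight into ``Vocabulary`` (sketch)::

    opt = VocabularyOption(max_size=10000, min_freq=2)
    opt.min_freq = 5                 # attribute-style update also works
    vocab = Vocabulary(**opt)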
def _check_build_vocab(func): | def _check_build_vocab(func): | ||||
@@ -1,10 +1,14 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"BaseLoader" | |||||
"BaseLoader", | |||||
'DataInfo', | |||||
'DataSetLoader', | |||||
] | ] | ||||
import _pickle as pickle | import _pickle as pickle | ||||
import os | import os | ||||
from typing import Union, Dict | |||||
from ..core.dataset import DataSet | |||||
class BaseLoader(object): | class BaseLoader(object): | ||||
""" | """ | ||||
@@ -51,24 +55,161 @@ class BaseLoader(object): | |||||
return obj | return obj | ||||
class DataLoaderRegister: | |||||
_readers = {} | |||||
@classmethod | |||||
def set_reader(cls, reader_cls, read_fn_name): | |||||
# def wrapper(reader_cls): | |||||
if read_fn_name in cls._readers: | |||||
raise KeyError( | |||||
'duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, | |||||
read_fn_name)) | |||||
if hasattr(reader_cls, 'load'): | |||||
cls._readers[read_fn_name] = reader_cls().load | |||||
return reader_cls | |||||
@classmethod | |||||
def get_reader(cls, read_fn_name): | |||||
if read_fn_name in cls._readers: | |||||
return cls._readers[read_fn_name] | |||||
raise AttributeError('no read function: {}'.format(read_fn_name)) | |||||
# TODO 这个类使用在何处? | |||||
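A hypothetical registration sketch, only to show how the API is meant to be used (``MyTxtReader`` is not a real fastNLP class)::

    class MyTxtReader:
        def load(self, path):
            ...  # read `path` and return a DataSet

    DataLoaderRegister.set_reader(MyTxtReader, 'read_txt')
    read_txt = DataLoaderRegister.get_reader('read_txt')
    # ds = read_txt('some/file.txt')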
def _download_from_url(url, path): | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from ..core.utils import _pseudo_tqdm as tqdm | |||||
import requests | |||||
"""Download file""" | |||||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||||
chunk_size = 16 * 1024 | |||||
total_size = int(r.headers.get('Content-length', 0)) | |||||
with open(path, "wb") as file, \ | |||||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||||
for chunk in r.iter_content(chunk_size): | |||||
if chunk: | |||||
file.write(chunk) | |||||
t.update(len(chunk)) | |||||
def _uncompress(src, dst): | |||||
import zipfile | |||||
import gzip | |||||
import tarfile | |||||
import os | |||||
def unzip(src, dst): | |||||
with zipfile.ZipFile(src, 'r') as f: | |||||
f.extractall(dst) | |||||
def ungz(src, dst): | |||||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||||
length = 16 * 1024 # 16KB | |||||
buf = f.read(length) | |||||
while buf: | |||||
uf.write(buf) | |||||
buf = f.read(length) | |||||
def untar(src, dst): | |||||
with tarfile.open(src, 'r:gz') as f: | |||||
f.extractall(dst) | |||||
fn, ext = os.path.splitext(src) | |||||
_, ext_2 = os.path.splitext(fn) | |||||
if ext == '.zip': | |||||
unzip(src, dst) | |||||
elif ext == '.gz' and ext_2 != '.tar': | |||||
ungz(src, dst) | |||||
elif (ext == '.gz' and ext_2 == '.tar') or ext == '.tgz': | |||||
untar(src, dst) | |||||
else: | |||||
raise ValueError('unsupported file {}'.format(src)) | |||||
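For reference, how the extension dispatch above resolves a few typical file names::

    # 'foo.zip'               -> unzip
    # 'foo.gz'                -> ungz   (plain gzip, not a tarball)
    # 'foo.tar.gz', 'foo.tgz' -> untar
    # anything else           -> ValueError('unsupported file ...')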
class DataInfo: | |||||
""" | |||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | |||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||||
""" | |||||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||||
self.vocabs = vocabs or {} | |||||
self.embeddings = embeddings or {} | |||||
self.datasets = datasets or {} | |||||
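A sketch of the ``DataInfo`` a loader's ``process`` might return (the dataset and vocabulary variables are placeholders)::

    info = DataInfo(
        datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds},
        vocabs={'words': word_vocab, 'target': target_vocab},
    )
    print(len(info.datasets['train']), len(info.vocabs['words']))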
class DataSetLoader: | |||||
""" | |||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||||
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它来实现各种 DataSetLoader。 | |||||
开发者至少应该编写如下内容: | |||||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||||
- process 函数:从一个或多个数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||||
**process 函数中可以 调用load 函数或 _load 函数** | |||||
""" | |||||
URL = '' | |||||
DATA_DIR = '' | |||||
ROOT_DIR = '.fastnlp/datasets/' | |||||
UNCOMPRESS = True | |||||
def _download(self, url: str, pdir: str, uncompress=True) -> str: | |||||
""" | |||||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||||
:param url: 下载的网站 | |||||
:param pdir: 下载到的目录 | |||||
:param uncompress: 是否自动解压缩 | |||||
:return: 数据的存放路径 | |||||
""" | |||||
fn = os.path.basename(url) | |||||
path = os.path.join(pdir, fn) | |||||
"""check data exists""" | |||||
if not os.path.exists(path): | |||||
os.makedirs(pdir, exist_ok=True) | |||||
_download_from_url(url, path) | |||||
if uncompress: | |||||
dst = os.path.join(pdir, 'data') | |||||
if not os.path.exists(dst): | |||||
_uncompress(path, dst) | |||||
return dst | |||||
return path | |||||
def download(self): | |||||
return self._download( | |||||
self.URL, | |||||
os.path.join(self.ROOT_DIR, self.DATA_DIR), | |||||
uncompress=self.UNCOMPRESS) | |||||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保持一致。 | |||||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||||
""" | |||||
if isinstance(paths, str): | |||||
return self._load(paths) | |||||
return {name: self._load(path) for name, path in paths.items()} | |||||
def _load(self, path: str) -> DataSet: | |||||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||||
:param str path: 文件路径 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: | |||||
""" | |||||
对于特定的任务和数据集,读取并处理数据,返回处理后的 DataInfo 类对象或字典。 | |||||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保持一致。 | |||||
返回的 :class:`DataInfo` 对象有如下属性: | |||||
- vocabs: 由从数据集中获取的词表组成的字典,key 为词表名称(字符串) | |||||
- embeddings: (可选) 数据集对应的词嵌入 | |||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||||
:param paths: 原始数据读取的路径 | |||||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||||
:return: 返回一个 DataInfo | |||||
""" | |||||
raise NotImplementedError |
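A minimal subclass sketch following the contract above; the tab-separated format, field names and class name are illustrative assumptions, not an existing fastNLP loader::

    from fastNLP import DataSet, Instance, Vocabulary

    class TabSepLoader(DataSetLoader):
        """Toy loader: one example per line, formatted as  sentence<TAB>label"""

        def _load(self, path: str) -> DataSet:
            ds = DataSet()
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if not line.strip():
                        continue
                    text, label = line.rstrip('\n').rsplit('\t', 1)
                    ds.append(Instance(words=text.split(), target=label))
            return ds

        def process(self, paths, **options) -> DataInfo:
            datasets = self.load(paths)
            if isinstance(datasets, DataSet):
                datasets = {'train': datasets}
            vocab = Vocabulary().from_dataset(*datasets.values(), field_name='words')
            vocab.index_dataset(*datasets.values(), field_name='words')
            return DataInfo(vocabs={'words': vocab}, datasets=datasets)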
@@ -0,0 +1,95 @@ | |||||
from typing import Iterable | |||||
from nltk import Tree | |||||
from ..base_loader import DataInfo, DataSetLoader | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||||
class SSTLoader(DataSetLoader): | |||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
DATA_DIR = 'sst/' | |||||
""" | |||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||||
读取SST数据集, DataSet包含fields:: | |||||
words: list(str) 需要分类的文本 | |||||
target: str 文本的标签 | |||||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
""" | |||||
def __init__(self, subtree=False, fine_grained=False): | |||||
self.subtree = subtree | |||||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||||
'3': 'positive', '4': 'very positive'} | |||||
if not fine_grained: | |||||
tag_v['0'] = tag_v['1'] | |||||
tag_v['4'] = tag_v['3'] | |||||
self.tag_v = tag_v | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 存储数据的路径 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
datas = [] | |||||
for l in f: | |||||
datas.extend([(s, self.tag_v[t]) | |||||
for s, t in self._get_one(l, self.subtree)]) | |||||
ds = DataSet() | |||||
for words, tag in datas: | |||||
ds.append(Instance(words=words, target=tag)) | |||||
return ds | |||||
@staticmethod | |||||
def _get_one(data, subtree): | |||||
tree = Tree.fromstring(data) | |||||
if subtree: | |||||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||||
return [(tree.leaves(), tree.label())] | |||||
def process(self, | |||||
paths, | |||||
train_ds: Iterable[str] = None, | |||||
src_vocab_op: VocabularyOption = None, | |||||
tgt_vocab_op: VocabularyOption = None, | |||||
src_embed_op: EmbeddingOption = None): | |||||
input_name, target_name = 'words', 'target' | |||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
info = DataInfo(datasets=self.load(paths)) | |||||
_train_ds = [info.datasets[name] | |||||
for name in train_ds] if train_ds else info.datasets.values() | |||||
src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
src_vocab.index_dataset( | |||||
*info.datasets.values(), | |||||
field_name=input_name, new_field_name=input_name) | |||||
tgt_vocab.index_dataset( | |||||
*info.datasets.values(), | |||||
field_name=target_name, new_field_name=target_name) | |||||
info.vocabs = { | |||||
input_name: src_vocab, | |||||
target_name: tgt_vocab | |||||
} | |||||
if src_embed_op is not None: | |||||
src_embed_op.vocab = src_vocab | |||||
init_emb = EmbedLoader.load_with_vocab(**src_embed_op) | |||||
info.embeddings[input_name] = init_emb | |||||
return info | |||||
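A usage sketch of the new loader (the paths point at the unpacked SST files and are placeholders)::

    loader = SSTLoader(fine_grained=True)
    info = loader.process(
        {'train': 'sst/train.txt', 'dev': 'sst/dev.txt', 'test': 'sst/test.txt'},
        train_ds=['train'],                                   # build vocabularies from train only
        src_vocab_op=VocabularyOption(max_size=30000, min_freq=2),
    )
    # the words/target fields are now indexed; the vocabularies are in
    # info.vocabs['words'] and info.vocabs['target']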
@@ -13,8 +13,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的 | |||||
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 | 为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'DataInfo', | |||||
'DataSetLoader', | |||||
'CSVLoader', | 'CSVLoader', | ||||
'JsonLoader', | 'JsonLoader', | ||||
'ConllLoader', | 'ConllLoader', | ||||
@@ -24,158 +22,12 @@ __all__ = [ | |||||
'Conll2003Loader', | 'Conll2003Loader', | ||||
] | ] | ||||
from nltk.tree import Tree | |||||
from nltk import Tree | |||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from ..core.instance import Instance | from ..core.instance import Instance | ||||
from .file_reader import _read_csv, _read_json, _read_conll | from .file_reader import _read_csv, _read_json, _read_conll | ||||
from typing import Union, Dict | |||||
import os | |||||
def _download_from_url(url, path): | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from ..core.utils import _pseudo_tqdm as tqdm | |||||
import requests | |||||
"""Download file""" | |||||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||||
chunk_size = 16 * 1024 | |||||
total_size = int(r.headers.get('Content-length', 0)) | |||||
with open(path, "wb") as file, \ | |||||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||||
for chunk in r.iter_content(chunk_size): | |||||
if chunk: | |||||
file.write(chunk) | |||||
t.update(len(chunk)) | |||||
return | |||||
def _uncompress(src, dst): | |||||
import zipfile | |||||
import gzip | |||||
import tarfile | |||||
import os | |||||
def unzip(src, dst): | |||||
with zipfile.ZipFile(src, 'r') as f: | |||||
f.extractall(dst) | |||||
def ungz(src, dst): | |||||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||||
length = 16 * 1024 # 16KB | |||||
buf = f.read(length) | |||||
while buf: | |||||
uf.write(buf) | |||||
buf = f.read(length) | |||||
def untar(src, dst): | |||||
with tarfile.open(src, 'r:gz') as f: | |||||
f.extractall(dst) | |||||
fn, ext = os.path.splitext(src) | |||||
_, ext_2 = os.path.splitext(fn) | |||||
if ext == '.zip': | |||||
unzip(src, dst) | |||||
elif ext == '.gz' and ext_2 != '.tar': | |||||
ungz(src, dst) | |||||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': | |||||
untar(src, dst) | |||||
else: | |||||
raise ValueError('unsupported file {}'.format(src)) | |||||
class DataInfo: | |||||
""" | |||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。 | |||||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||||
:param embeddings: 从名称(字符串)到一系列 embedding 的dict,参考 :class:`~fastNLP.io.EmbedLoader` | |||||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||||
""" | |||||
def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None): | |||||
self.vocabs = vocabs or {} | |||||
self.embeddings = embeddings or {} | |||||
self.datasets = datasets or {} | |||||
class DataSetLoader: | |||||
""" | |||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||||
定义了各种 DataSetLoader (针对特定数据上的特定任务) 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 | |||||
开发者至少应该编写如下内容: | |||||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||||
**process 函数中可以 调用load 函数或 _load 函数** | |||||
""" | |||||
def _download(self, url: str, path: str, uncompress=True) -> str: | |||||
""" | |||||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||||
:param url: 下载的网站 | |||||
:param path: 下载到的目录 | |||||
:param uncompress: 是否自动解压缩 | |||||
:return: 数据的存放路径 | |||||
""" | |||||
pdir = os.path.dirname(path) | |||||
os.makedirs(pdir, exist_ok=True) | |||||
_download_from_url(url, path) | |||||
if uncompress: | |||||
dst = os.path.join(pdir, 'data') | |||||
_uncompress(path, dst) | |||||
return dst | |||||
return path | |||||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 | |||||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||||
""" | |||||
if isinstance(paths, str): | |||||
return self._load(paths) | |||||
return {name: self._load(path) for name, path in paths.items()} | |||||
def _load(self, path: str) -> DataSet: | |||||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||||
:param str path: 文件路径 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo: | |||||
""" | |||||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | |||||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | |||||
返回的 :class:`DataInfo` 对象有如下属性: | |||||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||||
- embeddings: (可选) 数据集对应的词嵌入 | |||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||||
:param paths: 原始数据读取的路径 | |||||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||||
:return: 返回一个 DataInfo | |||||
""" | |||||
raise NotImplementedError | |||||
from .base_loader import DataSetLoader | |||||
from .data_loader.sst import SSTLoader | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | class PeopleDailyCorpusLoader(DataSetLoader): | ||||
""" | """ | ||||
@@ -183,12 +35,12 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
读取人民日报数据集 | 读取人民日报数据集 | ||||
""" | """ | ||||
def __init__(self, pos=True, ner=True): | def __init__(self, pos=True, ner=True): | ||||
super(PeopleDailyCorpusLoader, self).__init__() | super(PeopleDailyCorpusLoader, self).__init__() | ||||
self.pos = pos | self.pos = pos | ||||
self.ner = ner | self.ner = ner | ||||
def _load(self, data_path): | def _load(self, data_path): | ||||
with open(data_path, "r", encoding="utf-8") as f: | with open(data_path, "r", encoding="utf-8") as f: | ||||
sents = f.readlines() | sents = f.readlines() | ||||
@@ -233,7 +85,7 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||||
example.append(sent_ner) | example.append(sent_ner) | ||||
examples.append(example) | examples.append(example) | ||||
return self.convert(examples) | return self.convert(examples) | ||||
def convert(self, data): | def convert(self, data): | ||||
""" | """ | ||||
@@ -284,7 +136,7 @@ class ConllLoader(DataSetLoader): | |||||
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | ||||
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, headers, indexes=None, dropna=False): | def __init__(self, headers, indexes=None, dropna=False): | ||||
super(ConllLoader, self).__init__() | super(ConllLoader, self).__init__() | ||||
if not isinstance(headers, (list, tuple)): | if not isinstance(headers, (list, tuple)): | ||||
@@ -298,7 +150,7 @@ class ConllLoader(DataSetLoader): | |||||
if len(indexes) != len(headers): | if len(indexes) != len(headers): | ||||
raise ValueError | raise ValueError | ||||
self.indexes = indexes | self.indexes = indexes | ||||
def _load(self, path): | def _load(self, path): | ||||
ds = DataSet() | ds = DataSet() | ||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | ||||
@@ -316,7 +168,7 @@ class Conll2003Loader(ConllLoader): | |||||
关于数据集的更多信息,参考: | 关于数据集的更多信息,参考: | ||||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
headers = [ | headers = [ | ||||
'tokens', 'pos', 'chunks', 'ner', | 'tokens', 'pos', 'chunks', 'ner', | ||||
@@ -354,56 +206,6 @@ def _cut_long_sentence(sent, max_sample_length=200): | |||||
return cutted_sentence | return cutted_sentence | ||||
class SSTLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader` | |||||
读取SST数据集, DataSet包含fields:: | |||||
words: list(str) 需要分类的文本 | |||||
target: str 文本的标签 | |||||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
""" | |||||
def __init__(self, subtree=False, fine_grained=False): | |||||
self.subtree = subtree | |||||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||||
'3': 'positive', '4': 'very positive'} | |||||
if not fine_grained: | |||||
tag_v['0'] = tag_v['1'] | |||||
tag_v['4'] = tag_v['3'] | |||||
self.tag_v = tag_v | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 存储数据的路径 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
datas = [] | |||||
for l in f: | |||||
datas.extend([(s, self.tag_v[t]) | |||||
for s, t in self._get_one(l, self.subtree)]) | |||||
ds = DataSet() | |||||
for words, tag in datas: | |||||
ds.append(Instance(words=words, target=tag)) | |||||
return ds | |||||
@staticmethod | |||||
def _get_one(data, subtree): | |||||
tree = Tree.fromstring(data) | |||||
if subtree: | |||||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||||
return [(tree.leaves(), tree.label())] | |||||
class JsonLoader(DataSetLoader): | class JsonLoader(DataSetLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` | 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` | ||||
@@ -417,7 +219,7 @@ class JsonLoader(DataSetLoader): | |||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | ||||
Default: ``False`` | Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, fields=None, dropna=False): | def __init__(self, fields=None, dropna=False): | ||||
super(JsonLoader, self).__init__() | super(JsonLoader, self).__init__() | ||||
self.dropna = dropna | self.dropna = dropna | ||||
@@ -428,7 +230,7 @@ class JsonLoader(DataSetLoader): | |||||
for k, v in fields.items(): | for k, v in fields.items(): | ||||
self.fields[k] = k if v is None else v | self.fields[k] = k if v is None else v | ||||
self.fields_list = list(self.fields.keys()) | self.fields_list = list(self.fields.keys()) | ||||
def _load(self, path): | def _load(self, path): | ||||
ds = DataSet() | ds = DataSet() | ||||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | ||||
@@ -452,7 +254,7 @@ class SNLILoader(JsonLoader): | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | ||||
""" | """ | ||||
def __init__(self): | def __init__(self): | ||||
fields = { | fields = { | ||||
'sentence1_parse': 'words1', | 'sentence1_parse': 'words1', | ||||
@@ -460,14 +262,14 @@ class SNLILoader(JsonLoader): | |||||
'gold_label': 'target', | 'gold_label': 'target', | ||||
} | } | ||||
super(SNLILoader, self).__init__(fields=fields) | super(SNLILoader, self).__init__(fields=fields) | ||||
def _load(self, path): | def _load(self, path): | ||||
ds = super(SNLILoader, self)._load(path) | ds = super(SNLILoader, self)._load(path) | ||||
def parse_tree(x): | def parse_tree(x): | ||||
t = Tree.fromstring(x) | t = Tree.fromstring(x) | ||||
return t.leaves() | return t.leaves() | ||||
ds.apply(lambda ins: parse_tree( | ds.apply(lambda ins: parse_tree( | ||||
ins['words1']), new_field_name='words1') | ins['words1']), new_field_name='words1') | ||||
ds.apply(lambda ins: parse_tree( | ds.apply(lambda ins: parse_tree( | ||||
@@ -488,12 +290,12 @@ class CSVLoader(DataSetLoader): | |||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | ||||
Default: ``False`` | Default: ``False`` | ||||
""" | """ | ||||
def __init__(self, headers=None, sep=",", dropna=False): | def __init__(self, headers=None, sep=",", dropna=False): | ||||
self.headers = headers | self.headers = headers | ||||
self.sep = sep | self.sep = sep | ||||
self.dropna = dropna | self.dropna = dropna | ||||
def _load(self, path): | def _load(self, path): | ||||
ds = DataSet() | ds = DataSet() | ||||
for idx, data in _read_csv(path, headers=self.headers, | for idx, data in _read_csv(path, headers=self.headers, | ||||
@@ -508,7 +310,7 @@ def _add_seg_tag(data): | |||||
:param data: list of ([word], [pos], [heads], [head_tags]) | :param data: list of ([word], [pos], [heads], [head_tags]) | ||||
:return: list of ([word], [pos]) | :return: list of ([word], [pos]) | ||||
""" | """ | ||||
_processed = [] | _processed = [] | ||||
for word_list, pos_list, _, _ in data: | for word_list, pos_list, _, _ in data: | ||||
new_sample = [] | new_sample = [] | ||||
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"EmbedLoader" | |||||
"EmbedLoader", | |||||
"EmbeddingOption", | |||||
] | ] | ||||
import os | import os | ||||
@@ -9,8 +10,22 @@ import numpy as np | |||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
from .base_loader import BaseLoader | from .base_loader import BaseLoader | ||||
from ..core.utils import Example | |||||
class EmbeddingOption(Example): | |||||
def __init__(self, | |||||
embed_filepath=None, | |||||
dtype=np.float32, | |||||
normalize=True, | |||||
error='ignore'): | |||||
super().__init__( | |||||
embed_filepath=embed_filepath, | |||||
dtype=dtype, | |||||
normalize=normalize, | |||||
error=error | |||||
) | |||||
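A sketch of how an ``EmbeddingOption`` is consumed (the file path is a placeholder; ``word_vocab`` is an already-built Vocabulary)::

    embed_op = EmbeddingOption(embed_filepath='/path/to/glove.6B.100d.txt', normalize=False)
    embed_op.vocab = word_vocab
    embedding = EmbedLoader.load_with_vocab(**embed_op)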
class EmbedLoader(BaseLoader): | class EmbedLoader(BaseLoader): | ||||
""" | """ | ||||
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` | 别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` | ||||
@@ -10,6 +10,35 @@ from ..core.const import Const | |||||
from ..modules.encoder import BertModel | from ..modules.encoder import BertModel | ||||
class BertConfig: | |||||
def __init__( | |||||
self, | |||||
vocab_size=30522, | |||||
hidden_size=768, | |||||
num_hidden_layers=12, | |||||
num_attention_heads=12, | |||||
intermediate_size=3072, | |||||
hidden_act="gelu", | |||||
hidden_dropout_prob=0.1, | |||||
attention_probs_dropout_prob=0.1, | |||||
max_position_embeddings=512, | |||||
type_vocab_size=2, | |||||
initializer_range=0.02 | |||||
): | |||||
self.vocab_size = vocab_size | |||||
self.hidden_size = hidden_size | |||||
self.num_hidden_layers = num_hidden_layers | |||||
self.num_attention_heads = num_attention_heads | |||||
self.intermediate_size = intermediate_size | |||||
self.hidden_act = hidden_act | |||||
self.hidden_dropout_prob = hidden_dropout_prob | |||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob | |||||
self.max_position_embeddings = max_position_embeddings | |||||
self.type_vocab_size = type_vocab_size | |||||
self.initializer_range = initializer_range | |||||
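A sketch of building a randomly initialised model from a config (the sizes are chosen arbitrarily for illustration)::

    config = BertConfig(vocab_size=8000, hidden_size=256,
                        num_hidden_layers=4, num_attention_heads=4,
                        intermediate_size=1024)
    model = BertModel(**config.__dict__)                  # no pretrained weights
    clf = BertForSequenceClassification(num_labels=2, config=config)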
class BertForSequenceClassification(BaseModel): | class BertForSequenceClassification(BaseModel): | ||||
"""BERT model for classification. | """BERT model for classification. | ||||
This module is composed of the BERT model with a linear layer on top of | This module is composed of the BERT model with a linear layer on top of | ||||
@@ -44,14 +73,19 @@ class BertForSequenceClassification(BaseModel): | |||||
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | ||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | ||||
num_labels = 2 | num_labels = 2 | ||||
model = BertForSequenceClassification(config, num_labels) | |||||
model = BertForSequenceClassification(num_labels, config) | |||||
logits = model(input_ids, token_type_ids, input_mask) | logits = model(input_ids, token_type_ids, input_mask) | ||||
``` | ``` | ||||
""" | """ | ||||
def __init__(self, config, num_labels, bert_dir): | |||||
def __init__(self, num_labels, config=None, bert_dir=None): | |||||
super(BertForSequenceClassification, self).__init__() | super(BertForSequenceClassification, self).__init__() | ||||
self.num_labels = num_labels | self.num_labels = num_labels | ||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
if bert_dir is not None: | |||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
else: | |||||
if config is None: | |||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@@ -106,14 +140,19 @@ class BertForMultipleChoice(BaseModel): | |||||
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, | ||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | ||||
num_choices = 2 | num_choices = 2 | ||||
model = BertForMultipleChoice(config, num_choices, bert_dir) | |||||
model = BertForMultipleChoice(num_choices, config, bert_dir) | |||||
logits = model(input_ids, token_type_ids, input_mask) | logits = model(input_ids, token_type_ids, input_mask) | ||||
``` | ``` | ||||
""" | """ | ||||
def __init__(self, config, num_choices, bert_dir): | |||||
def __init__(self, num_choices, config=None, bert_dir=None): | |||||
super(BertForMultipleChoice, self).__init__() | super(BertForMultipleChoice, self).__init__() | ||||
self.num_choices = num_choices | self.num_choices = num_choices | ||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
if bert_dir is not None: | |||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
else: | |||||
if config is None: | |||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, 1) | self.classifier = nn.Linear(config.hidden_size, 1) | ||||
@@ -174,14 +213,19 @@ class BertForTokenClassification(BaseModel): | |||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | ||||
num_labels = 2 | num_labels = 2 | ||||
bert_dir = 'your-bert-file-dir' | bert_dir = 'your-bert-file-dir' | ||||
model = BertForTokenClassification(config, num_labels, bert_dir) | |||||
model = BertForTokenClassification(num_labels, config, bert_dir) | |||||
logits = model(input_ids, token_type_ids, input_mask) | logits = model(input_ids, token_type_ids, input_mask) | ||||
``` | ``` | ||||
""" | """ | ||||
def __init__(self, config, num_labels, bert_dir): | |||||
def __init__(self, num_labels, config=None, bert_dir=None): | |||||
super(BertForTokenClassification, self).__init__() | super(BertForTokenClassification, self).__init__() | ||||
self.num_labels = num_labels | self.num_labels = num_labels | ||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
if bert_dir is not None: | |||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
else: | |||||
if config is None: | |||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
self.dropout = nn.Dropout(config.hidden_dropout_prob) | self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.classifier = nn.Linear(config.hidden_size, num_labels) | self.classifier = nn.Linear(config.hidden_size, num_labels) | ||||
@@ -252,9 +296,14 @@ class BertForQuestionAnswering(BaseModel): | |||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask) | start_logits, end_logits = model(input_ids, token_type_ids, input_mask) | ||||
``` | ``` | ||||
""" | """ | ||||
def __init__(self, config, bert_dir): | |||||
def __init__(self, config=None, bert_dir=None): | |||||
super(BertForQuestionAnswering, self).__init__() | super(BertForQuestionAnswering, self).__init__() | ||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
if bert_dir is not None: | |||||
self.bert = BertModel.from_pretrained(bert_dir) | |||||
else: | |||||
if config is None: | |||||
config = BertConfig() | |||||
self.bert = BertModel(**config.__dict__) | |||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version | ||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob) | # self.dropout = nn.Dropout(config.hidden_dropout_prob) | ||||
self.qa_outputs = nn.Linear(config.hidden_size, 2) | self.qa_outputs = nn.Linear(config.hidden_size, 2) | ||||
@@ -2,20 +2,64 @@ import unittest | |||||
import torch | import torch | ||||
from fastNLP.models.bert import BertModel | |||||
from fastNLP.models.bert import * | |||||
class TestBert(unittest.TestCase): | class TestBert(unittest.TestCase): | ||||
def test_bert_1(self): | def test_bert_1(self): | ||||
# model = BertModel.from_pretrained("/home/zyfeng/data/bert-base-chinese") | |||||
model = BertModel(vocab_size=32000, hidden_size=768, | |||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||||
from fastNLP.core.const import Const | |||||
model = BertForSequenceClassification(2) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||||
pred = model(input_ids, token_type_ids, input_mask) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertTrue(Const.OUTPUT in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2)) | |||||
def test_bert_2(self): | |||||
from fastNLP.core.const import Const | |||||
model = BertForMultipleChoice(2) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||||
pred = model(input_ids, token_type_ids, input_mask) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertTrue(Const.OUTPUT in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (1, 2)) | |||||
def test_bert_3(self): | |||||
from fastNLP.core.const import Const | |||||
model = BertForTokenClassification(7) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||||
pred = model(input_ids, token_type_ids, input_mask) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertTrue(Const.OUTPUT in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 3, 7)) | |||||
def test_bert_4(self): | |||||
from fastNLP.core.const import Const | |||||
model = BertForQuestionAnswering() | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | ||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | ||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | ||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||||
for layer in all_encoder_layers: | |||||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | |||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) | |||||
pred = model(input_ids, token_type_ids, input_mask) | |||||
self.assertTrue(isinstance(pred, dict)) | |||||
self.assertTrue(Const.OUTPUTS(0) in pred) | |||||
self.assertTrue(Const.OUTPUTS(1) in pred) | |||||
self.assertEqual(tuple(pred[Const.OUTPUTS(0)].shape), (2, 3)) | |||||
self.assertEqual(tuple(pred[Const.OUTPUTS(1)].shape), (2, 3)) |
@@ -0,0 +1,21 @@ | |||||
import unittest | |||||
import torch | |||||
from fastNLP.models.bert import BertModel | |||||
class TestBert(unittest.TestCase): | |||||
def test_bert_1(self): | |||||
model = BertModel(vocab_size=32000, hidden_size=768, | |||||
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) | |||||
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]]) | |||||
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]]) | |||||
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]]) | |||||
all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) | |||||
for layer in all_encoder_layers: | |||||
self.assertEqual(tuple(layer.shape), (2, 3, 768)) | |||||
self.assertEqual(tuple(pooled_output.shape), (2, 768)) |