@@ -1,4 +1,8 @@
"""
.. warning::
    This module is deprecated in `version 0.5.0` and is replaced by the :mod:`~fastNLP.io.loader` and :mod:`~fastNLP.io.pipe` modules.
Module for reading datasets; it can read datasets for text classification, sequence labeling and Matching tasks.
These modules are described below; you can also read the :doc:`tutorial</tutorials/tutorial_2_load_dataset>` to learn more.
@@ -1,4 +1,8 @@
"""
.. warning::
    This module will be deprecated in `version 0.5.0` and replaced by the :mod:`~fastNLP.io.loader` and :mod:`~fastNLP.io.pipe` modules.
The dataset_loader module implements many DataSetLoader classes for reading data in different formats and returning a `DataSet` .
The resulting :class:`~fastNLP.DataSet` object can be passed directly to :class:`~fastNLP.Trainer` and :class:`~fastNLP.Tester` for model training and testing.
Taking the SNLI dataset as an example::
@@ -11,6 +15,7 @@ The dataset_loader module implements many DataSetLoader classes for reading data in different formats
    # ... do stuff
Developers who want to provide a DataSetLoader for fastNLP should refer to the documentation of :class:`~fastNLP.io.DataSetLoader` .
"""
__all__ = [
    'CSVLoader',
@@ -1,4 +1,3 @@
import os
from pathlib import Path
from urllib.parse import urlparse
@@ -9,35 +8,29 @@ from tqdm import tqdm
import shutil
from requests import HTTPError
PRETRAINED_BERT_MODEL_DIR = {
    'en': 'bert-large-cased-wwm.zip',
    'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
    'en-base-cased': 'bert-base-cased-f89bfe08.zip',
    'en-large-uncased': 'bert-large-uncased-20939f45.zip',
    'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
    'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip',
    'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip',
    'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip',
    'cn': 'bert-base-chinese-29d0a84a.zip',
    'cn-base': 'bert-base-chinese-29d0a84a.zip',
    'bert-base-chinese': 'bert-base-chinese.zip',
    'bert-base-cased': 'bert-base-cased.zip',
    'bert-base-cased-finetuned-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
    'bert-large-cased-wwm': 'bert-large-cased-wwm.zip',
    'bert-large-uncased': 'bert-large-uncased.zip',
    'bert-large-cased': 'bert-large-cased.zip',
    'bert-base-uncased': 'bert-base-uncased.zip',
    'bert-large-uncased-wwm': 'bert-large-uncased-wwm.zip',
    'bert-chinese-wwm': 'bert-chinese-wwm.zip',
    'bert-base-multilingual-cased': 'bert-base-multilingual-cased.zip',
    'bert-base-multilingual-uncased': 'bert-base-multilingual-uncased.zip',
    'en-large-cased-wwm': 'bert-large-cased-wwm.zip',
    'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip',
    'en-large-uncased': 'bert-large-uncased.zip',
    'en-large-cased': 'bert-large-cased.zip',
    'en-base-uncased': 'bert-base-uncased.zip',
    'en-base-cased': 'bert-base-cased.zip',
    'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
    'multi-base-cased': 'bert-base-multilingual-cased.zip',
    'multi-base-uncased': 'bert-base-multilingual-uncased.zip',
    'cn': 'bert-chinese-wwm.zip',
    'cn-base': 'bert-base-chinese.zip',
    'cn-wwm': 'bert-chinese-wwm.zip',
}
PRETRAINED_ELMO_MODEL_DIR = {
    'en': 'elmo_en-d39843fe.tar.gz',
    'en': 'elmo_en_Medium.tar.gz',
    'en-small': "elmo_en_Small.zip",
    'en-original-5.5b': 'elmo_en_Original_5.5B.zip',
    'en-original': 'elmo_en_Original.zip',
@@ -45,30 +38,33 @@ PRETRAINED_ELMO_MODEL_DIR = {
}
PRETRAIN_STATIC_FILES = {
    'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
    'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
    'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
    'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
    'en': 'glove.840B.300d.zip',
    'en-glove-6b-50d': 'glove.6B.50d.zip',
    'en-glove-6b-100d': 'glove.6B.100d.zip',
    'en-glove-6b-200d': 'glove.6B.200d.zip',
    'en-glove-6b-300d': 'glove.6B.300d.zip',
    'en-glove-42b-300d': 'glove.42B.300d.zip',
    'en-glove-840b-300d': 'glove.840B.300d.zip',
    'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip',
    'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip',
    'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip',
    'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip',
    'en-word2vec-300': "GoogleNews-vectors-negative300.txt.gz",
    'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip",
    'cn': "tencent_cn-dab24577.tar.gz",
    'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
    'sgns-literature-word':'sgns.literature.word.txt.zip',
    'glove-42b-300d': 'glove.42B.300d.zip',
    'glove-6b-50d': 'glove.6B.50d.zip',
    'glove-6b-100d': 'glove.6B.100d.zip',
    'glove-6b-200d': 'glove.6B.200d.zip',
    'glove-6b-300d': 'glove.6B.300d.zip',
    'glove-840b-300d': 'glove.840B.300d.zip',
    'glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip',
    'glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip',
    'glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip',
    'glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip'
}
    'en-fasttext-crawl': "crawl-300d-2M.vec.zip",
    'cn': "tencent_cn.txt.zip",
    'cn-tencent': "tencent_cn.txt.zip",
    'cn-fasttext': "cc.zh.300.vec.gz",
    'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
}
DATASET_DIR = {
    'aclImdb': "imdb.zip",
    "yelp-review-full":"yelp_review_full.tar.gz",
    "yelp-review-full": "yelp_review_full.tar.gz",
    "yelp-review-polarity": "yelp_review_polarity.tar.gz",
    "mnli": "MNLI.zip",
    "snli": "SNLI.zip",
@@ -90,7 +86,7 @@ FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
}
def cached_path(url_or_filename:str, cache_dir:str=None, name=None) -> Path:
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path:
    """
    Given a url, try to find the file at {cache_dir}/{name}/{filename}, where filename is parsed from the url:
    (1) if cache_dir=None, then cache_dir=~/.fastNLP/; otherwise cache_dir=cache_dir
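A hedged usage sketch of the caching behaviour described above; the embedding key is taken from the PRETRAIN_STATIC_FILES table in this file, while the exact call pattern is an assumption for illustration::

    # Resolve a pretrained-embedding URL through the cache; repeated calls with the
    # same url reuse the copy stored under {cache_dir}/{name}/{filename}.
    url = _get_embedding_url('static', 'en-glove-6b-50d')   # key from the table above
    local_path = cached_path(url, cache_dir=None, name='embedding')
    # With cache_dir=None the download lands under ~/.fastNLP/embedding/, and
    # local_path is a pathlib.Path to the cached (and, for archives, uncompressed) file.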
@@ -147,7 +143,7 @@ def get_filepath(filepath):
    """
    if os.path.isdir(filepath):
        files = os.listdir(filepath)
        if len(files)==1:
        if len(files) == 1:
            return os.path.join(filepath, files[0])
        else:
            return filepath
@@ -191,9 +187,9 @@ def _get_base_url(name):
        return url + '/'
    else:
        URLS = {
                'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/",
                "dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/"
            }
            'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/",
            "dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/"
        }
        if name.lower() not in URLS:
            raise KeyError(f"{name} is not recognized.")
        return URLS[name.lower()]
@@ -213,14 +209,13 @@ def _get_embedding_url(embed_type, name):
        url = _read_extend_url_file(_filename, name)
        if url:
            return url
    map = PRETRAIN_MAP.get(embed_type, None)
    if map:
        filename = map.get(name, None)
    embed_map = PRETRAIN_MAP.get(embed_type, None)
    if embed_map:
        filename = embed_map.get(name, None)
        if filename:
            url = _get_base_url('embedding') + filename
            return url
        raise KeyError("There is no {}. Only supports {}.".format(name, list(map.keys())))
        raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys())))
    else:
        raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
@@ -313,16 +308,16 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
    # Download to temporary file, then copy to cache dir once finished.
    # Otherwise you get corrupt cache entries if the download gets interrupted.
    fd, temp_filename = tempfile.mkstemp()
    print("%s not found in cache, downloading to %s"%(url, temp_filename))
    print("%s not found in cache, downloading to %s" % (url, temp_filename))
    # GET file object
    req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"})
    if req.status_code==200:
    if req.status_code == 200:
        content_length = req.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total, unit_scale=1)
        with open(temp_filename, "wb") as temp_file:
            for chunk in req.iter_content(chunk_size=1024*16):
            for chunk in req.iter_content(chunk_size=1024 * 16):
                if chunk:  # filter out keep-alive new chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
@@ -340,7 +335,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
            else:
                untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
            filenames = os.listdir(uncompress_temp_dir)
            if len(filenames)==1:
            if len(filenames) == 1:
                if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
                    uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])
@@ -356,9 +351,9 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
            if os.path.isdir(uncompress_temp_dir):
                for filename in os.listdir(uncompress_temp_dir):
                    if os.path.isdir(os.path.join(uncompress_temp_dir, filename)):
                        shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
                        shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename)
                    else:
                        shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
                        shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename)
            else:
                shutil.copyfile(uncompress_temp_dir, cache_path)
            success = True
@@ -390,7 +385,7 @@ def unzip_file(file: Path, to: Path):
        zipObj.extractall(to)
def untar_gz_file(file:Path, to:Path):
def untar_gz_file(file: Path, to: Path):
    import tarfile
    with tarfile.open(file, 'r:gz') as tar:
@@ -409,12 +404,11 @@ def match_file(dir_name: str, cache_dir: Path) -> str:
    files = os.listdir(cache_dir)
    matched_filenames = []
    for file_name in files:
        if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name):
        if re.match(dir_name + '$', file_name) or re.match(dir_name + '\\..*', file_name):
            matched_filenames.append(file_name)
    if len(matched_filenames)==0:
    if len(matched_filenames) == 0:
        return ''
    elif len(matched_filenames)==1:
    elif len(matched_filenames) == 1:
        return matched_filenames[-1]
    else:
        raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")
@@ -1,25 +1,35 @@
"""
Loader reads data and loads the content into a :class:`~fastNLP.DataSet` or a :class:`~fastNLP.io.DataBundle`. Every Loader supports the following
three methods: __init__(), _load(), load(). __init__() declares the reading parameters and documents the data format this Loader supports and the fields
of the resulting DataSet; _load(path) takes a single file path, reads that file and returns a DataSet; load(paths) reads the files under a directory and returns a DataBundle. The load() method
supports the following three kinds of arguments::

    (0) If None is passed, the Loader tries to download the dataset automatically and cache it. Not every dataset can be downloaded directly, though.
    (1) If a file path is passed, the returned DataBundle contains a single DataSet named train, accessible via data_bundle.datasets['train']
    (2) If a directory is passed, the files whose names contain 'train', 'test' or 'dev' are read; other files are ignored.
        Suppose a directory contains the files
            -train.txt
            -dev.txt
            -test.txt
            -other.txt
        Reading it with Loader().load('/path/to/dir'), the returned data_bundle exposes the corresponding DataSets via data_bundle.datasets['train'], data_bundle.datasets['dev'] and
        data_bundle.datasets['test']; the content of other.txt is ignored.
        Suppose a directory contains the files
            -train.txt
            -dev.txt
        Reading it with Loader().load('/path/to/dir'), the returned data_bundle exposes the corresponding DataSets via data_bundle.datasets['train'] and
        data_bundle.datasets['dev'].
    (3) If a dict is passed, its keys are dataset names and its values are the file paths of those datasets.
three methods: ``__init__`` , ``_load`` , ``load`` . ``__init__(...)`` declares the reading parameters and documents the data format this Loader supports
and the `field` s of the resulting :class:`~fastNLP.DataSet` ; the ``_load(path)`` method takes a file path, reads that single file and returns a :class:`~fastNLP.DataSet` ;
``load(paths)`` reads the files under a directory and returns an object of type :class:`~fastNLP.io.DataBundle` . The load() method supports the following kinds of arguments:

0. Passing None
    The Loader tries to download the dataset automatically and cache it. Not every dataset can be downloaded directly, though.

1. Passing a file path
    The returned data_bundle contains a single dataset named `train` , accessible via data_bundle.datasets['train']

2. Passing a directory
    The files whose names contain 'train', 'test' or 'dev' are read; other files are ignored. Suppose a directory contains the files::

        -train.txt
        -dev.txt
        -test.txt
        -other.txt

    Reading it with Loader().load('/path/to/dir'), the returned data_bundle exposes the corresponding DataSets via data_bundle.datasets['train'], data_bundle.datasets['dev'] and
    data_bundle.datasets['test']; the content of other.txt is ignored. Suppose a directory contains the files::

        -train.txt
        -dev.txt

    Reading it with Loader().load('/path/to/dir'), the returned data_bundle exposes the corresponding DataSets via data_bundle.datasets['train'] and
    data_bundle.datasets['dev'].

3. Passing a dict
    Keys are dataset names and values are the corresponding file paths::

        paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
        Loader().load(paths)  # the returned data_bundle exposes the corresponding DataSets via data_bundle.datasets['train'], data_bundle.datasets['dev'],
                              # data_bundle.datasets['test']
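To make the calling conventions above concrete, a hedged sketch using one concrete Loader subclass; SNLILoader and all paths below are placeholders, and any Loader follows the same pattern::

    from fastNLP.io.loader import SNLILoader

    # (0) None: try to download and cache the dataset automatically (where supported).
    data_bundle = SNLILoader().load()

    # (1) a single file: the bundle holds one DataSet named 'train'.
    data_bundle = SNLILoader().load('/path/to/snli/train.jsonl')
    train_ds = data_bundle.datasets['train']

    # (2) a directory: files whose names contain 'train'/'dev'/'test' are picked up.
    data_bundle = SNLILoader().load('/path/to/snli_dir')

    # (3) a dict mapping split names to file paths.
    paths = {'train': '/path/to/train', 'dev': '/path/to/dev', 'test': '/path/to/test'}
    data_bundle = SNLILoader().load(paths)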
@@ -1,7 +1,8 @@
"""
Pipe processes data. Every Pipe has a process(DataBundle) method that takes a DataBundle object, modifies that DataBundle in place and returns it;
process_from_file(paths) takes file paths and returns a DataBundle. The DataSets inside the DataBundle returned by process(DataBundle) or process_from_file(paths)
usually contain the raw text, the input converted to indices and the target converted to indices; besides the DataSets, the vocabularies built while converting fields to indices are also included.
Pipe processes data. Every Pipe has a process(data_bundle) method that takes an object of type :class:`~fastNLP.io.DataBundle` ,
modifies the data_bundle in place and returns it; process_from_file(paths) takes file paths and returns a :class:`~fastNLP.io.DataBundle` .
The :class:`~fastNLP.DataSet` s inside the :class:`~fastNLP.io.DataBundle` returned by process(data_bundle) or process_from_file(paths)
usually contain the raw text, the input converted to indices and the target converted to indices; besides the :class:`~fastNLP.DataSet` s, the vocabularies built while converting fields to indices are also included.
"""
__all__ = [
@@ -1,4 +1,3 @@
import math
from .pipe import Pipe
from .utils import get_tokenizer
@@ -19,19 +18,17 @@ class MatchingBertPipe(Pipe):
       "...", "...", "[...]", ., .

    The words column is obtained by joining raw_words1 (the premise) and raw_words2 (the hypothesis) with "[SEP]" and converting the result to indices.
    The words column is set as input and the target column is set as target.
    The words column is set as input; the target column is set as target and input (it is set as input so that the loss can be computed in the
    forward function; if the loss is not computed in forward this has no effect, since fastNLP passes arguments according to the parameter names of forward).

    :param bool lower: whether to lowercase the words.
    :param str tokenizer: which tokenizer to use to split sentences into words. Supports spacy and raw; raw splits on whitespace.
    :param int max_concat_sent_length: if the concatenated sentence exceeds this length, it is truncated to this length; the premise
        and hypothesis are truncated proportionally.
    """
    def __init__(self, lower=False, tokenizer:str='raw', max_concat_sent_length:int=480):
    def __init__(self, lower=False, tokenizer: str='raw'):
        super().__init__()
        self.lower = bool(lower)
        self.tokenizer = get_tokenizer(tokenizer=tokenizer)
        self.max_concat_sent_length = int(max_concat_sent_length)
    def _tokenize(self, data_bundle, field_names, new_field_names):
        """
@@ -43,11 +40,15 @@ class MatchingBertPipe(Pipe):
        """
        for name, dataset in data_bundle.datasets.items():
            for field_name, new_field_name in zip(field_names, new_field_names):
                dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name,
                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
                                    new_field_name=new_field_name)
        return data_bundle
    def process(self, data_bundle):
        for dataset in data_bundle.datasets.values():
            if dataset.has_field(Const.TARGET):
                dataset.drop(lambda x: x[Const.TARGET] == '-')
        for name, dataset in data_bundle.datasets.items():
            dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0))
            dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1))
@@ -64,40 +65,31 @@ class MatchingBertPipe(Pipe):
        def concat(ins):
            words0 = ins[Const.INPUTS(0)]
            words1 = ins[Const.INPUTS(1)]
            len0 = len(words0)
            len1 = len(words1)
            if len0 + len1 > self.max_concat_sent_length:
                ratio = self.max_concat_sent_length / (len0 + len1)
                len0 = math.floor(ratio * len0)
                len1 = math.floor(ratio * len1)
                words0 = words0[:len0]
                words1 = words1[:len1]
            words = words0 + ['[SEP]'] + words1
            return words
        for name, dataset in data_bundle.datasets.items():
            dataset.apply(concat, new_field_name=Const.INPUT)
            dataset.delete_field(Const.INPUTS(0))
            dataset.delete_field(Const.INPUTS(1))
        word_vocab = Vocabulary()
        word_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
        word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
                                field_name=Const.INPUT,
                                no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
                                                         name != 'train'])
                                                         'train' not in name])
        word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
        has_target_datasets = []
        for name, dataset in data_bundle.datasets.items():
            if dataset.has_field(Const.TARGET):
                has_target_datasets.append(dataset)
        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
                               dataset.has_field(Const.TARGET)]
        target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
        data_bundle.set_vocab(word_vocab, Const.INPUT)
        data_bundle.set_vocab(target_vocab, Const.TARGET)
        input_fields = [Const.INPUT, Const.INPUT_LEN]
        input_fields = [Const.INPUT, Const.INPUT_LEN, Const.TARGET]
        target_fields = [Const.TARGET]
        for name, dataset in data_bundle.datasets.items():
@@ -149,12 +141,14 @@ class MatchingPipe(Pipe):
       "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7
       "...", "...", "[...]", "[...]", ., ., .

    words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as target.
    words1 is the premise and words2 is the hypothesis. words1, words2, seq_len1 and seq_len2 are set as input; target is set as target
    and input (it is set as input so that the loss can be computed in the forward function; if the loss is not computed in forward this has
    no effect, since fastNLP passes arguments according to the parameter names of forward).

    :param bool lower: whether to lowercase all raw_words.
    :param str tokenizer: how to tokenize the raw data. Supports spacy and raw; spacy splits with spaCy, raw splits on whitespace.
    """
    def __init__(self, lower=False, tokenizer:str='raw'):
    def __init__(self, lower=False, tokenizer: str='raw'):
        super().__init__()
        self.lower = bool(lower)
@@ -170,7 +164,7 @@ class MatchingPipe(Pipe):
        """
        for name, dataset in data_bundle.datasets.items():
            for field_name, new_field_name in zip(field_names, new_field_names):
                dataset.apply_field(lambda words:self.tokenizer(words), field_name=field_name,
                dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
                                    new_field_name=new_field_name)
        return data_bundle
@@ -191,34 +185,37 @@ class MatchingPipe(Pipe):
        data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
                                     [Const.INPUTS(0), Const.INPUTS(1)])
        for dataset in data_bundle.datasets.values():
            if dataset.has_field(Const.TARGET):
                dataset.drop(lambda x: x[Const.TARGET] == '-')
        if self.lower:
            for name, dataset in data_bundle.datasets.items():
                dataset[Const.INPUTS(0)].lower()
                dataset[Const.INPUTS(1)].lower()
        word_vocab = Vocabulary()
        word_vocab.from_dataset(data_bundle.datasets['train'], field_name=[Const.INPUTS(0), Const.INPUTS(1)],
        word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
                                field_name=[Const.INPUTS(0), Const.INPUTS(1)],
                                no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
                                                         name != 'train'])
                                                         'train' not in name])
        word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
        has_target_datasets = []
        for name, dataset in data_bundle.datasets.items():
            if dataset.has_field(Const.TARGET):
                has_target_datasets.append(dataset)
        has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
                               dataset.has_field(Const.TARGET)]
        target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
        data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
        data_bundle.set_vocab(target_vocab, Const.TARGET)
        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LEN(0), Const.INPUT_LEN(1)]
        input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1), Const.TARGET]
        target_fields = [Const.TARGET]
        for name, dataset in data_bundle.datasets.items():
            dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LEN(0))
            dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LEN(1))
            dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
            dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
            dataset.set_input(*input_fields, flag=True)
            dataset.set_target(*target_fields, flag=True)
@@ -2,13 +2,14 @@
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
"""
import os
import torch
from torch import nn
from .base_model import BaseModel
from ..core.const import Const
from ..modules.encoder import BertModel
from ..modules.encoder.bert import BertConfig
from ..modules.encoder.bert import BertConfig, CONFIG_FILE
class BertForSequenceClassification(BaseModel):
@@ -54,6 +55,7 @@ class BertForSequenceClassification(BaseModel):
        self.num_labels = num_labels
        if bert_dir is not None:
            self.bert = BertModel.from_pretrained(bert_dir)
            config = BertConfig(os.path.join(bert_dir, CONFIG_FILE))
        else:
            if config is None:
                config = BertConfig(30522)
@@ -67,20 +69,20 @@ class BertForSequenceClassification(BaseModel):
        model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
        return model
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
    def forward(self, words, seq_len=None, target=None):
        _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        if labels is not None:
        if target is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            loss = loss_fct(logits, target)
            return {Const.OUTPUT: logits, Const.LOSS: loss}
        else:
            return {Const.OUTPUT: logits}
    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        logits = self.forward(input_ids, token_type_ids, attention_mask)
    def predict(self, words, seq_len=None):
        logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT]
        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
@@ -140,7 +142,8 @@ class BertForMultipleChoice(BaseModel):
        model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
        return model
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
        input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2
        flat_input_ids = input_ids.view(-1, input_ids.size(-1))
        flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
@@ -149,15 +152,15 @@ class BertForMultipleChoice(BaseModel):
        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, self.num_choices)
        if labels is not None:
        if target is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)
            loss = loss_fct(reshaped_logits, target)
            return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
        else:
            return {Const.OUTPUT: reshaped_logits}
    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
    def predict(self, words, seq_len1=None, seq_len2=None,):
        logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT]
        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
@@ -219,27 +222,27 @@ class BertForTokenClassification(BaseModel):
        model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
        return model
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
    def forward(self, words, seq_len1=None, seq_len2=None, target=None):
        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        if labels is not None:
        if target is not None:
            loss_fct = nn.CrossEntropyLoss()
            # Only keep active parts of the loss
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
            if seq_len2 is not None:
                active_loss = seq_len2.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                active_labels = target.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
                loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1))
            return {Const.OUTPUT: logits, Const.LOSS: loss}
        else:
            return {Const.OUTPUT: logits}
    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
    def predict(self, words, seq_len1=None, seq_len2=None):
        logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT]
        return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
@@ -304,34 +307,34 @@ class BertForQuestionAnswering(BaseModel):
        model = cls(config=config, bert_dir=pretrained_model_dir)
        return model
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
    def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None):
        sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        if start_positions is not None and end_positions is not None:
        if target1 is not None and target2 is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            if len(target1.size()) > 1:
                target1 = target1.squeeze(-1)
            if len(target2.size()) > 1:
                target2 = target2.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)
            target1.clamp_(0, ignored_index)
            target2.clamp_(0, ignored_index)
            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            start_loss = loss_fct(start_logits, target1)
            end_loss = loss_fct(end_logits, target2)
            total_loss = (start_loss + end_loss) / 2
            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
        else:
            return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}
    def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
        logits = self.forward(input_ids, token_type_ids, attention_mask)
    def predict(self, words, seq_len1=None, seq_len2=None):
        logits = self.forward(words, seq_len1, seq_len2)
        start_logits = logits[Const.OUTPUTS(0)]
        end_logits = logits[Const.OUTPUTS(1)]
        return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),