@@ -82,7 +82,6 @@ __all__ = [ | |||||
from .embed_loader import EmbedLoader | from .embed_loader import EmbedLoader | ||||
from .data_bundle import DataBundle | from .data_bundle import DataBundle | ||||
from .dataset_loader import CSVLoader, JsonLoader | |||||
from .model_io import ModelLoader, ModelSaver | from .model_io import ModelLoader, ModelSaver | ||||
from .loader import * | from .loader import * | ||||
@@ -6,112 +6,10 @@ __all__ = [ | |||||
'DataBundle', | 'DataBundle', | ||||
] | ] | ||||
import _pickle as pickle | |||||
import os | |||||
from typing import Union, Dict | |||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
class BaseLoader(object): | |||||
""" | |||||
各个 Loader 的基类,提供了 API 的参考。 | |||||
""" | |||||
def __init__(self): | |||||
super(BaseLoader, self).__init__() | |||||
@staticmethod | |||||
def load_lines(data_path): | |||||
""" | |||||
按行读取,舍弃每行两侧空白字符,返回list of str | |||||
:param data_path: 读取数据的路径 | |||||
""" | |||||
with open(data_path, "r", encoding="utf=8") as f: | |||||
text = f.readlines() | |||||
return [line.strip() for line in text] | |||||
@classmethod | |||||
def load(cls, data_path): | |||||
""" | |||||
先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||||
:param data_path: | |||||
""" | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
text = f.readlines() | |||||
return [[word for word in sent.strip()] for sent in text] | |||||
@classmethod | |||||
def load_with_cache(cls, data_path, cache_path): | |||||
"""缓存版的load | |||||
""" | |||||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | |||||
with open(cache_path, 'rb') as f: | |||||
return pickle.load(f) | |||||
else: | |||||
obj = cls.load(data_path) | |||||
with open(cache_path, 'wb') as f: | |||||
pickle.dump(obj, f) | |||||
return obj | |||||
def _download_from_url(url, path): | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from ..core.utils import _pseudo_tqdm as tqdm | |||||
import requests | |||||
"""Download file""" | |||||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||||
chunk_size = 16 * 1024 | |||||
total_size = int(r.headers.get('Content-length', 0)) | |||||
with open(path, "wb") as file, \ | |||||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||||
for chunk in r.iter_content(chunk_size): | |||||
if chunk: | |||||
file.write(chunk) | |||||
t.update(len(chunk)) | |||||
def _uncompress(src, dst): | |||||
import zipfile | |||||
import gzip | |||||
import tarfile | |||||
import os | |||||
def unzip(src, dst): | |||||
with zipfile.ZipFile(src, 'r') as f: | |||||
f.extractall(dst) | |||||
def ungz(src, dst): | |||||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||||
length = 16 * 1024 # 16KB | |||||
buf = f.read(length) | |||||
while buf: | |||||
uf.write(buf) | |||||
buf = f.read(length) | |||||
def untar(src, dst): | |||||
with tarfile.open(src, 'r:gz') as f: | |||||
f.extractall(dst) | |||||
fn, ext = os.path.splitext(src) | |||||
_, ext_2 = os.path.splitext(fn) | |||||
if ext == '.zip': | |||||
unzip(src, dst) | |||||
elif ext == '.gz' and ext_2 != '.tar': | |||||
ungz(src, dst) | |||||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': | |||||
untar(src, dst) | |||||
else: | |||||
raise ValueError('unsupported file {}'.format(src)) | |||||
class DataBundle: | class DataBundle: | ||||
""" | """ | ||||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | ||||
@@ -154,7 +52,7 @@ class DataBundle: | |||||
self.datasets[name] = dataset | self.datasets[name] = dataset | ||||
return self | return self | ||||
def get_dataset(self, name:str)->DataSet: | |||||
def get_dataset(self, name: str) -> DataSet: | |||||
""" | """ | ||||
获取名为name的dataset | 获取名为name的dataset | ||||
@@ -163,7 +61,7 @@ class DataBundle: | |||||
""" | """ | ||||
return self.datasets[name] | return self.datasets[name] | ||||
def delete_dataset(self, name:str): | |||||
def delete_dataset(self, name: str): | |||||
""" | """ | ||||
删除名为name的DataSet | 删除名为name的DataSet | ||||
@@ -173,7 +71,7 @@ class DataBundle: | |||||
self.datasets.pop(name, None) | self.datasets.pop(name, None) | ||||
return self | return self | ||||
def get_vocab(self, field_name:str)->Vocabulary: | |||||
def get_vocab(self, field_name: str) -> Vocabulary: | |||||
""" | """ | ||||
获取field名为field_name对应的vocab | 获取field名为field_name对应的vocab | ||||
@@ -182,7 +80,7 @@ class DataBundle: | |||||
""" | """ | ||||
return self.vocabs[field_name] | return self.vocabs[field_name] | ||||
def delete_vocab(self, field_name:str): | |||||
def delete_vocab(self, field_name: str): | |||||
""" | """ | ||||
删除vocab | 删除vocab | ||||
:param str field_name: | :param str field_name: | ||||
@@ -312,90 +210,3 @@ class DataBundle: | |||||
return _str | return _str | ||||
class DataSetLoader: | |||||
""" | |||||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||||
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 | |||||
开发者至少应该编写如下内容: | |||||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||||
**process 函数中可以 调用load 函数或 _load 函数** | |||||
""" | |||||
URL = '' | |||||
DATA_DIR = '' | |||||
ROOT_DIR = '.fastnlp/datasets/' | |||||
UNCOMPRESS = True | |||||
def _download(self, url: str, pdir: str, uncompress=True) -> str: | |||||
""" | |||||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||||
:param url: 下载的网站 | |||||
:param pdir: 下载到的目录 | |||||
:param uncompress: 是否自动解压缩 | |||||
:return: 数据的存放路径 | |||||
""" | |||||
fn = os.path.basename(url) | |||||
path = os.path.join(pdir, fn) | |||||
"""check data exists""" | |||||
if not os.path.exists(path): | |||||
os.makedirs(pdir, exist_ok=True) | |||||
_download_from_url(url, path) | |||||
if uncompress: | |||||
dst = os.path.join(pdir, 'data') | |||||
if not os.path.exists(dst): | |||||
_uncompress(path, dst) | |||||
return dst | |||||
return path | |||||
def download(self): | |||||
return self._download( | |||||
self.URL, | |||||
os.path.join(self.ROOT_DIR, self.DATA_DIR), | |||||
uncompress=self.UNCOMPRESS) | |||||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 | |||||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||||
""" | |||||
if isinstance(paths, str): | |||||
return self._load(paths) | |||||
return {name: self._load(path) for name, path in paths.items()} | |||||
def _load(self, path: str) -> DataSet: | |||||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||||
:param str path: 文件路径 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle: | |||||
""" | |||||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | |||||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | |||||
返回的 :class:`DataBundle` 对象有如下属性: | |||||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||||
:param paths: 原始数据读取的路径 | |||||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||||
:return: 返回一个 DataBundle | |||||
""" | |||||
raise NotImplementedError |
@@ -1,39 +0,0 @@ | |||||
"""undocumented | |||||
.. warning:: | |||||
本模块在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 | |||||
用于读数据集的模块, 可以读取文本分类、序列标注、Matching任务的数据集 | |||||
这些模块的具体介绍如下,您可以通过阅读 :doc:`教程</tutorials/tutorial_2_load_dataset>` 来进行了解。 | |||||
""" | |||||
__all__ = [ | |||||
'ConllLoader', | |||||
'Conll2003Loader', | |||||
'IMDBLoader', | |||||
'MatchingLoader', | |||||
'SNLILoader', | |||||
'MNLILoader', | |||||
'MTL16Loader', | |||||
'PeopleDailyCorpusLoader', | |||||
'QNLILoader', | |||||
'QuoraLoader', | |||||
'RTELoader', | |||||
'SSTLoader', | |||||
'SST2Loader', | |||||
'YelpLoader', | |||||
] | |||||
from .conll import ConllLoader, Conll2003Loader | |||||
from .imdb import IMDBLoader | |||||
from .matching import MatchingLoader | |||||
from .mnli import MNLILoader | |||||
from .mtl import MTL16Loader | |||||
from .people_daily import PeopleDailyCorpusLoader | |||||
from .qnli import QNLILoader | |||||
from .quora import QuoraLoader | |||||
from .rte import RTELoader | |||||
from .snli import SNLILoader | |||||
from .sst import SSTLoader, SST2Loader | |||||
from .yelp import YelpLoader |
@@ -1,109 +0,0 @@ | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ..data_bundle import DataSetLoader | |||||
from ..file_reader import _read_conll | |||||
from typing import Union, Dict | |||||
from ..utils import check_loader_paths | |||||
from ..data_bundle import DataBundle | |||||
class ConllLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader` | |||||
该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||||
Example:: | |||||
# 文件中的内容 | |||||
Nadim NNP B-NP B-PER | |||||
Ladki NNP I-NP I-PER | |||||
AL-AIN NNP B-NP B-LOC | |||||
United NNP B-NP B-LOC | |||||
Arab NNP I-NP I-LOC | |||||
Emirates NNPS I-NP I-LOC | |||||
1996-12-06 CD I-NP O | |||||
... | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 | |||||
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')中DataSet的raw_words | |||||
列与pos列的内容都是List[str] | |||||
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||||
""" | |||||
def __init__(self, headers, indexes=None, dropna=True): | |||||
super(ConllLoader, self).__init__() | |||||
if not isinstance(headers, (list, tuple)): | |||||
raise TypeError( | |||||
'invalid headers: {}, should be list of strings'.format(headers)) | |||||
self.headers = headers | |||||
self.dropna = dropna | |||||
if indexes is None: | |||||
self.indexes = list(range(len(self.headers))) | |||||
else: | |||||
if len(indexes) != len(headers): | |||||
raise ValueError | |||||
self.indexes = indexes | |||||
def _load(self, path): | |||||
""" | |||||
传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。 | |||||
:param str path: 文件的路径 | |||||
:return: DataSet | |||||
""" | |||||
ds = DataSet() | |||||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle: | |||||
""" | |||||
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。 | |||||
读取的field根据ConllLoader初始化时传入的headers决定。 | |||||
:param Union[str, Dict[str, str]] paths: | |||||
:return: :class:`~fastNLP.DataSet` 类的对象或 :class:`~fastNLP.io.DataBundle` 的字典 | |||||
""" | |||||
paths = check_loader_paths(paths) | |||||
datasets = {name: self._load(path) for name, path in paths.items()} | |||||
data_bundle = DataBundle(datasets=datasets) | |||||
return data_bundle | |||||
class Conll2003Loader(ConllLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader` | |||||
该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003 | |||||
找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||||
返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。 | |||||
.. csv-table:: Conll2003Loader处理之 :header: "raw_words", "words", "target", "seq_len" | |||||
"[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2 | |||||
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5 | |||||
"[...]", "[...]", "[...]", . | |||||
""" | |||||
def __init__(self): | |||||
headers = [ | |||||
'raw_words', 'pos', 'chunks', 'ner', | |||||
] | |||||
super(Conll2003Loader, self).__init__(headers=headers) |
@@ -1,99 +0,0 @@ | |||||
from typing import Union, Dict | |||||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||||
from ..data_bundle import DataSetLoader, DataBundle | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.const import Const | |||||
from ..utils import get_tokenizer | |||||
class IMDBLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.data_loader.IMDBLoader` | |||||
读取IMDB数据集,DataSet包含以下fields: | |||||
words: list(str), 需要分类的文本 | |||||
target: str, 文本的标签 | |||||
""" | |||||
def __init__(self): | |||||
super(IMDBLoader, self).__init__() | |||||
self.tokenizer = get_tokenizer() | |||||
def _load(self, path): | |||||
dataset = DataSet() | |||||
with open(path, 'r', encoding="utf-8") as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if not line: | |||||
continue | |||||
parts = line.split('\t') | |||||
target = parts[0] | |||||
words = self.tokenizer(parts[1].lower()) | |||||
dataset.append(Instance(words=words, target=target)) | |||||
if len(dataset) == 0: | |||||
raise RuntimeError(f"{path} has no valid data.") | |||||
return dataset | |||||
def process(self, | |||||
paths: Union[str, Dict[str, str]], | |||||
src_vocab_opt: VocabularyOption = None, | |||||
tgt_vocab_opt: VocabularyOption = None, | |||||
char_level_op=False): | |||||
datasets = {} | |||||
info = DataBundle() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
datasets[name] = dataset | |||||
def wordtochar(words): | |||||
chars = [] | |||||
for word in words: | |||||
word = word.lower() | |||||
for char in word: | |||||
chars.append(char) | |||||
chars.append('') | |||||
chars.pop() | |||||
return chars | |||||
if char_level_op: | |||||
for dataset in datasets.values(): | |||||
dataset.apply_field(wordtochar, field_name="words", new_field_name='chars') | |||||
datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False) | |||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
src_vocab.from_dataset(datasets['train'], field_name='words') | |||||
src_vocab.index_dataset(*datasets.values(), field_name='words') | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
tgt_vocab.from_dataset(datasets['train'], field_name='target') | |||||
tgt_vocab.index_dataset(*datasets.values(), field_name='target') | |||||
info.vocabs = { | |||||
Const.INPUT: src_vocab, | |||||
Const.TARGET: tgt_vocab | |||||
} | |||||
info.datasets = datasets | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info | |||||
@@ -1,248 +0,0 @@ | |||||
import os | |||||
from typing import Union, Dict, List | |||||
from ...core.const import Const | |||||
from ...core.vocabulary import Vocabulary | |||||
from ..data_bundle import DataBundle, DataSetLoader | |||||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||||
from ...modules.encoder.bert import BertTokenizer | |||||
class MatchingLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader` | |||||
读取Matching任务的数据集 | |||||
:param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名 | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
self.paths = paths | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 待读取数据集的路径名 | |||||
:return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子 | |||||
的原始字符串文本,第三个为标签 | |||||
""" | |||||
raise NotImplementedError | |||||
def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None, | |||||
to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None, | |||||
cut_text: int = None, get_index=True, auto_pad_length: int=None, | |||||
auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True, | |||||
set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None, | |||||
extra_split: List[str]=None, ) -> DataBundle: | |||||
""" | |||||
:param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹, | |||||
则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和 | |||||
对应的全路径文件名。 | |||||
:param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义 | |||||
这个数据集的名字,如果不定义则默认为train。 | |||||
:param bool to_lower: 是否将文本自动转为小写。默认值为False。 | |||||
:param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` : | |||||
提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和 | |||||
attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len | |||||
:param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径 | |||||
:param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。 | |||||
:param bool get_index: 是否需要根据词表将文本转为index | |||||
:param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad | |||||
:param str auto_pad_token: 自动pad的内容 | |||||
:param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False | |||||
则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input, | |||||
于此同时其他field不会被设置为input。默认值为True。 | |||||
:param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。 | |||||
:param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个<sep>。 | |||||
如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果 | |||||
传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]']. | |||||
:param extra_split: 额外的分隔符,即除了空格之外的用于分词的字符。 | |||||
:return: | |||||
""" | |||||
if isinstance(set_input, str): | |||||
set_input = [set_input] | |||||
if isinstance(set_target, str): | |||||
set_target = [set_target] | |||||
if isinstance(set_input, bool): | |||||
auto_set_input = set_input | |||||
else: | |||||
auto_set_input = False | |||||
if isinstance(set_target, bool): | |||||
auto_set_target = set_target | |||||
else: | |||||
auto_set_target = False | |||||
if isinstance(paths, str): | |||||
if os.path.isdir(paths): | |||||
path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()} | |||||
else: | |||||
path = {dataset_name if dataset_name is not None else 'train': paths} | |||||
else: | |||||
path = paths | |||||
data_info = DataBundle() | |||||
for data_name in path.keys(): | |||||
data_info.datasets[data_name] = self._load(path[data_name]) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if auto_set_input: | |||||
data_set.set_input(Const.INPUTS(0), Const.INPUTS(1)) | |||||
if auto_set_target: | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.set_target(Const.TARGET) | |||||
if extra_split is not None: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0)) | |||||
data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1)) | |||||
for s in extra_split: | |||||
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), | |||||
new_field_name=Const.INPUTS(0)) | |||||
data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '), | |||||
new_field_name=Const.INPUTS(0)) | |||||
_filt = lambda x: x | |||||
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))), | |||||
new_field_name=Const.INPUTS(0), is_input=auto_set_input) | |||||
data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))), | |||||
new_field_name=Const.INPUTS(1), is_input=auto_set_input) | |||||
_filt = None | |||||
if to_lower: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0), | |||||
is_input=auto_set_input) | |||||
data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1), | |||||
is_input=auto_set_input) | |||||
if bert_tokenizer is not None: | |||||
if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR: | |||||
PRETRAIN_URL = _get_base_url('bert') | |||||
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer] | |||||
model_url = PRETRAIN_URL + model_name | |||||
model_dir = cached_path(model_url, name='embedding') | |||||
# 检查是否存在 | |||||
elif os.path.isdir(bert_tokenizer): | |||||
model_dir = bert_tokenizer | |||||
else: | |||||
raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.") | |||||
words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]') | |||||
with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f: | |||||
lines = f.readlines() | |||||
lines = [line.strip() for line in lines] | |||||
words_vocab.add_word_lst(lines) | |||||
words_vocab.build_vocab() | |||||
tokenizer = BertTokenizer.from_pretrained(model_dir) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if isinstance(concat, bool): | |||||
concat = 'default' if concat else None | |||||
if concat is not None: | |||||
if isinstance(concat, str): | |||||
CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'], | |||||
'default': ['', '<sep>', '', '']} | |||||
if concat.lower() in CONCAT_MAP: | |||||
concat = CONCAT_MAP[concat] | |||||
else: | |||||
concat = 4 * [concat] | |||||
assert len(concat) == 4, \ | |||||
f'Please choose a list with 4 symbols which at the beginning of first sentence ' \ | |||||
f'the end of first sentence, the begin of second sentence, and the end of second' \ | |||||
f'sentence. Your input is {concat}' | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] + | |||||
x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT) | |||||
data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT, | |||||
is_input=auto_set_input) | |||||
if seq_len_type is not None: | |||||
if seq_len_type == 'seq_len': # | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'mask': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [1] * len(x[fields]), | |||||
new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN), | |||||
is_input=auto_set_input) | |||||
elif seq_len_type == 'bert': | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if Const.INPUT not in data_set.get_field_names(): | |||||
raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: ' | |||||
f'got {data_set.get_field_names()}') | |||||
data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1), | |||||
new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input) | |||||
data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]), | |||||
new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input) | |||||
if auto_pad_length is not None: | |||||
cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length) | |||||
if cut_text is not None: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')): | |||||
data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
data_set_list = [d for n, d in data_info.datasets.items()] | |||||
assert len(data_set_list) > 0, f'There are NO data sets in data info!' | |||||
if bert_tokenizer is None: | |||||
words_vocab = Vocabulary(padding=auto_pad_token) | |||||
words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=[n for n in data_set_list[0].get_field_names() | |||||
if (Const.INPUT in n)], | |||||
no_create_entry_dataset=[d for n, d in data_info.datasets.items() | |||||
if 'train' not in n]) | |||||
target_vocab = Vocabulary(padding=None, unknown=None) | |||||
target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n], | |||||
field_name=Const.TARGET) | |||||
data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab} | |||||
if get_index: | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
if Const.TARGET in data_set.get_field_names(): | |||||
data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET, | |||||
is_input=auto_set_input, is_target=auto_set_target) | |||||
if auto_pad_length is not None: | |||||
if seq_len_type == 'seq_len': | |||||
raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, ' | |||||
f'so the seq_len_type cannot be `{seq_len_type}`!') | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
for fields in data_set.get_field_names(): | |||||
if Const.INPUT in fields: | |||||
data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] * | |||||
(auto_pad_length - len(x[fields])), new_field_name=fields, | |||||
is_input=auto_set_input) | |||||
elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'): | |||||
data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])), | |||||
new_field_name=fields, is_input=auto_set_input) | |||||
for data_name, data_set in data_info.datasets.items(): | |||||
if isinstance(set_input, list): | |||||
data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()]) | |||||
if isinstance(set_target, list): | |||||
data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()]) | |||||
return data_info |
@@ -1,62 +0,0 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class MNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader` | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev_matched': 'dev_matched.tsv', | |||||
'dev_mismatched': 'dev_mismatched.tsv', | |||||
'test_matched': 'test_matched.tsv', | |||||
'test_mismatched': 'test_mismatched.tsv', | |||||
# 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt', | |||||
# 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt', | |||||
# test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle) | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t') | |||||
self.fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
if ds[0][Const.TARGET] == 'hidden': | |||||
ds.delete_field(Const.TARGET) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
if Const.TARGET in ds.get_field_names(): | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds |
@@ -1,68 +0,0 @@ | |||||
from typing import Union, Dict | |||||
from ..data_bundle import DataBundle | |||||
from ..dataset_loader import CSVLoader | |||||
from ...core.vocabulary import Vocabulary, VocabularyOption | |||||
from ...core.const import Const | |||||
from ..utils import check_loader_paths | |||||
class MTL16Loader(CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.MTL16Loader` :class:`fastNLP.io.data_loader.MTL16Loader` | |||||
读取MTL16数据集,DataSet包含以下fields: | |||||
words: list(str), 需要分类的文本 | |||||
target: str, 文本的标签 | |||||
数据来源:https://pan.baidu.com/s/1c2L6vdA | |||||
""" | |||||
def __init__(self): | |||||
super(MTL16Loader, self).__init__(headers=(Const.TARGET, Const.INPUT), sep='\t') | |||||
def _load(self, path): | |||||
dataset = super(MTL16Loader, self)._load(path) | |||||
dataset.apply(lambda x: x[Const.INPUT].lower().split(), new_field_name=Const.INPUT) | |||||
if len(dataset) == 0: | |||||
raise RuntimeError(f"{path} has no valid data.") | |||||
return dataset | |||||
def process(self, | |||||
paths: Union[str, Dict[str, str]], | |||||
src_vocab_opt: VocabularyOption = None, | |||||
tgt_vocab_opt: VocabularyOption = None,): | |||||
paths = check_loader_paths(paths) | |||||
datasets = {} | |||||
info = DataBundle() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
datasets[name] = dataset | |||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT) | |||||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||||
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||||
info.vocabs = { | |||||
Const.INPUT: src_vocab, | |||||
Const.TARGET: tgt_vocab | |||||
} | |||||
info.datasets = datasets | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info |
@@ -1,85 +0,0 @@ | |||||
from ..data_bundle import DataSetLoader | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.const import Const | |||||
class PeopleDailyCorpusLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.data_loader.PeopleDailyCorpusLoader` | |||||
读取人民日报数据集 | |||||
""" | |||||
def __init__(self, pos=True, ner=True): | |||||
super(PeopleDailyCorpusLoader, self).__init__() | |||||
self.pos = pos | |||||
self.ner = ner | |||||
def _load(self, data_path): | |||||
with open(data_path, "r", encoding="utf-8") as f: | |||||
sents = f.readlines() | |||||
examples = [] | |||||
for sent in sents: | |||||
if len(sent) <= 2: | |||||
continue | |||||
inside_ne = False | |||||
sent_pos_tag = [] | |||||
sent_words = [] | |||||
sent_ner = [] | |||||
words = sent.strip().split()[1:] | |||||
for word in words: | |||||
if "[" in word and "]" in word: | |||||
ner_tag = "U" | |||||
print(word) | |||||
elif "[" in word: | |||||
inside_ne = True | |||||
ner_tag = "B" | |||||
word = word[1:] | |||||
elif "]" in word: | |||||
ner_tag = "L" | |||||
word = word[:word.index("]")] | |||||
if inside_ne is True: | |||||
inside_ne = False | |||||
else: | |||||
raise RuntimeError("only ] appears!") | |||||
else: | |||||
if inside_ne is True: | |||||
ner_tag = "I" | |||||
else: | |||||
ner_tag = "O" | |||||
tmp = word.split("/") | |||||
token, pos = tmp[0], tmp[1] | |||||
sent_ner.append(ner_tag) | |||||
sent_pos_tag.append(pos) | |||||
sent_words.append(token) | |||||
example = [sent_words] | |||||
if self.pos is True: | |||||
example.append(sent_pos_tag) | |||||
if self.ner is True: | |||||
example.append(sent_ner) | |||||
examples.append(example) | |||||
return self.convert(examples) | |||||
def convert(self, data): | |||||
""" | |||||
:param data: python 内置对象 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
data_set = DataSet() | |||||
for item in data: | |||||
sent_words = item[0] | |||||
if self.pos is True and self.ner is True: | |||||
instance = Instance( | |||||
words=sent_words, pos_tags=item[1], ner=item[2]) | |||||
elif self.pos is True: | |||||
instance = Instance(words=sent_words, pos_tags=item[1]) | |||||
elif self.ner is True: | |||||
instance = Instance(words=sent_words, ner=item[1]) | |||||
else: | |||||
instance = Instance(words=sent_words) | |||||
data_set.append(instance) | |||||
data_set.apply(lambda ins: len(ins[Const.INPUT]), new_field_name=Const.INPUT_LEN) | |||||
return data_set |
@@ -1,47 +0,0 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class QNLILoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.data_loader.QNLILoader` | |||||
读取QNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv' # test set has not label | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'question': Const.INPUTS(0), | |||||
'sentence': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds |
@@ -1,34 +0,0 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class QuoraLoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.data_loader.QuoraLoader` | |||||
读取MNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv', | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID')) | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
return ds |
@@ -1,47 +0,0 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import CSVLoader | |||||
class RTELoader(MatchingLoader, CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.data_loader.RTELoader` | |||||
读取RTE数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
paths = paths if paths is not None else { | |||||
'train': 'train.tsv', | |||||
'dev': 'dev.tsv', | |||||
'test': 'test.tsv' # test set has not label | |||||
} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
self.fields = { | |||||
'sentence1': Const.INPUTS(0), | |||||
'sentence2': Const.INPUTS(1), | |||||
'label': Const.TARGET, | |||||
} | |||||
CSVLoader.__init__(self, sep='\t') | |||||
def _load(self, path): | |||||
ds = CSVLoader._load(self, path) | |||||
for k, v in self.fields.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
for fields in ds.get_all_fields(): | |||||
if Const.INPUT in fields: | |||||
ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields) | |||||
return ds |
@@ -1,46 +0,0 @@ | |||||
from ...core.const import Const | |||||
from .matching import MatchingLoader | |||||
from ..dataset_loader import JsonLoader | |||||
class SNLILoader(MatchingLoader, JsonLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.data_loader.SNLILoader` | |||||
读取SNLI数据集,读取的DataSet包含fields:: | |||||
words1: list(str),第一句文本, premise | |||||
words2: list(str), 第二句文本, hypothesis | |||||
target: str, 真实标签 | |||||
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||||
""" | |||||
def __init__(self, paths: dict=None): | |||||
fields = { | |||||
'sentence1_binary_parse': Const.INPUTS(0), | |||||
'sentence2_binary_parse': Const.INPUTS(1), | |||||
'gold_label': Const.TARGET, | |||||
} | |||||
paths = paths if paths is not None else { | |||||
'train': 'snli_1.0_train.jsonl', | |||||
'dev': 'snli_1.0_dev.jsonl', | |||||
'test': 'snli_1.0_test.jsonl'} | |||||
MatchingLoader.__init__(self, paths=paths) | |||||
JsonLoader.__init__(self, fields=fields) | |||||
def _load(self, path): | |||||
ds = JsonLoader._load(self, path) | |||||
parentheses_table = str.maketrans({'(': None, ')': None}) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(0)) | |||||
ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(), | |||||
new_field_name=Const.INPUTS(1)) | |||||
ds.drop(lambda x: x[Const.TARGET] == '-') | |||||
return ds |
@@ -1,180 +0,0 @@ | |||||
from typing import Union, Dict | |||||
from nltk import Tree | |||||
from ..data_bundle import DataBundle, DataSetLoader | |||||
from ..dataset_loader import CSVLoader | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ...core.dataset import DataSet | |||||
from ...core.const import Const | |||||
from ...core.instance import Instance | |||||
from ..utils import check_loader_paths, get_tokenizer | |||||
class SSTLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.data_loader.SSTLoader` | |||||
读取SST数据集, DataSet包含fields:: | |||||
words: list(str) 需要分类的文本 | |||||
target: str 文本的标签 | |||||
数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||||
:param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` | |||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
""" | |||||
URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip' | |||||
DATA_DIR = 'sst/' | |||||
def __init__(self, subtree=False, fine_grained=False): | |||||
self.subtree = subtree | |||||
tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral', | |||||
'3': 'positive', '4': 'very positive'} | |||||
if not fine_grained: | |||||
tag_v['0'] = tag_v['1'] | |||||
tag_v['4'] = tag_v['3'] | |||||
self.tag_v = tag_v | |||||
self.tokenizer = get_tokenizer() | |||||
def _load(self, path): | |||||
""" | |||||
:param str path: 存储数据的路径 | |||||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
datas = [] | |||||
for l in f: | |||||
datas.extend([(s, self.tag_v[t]) | |||||
for s, t in self._get_one(l, self.subtree)]) | |||||
ds = DataSet() | |||||
for words, tag in datas: | |||||
ds.append(Instance(words=words, target=tag)) | |||||
return ds | |||||
def _get_one(self, data, subtree): | |||||
tree = Tree.fromstring(data) | |||||
if subtree: | |||||
return [(self.tokenizer(' '.join(t.leaves())), t.label()) for t in tree.subtrees() ] | |||||
return [(self.tokenizer(' '.join(tree.leaves())), tree.label())] | |||||
def process(self, | |||||
paths, train_subtree=True, | |||||
src_vocab_op: VocabularyOption = None, | |||||
tgt_vocab_op: VocabularyOption = None,): | |||||
paths = check_loader_paths(paths) | |||||
input_name, target_name = 'words', 'target' | |||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
info = DataBundle() | |||||
origin_subtree = self.subtree | |||||
self.subtree = train_subtree | |||||
info.datasets['train'] = self._load(paths['train']) | |||||
self.subtree = origin_subtree | |||||
for n, p in paths.items(): | |||||
if n != 'train': | |||||
info.datasets[n] = self._load(p) | |||||
src_vocab.from_dataset( | |||||
info.datasets['train'], | |||||
field_name=input_name, | |||||
no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train']) | |||||
tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name) | |||||
src_vocab.index_dataset( | |||||
*info.datasets.values(), | |||||
field_name=input_name, new_field_name=input_name) | |||||
tgt_vocab.index_dataset( | |||||
*info.datasets.values(), | |||||
field_name=target_name, new_field_name=target_name) | |||||
info.vocabs = { | |||||
input_name: src_vocab, | |||||
target_name: tgt_vocab | |||||
} | |||||
return info | |||||
class SST2Loader(CSVLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.SST2Loader` :class:`fastNLP.io.data_loader.SST2Loader` | |||||
数据来源 SST: https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8 | |||||
""" | |||||
def __init__(self): | |||||
super(SST2Loader, self).__init__(sep='\t') | |||||
self.tokenizer = get_tokenizer() | |||||
self.field = {'sentence': Const.INPUT, 'label': Const.TARGET} | |||||
def _load(self, path: str) -> DataSet: | |||||
ds = super(SST2Loader, self)._load(path) | |||||
for k, v in self.field.items(): | |||||
if k in ds.get_field_names(): | |||||
ds.rename_field(k, v) | |||||
ds.apply(lambda x: self.tokenizer(x[Const.INPUT]), new_field_name=Const.INPUT) | |||||
print("all count:", len(ds)) | |||||
return ds | |||||
def process(self, | |||||
paths: Union[str, Dict[str, str]], | |||||
src_vocab_opt: VocabularyOption = None, | |||||
tgt_vocab_opt: VocabularyOption = None, | |||||
char_level_op=False): | |||||
paths = check_loader_paths(paths) | |||||
datasets = {} | |||||
info = DataBundle() | |||||
for name, path in paths.items(): | |||||
dataset = self.load(path) | |||||
dataset.apply_field(lambda words:words.copy(), field_name='words', new_field_name='raw_words') | |||||
datasets[name] = dataset | |||||
def wordtochar(words): | |||||
chars = [] | |||||
for word in words: | |||||
word = word.lower() | |||||
for char in word: | |||||
chars.append(char) | |||||
chars.append('') | |||||
chars.pop() | |||||
return chars | |||||
input_name, target_name = Const.INPUT, Const.TARGET | |||||
info.vocabs={} | |||||
# 就分隔为char形式 | |||||
if char_level_op: | |||||
for dataset in datasets.values(): | |||||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||||
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt) | |||||
src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[ | |||||
dataset for name, dataset in datasets.items() if name!='train' | |||||
]) | |||||
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt) | |||||
tgt_vocab.from_dataset(datasets['train'], field_name=Const.TARGET) | |||||
tgt_vocab.index_dataset(*datasets.values(), field_name=Const.TARGET) | |||||
info.vocabs = { | |||||
Const.INPUT: src_vocab, | |||||
Const.TARGET: tgt_vocab | |||||
} | |||||
info.datasets = datasets | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info | |||||
@@ -1,132 +0,0 @@ | |||||
import csv | |||||
from typing import Iterable | |||||
from ...core.const import Const | |||||
from ...core.dataset import DataSet | |||||
from ...core.instance import Instance | |||||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||||
from ..data_bundle import DataBundle, DataSetLoader | |||||
from typing import Union, Dict | |||||
from ..utils import check_loader_paths, get_tokenizer | |||||
class YelpLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.data_loader.YelpLoader` | |||||
读取Yelp_full/Yelp_polarity数据集, DataSet包含fields: | |||||
words: list(str), 需要分类的文本 | |||||
target: str, 文本的标签 | |||||
chars:list(str),未index的字符列表 | |||||
数据集:yelp_full/yelp_polarity | |||||
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` | |||||
:param lower: 是否需要自动转小写,默认为False。 | |||||
""" | |||||
def __init__(self, fine_grained=False, lower=False): | |||||
super(YelpLoader, self).__init__() | |||||
tag_v = {'1.0': 'very negative', '2.0': 'negative', '3.0': 'neutral', | |||||
'4.0': 'positive', '5.0': 'very positive'} | |||||
if not fine_grained: | |||||
tag_v['1.0'] = tag_v['2.0'] | |||||
tag_v['5.0'] = tag_v['4.0'] | |||||
self.fine_grained = fine_grained | |||||
self.tag_v = tag_v | |||||
self.lower = lower | |||||
self.tokenizer = get_tokenizer() | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
csv_reader = csv.reader(open(path, encoding='utf-8')) | |||||
all_count = 0 | |||||
real_count = 0 | |||||
for row in csv_reader: | |||||
all_count += 1 | |||||
if len(row) == 2: | |||||
target = self.tag_v[row[0] + ".0"] | |||||
words = clean_str(row[1], self.tokenizer, self.lower) | |||||
if len(words) != 0: | |||||
ds.append(Instance(words=words, target=target)) | |||||
real_count += 1 | |||||
print("all count:", all_count) | |||||
print("real count:", real_count) | |||||
return ds | |||||
def process(self, paths: Union[str, Dict[str, str]], | |||||
train_ds: Iterable[str] = None, | |||||
src_vocab_op: VocabularyOption = None, | |||||
tgt_vocab_op: VocabularyOption = None, | |||||
char_level_op=False): | |||||
paths = check_loader_paths(paths) | |||||
info = DataBundle(datasets=self.load(paths)) | |||||
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op) | |||||
tgt_vocab = Vocabulary(unknown=None, padding=None) \ | |||||
if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op) | |||||
_train_ds = [info.datasets[name] | |||||
for name in train_ds] if train_ds else info.datasets.values() | |||||
def wordtochar(words): | |||||
chars = [] | |||||
for word in words: | |||||
word = word.lower() | |||||
for char in word: | |||||
chars.append(char) | |||||
chars.append('') | |||||
chars.pop() | |||||
return chars | |||||
input_name, target_name = Const.INPUT, Const.TARGET | |||||
info.vocabs = {} | |||||
# 就分隔为char形式 | |||||
if char_level_op: | |||||
for dataset in info.datasets.values(): | |||||
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT) | |||||
else: | |||||
src_vocab.from_dataset(*_train_ds, field_name=input_name) | |||||
src_vocab.index_dataset(*info.datasets.values(), field_name=input_name, new_field_name=input_name) | |||||
info.vocabs[input_name] = src_vocab | |||||
tgt_vocab.from_dataset(*_train_ds, field_name=target_name) | |||||
tgt_vocab.index_dataset( | |||||
*info.datasets.values(), | |||||
field_name=target_name, new_field_name=target_name) | |||||
info.vocabs[target_name] = tgt_vocab | |||||
info.datasets['train'], info.datasets['dev'] = info.datasets['train'].split(0.1, shuffle=False) | |||||
for name, dataset in info.datasets.items(): | |||||
dataset.set_input(Const.INPUT) | |||||
dataset.set_target(Const.TARGET) | |||||
return info | |||||
def clean_str(sentence, tokenizer, char_lower=False): | |||||
""" | |||||
heavily borrowed from github | |||||
https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb | |||||
:param sentence: is a str | |||||
:return: | |||||
""" | |||||
if char_lower: | |||||
sentence = sentence.lower() | |||||
import re | |||||
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+') | |||||
words = tokenizer(sentence) | |||||
words_collection = [] | |||||
for word in words: | |||||
if word in ['-lrb-', '-rrb-', '<sssss>', '-r', '-l', 'b-']: | |||||
continue | |||||
tt = nonalpnum.split(word) | |||||
t = ''.join(tt) | |||||
if t != '': | |||||
words_collection.append(t) | |||||
return words_collection | |||||
@@ -1,121 +0,0 @@ | |||||
"""undocumented | |||||
.. warning:: | |||||
本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。 | |||||
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` , | |||||
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。 | |||||
以SNLI数据集为例:: | |||||
loader = SNLILoader() | |||||
train_ds = loader.load('path/to/train') | |||||
dev_ds = loader.load('path/to/dev') | |||||
test_ds = loader.load('path/to/test') | |||||
# ... do stuff | |||||
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。 | |||||
""" | |||||
__all__ = [ | |||||
'CSVLoader', | |||||
'JsonLoader', | |||||
] | |||||
from .data_bundle import DataSetLoader | |||||
from .file_reader import _read_csv, _read_json | |||||
from ..core.dataset import DataSet | |||||
from ..core.instance import Instance | |||||
class JsonLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader` | |||||
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||||
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , | |||||
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 | |||||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||||
Default: ``False`` | |||||
""" | |||||
def __init__(self, fields=None, dropna=False): | |||||
super(JsonLoader, self).__init__() | |||||
self.dropna = dropna | |||||
self.fields = None | |||||
self.fields_list = None | |||||
if fields: | |||||
self.fields = {} | |||||
for k, v in fields.items(): | |||||
self.fields[k] = k if v is None else v | |||||
self.fields_list = list(self.fields.keys()) | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||||
if self.fields: | |||||
ins = {self.fields[k]: v for k, v in d.items()} | |||||
else: | |||||
ins = d | |||||
ds.append(Instance(**ins)) | |||||
return ds | |||||
class CSVLoader(DataSetLoader): | |||||
""" | |||||
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader` | |||||
读取CSV格式的数据集。返回 ``DataSet`` | |||||
:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 | |||||
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||||
Default: ``False`` | |||||
""" | |||||
def __init__(self, headers=None, sep=",", dropna=False): | |||||
self.headers = headers | |||||
self.sep = sep | |||||
self.dropna = dropna | |||||
def _load(self, path): | |||||
ds = DataSet() | |||||
for idx, data in _read_csv(path, headers=self.headers, | |||||
sep=self.sep, dropna=self.dropna): | |||||
ds.append(Instance(**data)) | |||||
return ds | |||||
def _cut_long_sentence(sent, max_sample_length=200): | |||||
""" | |||||
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。 | |||||
所以截取的句子可能长于或者短于max_sample_length | |||||
:param sent: str. | |||||
:param max_sample_length: int. | |||||
:return: list of str. | |||||
""" | |||||
sent_no_space = sent.replace(' ', '') | |||||
cutted_sentence = [] | |||||
if len(sent_no_space) > max_sample_length: | |||||
parts = sent.strip().split() | |||||
new_line = '' | |||||
length = 0 | |||||
for part in parts: | |||||
length += len(part) | |||||
new_line += part + ' ' | |||||
if length > max_sample_length: | |||||
new_line = new_line[:-1] | |||||
cutted_sentence.append(new_line) | |||||
length = 0 | |||||
new_line = '' | |||||
if new_line != '': | |||||
cutted_sentence.append(new_line[:-1]) | |||||
else: | |||||
cutted_sentence.append(sent) | |||||
return cutted_sentence |
@@ -13,7 +13,6 @@ import warnings | |||||
import numpy as np | import numpy as np | ||||
from .data_bundle import BaseLoader | |||||
from ..core.utils import Option | from ..core.utils import Option | ||||
from ..core.vocabulary import Vocabulary | from ..core.vocabulary import Vocabulary | ||||
@@ -32,7 +31,7 @@ class EmbeddingOption(Option): | |||||
) | ) | ||||
class EmbedLoader(BaseLoader): | |||||
class EmbedLoader: | |||||
""" | """ | ||||
别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` | 别名::class:`fastNLP.io.EmbedLoader` :class:`fastNLP.io.embed_loader.EmbedLoader` | ||||
@@ -84,9 +83,9 @@ class EmbedLoader(BaseLoader): | |||||
word = ''.join(parts[:-dim]) | word = ''.join(parts[:-dim]) | ||||
nums = parts[-dim:] | nums = parts[-dim:] | ||||
# 对齐unk与pad | # 对齐unk与pad | ||||
if word==padding and vocab.padding is not None: | |||||
if word == padding and vocab.padding is not None: | |||||
word = vocab.padding | word = vocab.padding | ||||
elif word==unknown and vocab.unknown is not None: | |||||
elif word == unknown and vocab.unknown is not None: | |||||
word = vocab.unknown | word = vocab.unknown | ||||
if word in vocab: | if word in vocab: | ||||
index = vocab.to_index(word) | index = vocab.to_index(word) | ||||
@@ -171,7 +170,7 @@ class EmbedLoader(BaseLoader): | |||||
index = vocab.to_index(key) | index = vocab.to_index(key) | ||||
matrix[index] = vec | matrix[index] = vec | ||||
if (unknown is not None and not found_unknown) or (padding is not None and not found_pad): | |||||
if ((unknown is not None) and (not found_unknown)) or ((padding is not None) and (not found_pad)): | |||||
start_idx = 0 | start_idx = 0 | ||||
if padding is not None: | if padding is not None: | ||||
start_idx += 1 | start_idx += 1 | ||||
@@ -180,9 +179,9 @@ class EmbedLoader(BaseLoader): | |||||
mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) | mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) | ||||
std = np.std(matrix[start_idx:], axis=0, keepdims=True) | std = np.std(matrix[start_idx:], axis=0, keepdims=True) | ||||
if (unknown is not None and not found_unknown): | |||||
if (unknown is not None) and (not found_unknown): | |||||
matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean | matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean | ||||
if (padding is not None and not found_pad): | |||||
if (padding is not None) and (not found_pad): | |||||
matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean | matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean | ||||
if normalize: | if normalize: | ||||
@@ -8,10 +8,8 @@ __all__ = [ | |||||
import torch | import torch | ||||
from .data_bundle import BaseLoader | |||||
class ModelLoader(BaseLoader): | |||||
class ModelLoader: | |||||
""" | """ | ||||
别名::class:`fastNLP.io.ModelLoader` :class:`fastNLP.io.model_io.ModelLoader` | 别名::class:`fastNLP.io.ModelLoader` :class:`fastNLP.io.model_io.ModelLoader` | ||||
@@ -1,228 +0,0 @@ | |||||
"""undocumented | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
A module with NAS controller-related code. | |||||
""" | |||||
__all__ = [] | |||||
import collections | |||||
import os | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from . import enas_utils as utils | |||||
from .enas_utils import Node | |||||
def _construct_dags(prev_nodes, activations, func_names, num_blocks): | |||||
"""Constructs a set of DAGs based on the actions, i.e., previous nodes and | |||||
activation functions, sampled from the controller/policy pi. | |||||
Args: | |||||
prev_nodes: Previous node actions from the policy. | |||||
activations: Activations sampled from the policy. | |||||
func_names: Mapping from activation function names to functions. | |||||
num_blocks: Number of blocks in the target RNN cell. | |||||
Returns: | |||||
A list of DAGs defined by the inputs. | |||||
RNN cell DAGs are represented in the following way: | |||||
1. Each element (node) in a DAG is a list of `Node`s. | |||||
2. The `Node`s in the list dag[i] correspond to the subsequent nodes | |||||
that take the output from node i as their own input. | |||||
3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}. | |||||
dag[-1] always feeds dag[0]. | |||||
dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its | |||||
weights. | |||||
4. dag[N - 1] is the node that produces the hidden state passed to | |||||
the next timestep. dag[N - 1] is also always a leaf node, and therefore | |||||
is always averaged with the other leaf nodes and fed to the output | |||||
decoder. | |||||
""" | |||||
dags = [] | |||||
for nodes, func_ids in zip(prev_nodes, activations): | |||||
dag = collections.defaultdict(list) | |||||
# add first node | |||||
dag[-1] = [Node(0, func_names[func_ids[0]])] | |||||
dag[-2] = [Node(0, func_names[func_ids[0]])] | |||||
# add following nodes | |||||
for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])): | |||||
dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id])) | |||||
leaf_nodes = set(range(num_blocks)) - dag.keys() | |||||
# merge with avg | |||||
for idx in leaf_nodes: | |||||
dag[idx] = [Node(num_blocks, 'avg')] | |||||
# This is actually y^{(t)}. h^{(t)} is node N - 1 in | |||||
# the graph, where N Is the number of nodes. I.e., h^{(t)} takes | |||||
# only one other node as its input. | |||||
# last h[t] node | |||||
last_node = Node(num_blocks + 1, 'h[t]') | |||||
dag[num_blocks] = [last_node] | |||||
dags.append(dag) | |||||
return dags | |||||
class Controller(torch.nn.Module): | |||||
"""Based on | |||||
https://github.com/pytorch/examples/blob/master/word_language_model/model.py | |||||
RL controllers do not necessarily have much to do with | |||||
language models. | |||||
Base the controller RNN on the GRU from: | |||||
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py | |||||
""" | |||||
def __init__(self, num_blocks=4, controller_hid=100, cuda=False): | |||||
torch.nn.Module.__init__(self) | |||||
# `num_tokens` here is just the activation function | |||||
# for every even step, | |||||
self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid'] | |||||
self.num_tokens = [len(self.shared_rnn_activations)] | |||||
self.controller_hid = controller_hid | |||||
self.use_cuda = cuda | |||||
self.num_blocks = num_blocks | |||||
for idx in range(num_blocks): | |||||
self.num_tokens += [idx + 1, len(self.shared_rnn_activations)] | |||||
self.func_names = self.shared_rnn_activations | |||||
num_total_tokens = sum(self.num_tokens) | |||||
self.encoder = torch.nn.Embedding(num_total_tokens, | |||||
controller_hid) | |||||
self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid) | |||||
# Perhaps these weights in the decoder should be | |||||
# shared? At least for the activation functions, which all have the | |||||
# same size. | |||||
self.decoders = [] | |||||
for idx, size in enumerate(self.num_tokens): | |||||
decoder = torch.nn.Linear(controller_hid, size) | |||||
self.decoders.append(decoder) | |||||
self._decoders = torch.nn.ModuleList(self.decoders) | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def _get_default_hidden(key): | |||||
return utils.get_variable( | |||||
torch.zeros(key, self.controller_hid), | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
self.static_inputs = utils.keydefaultdict(_get_default_hidden) | |||||
def reset_parameters(self): | |||||
init_range = 0.1 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
for decoder in self.decoders: | |||||
decoder.bias.data.fill_(0) | |||||
def forward(self, # pylint:disable=arguments-differ | |||||
inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed): | |||||
if not is_embed: | |||||
embed = self.encoder(inputs) | |||||
else: | |||||
embed = inputs | |||||
hx, cx = self.lstm(embed, hidden) | |||||
logits = self.decoders[block_idx](hx) | |||||
logits /= 5.0 | |||||
# # exploration | |||||
# if self.args.mode == 'train': | |||||
# logits = (2.5 * F.tanh(logits)) | |||||
return logits, (hx, cx) | |||||
def sample(self, batch_size=1, with_details=False, save_dir=None): | |||||
"""Samples a set of `args.num_blocks` many computational nodes from the | |||||
controller, where each node is made up of an activation function, and | |||||
each node except the last also includes a previous node. | |||||
""" | |||||
if batch_size < 1: | |||||
raise Exception(f'Wrong batch_size: {batch_size} < 1') | |||||
# [B, L, H] | |||||
inputs = self.static_inputs[batch_size] | |||||
hidden = self.static_init_hidden[batch_size] | |||||
activations = [] | |||||
entropies = [] | |||||
log_probs = [] | |||||
prev_nodes = [] | |||||
# The RNN controller alternately outputs an activation, | |||||
# followed by a previous node, for each block except the last one, | |||||
# which only gets an activation function. The last node is the output | |||||
# node, and its previous node is the average of all leaf nodes. | |||||
for block_idx in range(2*(self.num_blocks - 1) + 1): | |||||
logits, hidden = self.forward(inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed=(block_idx == 0)) | |||||
probs = F.softmax(logits, dim=-1) | |||||
log_prob = F.log_softmax(logits, dim=-1) | |||||
# .mean() for entropy? | |||||
entropy = -(log_prob * probs).sum(1, keepdim=False) | |||||
action = probs.multinomial(num_samples=1).data | |||||
selected_log_prob = log_prob.gather( | |||||
1, utils.get_variable(action, requires_grad=False)) | |||||
# why the [:, 0] here? Should it be .squeeze(), or | |||||
# .view()? Same below with `action`. | |||||
entropies.append(entropy) | |||||
log_probs.append(selected_log_prob[:, 0]) | |||||
# 0: function, 1: previous node | |||||
mode = block_idx % 2 | |||||
inputs = utils.get_variable( | |||||
action[:, 0] + sum(self.num_tokens[:mode]), | |||||
requires_grad=False) | |||||
if mode == 0: | |||||
activations.append(action[:, 0]) | |||||
elif mode == 1: | |||||
prev_nodes.append(action[:, 0]) | |||||
prev_nodes = torch.stack(prev_nodes).transpose(0, 1) | |||||
activations = torch.stack(activations).transpose(0, 1) | |||||
dags = _construct_dags(prev_nodes, | |||||
activations, | |||||
self.func_names, | |||||
self.num_blocks) | |||||
if save_dir is not None: | |||||
for idx, dag in enumerate(dags): | |||||
utils.draw_network(dag, | |||||
os.path.join(save_dir, f'graph{idx}.png')) | |||||
if with_details: | |||||
return dags, torch.cat(log_probs), torch.cat(entropies) | |||||
return dags | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.controller_hid) | |||||
return (utils.get_variable(zeros, self.use_cuda, requires_grad=False), | |||||
utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False)) |
@@ -1,393 +0,0 @@ | |||||
"""undocumented | |||||
Module containing the shared RNN model. | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
""" | |||||
__all__ = [] | |||||
import collections | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
from torch.autograd import Variable | |||||
from . import enas_utils as utils | |||||
from .base_model import BaseModel | |||||
def _get_dropped_weights(w_raw, dropout_p, is_training): | |||||
"""Drops out weights to implement DropConnect. | |||||
Args: | |||||
w_raw: Full, pre-dropout, weights to be dropped out. | |||||
dropout_p: Proportion of weights to drop out. | |||||
is_training: True iff _shared_ model is training. | |||||
Returns: | |||||
The dropped weights. | |||||
Why does torch.nn.functional.dropout() return: | |||||
1. `torch.autograd.Variable()` on the training loop | |||||
2. `torch.nn.Parameter()` on the controller or eval loop, when | |||||
training = False... | |||||
Even though the call to `_setweights` in the Smerity repo's | |||||
`weight_drop.py` does not have this behaviour, and `F.dropout` always | |||||
returns `torch.autograd.Variable` there, even when `training=False`? | |||||
The above TODO is the reason for the hacky check for `torch.nn.Parameter`. | |||||
""" | |||||
dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training) | |||||
if isinstance(dropped_w, torch.nn.Parameter): | |||||
dropped_w = dropped_w.clone() | |||||
return dropped_w | |||||
class EmbeddingDropout(torch.nn.Embedding): | |||||
"""Class for dropping out embeddings by zero'ing out parameters in the | |||||
embedding matrix. | |||||
This is equivalent to dropping out particular words, e.g., in the sentence | |||||
'the quick brown fox jumps over the lazy dog', dropping out 'the' would | |||||
lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the | |||||
embedding vector space). | |||||
See 'A Theoretically Grounded Application of Dropout in Recurrent Neural | |||||
Networks', (Gal and Ghahramani, 2016). | |||||
""" | |||||
def __init__(self, | |||||
num_embeddings, | |||||
embedding_dim, | |||||
max_norm=None, | |||||
norm_type=2, | |||||
scale_grad_by_freq=False, | |||||
sparse=False, | |||||
dropout=0.1, | |||||
scale=None): | |||||
"""Embedding constructor. | |||||
Args: | |||||
dropout: Dropout probability. | |||||
scale: Used to scale parameters of embedding weight matrix that are | |||||
not dropped out. Note that this is _in addition_ to the | |||||
`1/(1 - dropout)` scaling. | |||||
See `torch.nn.Embedding` for remaining arguments. | |||||
""" | |||||
torch.nn.Embedding.__init__(self, | |||||
num_embeddings=num_embeddings, | |||||
embedding_dim=embedding_dim, | |||||
max_norm=max_norm, | |||||
norm_type=norm_type, | |||||
scale_grad_by_freq=scale_grad_by_freq, | |||||
sparse=sparse) | |||||
self.dropout = dropout | |||||
assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' | |||||
'and < 1.0') | |||||
self.scale = scale | |||||
def forward(self, inputs): # pylint:disable=arguments-differ | |||||
"""Embeds `inputs` with the dropped out embedding weight matrix.""" | |||||
if self.training: | |||||
dropout = self.dropout | |||||
else: | |||||
dropout = 0 | |||||
if dropout: | |||||
mask = self.weight.data.new(self.weight.size(0), 1) | |||||
mask.bernoulli_(1 - dropout) | |||||
mask = mask.expand_as(self.weight) | |||||
mask = mask / (1 - dropout) | |||||
masked_weight = self.weight * Variable(mask) | |||||
else: | |||||
masked_weight = self.weight | |||||
if self.scale and self.scale != 1: | |||||
masked_weight = masked_weight * self.scale | |||||
return F.embedding(inputs, | |||||
masked_weight, | |||||
max_norm=self.max_norm, | |||||
norm_type=self.norm_type, | |||||
scale_grad_by_freq=self.scale_grad_by_freq, | |||||
sparse=self.sparse) | |||||
class LockedDropout(nn.Module): | |||||
# code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, x, dropout=0.5): | |||||
if not self.training or not dropout: | |||||
return x | |||||
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) | |||||
mask = Variable(m, requires_grad=False) / (1 - dropout) | |||||
mask = mask.expand_as(x) | |||||
return mask * x | |||||
class ENASModel(BaseModel): | |||||
"""Shared RNN model.""" | |||||
def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000): | |||||
super(ENASModel, self).__init__() | |||||
self.use_cuda = cuda | |||||
self.shared_hid = shared_hid | |||||
self.num_blocks = num_blocks | |||||
self.decoder = nn.Linear(self.shared_hid, num_classes) | |||||
self.encoder = EmbeddingDropout(embed_num, | |||||
shared_embed, | |||||
dropout=0.1) | |||||
self.lockdrop = LockedDropout() | |||||
self.dag = None | |||||
# Tie weights | |||||
# self.decoder.weight = self.encoder.weight | |||||
# Since W^{x, c} and W^{h, c} are always summed, there | |||||
# is no point duplicating their bias offset parameter. Likewise for | |||||
# W^{x, h} and W^{h, h}. | |||||
self.w_xc = nn.Linear(shared_embed, self.shared_hid) | |||||
self.w_xh = nn.Linear(shared_embed, self.shared_hid) | |||||
# The raw weights are stored here because the hidden-to-hidden weights | |||||
# are weight dropped on the forward pass. | |||||
self.w_hc_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hh_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hc = None | |||||
self.w_hh = None | |||||
self.w_h = collections.defaultdict(dict) | |||||
self.w_c = collections.defaultdict(dict) | |||||
for idx in range(self.num_blocks): | |||||
for jdx in range(idx + 1, self.num_blocks): | |||||
self.w_h[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self.w_c[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self._w_h = nn.ModuleList([self.w_h[idx][jdx] | |||||
for idx in self.w_h | |||||
for jdx in self.w_h[idx]]) | |||||
self._w_c = nn.ModuleList([self.w_c[idx][jdx] | |||||
for idx in self.w_c | |||||
for jdx in self.w_c[idx]]) | |||||
self.batch_norm = None | |||||
# if args.mode == 'train': | |||||
# self.batch_norm = nn.BatchNorm1d(self.shared_hid) | |||||
# else: | |||||
# self.batch_norm = None | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def setDAG(self, dag): | |||||
if self.dag is None: | |||||
self.dag = dag | |||||
def forward(self, word_seq, hidden=None): | |||||
inputs = torch.transpose(word_seq, 0, 1) | |||||
time_steps = inputs.size(0) | |||||
batch_size = inputs.size(1) | |||||
self.w_hh = _get_dropped_weights(self.w_hh_raw, | |||||
0.5, | |||||
self.training) | |||||
self.w_hc = _get_dropped_weights(self.w_hc_raw, | |||||
0.5, | |||||
self.training) | |||||
# hidden = self.static_init_hidden[batch_size] if hidden is None else hidden | |||||
hidden = self.static_init_hidden[batch_size] | |||||
embed = self.encoder(inputs) | |||||
embed = self.lockdrop(embed, 0.65 if self.training else 0) | |||||
# The norm of hidden states are clipped here because | |||||
# otherwise ENAS is especially prone to exploding activations on the | |||||
# forward pass. This could probably be fixed in a more elegant way, but | |||||
# it might be exposing a weakness in the ENAS algorithm as currently | |||||
# proposed. | |||||
# | |||||
# For more details, see | |||||
# https://github.com/carpedm20/ENAS-pytorch/issues/6 | |||||
clipped_num = 0 | |||||
max_clipped_norm = 0 | |||||
h1tohT = [] | |||||
logits = [] | |||||
for step in range(time_steps): | |||||
x_t = embed[step] | |||||
logit, hidden = self.cell(x_t, hidden, self.dag) | |||||
hidden_norms = hidden.norm(dim=-1) | |||||
max_norm = 25.0 | |||||
if hidden_norms.data.max() > max_norm: | |||||
# Just directly use the torch slice operations | |||||
# in PyTorch v0.4. | |||||
# | |||||
# This workaround for PyTorch v0.3.1 does everything in numpy, | |||||
# because the PyTorch slicing and slice assignment is too | |||||
# flaky. | |||||
hidden_norms = hidden_norms.data.cpu().numpy() | |||||
clipped_num += 1 | |||||
if hidden_norms.max() > max_clipped_norm: | |||||
max_clipped_norm = hidden_norms.max() | |||||
clip_select = hidden_norms > max_norm | |||||
clip_norms = hidden_norms[clip_select] | |||||
mask = np.ones(hidden.size()) | |||||
normalizer = max_norm / clip_norms | |||||
normalizer = normalizer[:, np.newaxis] | |||||
mask[clip_select] = normalizer | |||||
if self.use_cuda: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask).cuda(), requires_grad=False) | |||||
else: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask), requires_grad=False) | |||||
logits.append(logit) | |||||
h1tohT.append(hidden) | |||||
h1tohT = torch.stack(h1tohT) | |||||
output = torch.stack(logits) | |||||
raw_output = output | |||||
output = self.lockdrop(output, 0.4 if self.training else 0) | |||||
# Pooling | |||||
output = torch.mean(output, 0) | |||||
decoded = self.decoder(output) | |||||
extra_out = {'dropped': decoded, | |||||
'hiddens': h1tohT, | |||||
'raw': raw_output} | |||||
return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out} | |||||
def cell(self, x, h_prev, dag): | |||||
"""Computes a single pass through the discovered RNN cell.""" | |||||
c = {} | |||||
h = {} | |||||
f = {} | |||||
f[0] = self.get_f(dag[-1][0].name) | |||||
c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) | |||||
h[0] = (c[0] * f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + | |||||
(1 - c[0]) * h_prev) | |||||
leaf_node_ids = [] | |||||
q = collections.deque() | |||||
q.append(0) | |||||
# Computes connections from the parent nodes `node_id` | |||||
# to their child nodes `next_id` recursively, skipping leaf nodes. A | |||||
# leaf node is a node whose id == `self.num_blocks`. | |||||
# | |||||
# Connections between parent i and child j should be computed as | |||||
# h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, | |||||
# where c_j = \sigmoid{(W^c_{ij}*h_i)} | |||||
# | |||||
# See Training details from Section 3.1 of the paper. | |||||
# | |||||
# The following algorithm does a breadth-first (since `q.popleft()` is | |||||
# used) search over the nodes and computes all the hidden states. | |||||
while True: | |||||
if len(q) == 0: | |||||
break | |||||
node_id = q.popleft() | |||||
nodes = dag[node_id] | |||||
for next_node in nodes: | |||||
next_id = next_node.id | |||||
if next_id == self.num_blocks: | |||||
leaf_node_ids.append(node_id) | |||||
assert len(nodes) == 1, ('parent of leaf node should have ' | |||||
'only one child') | |||||
continue | |||||
w_h = self.w_h[node_id][next_id] | |||||
w_c = self.w_c[node_id][next_id] | |||||
f[next_id] = self.get_f(next_node.name) | |||||
c[next_id] = torch.sigmoid(w_c(h[node_id])) | |||||
h[next_id] = (c[next_id] * f[next_id](w_h(h[node_id])) + | |||||
(1 - c[next_id]) * h[node_id]) | |||||
q.append(next_id) | |||||
# Instead of averaging loose ends, perhaps there should | |||||
# be a set of separate unshared weights for each "loose" connection | |||||
# between each node in a cell and the output. | |||||
# | |||||
# As it stands, all weights W^h_{ij} are doing double duty by | |||||
# connecting both from i to j, as well as from i to the output. | |||||
# average all the loose ends | |||||
leaf_nodes = [h[node_id] for node_id in leaf_node_ids] | |||||
output = torch.mean(torch.stack(leaf_nodes, 2), -1) | |||||
# stabilizing the Updates of omega | |||||
if self.batch_norm is not None: | |||||
output = self.batch_norm(output) | |||||
return output, h[self.num_blocks - 1] | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.shared_hid) | |||||
return utils.get_variable(zeros, self.use_cuda, requires_grad=False) | |||||
def get_f(self, name): | |||||
name = name.lower() | |||||
if name == 'relu': | |||||
f = torch.relu | |||||
elif name == 'tanh': | |||||
f = torch.tanh | |||||
elif name == 'identity': | |||||
f = lambda x: x | |||||
elif name == 'sigmoid': | |||||
f = torch.sigmoid | |||||
return f | |||||
@property | |||||
def num_parameters(self): | |||||
def size(p): | |||||
return np.prod(p.size()) | |||||
return sum([size(param) for param in self.parameters()]) | |||||
def reset_parameters(self): | |||||
init_range = 0.025 | |||||
# init_range = 0.025 if self.args.mode == 'train' else 0.04 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
self.decoder.bias.data.fill_(0) | |||||
def predict(self, word_seq): | |||||
""" | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||||
""" | |||||
output = self(word_seq) | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
@@ -1,384 +0,0 @@ | |||||
"""undocumented | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
""" | |||||
__all__ = [] | |||||
import math | |||||
import time | |||||
from datetime import datetime, timedelta | |||||
import numpy as np | |||||
import torch | |||||
from torch.optim import Adam | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from ..core.utils import _pseudo_tqdm as tqdm | |||||
from ..core.trainer import Trainer | |||||
from ..core.batch import DataSetIter | |||||
from ..core.callback import CallbackException | |||||
from ..core.dataset import DataSet | |||||
from ..core.utils import _move_dict_value_to_device | |||||
from . import enas_utils as utils | |||||
from ..core.utils import _build_args | |||||
def _get_no_grad_ctx_mgr(): | |||||
"""Returns a the `torch.no_grad` context manager for PyTorch version >= | |||||
0.4, or a no-op context manager otherwise. | |||||
""" | |||||
return torch.no_grad() | |||||
class ENASTrainer(Trainer): | |||||
"""A class to wrap training code.""" | |||||
def __init__(self, train_data, model, controller, **kwargs): | |||||
"""Constructor for training algorithm. | |||||
:param DataSet train_data: the training data | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param torch.nn.modules.module controller: a PyTorch model | |||||
""" | |||||
self.final_epochs = kwargs['final_epochs'] | |||||
kwargs.pop('final_epochs') | |||||
super(ENASTrainer, self).__init__(train_data, model, **kwargs) | |||||
self.controller_step = 0 | |||||
self.shared_step = 0 | |||||
self.max_length = 35 | |||||
self.shared = model | |||||
self.controller = controller | |||||
self.shared_optim = Adam( | |||||
self.shared.parameters(), | |||||
lr=20.0, | |||||
weight_decay=1e-7) | |||||
self.controller_optim = Adam( | |||||
self.controller.parameters(), | |||||
lr=3.5e-4) | |||||
def train(self, load_best_model=True): | |||||
""" | |||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||||
最好的模型参数。 | |||||
:return results: 返回一个字典类型的数据, | |||||
内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
""" | |||||
results = {} | |||||
if self.n_epochs <= 0: | |||||
print(f"training epoch is {self.n_epochs}, nothing was done.") | |||||
results['seconds'] = 0. | |||||
return results | |||||
try: | |||||
if torch.cuda.is_available() and "cuda" in self.device: | |||||
self.model = self.model.cuda() | |||||
self._model_device = self.model.parameters().__next__().device | |||||
self._mode(self.model, is_test=False) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
start_time = time.time() | |||||
print("training epochs started " + self.start_time, flush=True) | |||||
try: | |||||
self.callback_manager.on_train_begin() | |||||
self._train() | |||||
self.callback_manager.on_train_end() | |||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
self.callback_manager.on_exception(e) | |||||
if self.dev_data is not None: | |||||
print( | |||||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||||
self.tester._format_eval_results(self.best_dev_perf), ) | |||||
results['best_eval'] = self.best_dev_perf | |||||
results['best_epoch'] = self.best_dev_epoch | |||||
results['best_step'] = self.best_dev_step | |||||
if load_best_model: | |||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | |||||
load_succeed = self._load_model(self.model, model_name) | |||||
if load_succeed: | |||||
print("Reloaded the best model.") | |||||
else: | |||||
print("Fail to reload best model.") | |||||
finally: | |||||
pass | |||||
results['seconds'] = round(time.time() - start_time, 2) | |||||
return results | |||||
def _train(self): | |||||
if not self.use_tqdm: | |||||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||||
else: | |||||
inner_tqdm = tqdm | |||||
self.step = 0 | |||||
start = time.time() | |||||
total_steps = (len(self.train_data) // self.batch_size + int( | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0 | |||||
data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for epoch in range(1, self.n_epochs + 1): | |||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||||
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs) | |||||
if epoch == self.n_epochs + 1 - self.final_epochs: | |||||
print('Entering the final stage. (Only train the selected structure)') | |||||
# early stopping | |||||
self.callback_manager.on_epoch_begin() | |||||
# 1. Training the shared parameters omega of the child models | |||||
self.train_shared(pbar) | |||||
# 2. Training the controller parameters theta | |||||
if not last_stage: | |||||
self.train_controller() | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | |||||
if not last_stage: | |||||
self.derive() | |||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
# lr decay; early stopping | |||||
self.callback_manager.on_epoch_end() | |||||
# =============== epochs end =================== # | |||||
pbar.close() | |||||
# ============ tqdm end ============== # | |||||
def get_loss(self, inputs, targets, hidden, dags): | |||||
"""Computes the loss for the same batch for M models. | |||||
This amounts to an estimate of the loss, which is turned into an | |||||
estimate for the gradients of the shared model. | |||||
""" | |||||
if not isinstance(dags, list): | |||||
dags = [dags] | |||||
loss = 0 | |||||
for dag in dags: | |||||
self.shared.setDAG(dag) | |||||
inputs = _build_args(self.shared.forward, **inputs) | |||||
inputs['hidden'] = hidden | |||||
result = self.shared(**inputs) | |||||
output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out'] | |||||
self.callback_manager.on_loss_begin(targets, result) | |||||
sample_loss = self._compute_loss(result, targets) | |||||
loss += sample_loss | |||||
assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' | |||||
return loss, hidden, extra_out | |||||
def train_shared(self, pbar=None, max_step=None, dag=None): | |||||
"""Train the language model for 400 steps of minibatches of 64 | |||||
examples. | |||||
Args: | |||||
max_step: Used to run extra training steps as a warm-up. | |||||
dag: If not None, is used instead of calling sample(). | |||||
BPTT is truncated at 35 timesteps. | |||||
For each weight update, gradients are estimated by sampling M models | |||||
from the fixed controller policy, and averaging their gradients | |||||
computed on a batch of training data. | |||||
""" | |||||
model = self.shared | |||||
model.train() | |||||
self.controller.eval() | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
abs_max_grad = 0 | |||||
abs_max_hidden_norm = 0 | |||||
step = 0 | |||||
raw_total_loss = 0 | |||||
total_loss = 0 | |||||
train_idx = 0 | |||||
avg_loss = 0 | |||||
data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
indices = data_iterator.get_batch_indices() | |||||
# negative sampling; replace unknown; re-weight batch_y | |||||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||||
# prediction = self._data_forward(self.model, batch_x) | |||||
dags = self.controller.sample(1) | |||||
inputs, targets = batch_x, batch_y | |||||
# self.callback_manager.on_loss_begin(batch_y, prediction) | |||||
loss, hidden, extra_out = self.get_loss(inputs, | |||||
targets, | |||||
hidden, | |||||
dags) | |||||
hidden.detach_() | |||||
avg_loss += loss.item() | |||||
# Is loss NaN or inf? requires_grad = False | |||||
self.callback_manager.on_backward_begin(loss) | |||||
self._grad_backward(loss) | |||||
self.callback_manager.on_backward_end() | |||||
self._update() | |||||
self.callback_manager.on_step_end() | |||||
if (self.step + 1) % self.print_every == 0: | |||||
if self.use_tqdm: | |||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||||
pbar.update(self.print_every) | |||||
else: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, avg_loss, diff) | |||||
pbar.set_postfix_str(print_output) | |||||
avg_loss = 0 | |||||
self.step += 1 | |||||
step += 1 | |||||
self.shared_step += 1 | |||||
self.callback_manager.on_batch_end() | |||||
# ================= mini-batch end ==================== # | |||||
def get_reward(self, dag, entropies, hidden, valid_idx=0): | |||||
"""Computes the perplexity of a single sampled model on a minibatch of | |||||
validation data. | |||||
""" | |||||
if not isinstance(entropies, np.ndarray): | |||||
entropies = entropies.data.cpu().numpy() | |||||
data_iterator = DataSetIter(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for inputs, targets in data_iterator: | |||||
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) | |||||
valid_loss = utils.to_item(valid_loss.data) | |||||
valid_ppl = math.exp(valid_loss) | |||||
R = 80 / valid_ppl | |||||
rewards = R + 1e-4 * entropies | |||||
return rewards, hidden | |||||
def train_controller(self): | |||||
"""Fixes the shared parameters and updates the controller parameters. | |||||
The controller is updated with a score function gradient estimator | |||||
(i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl | |||||
is computed on a minibatch of validation data. | |||||
A moving average baseline is used. | |||||
The controller is trained for 2000 steps per epoch (i.e., | |||||
first (Train Shared) phase -> second (Train Controller) phase). | |||||
""" | |||||
model = self.controller | |||||
model.train() | |||||
# Why can't we call shared.eval() here? Leads to loss | |||||
# being uniformly zero for the controller. | |||||
# self.shared.eval() | |||||
avg_reward_base = None | |||||
baseline = None | |||||
adv_history = [] | |||||
entropy_history = [] | |||||
reward_history = [] | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
total_loss = 0 | |||||
valid_idx = 0 | |||||
for step in range(20): | |||||
# sample models | |||||
dags, log_probs, entropies = self.controller.sample( | |||||
with_details=True) | |||||
# calculate reward | |||||
np_entropies = entropies.data.cpu().numpy() | |||||
# No gradients should be backpropagated to the | |||||
# shared model during controller training, obviously. | |||||
with _get_no_grad_ctx_mgr(): | |||||
rewards, hidden = self.get_reward(dags, | |||||
np_entropies, | |||||
hidden, | |||||
valid_idx) | |||||
reward_history.extend(rewards) | |||||
entropy_history.extend(np_entropies) | |||||
# moving average baseline | |||||
if baseline is None: | |||||
baseline = rewards | |||||
else: | |||||
decay = 0.95 | |||||
baseline = decay * baseline + (1 - decay) * rewards | |||||
adv = rewards - baseline | |||||
adv_history.extend(adv) | |||||
# policy loss | |||||
loss = -log_probs * utils.get_variable(adv, | |||||
'cuda' in self.device, | |||||
requires_grad=False) | |||||
loss = loss.sum() # or loss.mean() | |||||
# update | |||||
self.controller_optim.zero_grad() | |||||
loss.backward() | |||||
self.controller_optim.step() | |||||
total_loss += utils.to_item(loss.data) | |||||
if ((step % 50) == 0) and (step > 0): | |||||
reward_history, adv_history, entropy_history = [], [], [] | |||||
total_loss = 0 | |||||
self.controller_step += 1 | |||||
# prev_valid_idx = valid_idx | |||||
# valid_idx = ((valid_idx + self.max_length) % | |||||
# (self.valid_data.size(0) - 1)) | |||||
# # Whenever we wrap around to the beginning of the | |||||
# # validation data, we reset the hidden states. | |||||
# if prev_valid_idx > valid_idx: | |||||
# hidden = self.shared.init_hidden(self.batch_size) | |||||
def derive(self, sample_num=10, valid_idx=0): | |||||
"""We are always deriving based on the very first batch | |||||
of validation data? This seems wrong... | |||||
""" | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
dags, _, entropies = self.controller.sample(sample_num, | |||||
with_details=True) | |||||
max_R = 0 | |||||
best_dag = None | |||||
for dag in dags: | |||||
R, _ = self.get_reward(dag, entropies, hidden, valid_idx) | |||||
if R.max() > max_R: | |||||
max_R = R.max() | |||||
best_dag = dag | |||||
self.model.setDAG(best_dag) |
@@ -1,58 +0,0 @@ | |||||
"""undocumented | |||||
Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
""" | |||||
__all__ = [] | |||||
import collections | |||||
from collections import defaultdict | |||||
import numpy as np | |||||
import torch | |||||
from torch.autograd import Variable | |||||
def detach(h): | |||||
if type(h) == Variable: | |||||
return Variable(h.data) | |||||
else: | |||||
return tuple(detach(v) for v in h) | |||||
def get_variable(inputs, cuda=False, **kwargs): | |||||
if type(inputs) in [list, np.ndarray]: | |||||
inputs = torch.Tensor(inputs) | |||||
if cuda: | |||||
out = Variable(inputs.cuda(), **kwargs) | |||||
else: | |||||
out = Variable(inputs, **kwargs) | |||||
return out | |||||
def update_lr(optimizer, lr): | |||||
for param_group in optimizer.param_groups: | |||||
param_group['lr'] = lr | |||||
Node = collections.namedtuple('Node', ['id', 'name']) | |||||
class keydefaultdict(defaultdict): | |||||
def __missing__(self, key): | |||||
if self.default_factory is None: | |||||
raise KeyError(key) | |||||
else: | |||||
ret = self[key] = self.default_factory(key) | |||||
return ret | |||||
def to_item(x): | |||||
"""Converts x, possibly scalar and possibly tensor, to a Python scalar.""" | |||||
if isinstance(x, (float, int)): | |||||
return x | |||||
if float(torch.__version__[0:3]) < 0.4: | |||||
assert (x.dim() == 1) and (len(x) == 1) | |||||
return x[0] | |||||
return x.item() |
@@ -1,44 +0,0 @@ | |||||
# fastNLP 高级接口 | |||||
### 环境与配置 | |||||
1. 系统环境:linux/ubuntu(推荐) | |||||
2. 编程语言:Python>=3.6 | |||||
3. Python包依赖 | |||||
- **torch==1.0** | |||||
- numpy>=1.14.2 | |||||
### 中文分词 | |||||
```python | |||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
from fastNLP.api import CWS | |||||
cws = CWS(device='cpu') | |||||
print(cws.predict(text)) | |||||
# ['编者 按 : 7月 12日 , 英国 航空 航天 系统 公司 公布 了 该 公司 研制 的 第一 款 高 科技 隐形 无人 机雷电 之 神 。', '这 款 飞行 从 外型 上 来 看 酷似 电影 中 的 太空 飞行器 , 据 英国 方面 介绍 , 可以 实现 洲际 远程 打击 。', '那么 这 款 无人 机 到底 有 多 厉害 ?'] | |||||
``` | |||||
### 词性标注 | |||||
```python | |||||
# 输入已分词序列 | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
from fastNLP.api import POS | |||||
pos = POS(device='cpu') | |||||
print(pos.predict(text)) | |||||
# [['编者/NN', '按:/NN', '7月/NT', '12日/NT', ',/PU', '英国/NR', '航空/NN', '航天/NN', '系统/NN', '公司/NN', '公布/VV', '了/AS', '该/DT', '公司/NN', '研制/VV', '的/DEC', '第一款/NN', '高科技/NN', '隐形/AD', '无人机/VV', '雷电之神/NN', '。/PU'], ['那么/AD', '这/DT', '款/NN', '无人机/VV', '到底/AD', '有/VE', '多/AD', '厉害/VA', '?/PU']] | |||||
``` | |||||
### 句法分析 | |||||
```python | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
from fastNLP.api import Parser | |||||
parser = Parser(device='cpu') | |||||
print(parser.predict(text)) | |||||
# [['2/nn', '4/nn', '4/nn', '20/tmod', '11/punct', '10/nn', '10/nn', '10/nn', '10/nn', '11/nsubj', '20/dep', '11/asp', '14/det', '15/nsubj', '18/rcmod', '15/cpm', '18/nn', '11/dobj', '20/advmod', '0/root', '20/dobj', '20/punct'], ['4/advmod', '3/det', '8/xsubj', '8/dep', '8/advmod', '8/dep', '8/advmod', '0/root', '8/punct']] | |||||
``` | |||||
完整样例见`examples.py` |
@@ -1,2 +0,0 @@ | |||||
__all__ = ["CWS", "POS", "Parser"] | |||||
from .api import CWS, POS, Parser |
@@ -1,463 +0,0 @@ | |||||
import warnings | |||||
import torch | |||||
warnings.filterwarnings('ignore') | |||||
import os | |||||
from fastNLP.core.dataset import DataSet | |||||
from .utils import load_url | |||||
from .processor import ModelProcessor | |||||
from fastNLP.io.dataset_loader import _cut_long_sentence | |||||
from fastNLP.io.data_loader import ConllLoader | |||||
from fastNLP.core.instance import Instance | |||||
from ..api.pipeline import Pipeline | |||||
from fastNLP.core.metrics import SpanFPreRecMetric | |||||
from .processor import IndexerProcessor | |||||
# TODO add pretrain urls | |||||
model_urls = { | |||||
"cws": "http://123.206.98.91:8888/download/cws_lstm_ctb9_1_20-09908656.pkl", | |||||
"pos": "http://123.206.98.91:8888/download/pos_tag_model_20190119-43f8b435.pkl", | |||||
"parser": "http://123.206.98.91:8888/download/parser_20190204-c72ca5c0.pkl" | |||||
} | |||||
class ConllCWSReader(object): | |||||
"""Deprecated. Use ConllLoader for all types of conll-format files.""" | |||||
def __init__(self): | |||||
pass | |||||
def load(self, path, cut_long_sent=False): | |||||
""" | |||||
返回的DataSet只包含raw_sentence这个field,内容为str。 | |||||
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 | |||||
:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
""" | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.strip().split()) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
ds = DataSet() | |||||
for sample in datalist: | |||||
# print(sample) | |||||
res = self.get_char_lst(sample) | |||||
if res is None: | |||||
continue | |||||
line = ' '.join(res) | |||||
if cut_long_sent: | |||||
sents = _cut_long_sentence(line) | |||||
else: | |||||
sents = [line] | |||||
for raw_sentence in sents: | |||||
ds.append(Instance(raw_sentence=raw_sentence)) | |||||
return ds | |||||
def get_char_lst(self, sample): | |||||
if len(sample) == 0: | |||||
return None | |||||
text = [] | |||||
for w in sample: | |||||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||||
if t3 == '_': | |||||
return None | |||||
text.append(t1) | |||||
return text | |||||
class ConllxDataLoader(ConllLoader): | |||||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||||
Deprecated. Use ConllLoader for all types of conll-format files. | |||||
""" | |||||
def __init__(self): | |||||
headers = [ | |||||
'words', 'pos_tags', 'heads', 'labels', | |||||
] | |||||
indexs = [ | |||||
1, 3, 6, 7, | |||||
] | |||||
super(ConllxDataLoader, self).__init__(headers=headers, indexes=indexs) | |||||
class API: | |||||
def __init__(self): | |||||
self.pipeline = None | |||||
self._dict = None | |||||
def predict(self, *args, **kwargs): | |||||
"""Do prediction for the given input. | |||||
""" | |||||
raise NotImplementedError | |||||
def test(self, file_path): | |||||
"""Test performance over the given data set. | |||||
:param str file_path: | |||||
:return: a dictionary of metric values | |||||
""" | |||||
raise NotImplementedError | |||||
def load(self, path, device): | |||||
if os.path.exists(os.path.expanduser(path)): | |||||
_dict = torch.load(path, map_location='cpu') | |||||
else: | |||||
_dict = load_url(path, map_location='cpu') | |||||
self._dict = _dict | |||||
self.pipeline = _dict['pipeline'] | |||||
for processor in self.pipeline.pipeline: | |||||
if isinstance(processor, ModelProcessor): | |||||
processor.set_model_device(device) | |||||
class POS(API): | |||||
"""FastNLP API for Part-Of-Speech tagging. | |||||
:param str model_path: the path to the model. | |||||
:param str device: device name such as "cpu" or "cuda:0". Use the same notation as PyTorch. | |||||
""" | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(POS, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['pos'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
"""predict函数的介绍, | |||||
函数介绍的第二句,这句话不会换行 | |||||
:param content: list of list of str. Each string is a token(word). | |||||
:return answer: list of list of str. Each string is a tag. | |||||
""" | |||||
if not hasattr(self, "pipeline"): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = content | |||||
# 1. 检查sentence的类型 | |||||
for sentence in sentence_list: | |||||
if not all((type(obj) == str for obj in sentence)): | |||||
raise ValueError("Input must be list of list of string.") | |||||
# 2. 组建dataset | |||||
dataset = DataSet() | |||||
dataset.add_field("words", sentence_list) | |||||
# 3. 使用pipeline | |||||
self.pipeline(dataset) | |||||
def merge_tag(words_list, tags_list): | |||||
rtn = [] | |||||
for words, tags in zip(words_list, tags_list): | |||||
rtn.append([w + "/" + t for w, t in zip(words, tags)]) | |||||
return rtn | |||||
output = dataset.field_arrays["tag"].content | |||||
if isinstance(content, str): | |||||
return output[0] | |||||
elif isinstance(content, list): | |||||
return merge_tag(content, output) | |||||
def test(self, file_path): | |||||
test_data = ConllxDataLoader().load(file_path) | |||||
save_dict = self._dict | |||||
tag_vocab = save_dict["tag_vocab"] | |||||
pipeline = save_dict["pipeline"] | |||||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||||
pipeline.pipeline = [index_tag] + pipeline.pipeline | |||||
test_data.rename_field("pos_tags", "tag") | |||||
pipeline(test_data) | |||||
test_data.set_target("truth") | |||||
prediction = test_data.field_arrays["predict"].content | |||||
truth = test_data.field_arrays["truth"].content | |||||
seq_len = test_data.field_arrays["word_seq_origin_len"].content | |||||
# padding by hand | |||||
max_length = max([len(seq) for seq in prediction]) | |||||
for idx in range(len(prediction)): | |||||
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | |||||
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | |||||
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | |||||
seq_len="word_seq_origin_len") | |||||
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | |||||
{"truth": torch.Tensor(truth)}) | |||||
test_result = evaluator.get_metric() | |||||
f1 = round(test_result['f'] * 100, 2) | |||||
pre = round(test_result['pre'] * 100, 2) | |||||
rec = round(test_result['rec'] * 100, 2) | |||||
return {"F1": f1, "precision": pre, "recall": rec} | |||||
class CWS(API): | |||||
""" | |||||
中文分词高级接口。 | |||||
:param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 | |||||
:param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 | |||||
""" | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(CWS, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['cws'] | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
""" | |||||
分词接口。 | |||||
:param content: str或List[str], 例如: "中文分词很重要!", 返回的结果是"中文 分词 很 重要 !"。 如果传入的为List[str],比如 | |||||
[ "中文分词很重要!", ...], 返回的结果["中文 分词 很 重要 !", ...]。 | |||||
:return: str或List[str], 根据输入的的类型决定。 | |||||
""" | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
sentence_list = [] | |||||
# 1. 检查sentence的类型 | |||||
if isinstance(content, str): | |||||
sentence_list.append(content) | |||||
elif isinstance(content, list): | |||||
sentence_list = content | |||||
# 2. 组建dataset | |||||
dataset = DataSet() | |||||
dataset.add_field('raw_sentence', sentence_list) | |||||
# 3. 使用pipeline | |||||
self.pipeline(dataset) | |||||
output = dataset.get_field('output').content | |||||
if isinstance(content, str): | |||||
return output[0] | |||||
elif isinstance(content, list): | |||||
return output | |||||
def test(self, filepath): | |||||
""" | |||||
传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 | |||||
分词文件应该为:: | |||||
1 编者按 编者按 NN O 11 nmod:topic | |||||
2 : : PU O 11 punct | |||||
3 7月 7月 NT DATE 4 compound:nn | |||||
4 12日 12日 NT DATE 11 nmod:tmod | |||||
5 , , PU O 11 punct | |||||
1 这 这 DT O 3 det | |||||
2 款 款 M O 1 mark:clf | |||||
3 飞行 飞行 NN O 8 nsubj | |||||
4 从 从 P O 5 case | |||||
5 外型 外型 NN O 8 nmod:prep | |||||
以空行分割两个句子,有内容的每行有7列。 | |||||
:param filepath: str, 文件路径路径。 | |||||
:return: float, float, float. 分别f1, precision, recall. | |||||
""" | |||||
tag_proc = self._dict['tag_proc'] | |||||
cws_model = self.pipeline.pipeline[-2].model | |||||
pipeline = self.pipeline.pipeline[:-2] | |||||
pipeline.insert(1, tag_proc) | |||||
pp = Pipeline(pipeline) | |||||
reader = ConllCWSReader() | |||||
# te_filename = '/home/hyan/ctb3/test.conllx' | |||||
te_dataset = reader.load(filepath) | |||||
pp(te_dataset) | |||||
from ..core.tester import Tester | |||||
from ..core.metrics import SpanFPreRecMetric | |||||
tester = Tester(data=te_dataset, model=cws_model, metrics=SpanFPreRecMetric(tag_proc.get_vocab()), batch_size=64, | |||||
verbose=0) | |||||
eval_res = tester.test() | |||||
f1 = eval_res['SpanFPreRecMetric']['f'] | |||||
pre = eval_res['SpanFPreRecMetric']['pre'] | |||||
rec = eval_res['SpanFPreRecMetric']['rec'] | |||||
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec)) | |||||
return {"F1": f1, "precision": pre, "recall": rec} | |||||
class Parser(API): | |||||
def __init__(self, model_path=None, device='cpu'): | |||||
super(Parser, self).__init__() | |||||
if model_path is None: | |||||
model_path = model_urls['parser'] | |||||
self.pos_tagger = POS(device=device) | |||||
self.load(model_path, device) | |||||
def predict(self, content): | |||||
if not hasattr(self, 'pipeline'): | |||||
raise ValueError("You have to load model first.") | |||||
# 1. 利用POS得到分词和pos tagging结果 | |||||
pos_out = self.pos_tagger.predict(content) | |||||
# pos_out = ['这里/NN 是/VB 分词/NN 结果/NN'.split()] | |||||
# 2. 组建dataset | |||||
dataset = DataSet() | |||||
dataset.add_field('wp', pos_out) | |||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[0] for w in x['wp']], new_field_name='words') | |||||
dataset.apply(lambda x: ['<BOS>'] + [w.split('/')[1] for w in x['wp']], new_field_name='pos') | |||||
dataset.rename_field("words", "raw_words") | |||||
# 3. 使用pipeline | |||||
self.pipeline(dataset) | |||||
dataset.apply(lambda x: [str(arc) for arc in x['arc_pred']], new_field_name='arc_pred') | |||||
dataset.apply(lambda x: [arc + '/' + label for arc, label in | |||||
zip(x['arc_pred'], x['label_pred_seq'])][1:], new_field_name='output') | |||||
# output like: [['2/top', '0/root', '4/nn', '2/dep']] | |||||
return dataset.field_arrays['output'].content | |||||
def load_test_file(self, path): | |||||
def get_one(sample): | |||||
sample = list(map(list, zip(*sample))) | |||||
if len(sample) == 0: | |||||
return None | |||||
for w in sample[7]: | |||||
if w == '_': | |||||
print('Error Sample {}'.format(sample)) | |||||
return None | |||||
# return word_seq, pos_seq, head_seq, head_tag_seq | |||||
return sample[1], sample[3], list(map(int, sample[6])), sample[7] | |||||
datalist = [] | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
sample = [] | |||||
for line in f: | |||||
if line.startswith('\n'): | |||||
datalist.append(sample) | |||||
sample = [] | |||||
elif line.startswith('#'): | |||||
continue | |||||
else: | |||||
sample.append(line.split('\t')) | |||||
if len(sample) > 0: | |||||
datalist.append(sample) | |||||
data = [get_one(sample) for sample in datalist] | |||||
data_list = list(filter(lambda x: x is not None, data)) | |||||
return data_list | |||||
def test(self, filepath): | |||||
data = self.load_test_file(filepath) | |||||
def convert(data): | |||||
BOS = '<BOS>' | |||||
dataset = DataSet() | |||||
for sample in data: | |||||
word_seq = [BOS] + sample[0] | |||||
pos_seq = [BOS] + sample[1] | |||||
heads = [0] + sample[2] | |||||
head_tags = [BOS] + sample[3] | |||||
dataset.append(Instance(raw_words=word_seq, | |||||
pos=pos_seq, | |||||
gold_heads=heads, | |||||
arc_true=heads, | |||||
tags=head_tags)) | |||||
return dataset | |||||
ds = convert(data) | |||||
pp = self.pipeline | |||||
for p in pp: | |||||
if p.field_name == 'word_list': | |||||
p.field_name = 'gold_words' | |||||
elif p.field_name == 'pos_list': | |||||
p.field_name = 'gold_pos' | |||||
# ds.rename_field("words", "raw_words") | |||||
# ds.rename_field("tag", "pos") | |||||
pp(ds) | |||||
head_cor, label_cor, total = 0, 0, 0 | |||||
for ins in ds: | |||||
head_gold = ins['gold_heads'] | |||||
head_pred = ins['arc_pred'] | |||||
length = len(head_gold) | |||||
total += length | |||||
for i in range(length): | |||||
head_cor += 1 if head_pred[i] == head_gold[i] else 0 | |||||
uas = head_cor / total | |||||
# print('uas:{:.2f}'.format(uas)) | |||||
for p in pp: | |||||
if p.field_name == 'gold_words': | |||||
p.field_name = 'word_list' | |||||
elif p.field_name == 'gold_pos': | |||||
p.field_name = 'pos_list' | |||||
return {"USA": round(uas, 5)} | |||||
class Analyzer: | |||||
def __init__(self, device='cpu'): | |||||
self.cws = CWS(device=device) | |||||
self.pos = POS(device=device) | |||||
self.parser = Parser(device=device) | |||||
def predict(self, content, seg=False, pos=False, parser=False): | |||||
if seg is False and pos is False and parser is False: | |||||
seg = True | |||||
output_dict = {} | |||||
if seg: | |||||
seg_output = self.cws.predict(content) | |||||
output_dict['seg'] = seg_output | |||||
if pos: | |||||
pos_output = self.pos.predict(content) | |||||
output_dict['pos'] = pos_output | |||||
if parser: | |||||
parser_output = self.parser.predict(content) | |||||
output_dict['parser'] = parser_output | |||||
return output_dict | |||||
def test(self, filepath): | |||||
output_dict = {} | |||||
if self.cws: | |||||
seg_output = self.cws.test(filepath) | |||||
output_dict['seg'] = seg_output | |||||
if self.pos: | |||||
pos_output = self.pos.test(filepath) | |||||
output_dict['pos'] = pos_output | |||||
if self.parser: | |||||
parser_output = self.parser.test(filepath) | |||||
output_dict['parser'] = parser_output | |||||
return output_dict |
@@ -1,181 +0,0 @@ | |||||
import re | |||||
class SpanConverter: | |||||
def __init__(self, replace_tag, pattern): | |||||
super(SpanConverter, self).__init__() | |||||
self.replace_tag = replace_tag | |||||
self.pattern = pattern | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
prev_end = 0 | |||||
for match in re.finditer(self.pattern, sentence): | |||||
start, end = match.span() | |||||
span = sentence[start:end] | |||||
replaced_sentence += sentence[prev_end:start] + self.span_to_special_tag(span) | |||||
prev_end = end | |||||
replaced_sentence += sentence[prev_end:] | |||||
return replaced_sentence | |||||
def span_to_special_tag(self, span): | |||||
return self.replace_tag | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
for match in re.finditer(self.pattern, sentence): | |||||
spans.append(match.span()) | |||||
return spans | |||||
class AlphaSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<ALPHA>' | |||||
# 理想状态下仅处理纯为字母的情况, 但不处理<[a-zA-Z]+>(因为这应该是特殊的tag). | |||||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])' | |||||
super(AlphaSpanConverter, self).__init__(replace_tag, pattern) | |||||
class DigitSpanConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<NUM>' | |||||
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super(DigitSpanConverter, self).__init__(replace_tag, pattern) | |||||
def span_to_special_tag(self, span): | |||||
# return self.special_tag | |||||
if span[0] == '0' and len(span) > 2: | |||||
return '<NUM>' | |||||
decimal_point_count = 0 # one might have more than one decimal pointers | |||||
for idx, char in enumerate(span): | |||||
if char == '.' or char == '﹒' or char == '·': | |||||
decimal_point_count += 1 | |||||
if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·': | |||||
# last digit being decimal point means this is not a number | |||||
if decimal_point_count == 1: | |||||
return span | |||||
else: | |||||
return '<UNKDGT>' | |||||
if decimal_point_count == 1: | |||||
return '<DEC>' | |||||
elif decimal_point_count > 1: | |||||
return '<UNKDGT>' | |||||
else: | |||||
return '<NUM>' | |||||
class TimeConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<TOC>' | |||||
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])' | |||||
super().__init__(replace_tag, pattern) | |||||
class MixNumAlphaConverter(SpanConverter): | |||||
def __init__(self): | |||||
replace_tag = '<MIX>' | |||||
pattern = None | |||||
super().__init__(replace_tag, pattern) | |||||
def find_certain_span_and_replace(self, sentence): | |||||
replaced_sentence = '' | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
replaced_sentence += sentence[start:idx] | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
span = sentence[start:idx] | |||||
start = idx | |||||
replaced_sentence += self.span_to_special_tag(span) | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
replaced_sentence += sentence[start:] | |||||
return replaced_sentence | |||||
def find_certain_span(self, sentence): | |||||
spans = [] | |||||
start = 0 | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
for idx in range(len(sentence)): | |||||
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]): | |||||
if not matching_flag: | |||||
start = idx | |||||
if re.match('[0-9]', sentence[idx]): | |||||
number_flag = True | |||||
elif re.match('[\'′&\\-]', sentence[idx]): | |||||
link_flag = True | |||||
elif re.match('/', sentence[idx]): | |||||
slash_flag = True | |||||
elif re.match('[\\(\\)]', sentence[idx]): | |||||
bracket_flag = True | |||||
else: | |||||
alpha_flag = True | |||||
matching_flag = True | |||||
elif re.match('[\\.]', sentence[idx]): | |||||
pass | |||||
else: | |||||
if matching_flag: | |||||
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \ | |||||
or (slash_flag and alpha_flag) or (link_flag and number_flag) \ | |||||
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag): | |||||
spans.append((start, idx)) | |||||
start = idx | |||||
matching_flag = False | |||||
number_flag = False | |||||
alpha_flag = False | |||||
link_flag = False | |||||
slash_flag = False | |||||
bracket_flag = False | |||||
return spans | |||||
class EmailConverter(SpanConverter): | |||||
def __init__(self): | |||||
replaced_tag = "<EML>" | |||||
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])' | |||||
super(EmailConverter, self).__init__(replaced_tag, pattern) |
@@ -1,56 +0,0 @@ | |||||
""" | |||||
api/example.py contains all API examples provided by fastNLP. | |||||
It is used as a tutorial for API or a test script since it is difficult to test APIs in travis. | |||||
""" | |||||
from . import CWS, POS, Parser | |||||
text = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。', | |||||
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。', | |||||
'那么这款无人机到底有多厉害?'] | |||||
def chinese_word_segmentation(): | |||||
cws = CWS(device='cpu') | |||||
print(cws.predict(text)) | |||||
def chinese_word_segmentation_test(): | |||||
cws = CWS(device='cpu') | |||||
print(cws.test("../../test/data_for_tests/zh_sample.conllx")) | |||||
def pos_tagging(): | |||||
# 输入已分词序列 | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
pos = POS(device='cpu') | |||||
print(pos.predict(text)) | |||||
def pos_tagging_test(): | |||||
pos = POS(device='cpu') | |||||
print(pos.test("../../test/data_for_tests/zh_sample.conllx")) | |||||
def syntactic_parsing(): | |||||
text = [['编者', '按:', '7月', '12日', ',', '英国', '航空', '航天', '系统', '公司', '公布', '了', '该', '公司', | |||||
'研制', '的', '第一款', '高科技', '隐形', '无人机', '雷电之神', '。'], | |||||
['那么', '这', '款', '无人机', '到底', '有', '多', '厉害', '?']] | |||||
parser = Parser(device='cpu') | |||||
print(parser.predict(text)) | |||||
def syntactic_parsing_test(): | |||||
parser = Parser(device='cpu') | |||||
print(parser.test("../../test/data_for_tests/zh_sample.conllx")) | |||||
if __name__ == "__main__": | |||||
# chinese_word_segmentation() | |||||
# chinese_word_segmentation_test() | |||||
# pos_tagging() | |||||
# pos_tagging_test() | |||||
syntactic_parsing() | |||||
# syntactic_parsing_test() |
@@ -1,33 +0,0 @@ | |||||
from ..api.processor import Processor | |||||
class Pipeline: | |||||
""" | |||||
Pipeline takes a DataSet object as input, runs multiple processors sequentially, and | |||||
outputs a DataSet object. | |||||
""" | |||||
def __init__(self, processors=None): | |||||
self.pipeline = [] | |||||
if isinstance(processors, list): | |||||
for proc in processors: | |||||
assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(proc)) | |||||
self.pipeline = processors | |||||
def add_processor(self, processor): | |||||
assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor)) | |||||
self.pipeline.append(processor) | |||||
def process(self, dataset): | |||||
assert len(self.pipeline) != 0, "You need to add some processor first." | |||||
for proc in self.pipeline: | |||||
dataset = proc(dataset) | |||||
return dataset | |||||
def __call__(self, *args, **kwargs): | |||||
return self.process(*args, **kwargs) | |||||
def __getitem__(self, item): | |||||
return self.pipeline[item] |
@@ -1,428 +0,0 @@ | |||||
import re | |||||
from collections import defaultdict | |||||
import torch | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP.core.vocabulary import Vocabulary | |||||
class Processor(object): | |||||
def __init__(self, field_name, new_added_field_name): | |||||
""" | |||||
:param field_name: 处理哪个field | |||||
:param new_added_field_name: 如果为None,则认为是field_name,即覆盖原有的field | |||||
""" | |||||
self.field_name = field_name | |||||
if new_added_field_name is None: | |||||
self.new_added_field_name = field_name | |||||
else: | |||||
self.new_added_field_name = new_added_field_name | |||||
def process(self, *args, **kwargs): | |||||
raise NotImplementedError | |||||
def __call__(self, *args, **kwargs): | |||||
return self.process(*args, **kwargs) | |||||
class FullSpaceToHalfSpaceProcessor(Processor): | |||||
"""全角转半角,以字符为处理单元 | |||||
""" | |||||
def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True, | |||||
change_space=True): | |||||
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None) | |||||
self.change_alpha = change_alpha | |||||
self.change_digit = change_digit | |||||
self.change_punctuation = change_punctuation | |||||
self.change_space = change_space | |||||
FH_SPACE = [(u" ", u" ")] | |||||
FH_NUM = [ | |||||
(u"0", u"0"), (u"1", u"1"), (u"2", u"2"), (u"3", u"3"), (u"4", u"4"), | |||||
(u"5", u"5"), (u"6", u"6"), (u"7", u"7"), (u"8", u"8"), (u"9", u"9")] | |||||
FH_ALPHA = [ | |||||
(u"a", u"a"), (u"b", u"b"), (u"c", u"c"), (u"d", u"d"), (u"e", u"e"), | |||||
(u"f", u"f"), (u"g", u"g"), (u"h", u"h"), (u"i", u"i"), (u"j", u"j"), | |||||
(u"k", u"k"), (u"l", u"l"), (u"m", u"m"), (u"n", u"n"), (u"o", u"o"), | |||||
(u"p", u"p"), (u"q", u"q"), (u"r", u"r"), (u"s", u"s"), (u"t", u"t"), | |||||
(u"u", u"u"), (u"v", u"v"), (u"w", u"w"), (u"x", u"x"), (u"y", u"y"), | |||||
(u"z", u"z"), | |||||
(u"A", u"A"), (u"B", u"B"), (u"C", u"C"), (u"D", u"D"), (u"E", u"E"), | |||||
(u"F", u"F"), (u"G", u"G"), (u"H", u"H"), (u"I", u"I"), (u"J", u"J"), | |||||
(u"K", u"K"), (u"L", u"L"), (u"M", u"M"), (u"N", u"N"), (u"O", u"O"), | |||||
(u"P", u"P"), (u"Q", u"Q"), (u"R", u"R"), (u"S", u"S"), (u"T", u"T"), | |||||
(u"U", u"U"), (u"V", u"V"), (u"W", u"W"), (u"X", u"X"), (u"Y", u"Y"), | |||||
(u"Z", u"Z")] | |||||
# 谨慎使用标点符号转换, 因为"5.12特大地震"转换后可能就成了"5.12特大地震" | |||||
FH_PUNCTUATION = [ | |||||
(u'%', u'%'), (u'!', u'!'), (u'"', u'\"'), (u''', u'\''), (u'#', u'#'), | |||||
(u'¥', u'$'), (u'&', u'&'), (u'(', u'('), (u')', u')'), (u'*', u'*'), | |||||
(u'+', u'+'), (u',', u','), (u'-', u'-'), (u'.', u'.'), (u'/', u'/'), | |||||
(u':', u':'), (u';', u';'), (u'<', u'<'), (u'=', u'='), (u'>', u'>'), | |||||
(u'?', u'?'), (u'@', u'@'), (u'[', u'['), (u']', u']'), (u'\', u'\\'), | |||||
(u'^', u'^'), (u'_', u'_'), (u'`', u'`'), (u'~', u'~'), (u'{', u'{'), | |||||
(u'}', u'}'), (u'|', u'|')] | |||||
FHs = [] | |||||
if self.change_alpha: | |||||
FHs = FH_ALPHA | |||||
if self.change_digit: | |||||
FHs += FH_NUM | |||||
if self.change_punctuation: | |||||
FHs += FH_PUNCTUATION | |||||
if self.change_space: | |||||
FHs += FH_SPACE | |||||
self.convert_map = {k: v for k, v in FHs} | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
def inner_proc(ins): | |||||
sentence = ins[self.field_name] | |||||
new_sentence = [""] * len(sentence) | |||||
for idx, char in enumerate(sentence): | |||||
if char in self.convert_map: | |||||
char = self.convert_map[char] | |||||
new_sentence[idx] = char | |||||
return "".join(new_sentence) | |||||
dataset.apply(inner_proc, new_field_name=self.field_name) | |||||
return dataset | |||||
class PreAppendProcessor(Processor): | |||||
""" | |||||
向某个field的起始增加data(应该为str类型)。该field需要为list类型。即新增的field为 | |||||
[data] + instance[field_name] | |||||
""" | |||||
def __init__(self, data, field_name, new_added_field_name=None): | |||||
super(PreAppendProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.data = data | |||||
def process(self, dataset): | |||||
dataset.apply(lambda ins: [self.data] + ins[self.field_name], new_field_name=self.new_added_field_name) | |||||
return dataset | |||||
class SliceProcessor(Processor): | |||||
""" | |||||
从某个field中只取部分内容。等价于instance[field_name][start:end:step] | |||||
""" | |||||
def __init__(self, start, end, step, field_name, new_added_field_name=None): | |||||
super(SliceProcessor, self).__init__(field_name, new_added_field_name) | |||||
for o in (start, end, step): | |||||
assert isinstance(o, int) or o is None | |||||
self.slice = slice(start, end, step) | |||||
def process(self, dataset): | |||||
dataset.apply(lambda ins: ins[self.field_name][self.slice], new_field_name=self.new_added_field_name) | |||||
return dataset | |||||
class Num2TagProcessor(Processor): | |||||
""" | |||||
将一句话中的数字转换为某个tag。 | |||||
""" | |||||
def __init__(self, tag, field_name, new_added_field_name=None): | |||||
""" | |||||
:param tag: str, 将数字转换为该tag | |||||
:param field_name: | |||||
:param new_added_field_name: | |||||
""" | |||||
super(Num2TagProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.tag = tag | |||||
self.pattern = r'[-+]?([0-9]+[.]?[0-9]*)+[/eE]?[-+]?([0-9]+[.]?[0-9]*)' | |||||
def process(self, dataset): | |||||
def inner_proc(ins): | |||||
s = ins[self.field_name] | |||||
new_s = [None] * len(s) | |||||
for i, w in enumerate(s): | |||||
if re.search(self.pattern, w) is not None: | |||||
w = self.tag | |||||
new_s[i] = w | |||||
return new_s | |||||
dataset.apply(inner_proc, new_field_name=self.new_added_field_name) | |||||
return dataset | |||||
class IndexerProcessor(Processor): | |||||
""" | |||||
给定一个vocabulary , 将指定field转换为index形式。指定field应该是一维的list,比如 | |||||
['我', '是', xxx] | |||||
""" | |||||
def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||||
super(IndexerProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.vocab = vocab | |||||
self.delete_old_field = delete_old_field | |||||
self.is_input = is_input | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name) | |||||
if self.is_input: | |||||
dataset.set_input(self.new_added_field_name) | |||||
if self.delete_old_field: | |||||
dataset.delete_field(self.field_name) | |||||
return dataset | |||||
class VocabProcessor(Processor): | |||||
""" | |||||
传入若干个DataSet以建立vocabulary。 | |||||
""" | |||||
def __init__(self, field_name, min_freq=1, max_size=None): | |||||
super(VocabProcessor, self).__init__(field_name, None) | |||||
self.vocab = Vocabulary(min_freq=min_freq, max_size=max_size) | |||||
def process(self, *datasets): | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
def get_vocab(self): | |||||
self.vocab.build_vocab() | |||||
return self.vocab | |||||
class SeqLenProcessor(Processor): | |||||
""" | |||||
根据某个field新增一个sequence length的field。取该field的第一维 | |||||
""" | |||||
def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True): | |||||
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.is_input = is_input | |||||
def process(self, dataset): | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: len(ins[self.field_name]), new_field_name=self.new_added_field_name) | |||||
if self.is_input: | |||||
dataset.set_input(self.new_added_field_name) | |||||
return dataset | |||||
from fastNLP.core.utils import _build_args | |||||
class ModelProcessor(Processor): | |||||
def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32): | |||||
""" | |||||
传入一个model,在process()时传入一个dataset,该processor会通过Batch将DataSet的内容输出给model.predict或者model.forward. | |||||
model输出的内容会被增加到dataset中,field_name由model输出决定。如果生成的内容维度不是(Batch_size, )与 | |||||
(Batch_size, 1),则使用seqence length这个field进行unpad | |||||
TODO 这个类需要删除对seq_lens的依赖。 | |||||
:param seq_len_field_name: | |||||
:param batch_size: | |||||
""" | |||||
super(ModelProcessor, self).__init__(None, None) | |||||
self.batch_size = batch_size | |||||
self.seq_len_field_name = seq_len_field_name | |||||
self.model = model | |||||
def process(self, dataset): | |||||
self.model.eval() | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
data_iterator = Batch(dataset, batch_size=self.batch_size, sampler=SequentialSampler()) | |||||
batch_output = defaultdict(list) | |||||
predict_func = self.model.forward | |||||
with torch.no_grad(): | |||||
for batch_x, _ in data_iterator: | |||||
refined_batch_x = _build_args(predict_func, **batch_x) | |||||
prediction = predict_func(**refined_batch_x) | |||||
seq_lens = batch_x[self.seq_len_field_name].tolist() | |||||
for key, value in prediction.items(): | |||||
tmp_batch = [] | |||||
value = value.cpu().numpy() | |||||
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): | |||||
batch_output[key].extend(value.tolist()) | |||||
else: | |||||
for idx, seq_len in enumerate(seq_lens): | |||||
tmp_batch.append(value[idx, :seq_len]) | |||||
batch_output[key].extend(tmp_batch) | |||||
if not self.seq_len_field_name in prediction: | |||||
batch_output[self.seq_len_field_name].extend(seq_lens) | |||||
# TODO 当前的实现会导致之后的processor需要知道model输出的output的key是什么 | |||||
for field_name, fields in batch_output.items(): | |||||
dataset.add_field(field_name, fields, is_input=True, is_target=False) | |||||
return dataset | |||||
def set_model(self, model): | |||||
self.model = model | |||||
def set_model_device(self, device): | |||||
device = torch.device(device) | |||||
self.model.to(device) | |||||
class Index2WordProcessor(Processor): | |||||
""" | |||||
将DataSet中某个为index的field根据vocab转换为str | |||||
""" | |||||
def __init__(self, vocab, field_name, new_added_field_name): | |||||
super(Index2WordProcessor, self).__init__(field_name, new_added_field_name) | |||||
self.vocab = vocab | |||||
def process(self, dataset): | |||||
dataset.apply(lambda ins: [self.vocab.to_word(w) for w in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name) | |||||
return dataset | |||||
class SetTargetProcessor(Processor): | |||||
def __init__(self, *fields, flag=True): | |||||
super(SetTargetProcessor, self).__init__(None, None) | |||||
self.fields = fields | |||||
self.flag = flag | |||||
def process(self, dataset): | |||||
dataset.set_target(*self.fields, flag=self.flag) | |||||
return dataset | |||||
class SetInputProcessor(Processor): | |||||
def __init__(self, *fields, flag=True): | |||||
super(SetInputProcessor, self).__init__(None, None) | |||||
self.fields = fields | |||||
self.flag = flag | |||||
def process(self, dataset): | |||||
dataset.set_input(*self.fields, flag=self.flag) | |||||
return dataset | |||||
class VocabIndexerProcessor(Processor): | |||||
""" | |||||
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供 | |||||
new_added_field_name, 则覆盖原有的field_name. | |||||
""" | |||||
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None, | |||||
verbose=0, is_input=True): | |||||
""" | |||||
:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作 | |||||
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name. | |||||
:param min_freq: 创建的Vocabulary允许的单词最少出现次数. | |||||
:param max_size: 创建的Vocabulary允许的最大的单词数量 | |||||
:param verbose: 0, 不输出任何信息;1,输出信息 | |||||
:param bool is_input: | |||||
""" | |||||
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name) | |||||
self.min_freq = min_freq | |||||
self.max_size = max_size | |||||
self.verbose = verbose | |||||
self.is_input = is_input | |||||
def construct_vocab(self, *datasets): | |||||
""" | |||||
使用传入的DataSet创建vocabulary | |||||
:param datasets: DataSet类型的数据,用于构建vocabulary | |||||
:return: | |||||
""" | |||||
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size) | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name])) | |||||
self.vocab.build_vocab() | |||||
if self.verbose: | |||||
print("Vocabulary Constructed, has {} items.".format(len(self.vocab))) | |||||
def process(self, *datasets, only_index_dataset=None): | |||||
""" | |||||
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary | |||||
后,则会index datasets与only_index_dataset。 | |||||
:param datasets: DataSet类型的数据 | |||||
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。 | |||||
:return: | |||||
""" | |||||
if len(datasets) == 0 and not hasattr(self, 'vocab'): | |||||
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.") | |||||
if not hasattr(self, 'vocab'): | |||||
self.construct_vocab(*datasets) | |||||
else: | |||||
if self.verbose: | |||||
print("Using constructed vocabulary with {} items.".format(len(self.vocab))) | |||||
to_index_datasets = [] | |||||
if len(datasets) != 0: | |||||
for dataset in datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
if not (only_index_dataset is None): | |||||
if isinstance(only_index_dataset, list): | |||||
for dataset in only_index_dataset: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
to_index_datasets.append(dataset) | |||||
elif isinstance(only_index_dataset, DataSet): | |||||
to_index_datasets.append(only_index_dataset) | |||||
else: | |||||
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset))) | |||||
for dataset in to_index_datasets: | |||||
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset)) | |||||
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]], | |||||
new_field_name=self.new_added_field_name, is_input=self.is_input) | |||||
# 只返回一个,infer时为了跟其他processor保持一致 | |||||
if len(to_index_datasets) == 1: | |||||
return to_index_datasets[0] | |||||
def set_vocab(self, vocab): | |||||
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab)) | |||||
self.vocab = vocab | |||||
def delete_vocab(self): | |||||
del self.vocab | |||||
def get_vocab_size(self): | |||||
return len(self.vocab) | |||||
def set_verbose(self, verbose): | |||||
""" | |||||
设置processor verbose状态。 | |||||
:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。 | |||||
:return: | |||||
""" | |||||
self.verbose = verbose |
@@ -1,134 +0,0 @@ | |||||
import hashlib | |||||
import os | |||||
import re | |||||
import shutil | |||||
import sys | |||||
import tempfile | |||||
import torch | |||||
try: | |||||
from requests.utils import urlparse | |||||
from requests import get as urlopen | |||||
requests_available = True | |||||
except ImportError: | |||||
requests_available = False | |||||
if sys.version_info[0] == 2: | |||||
from urlparse import urlparse # noqa f811 | |||||
from urllib2 import urlopen # noqa f811 | |||||
else: | |||||
from urllib.request import urlopen | |||||
from urllib.parse import urlparse | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||||
# matches bfd8deac from resnet18-bfd8deac.pth | |||||
HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') | |||||
def load_url(url, model_dir=None, map_location=None, progress=True): | |||||
r"""Loads the Torch serialized object at the given URL. | |||||
If the object is already present in `model_dir`, it's deserialized and | |||||
returned. The filename part of the URL should follow the naming convention | |||||
``filename-<sha256>.ext`` where ``<sha256>`` is the first eight or more | |||||
digits of the SHA256 hash of the contents of the file. The hash is used to | |||||
ensure unique names and to verify the contents of the file. | |||||
The default value of `model_dir` is ``$TORCH_HOME/models`` where | |||||
``$TORCH_HOME`` defaults to ``~/.torch``. The default directory can be | |||||
overridden with the ``$TORCH_MODEL_ZOO`` environment variable. | |||||
Args: | |||||
url (string): URL of the object to download | |||||
model_dir (string, optional): directory in which to save the object | |||||
map_location (optional): a function or a dict specifying how to remap storage locations (see torch.load) | |||||
progress (bool, optional): whether or not to display a progress bar to stderr | |||||
Example: | |||||
# >>> state_dict = model_zoo.load_url('https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth') | |||||
""" | |||||
if model_dir is None: | |||||
torch_home = os.path.expanduser(os.getenv('fastNLP_HOME', '~/.fastNLP')) | |||||
model_dir = os.getenv('fastNLP_MODEL_ZOO', os.path.join(torch_home, 'models')) | |||||
if not os.path.exists(model_dir): | |||||
os.makedirs(model_dir) | |||||
parts = urlparse(url) | |||||
filename = os.path.basename(parts.path) | |||||
cached_file = os.path.join(model_dir, filename) | |||||
if not os.path.exists(cached_file): | |||||
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) | |||||
# hash_prefix = HASH_REGEX.search(filename).group(1) | |||||
_download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) | |||||
return torch.load(cached_file, map_location=map_location) | |||||
def _download_url_to_file(url, dst, hash_prefix, progress): | |||||
if requests_available: | |||||
u = urlopen(url, stream=True) | |||||
file_size = int(u.headers["Content-Length"]) | |||||
u = u.raw | |||||
else: | |||||
u = urlopen(url) | |||||
meta = u.info() | |||||
if hasattr(meta, 'getheaders'): | |||||
file_size = int(meta.getheaders("Content-Length")[0]) | |||||
else: | |||||
file_size = int(meta.get_all("Content-Length")[0]) | |||||
f = tempfile.NamedTemporaryFile(delete=False) | |||||
try: | |||||
if hash_prefix is not None: | |||||
sha256 = hashlib.sha256() | |||||
with tqdm(total=file_size, disable=not progress) as pbar: | |||||
while True: | |||||
buffer = u.read(8192) | |||||
if len(buffer) == 0: | |||||
break | |||||
f.write(buffer) | |||||
if hash_prefix is not None: | |||||
sha256.update(buffer) | |||||
pbar.update(len(buffer)) | |||||
f.close() | |||||
if hash_prefix is not None: | |||||
digest = sha256.hexdigest() | |||||
if digest[:len(hash_prefix)] != hash_prefix: | |||||
raise RuntimeError('invalid hash value (expected "{}", got "{}")' | |||||
.format(hash_prefix, digest)) | |||||
shutil.move(f.name, dst) | |||||
finally: | |||||
f.close() | |||||
if os.path.exists(f.name): | |||||
os.remove(f.name) | |||||
if tqdm is None: | |||||
# fake tqdm if it's not installed | |||||
class tqdm(object): | |||||
def __init__(self, total, disable=False): | |||||
self.total = total | |||||
self.disable = disable | |||||
self.n = 0 | |||||
def update(self, n): | |||||
if self.disable: | |||||
return | |||||
self.n += n | |||||
sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(self.total))) | |||||
sys.stderr.flush() | |||||
def __enter__(self): | |||||
return self | |||||
def __exit__(self, exc_type, exc_val, exc_tb): | |||||
if self.disable: | |||||
return | |||||
sys.stderr.write('\n') | |||||
@@ -1,223 +0,0 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""A module with NAS controller-related code.""" | |||||
import collections | |||||
import os | |||||
import torch | |||||
import torch.nn.functional as F | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.automl.enas_utils import Node | |||||
def _construct_dags(prev_nodes, activations, func_names, num_blocks): | |||||
"""Constructs a set of DAGs based on the actions, i.e., previous nodes and | |||||
activation functions, sampled from the controller/policy pi. | |||||
Args: | |||||
prev_nodes: Previous node actions from the policy. | |||||
activations: Activations sampled from the policy. | |||||
func_names: Mapping from activation function names to functions. | |||||
num_blocks: Number of blocks in the target RNN cell. | |||||
Returns: | |||||
A list of DAGs defined by the inputs. | |||||
RNN cell DAGs are represented in the following way: | |||||
1. Each element (node) in a DAG is a list of `Node`s. | |||||
2. The `Node`s in the list dag[i] correspond to the subsequent nodes | |||||
that take the output from node i as their own input. | |||||
3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}. | |||||
dag[-1] always feeds dag[0]. | |||||
dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its | |||||
weights. | |||||
4. dag[N - 1] is the node that produces the hidden state passed to | |||||
the next timestep. dag[N - 1] is also always a leaf node, and therefore | |||||
is always averaged with the other leaf nodes and fed to the output | |||||
decoder. | |||||
""" | |||||
dags = [] | |||||
for nodes, func_ids in zip(prev_nodes, activations): | |||||
dag = collections.defaultdict(list) | |||||
# add first node | |||||
dag[-1] = [Node(0, func_names[func_ids[0]])] | |||||
dag[-2] = [Node(0, func_names[func_ids[0]])] | |||||
# add following nodes | |||||
for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])): | |||||
dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id])) | |||||
leaf_nodes = set(range(num_blocks)) - dag.keys() | |||||
# merge with avg | |||||
for idx in leaf_nodes: | |||||
dag[idx] = [Node(num_blocks, 'avg')] | |||||
# This is actually y^{(t)}. h^{(t)} is node N - 1 in | |||||
# the graph, where N Is the number of nodes. I.e., h^{(t)} takes | |||||
# only one other node as its input. | |||||
# last h[t] node | |||||
last_node = Node(num_blocks + 1, 'h[t]') | |||||
dag[num_blocks] = [last_node] | |||||
dags.append(dag) | |||||
return dags | |||||
class Controller(torch.nn.Module): | |||||
"""Based on | |||||
https://github.com/pytorch/examples/blob/master/word_language_model/model.py | |||||
RL controllers do not necessarily have much to do with | |||||
language models. | |||||
Base the controller RNN on the GRU from: | |||||
https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py | |||||
""" | |||||
def __init__(self, num_blocks=4, controller_hid=100, cuda=False): | |||||
torch.nn.Module.__init__(self) | |||||
# `num_tokens` here is just the activation function | |||||
# for every even step, | |||||
self.shared_rnn_activations = ['tanh', 'ReLU', 'identity', 'sigmoid'] | |||||
self.num_tokens = [len(self.shared_rnn_activations)] | |||||
self.controller_hid = controller_hid | |||||
self.use_cuda = cuda | |||||
self.num_blocks = num_blocks | |||||
for idx in range(num_blocks): | |||||
self.num_tokens += [idx + 1, len(self.shared_rnn_activations)] | |||||
self.func_names = self.shared_rnn_activations | |||||
num_total_tokens = sum(self.num_tokens) | |||||
self.encoder = torch.nn.Embedding(num_total_tokens, | |||||
controller_hid) | |||||
self.lstm = torch.nn.LSTMCell(controller_hid, controller_hid) | |||||
# Perhaps these weights in the decoder should be | |||||
# shared? At least for the activation functions, which all have the | |||||
# same size. | |||||
self.decoders = [] | |||||
for idx, size in enumerate(self.num_tokens): | |||||
decoder = torch.nn.Linear(controller_hid, size) | |||||
self.decoders.append(decoder) | |||||
self._decoders = torch.nn.ModuleList(self.decoders) | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def _get_default_hidden(key): | |||||
return utils.get_variable( | |||||
torch.zeros(key, self.controller_hid), | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
self.static_inputs = utils.keydefaultdict(_get_default_hidden) | |||||
def reset_parameters(self): | |||||
init_range = 0.1 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
for decoder in self.decoders: | |||||
decoder.bias.data.fill_(0) | |||||
def forward(self, # pylint:disable=arguments-differ | |||||
inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed): | |||||
if not is_embed: | |||||
embed = self.encoder(inputs) | |||||
else: | |||||
embed = inputs | |||||
hx, cx = self.lstm(embed, hidden) | |||||
logits = self.decoders[block_idx](hx) | |||||
logits /= 5.0 | |||||
# # exploration | |||||
# if self.args.mode == 'train': | |||||
# logits = (2.5 * F.tanh(logits)) | |||||
return logits, (hx, cx) | |||||
def sample(self, batch_size=1, with_details=False, save_dir=None): | |||||
"""Samples a set of `args.num_blocks` many computational nodes from the | |||||
controller, where each node is made up of an activation function, and | |||||
each node except the last also includes a previous node. | |||||
""" | |||||
if batch_size < 1: | |||||
raise Exception(f'Wrong batch_size: {batch_size} < 1') | |||||
# [B, L, H] | |||||
inputs = self.static_inputs[batch_size] | |||||
hidden = self.static_init_hidden[batch_size] | |||||
activations = [] | |||||
entropies = [] | |||||
log_probs = [] | |||||
prev_nodes = [] | |||||
# The RNN controller alternately outputs an activation, | |||||
# followed by a previous node, for each block except the last one, | |||||
# which only gets an activation function. The last node is the output | |||||
# node, and its previous node is the average of all leaf nodes. | |||||
for block_idx in range(2*(self.num_blocks - 1) + 1): | |||||
logits, hidden = self.forward(inputs, | |||||
hidden, | |||||
block_idx, | |||||
is_embed=(block_idx == 0)) | |||||
probs = F.softmax(logits, dim=-1) | |||||
log_prob = F.log_softmax(logits, dim=-1) | |||||
# .mean() for entropy? | |||||
entropy = -(log_prob * probs).sum(1, keepdim=False) | |||||
action = probs.multinomial(num_samples=1).data | |||||
selected_log_prob = log_prob.gather( | |||||
1, utils.get_variable(action, requires_grad=False)) | |||||
# why the [:, 0] here? Should it be .squeeze(), or | |||||
# .view()? Same below with `action`. | |||||
entropies.append(entropy) | |||||
log_probs.append(selected_log_prob[:, 0]) | |||||
# 0: function, 1: previous node | |||||
mode = block_idx % 2 | |||||
inputs = utils.get_variable( | |||||
action[:, 0] + sum(self.num_tokens[:mode]), | |||||
requires_grad=False) | |||||
if mode == 0: | |||||
activations.append(action[:, 0]) | |||||
elif mode == 1: | |||||
prev_nodes.append(action[:, 0]) | |||||
prev_nodes = torch.stack(prev_nodes).transpose(0, 1) | |||||
activations = torch.stack(activations).transpose(0, 1) | |||||
dags = _construct_dags(prev_nodes, | |||||
activations, | |||||
self.func_names, | |||||
self.num_blocks) | |||||
if save_dir is not None: | |||||
for idx, dag in enumerate(dags): | |||||
utils.draw_network(dag, | |||||
os.path.join(save_dir, f'graph{idx}.png')) | |||||
if with_details: | |||||
return dags, torch.cat(log_probs), torch.cat(entropies) | |||||
return dags | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.controller_hid) | |||||
return (utils.get_variable(zeros, self.use_cuda, requires_grad=False), | |||||
utils.get_variable(zeros.clone(), self.use_cuda, requires_grad=False)) |
@@ -1,388 +0,0 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
"""Module containing the shared RNN model.""" | |||||
import collections | |||||
import numpy as np | |||||
import torch | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
from torch.autograd import Variable | |||||
import fastNLP.automl.enas_utils as utils | |||||
from fastNLP.models.base_model import BaseModel | |||||
def _get_dropped_weights(w_raw, dropout_p, is_training): | |||||
"""Drops out weights to implement DropConnect. | |||||
Args: | |||||
w_raw: Full, pre-dropout, weights to be dropped out. | |||||
dropout_p: Proportion of weights to drop out. | |||||
is_training: True iff _shared_ model is training. | |||||
Returns: | |||||
The dropped weights. | |||||
Why does torch.nn.functional.dropout() return: | |||||
1. `torch.autograd.Variable()` on the training loop | |||||
2. `torch.nn.Parameter()` on the controller or eval loop, when | |||||
training = False... | |||||
Even though the call to `_setweights` in the Smerity repo's | |||||
`weight_drop.py` does not have this behaviour, and `F.dropout` always | |||||
returns `torch.autograd.Variable` there, even when `training=False`? | |||||
The above TODO is the reason for the hacky check for `torch.nn.Parameter`. | |||||
""" | |||||
dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training) | |||||
if isinstance(dropped_w, torch.nn.Parameter): | |||||
dropped_w = dropped_w.clone() | |||||
return dropped_w | |||||
class EmbeddingDropout(torch.nn.Embedding): | |||||
"""Class for dropping out embeddings by zero'ing out parameters in the | |||||
embedding matrix. | |||||
This is equivalent to dropping out particular words, e.g., in the sentence | |||||
'the quick brown fox jumps over the lazy dog', dropping out 'the' would | |||||
lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the | |||||
embedding vector space). | |||||
See 'A Theoretically Grounded Application of Dropout in Recurrent Neural | |||||
Networks', (Gal and Ghahramani, 2016). | |||||
""" | |||||
def __init__(self, | |||||
num_embeddings, | |||||
embedding_dim, | |||||
max_norm=None, | |||||
norm_type=2, | |||||
scale_grad_by_freq=False, | |||||
sparse=False, | |||||
dropout=0.1, | |||||
scale=None): | |||||
"""Embedding constructor. | |||||
Args: | |||||
dropout: Dropout probability. | |||||
scale: Used to scale parameters of embedding weight matrix that are | |||||
not dropped out. Note that this is _in addition_ to the | |||||
`1/(1 - dropout)` scaling. | |||||
See `torch.nn.Embedding` for remaining arguments. | |||||
""" | |||||
torch.nn.Embedding.__init__(self, | |||||
num_embeddings=num_embeddings, | |||||
embedding_dim=embedding_dim, | |||||
max_norm=max_norm, | |||||
norm_type=norm_type, | |||||
scale_grad_by_freq=scale_grad_by_freq, | |||||
sparse=sparse) | |||||
self.dropout = dropout | |||||
assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' | |||||
'and < 1.0') | |||||
self.scale = scale | |||||
def forward(self, inputs): # pylint:disable=arguments-differ | |||||
"""Embeds `inputs` with the dropped out embedding weight matrix.""" | |||||
if self.training: | |||||
dropout = self.dropout | |||||
else: | |||||
dropout = 0 | |||||
if dropout: | |||||
mask = self.weight.data.new(self.weight.size(0), 1) | |||||
mask.bernoulli_(1 - dropout) | |||||
mask = mask.expand_as(self.weight) | |||||
mask = mask / (1 - dropout) | |||||
masked_weight = self.weight * Variable(mask) | |||||
else: | |||||
masked_weight = self.weight | |||||
if self.scale and self.scale != 1: | |||||
masked_weight = masked_weight * self.scale | |||||
return F.embedding(inputs, | |||||
masked_weight, | |||||
max_norm=self.max_norm, | |||||
norm_type=self.norm_type, | |||||
scale_grad_by_freq=self.scale_grad_by_freq, | |||||
sparse=self.sparse) | |||||
class LockedDropout(nn.Module): | |||||
# code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py | |||||
def __init__(self): | |||||
super().__init__() | |||||
def forward(self, x, dropout=0.5): | |||||
if not self.training or not dropout: | |||||
return x | |||||
m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) | |||||
mask = Variable(m, requires_grad=False) / (1 - dropout) | |||||
mask = mask.expand_as(x) | |||||
return mask * x | |||||
class ENASModel(BaseModel): | |||||
"""Shared RNN model.""" | |||||
def __init__(self, embed_num, num_classes, num_blocks=4, cuda=False, shared_hid=1000, shared_embed=1000): | |||||
super(ENASModel, self).__init__() | |||||
self.use_cuda = cuda | |||||
self.shared_hid = shared_hid | |||||
self.num_blocks = num_blocks | |||||
self.decoder = nn.Linear(self.shared_hid, num_classes) | |||||
self.encoder = EmbeddingDropout(embed_num, | |||||
shared_embed, | |||||
dropout=0.1) | |||||
self.lockdrop = LockedDropout() | |||||
self.dag = None | |||||
# Tie weights | |||||
# self.decoder.weight = self.encoder.weight | |||||
# Since W^{x, c} and W^{h, c} are always summed, there | |||||
# is no point duplicating their bias offset parameter. Likewise for | |||||
# W^{x, h} and W^{h, h}. | |||||
self.w_xc = nn.Linear(shared_embed, self.shared_hid) | |||||
self.w_xh = nn.Linear(shared_embed, self.shared_hid) | |||||
# The raw weights are stored here because the hidden-to-hidden weights | |||||
# are weight dropped on the forward pass. | |||||
self.w_hc_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hh_raw = torch.nn.Parameter( | |||||
torch.Tensor(self.shared_hid, self.shared_hid)) | |||||
self.w_hc = None | |||||
self.w_hh = None | |||||
self.w_h = collections.defaultdict(dict) | |||||
self.w_c = collections.defaultdict(dict) | |||||
for idx in range(self.num_blocks): | |||||
for jdx in range(idx + 1, self.num_blocks): | |||||
self.w_h[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self.w_c[idx][jdx] = nn.Linear(self.shared_hid, | |||||
self.shared_hid, | |||||
bias=False) | |||||
self._w_h = nn.ModuleList([self.w_h[idx][jdx] | |||||
for idx in self.w_h | |||||
for jdx in self.w_h[idx]]) | |||||
self._w_c = nn.ModuleList([self.w_c[idx][jdx] | |||||
for idx in self.w_c | |||||
for jdx in self.w_c[idx]]) | |||||
self.batch_norm = None | |||||
# if args.mode == 'train': | |||||
# self.batch_norm = nn.BatchNorm1d(self.shared_hid) | |||||
# else: | |||||
# self.batch_norm = None | |||||
self.reset_parameters() | |||||
self.static_init_hidden = utils.keydefaultdict(self.init_hidden) | |||||
def setDAG(self, dag): | |||||
if self.dag is None: | |||||
self.dag = dag | |||||
def forward(self, word_seq, hidden=None): | |||||
inputs = torch.transpose(word_seq, 0, 1) | |||||
time_steps = inputs.size(0) | |||||
batch_size = inputs.size(1) | |||||
self.w_hh = _get_dropped_weights(self.w_hh_raw, | |||||
0.5, | |||||
self.training) | |||||
self.w_hc = _get_dropped_weights(self.w_hc_raw, | |||||
0.5, | |||||
self.training) | |||||
# hidden = self.static_init_hidden[batch_size] if hidden is None else hidden | |||||
hidden = self.static_init_hidden[batch_size] | |||||
embed = self.encoder(inputs) | |||||
embed = self.lockdrop(embed, 0.65 if self.training else 0) | |||||
# The norm of hidden states are clipped here because | |||||
# otherwise ENAS is especially prone to exploding activations on the | |||||
# forward pass. This could probably be fixed in a more elegant way, but | |||||
# it might be exposing a weakness in the ENAS algorithm as currently | |||||
# proposed. | |||||
# | |||||
# For more details, see | |||||
# https://github.com/carpedm20/ENAS-pytorch/issues/6 | |||||
clipped_num = 0 | |||||
max_clipped_norm = 0 | |||||
h1tohT = [] | |||||
logits = [] | |||||
for step in range(time_steps): | |||||
x_t = embed[step] | |||||
logit, hidden = self.cell(x_t, hidden, self.dag) | |||||
hidden_norms = hidden.norm(dim=-1) | |||||
max_norm = 25.0 | |||||
if hidden_norms.data.max() > max_norm: | |||||
# Just directly use the torch slice operations | |||||
# in PyTorch v0.4. | |||||
# | |||||
# This workaround for PyTorch v0.3.1 does everything in numpy, | |||||
# because the PyTorch slicing and slice assignment is too | |||||
# flaky. | |||||
hidden_norms = hidden_norms.data.cpu().numpy() | |||||
clipped_num += 1 | |||||
if hidden_norms.max() > max_clipped_norm: | |||||
max_clipped_norm = hidden_norms.max() | |||||
clip_select = hidden_norms > max_norm | |||||
clip_norms = hidden_norms[clip_select] | |||||
mask = np.ones(hidden.size()) | |||||
normalizer = max_norm/clip_norms | |||||
normalizer = normalizer[:, np.newaxis] | |||||
mask[clip_select] = normalizer | |||||
if self.use_cuda: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask).cuda(), requires_grad=False) | |||||
else: | |||||
hidden *= torch.autograd.Variable( | |||||
torch.FloatTensor(mask), requires_grad=False) | |||||
logits.append(logit) | |||||
h1tohT.append(hidden) | |||||
h1tohT = torch.stack(h1tohT) | |||||
output = torch.stack(logits) | |||||
raw_output = output | |||||
output = self.lockdrop(output, 0.4 if self.training else 0) | |||||
#Pooling | |||||
output = torch.mean(output, 0) | |||||
decoded = self.decoder(output) | |||||
extra_out = {'dropped': decoded, | |||||
'hiddens': h1tohT, | |||||
'raw': raw_output} | |||||
return {'pred': decoded, 'hidden': hidden, 'extra_out': extra_out} | |||||
def cell(self, x, h_prev, dag): | |||||
"""Computes a single pass through the discovered RNN cell.""" | |||||
c = {} | |||||
h = {} | |||||
f = {} | |||||
f[0] = self.get_f(dag[-1][0].name) | |||||
c[0] = torch.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) | |||||
h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + | |||||
(1 - c[0])*h_prev) | |||||
leaf_node_ids = [] | |||||
q = collections.deque() | |||||
q.append(0) | |||||
# Computes connections from the parent nodes `node_id` | |||||
# to their child nodes `next_id` recursively, skipping leaf nodes. A | |||||
# leaf node is a node whose id == `self.num_blocks`. | |||||
# | |||||
# Connections between parent i and child j should be computed as | |||||
# h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, | |||||
# where c_j = \sigmoid{(W^c_{ij}*h_i)} | |||||
# | |||||
# See Training details from Section 3.1 of the paper. | |||||
# | |||||
# The following algorithm does a breadth-first (since `q.popleft()` is | |||||
# used) search over the nodes and computes all the hidden states. | |||||
while True: | |||||
if len(q) == 0: | |||||
break | |||||
node_id = q.popleft() | |||||
nodes = dag[node_id] | |||||
for next_node in nodes: | |||||
next_id = next_node.id | |||||
if next_id == self.num_blocks: | |||||
leaf_node_ids.append(node_id) | |||||
assert len(nodes) == 1, ('parent of leaf node should have ' | |||||
'only one child') | |||||
continue | |||||
w_h = self.w_h[node_id][next_id] | |||||
w_c = self.w_c[node_id][next_id] | |||||
f[next_id] = self.get_f(next_node.name) | |||||
c[next_id] = torch.sigmoid(w_c(h[node_id])) | |||||
h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) + | |||||
(1 - c[next_id])*h[node_id]) | |||||
q.append(next_id) | |||||
# Instead of averaging loose ends, perhaps there should | |||||
# be a set of separate unshared weights for each "loose" connection | |||||
# between each node in a cell and the output. | |||||
# | |||||
# As it stands, all weights W^h_{ij} are doing double duty by | |||||
# connecting both from i to j, as well as from i to the output. | |||||
# average all the loose ends | |||||
leaf_nodes = [h[node_id] for node_id in leaf_node_ids] | |||||
output = torch.mean(torch.stack(leaf_nodes, 2), -1) | |||||
# stabilizing the Updates of omega | |||||
if self.batch_norm is not None: | |||||
output = self.batch_norm(output) | |||||
return output, h[self.num_blocks - 1] | |||||
def init_hidden(self, batch_size): | |||||
zeros = torch.zeros(batch_size, self.shared_hid) | |||||
return utils.get_variable(zeros, self.use_cuda, requires_grad=False) | |||||
def get_f(self, name): | |||||
name = name.lower() | |||||
if name == 'relu': | |||||
f = torch.relu | |||||
elif name == 'tanh': | |||||
f = torch.tanh | |||||
elif name == 'identity': | |||||
f = lambda x: x | |||||
elif name == 'sigmoid': | |||||
f = torch.sigmoid | |||||
return f | |||||
@property | |||||
def num_parameters(self): | |||||
def size(p): | |||||
return np.prod(p.size()) | |||||
return sum([size(param) for param in self.parameters()]) | |||||
def reset_parameters(self): | |||||
init_range = 0.025 | |||||
# init_range = 0.025 if self.args.mode == 'train' else 0.04 | |||||
for param in self.parameters(): | |||||
param.data.uniform_(-init_range, init_range) | |||||
self.decoder.bias.data.fill_(0) | |||||
def predict(self, word_seq): | |||||
""" | |||||
:param word_seq: torch.LongTensor, [batch_size, seq_len] | |||||
:return predict: dict of torch.LongTensor, [batch_size, seq_len] | |||||
""" | |||||
output = self(word_seq) | |||||
_, predict = output['pred'].max(dim=1) | |||||
return {'pred': predict} |
@@ -1,383 +0,0 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
import math | |||||
import time | |||||
from datetime import datetime | |||||
from datetime import timedelta | |||||
import numpy as np | |||||
import torch | |||||
try: | |||||
from tqdm.auto import tqdm | |||||
except: | |||||
from fastNLP.core.utils import _pseudo_tqdm as tqdm | |||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.callback import CallbackException | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.utils import _move_dict_value_to_device | |||||
import fastNLP | |||||
from . import enas_utils as utils | |||||
from fastNLP.core.utils import _build_args | |||||
from torch.optim import Adam | |||||
def _get_no_grad_ctx_mgr(): | |||||
"""Returns a the `torch.no_grad` context manager for PyTorch version >= | |||||
0.4, or a no-op context manager otherwise. | |||||
""" | |||||
return torch.no_grad() | |||||
class ENASTrainer(fastNLP.Trainer): | |||||
"""A class to wrap training code.""" | |||||
def __init__(self, train_data, model, controller, **kwargs): | |||||
"""Constructor for training algorithm. | |||||
:param DataSet train_data: the training data | |||||
:param torch.nn.modules.module model: a PyTorch model | |||||
:param torch.nn.modules.module controller: a PyTorch model | |||||
""" | |||||
self.final_epochs = kwargs['final_epochs'] | |||||
kwargs.pop('final_epochs') | |||||
super(ENASTrainer, self).__init__(train_data, model, **kwargs) | |||||
self.controller_step = 0 | |||||
self.shared_step = 0 | |||||
self.max_length = 35 | |||||
self.shared = model | |||||
self.controller = controller | |||||
self.shared_optim = Adam( | |||||
self.shared.parameters(), | |||||
lr=20.0, | |||||
weight_decay=1e-7) | |||||
self.controller_optim = Adam( | |||||
self.controller.parameters(), | |||||
lr=3.5e-4) | |||||
def train(self, load_best_model=True): | |||||
""" | |||||
:param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 | |||||
最好的模型参数。 | |||||
:return results: 返回一个字典类型的数据, | |||||
内含以下内容:: | |||||
seconds: float, 表示训练时长 | |||||
以下三个内容只有在提供了dev_data的情况下会有。 | |||||
best_eval: Dict of Dict, 表示evaluation的结果 | |||||
best_epoch: int,在第几个epoch取得的最佳值 | |||||
best_step: int, 在第几个step(batch)更新取得的最佳值 | |||||
""" | |||||
results = {} | |||||
if self.n_epochs <= 0: | |||||
print(f"training epoch is {self.n_epochs}, nothing was done.") | |||||
results['seconds'] = 0. | |||||
return results | |||||
try: | |||||
if torch.cuda.is_available() and self.use_cuda: | |||||
self.model = self.model.cuda() | |||||
self._model_device = self.model.parameters().__next__().device | |||||
self._mode(self.model, is_test=False) | |||||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||||
start_time = time.time() | |||||
print("training epochs started " + self.start_time, flush=True) | |||||
try: | |||||
self.callback_manager.on_train_begin() | |||||
self._train() | |||||
self.callback_manager.on_train_end(self.model) | |||||
except (CallbackException, KeyboardInterrupt) as e: | |||||
self.callback_manager.on_exception(e, self.model) | |||||
if self.dev_data is not None: | |||||
print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||||
self.tester._format_eval_results(self.best_dev_perf),) | |||||
results['best_eval'] = self.best_dev_perf | |||||
results['best_epoch'] = self.best_dev_epoch | |||||
results['best_step'] = self.best_dev_step | |||||
if load_best_model: | |||||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | |||||
load_succeed = self._load_model(self.model, model_name) | |||||
if load_succeed: | |||||
print("Reloaded the best model.") | |||||
else: | |||||
print("Fail to reload best model.") | |||||
finally: | |||||
pass | |||||
results['seconds'] = round(time.time() - start_time, 2) | |||||
return results | |||||
def _train(self): | |||||
if not self.use_tqdm: | |||||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||||
else: | |||||
inner_tqdm = tqdm | |||||
self.step = 0 | |||||
start = time.time() | |||||
total_steps = (len(self.train_data) // self.batch_size + int( | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||||
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||||
avg_loss = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for epoch in range(1, self.n_epochs+1): | |||||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||||
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs) | |||||
if epoch == self.n_epochs + 1 - self.final_epochs: | |||||
print('Entering the final stage. (Only train the selected structure)') | |||||
# early stopping | |||||
self.callback_manager.on_epoch_begin(epoch, self.n_epochs) | |||||
# 1. Training the shared parameters omega of the child models | |||||
self.train_shared(pbar) | |||||
# 2. Training the controller parameters theta | |||||
if not last_stage: | |||||
self.train_controller() | |||||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||||
and self.dev_data is not None: | |||||
if not last_stage: | |||||
self.derive() | |||||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||||
total_steps) + \ | |||||
self.tester._format_eval_results(eval_res) | |||||
pbar.write(eval_str) | |||||
# lr decay; early stopping | |||||
self.callback_manager.on_epoch_end(epoch, self.n_epochs, self.optimizer) | |||||
# =============== epochs end =================== # | |||||
pbar.close() | |||||
# ============ tqdm end ============== # | |||||
def get_loss(self, inputs, targets, hidden, dags): | |||||
"""Computes the loss for the same batch for M models. | |||||
This amounts to an estimate of the loss, which is turned into an | |||||
estimate for the gradients of the shared model. | |||||
""" | |||||
if not isinstance(dags, list): | |||||
dags = [dags] | |||||
loss = 0 | |||||
for dag in dags: | |||||
self.shared.setDAG(dag) | |||||
inputs = _build_args(self.shared.forward, **inputs) | |||||
inputs['hidden'] = hidden | |||||
result = self.shared(**inputs) | |||||
output, hidden, extra_out = result['pred'], result['hidden'], result['extra_out'] | |||||
self.callback_manager.on_loss_begin(targets, result) | |||||
sample_loss = self._compute_loss(result, targets) | |||||
loss += sample_loss | |||||
assert len(dags) == 1, 'there are multiple `hidden` for multple `dags`' | |||||
return loss, hidden, extra_out | |||||
def train_shared(self, pbar=None, max_step=None, dag=None): | |||||
"""Train the language model for 400 steps of minibatches of 64 | |||||
examples. | |||||
Args: | |||||
max_step: Used to run extra training steps as a warm-up. | |||||
dag: If not None, is used instead of calling sample(). | |||||
BPTT is truncated at 35 timesteps. | |||||
For each weight update, gradients are estimated by sampling M models | |||||
from the fixed controller policy, and averaging their gradients | |||||
computed on a batch of training data. | |||||
""" | |||||
model = self.shared | |||||
model.train() | |||||
self.controller.eval() | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
abs_max_grad = 0 | |||||
abs_max_hidden_norm = 0 | |||||
step = 0 | |||||
raw_total_loss = 0 | |||||
total_loss = 0 | |||||
train_idx = 0 | |||||
avg_loss = 0 | |||||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for batch_x, batch_y in data_iterator: | |||||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||||
indices = data_iterator.get_batch_indices() | |||||
# negative sampling; replace unknown; re-weight batch_y | |||||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||||
# prediction = self._data_forward(self.model, batch_x) | |||||
dags = self.controller.sample(1) | |||||
inputs, targets = batch_x, batch_y | |||||
# self.callback_manager.on_loss_begin(batch_y, prediction) | |||||
loss, hidden, extra_out = self.get_loss(inputs, | |||||
targets, | |||||
hidden, | |||||
dags) | |||||
hidden.detach_() | |||||
avg_loss += loss.item() | |||||
# Is loss NaN or inf? requires_grad = False | |||||
self.callback_manager.on_backward_begin(loss, self.model) | |||||
self._grad_backward(loss) | |||||
self.callback_manager.on_backward_end(self.model) | |||||
self._update() | |||||
self.callback_manager.on_step_end(self.optimizer) | |||||
if (self.step+1) % self.print_every == 0: | |||||
if self.use_tqdm: | |||||
print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every) | |||||
pbar.update(self.print_every) | |||||
else: | |||||
end = time.time() | |||||
diff = timedelta(seconds=round(end - start)) | |||||
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format( | |||||
epoch, self.step, avg_loss, diff) | |||||
pbar.set_postfix_str(print_output) | |||||
avg_loss = 0 | |||||
self.step += 1 | |||||
step += 1 | |||||
self.shared_step += 1 | |||||
self.callback_manager.on_batch_end() | |||||
# ================= mini-batch end ==================== # | |||||
def get_reward(self, dag, entropies, hidden, valid_idx=0): | |||||
"""Computes the perplexity of a single sampled model on a minibatch of | |||||
validation data. | |||||
""" | |||||
if not isinstance(entropies, np.ndarray): | |||||
entropies = entropies.data.cpu().numpy() | |||||
data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||||
prefetch=self.prefetch) | |||||
for inputs, targets in data_iterator: | |||||
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag) | |||||
valid_loss = utils.to_item(valid_loss.data) | |||||
valid_ppl = math.exp(valid_loss) | |||||
R = 80 / valid_ppl | |||||
rewards = R + 1e-4 * entropies | |||||
return rewards, hidden | |||||
def train_controller(self): | |||||
"""Fixes the shared parameters and updates the controller parameters. | |||||
The controller is updated with a score function gradient estimator | |||||
(i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl | |||||
is computed on a minibatch of validation data. | |||||
A moving average baseline is used. | |||||
The controller is trained for 2000 steps per epoch (i.e., | |||||
first (Train Shared) phase -> second (Train Controller) phase). | |||||
""" | |||||
model = self.controller | |||||
model.train() | |||||
# Why can't we call shared.eval() here? Leads to loss | |||||
# being uniformly zero for the controller. | |||||
# self.shared.eval() | |||||
avg_reward_base = None | |||||
baseline = None | |||||
adv_history = [] | |||||
entropy_history = [] | |||||
reward_history = [] | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
total_loss = 0 | |||||
valid_idx = 0 | |||||
for step in range(20): | |||||
# sample models | |||||
dags, log_probs, entropies = self.controller.sample( | |||||
with_details=True) | |||||
# calculate reward | |||||
np_entropies = entropies.data.cpu().numpy() | |||||
# No gradients should be backpropagated to the | |||||
# shared model during controller training, obviously. | |||||
with _get_no_grad_ctx_mgr(): | |||||
rewards, hidden = self.get_reward(dags, | |||||
np_entropies, | |||||
hidden, | |||||
valid_idx) | |||||
reward_history.extend(rewards) | |||||
entropy_history.extend(np_entropies) | |||||
# moving average baseline | |||||
if baseline is None: | |||||
baseline = rewards | |||||
else: | |||||
decay = 0.95 | |||||
baseline = decay * baseline + (1 - decay) * rewards | |||||
adv = rewards - baseline | |||||
adv_history.extend(adv) | |||||
# policy loss | |||||
loss = -log_probs*utils.get_variable(adv, | |||||
self.use_cuda, | |||||
requires_grad=False) | |||||
loss = loss.sum() # or loss.mean() | |||||
# update | |||||
self.controller_optim.zero_grad() | |||||
loss.backward() | |||||
self.controller_optim.step() | |||||
total_loss += utils.to_item(loss.data) | |||||
if ((step % 50) == 0) and (step > 0): | |||||
reward_history, adv_history, entropy_history = [], [], [] | |||||
total_loss = 0 | |||||
self.controller_step += 1 | |||||
# prev_valid_idx = valid_idx | |||||
# valid_idx = ((valid_idx + self.max_length) % | |||||
# (self.valid_data.size(0) - 1)) | |||||
# # Whenever we wrap around to the beginning of the | |||||
# # validation data, we reset the hidden states. | |||||
# if prev_valid_idx > valid_idx: | |||||
# hidden = self.shared.init_hidden(self.batch_size) | |||||
def derive(self, sample_num=10, valid_idx=0): | |||||
"""We are always deriving based on the very first batch | |||||
of validation data? This seems wrong... | |||||
""" | |||||
hidden = self.shared.init_hidden(self.batch_size) | |||||
dags, _, entropies = self.controller.sample(sample_num, | |||||
with_details=True) | |||||
max_R = 0 | |||||
best_dag = None | |||||
for dag in dags: | |||||
R, _ = self.get_reward(dag, entropies, hidden, valid_idx) | |||||
if R.max() > max_R: | |||||
max_R = R.max() | |||||
best_dag = dag | |||||
self.model.setDAG(best_dag) |
@@ -1,53 +0,0 @@ | |||||
# Code Modified from https://github.com/carpedm20/ENAS-pytorch | |||||
from __future__ import print_function | |||||
import collections | |||||
from collections import defaultdict | |||||
import numpy as np | |||||
import torch | |||||
from torch.autograd import Variable | |||||
def detach(h): | |||||
if type(h) == Variable: | |||||
return Variable(h.data) | |||||
else: | |||||
return tuple(detach(v) for v in h) | |||||
def get_variable(inputs, cuda=False, **kwargs): | |||||
if type(inputs) in [list, np.ndarray]: | |||||
inputs = torch.Tensor(inputs) | |||||
if cuda: | |||||
out = Variable(inputs.cuda(), **kwargs) | |||||
else: | |||||
out = Variable(inputs, **kwargs) | |||||
return out | |||||
def update_lr(optimizer, lr): | |||||
for param_group in optimizer.param_groups: | |||||
param_group['lr'] = lr | |||||
Node = collections.namedtuple('Node', ['id', 'name']) | |||||
class keydefaultdict(defaultdict): | |||||
def __missing__(self, key): | |||||
if self.default_factory is None: | |||||
raise KeyError(key) | |||||
else: | |||||
ret = self[key] = self.default_factory(key) | |||||
return ret | |||||
def to_item(x): | |||||
"""Converts x, possibly scalar and possibly tensor, to a Python scalar.""" | |||||
if isinstance(x, (float, int)): | |||||
return x | |||||
if float(torch.__version__[0:3]) < 0.4: | |||||
assert (x.dim() == 1) and (len(x) == 1) | |||||
return x[0] | |||||
return x.item() |
@@ -1 +0,0 @@ | |||||
from .bert_tokenizer import BertTokenizer |
@@ -1,378 +0,0 @@ | |||||
""" | |||||
bert_tokenizer.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0. | |||||
""" | |||||
import collections | |||||
import os | |||||
import unicodedata | |||||
from io import open | |||||
PRETRAINED_VOCAB_ARCHIVE_MAP = { | |||||
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", | |||||
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-vocab.txt", | |||||
'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt", | |||||
'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-vocab.txt", | |||||
'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-vocab.txt", | |||||
'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt", | |||||
'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt", | |||||
} | |||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { | |||||
'bert-base-uncased': 512, | |||||
'bert-large-uncased': 512, | |||||
'bert-base-cased': 512, | |||||
'bert-large-cased': 512, | |||||
'bert-base-multilingual-uncased': 512, | |||||
'bert-base-multilingual-cased': 512, | |||||
'bert-base-chinese': 512, | |||||
} | |||||
VOCAB_NAME = 'vocab.txt' | |||||
def load_vocab(vocab_file): | |||||
"""Loads a vocabulary file into a dictionary.""" | |||||
vocab = collections.OrderedDict() | |||||
index = 0 | |||||
with open(vocab_file, "r", encoding="utf-8") as reader: | |||||
while True: | |||||
token = reader.readline() | |||||
if not token: | |||||
break | |||||
token = token.strip() | |||||
vocab[token] = index | |||||
index += 1 | |||||
return vocab | |||||
def whitespace_tokenize(text): | |||||
"""Runs basic whitespace cleaning and splitting on a piece of text.""" | |||||
text = text.strip() | |||||
if not text: | |||||
return [] | |||||
tokens = text.split() | |||||
return tokens | |||||
class BertTokenizer(object): | |||||
"""Runs end-to-end tokenization: punctuation splitting + wordpiece""" | |||||
def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True, | |||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||||
"""Constructs a BertTokenizer. | |||||
Args: | |||||
vocab_file: Path to a one-wordpiece-per-line vocabulary file | |||||
do_lower_case: Whether to lower case the input | |||||
Only has an effect when do_wordpiece_only=False | |||||
do_basic_tokenize: Whether to do basic tokenization before wordpiece. | |||||
max_len: An artificial maximum length to truncate tokenized sequences to; | |||||
Effective maximum length is always the minimum of this | |||||
value (if specified) and the underlying BERT model's | |||||
sequence length. | |||||
never_split: List of tokens which will never be split during tokenization. | |||||
Only has an effect when do_wordpiece_only=False | |||||
""" | |||||
if not os.path.isfile(vocab_file): | |||||
raise ValueError( | |||||
"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained " | |||||
"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file)) | |||||
self.vocab = load_vocab(vocab_file) | |||||
self.ids_to_tokens = collections.OrderedDict( | |||||
[(ids, tok) for tok, ids in self.vocab.items()]) | |||||
self.do_basic_tokenize = do_basic_tokenize | |||||
if do_basic_tokenize: | |||||
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, | |||||
never_split=never_split) | |||||
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab) | |||||
self.max_len = max_len if max_len is not None else int(1e12) | |||||
def tokenize(self, text): | |||||
split_tokens = [] | |||||
if self.do_basic_tokenize: | |||||
for token in self.basic_tokenizer.tokenize(text): | |||||
for sub_token in self.wordpiece_tokenizer.tokenize(token): | |||||
split_tokens.append(sub_token) | |||||
else: | |||||
split_tokens = self.wordpiece_tokenizer.tokenize(text) | |||||
return split_tokens | |||||
def convert_tokens_to_ids(self, tokens): | |||||
"""Converts a sequence of tokens into ids using the vocab.""" | |||||
ids = [] | |||||
for token in tokens: | |||||
ids.append(self.vocab[token]) | |||||
if len(ids) > self.max_len: | |||||
print( | |||||
"WARNING!\n\"" | |||||
"Token indices sequence length is longer than the specified maximum " | |||||
"sequence length for this BERT model ({} > {}). Running this" | |||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len) | |||||
) | |||||
return ids | |||||
def convert_ids_to_tokens(self, ids): | |||||
"""Converts a sequence of ids in wordpiece tokens using the vocab.""" | |||||
tokens = [] | |||||
for i in ids: | |||||
tokens.append(self.ids_to_tokens[i]) | |||||
return tokens | |||||
def save_vocabulary(self, vocab_path): | |||||
"""Save the tokenizer vocabulary to a directory or file.""" | |||||
index = 0 | |||||
if os.path.isdir(vocab_path): | |||||
vocab_file = os.path.join(vocab_path, VOCAB_NAME) | |||||
with open(vocab_file, "w", encoding="utf-8") as writer: | |||||
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]): | |||||
if index != token_index: | |||||
print("Saving vocabulary to {}: vocabulary indices are not consecutive." | |||||
" Please check that the vocabulary is not corrupted!".format(vocab_file)) | |||||
index = token_index | |||||
writer.write(token + u'\n') | |||||
index += 1 | |||||
return vocab_file | |||||
@classmethod | |||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): | |||||
""" | |||||
Instantiate a PreTrainedBertModel from a pre-trained model file. | |||||
Download and cache the pre-trained model file if needed. | |||||
""" | |||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: | |||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] | |||||
if '-cased' in pretrained_model_name_or_path and kwargs.get('do_lower_case', True): | |||||
print("The pre-trained model you are loading is a cased model but you have not set " | |||||
"`do_lower_case` to False. We are setting `do_lower_case=False` for you but " | |||||
"you may want to check this behavior.") | |||||
kwargs['do_lower_case'] = False | |||||
elif '-cased' not in pretrained_model_name_or_path and not kwargs.get('do_lower_case', True): | |||||
print("The pre-trained model you are loading is an uncased model but you have set " | |||||
"`do_lower_case` to False. We are setting `do_lower_case=True` for you " | |||||
"but you may want to check this behavior.") | |||||
kwargs['do_lower_case'] = True | |||||
else: | |||||
vocab_file = pretrained_model_name_or_path | |||||
if os.path.isdir(vocab_file): | |||||
vocab_file = os.path.join(vocab_file, VOCAB_NAME) | |||||
# redirect to the cache, if necessary | |||||
resolved_vocab_file = vocab_file | |||||
print("loading vocabulary file {}".format(vocab_file)) | |||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: | |||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer | |||||
# than the number of positional embeddings | |||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] | |||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) | |||||
# Instantiate tokenizer. | |||||
tokenizer = cls(resolved_vocab_file, *inputs, **kwargs) | |||||
return tokenizer | |||||
class BasicTokenizer(object): | |||||
"""Runs basic tokenization (punctuation splitting, lower casing, etc.).""" | |||||
def __init__(self, | |||||
do_lower_case=True, | |||||
never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")): | |||||
"""Constructs a BasicTokenizer. | |||||
Args: | |||||
do_lower_case: Whether to lower case the input. | |||||
""" | |||||
self.do_lower_case = do_lower_case | |||||
self.never_split = never_split | |||||
def tokenize(self, text): | |||||
"""Tokenizes a piece of text.""" | |||||
text = self._clean_text(text) | |||||
# This was added on November 1st, 2018 for the multilingual and Chinese | |||||
# models. This is also applied to the English models now, but it doesn't | |||||
# matter since the English models were not trained on any Chinese data | |||||
# and generally don't have any Chinese data in them (there are Chinese | |||||
# characters in the vocabulary because Wikipedia does have some Chinese | |||||
# words in the English Wikipedia.). | |||||
text = self._tokenize_chinese_chars(text) | |||||
orig_tokens = whitespace_tokenize(text) | |||||
split_tokens = [] | |||||
for token in orig_tokens: | |||||
if self.do_lower_case and token not in self.never_split: | |||||
token = token.lower() | |||||
token = self._run_strip_accents(token) | |||||
split_tokens.extend(self._run_split_on_punc(token)) | |||||
output_tokens = whitespace_tokenize(" ".join(split_tokens)) | |||||
return output_tokens | |||||
def _run_strip_accents(self, text): | |||||
"""Strips accents from a piece of text.""" | |||||
text = unicodedata.normalize("NFD", text) | |||||
output = [] | |||||
for char in text: | |||||
cat = unicodedata.category(char) | |||||
if cat == "Mn": | |||||
continue | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _run_split_on_punc(self, text): | |||||
"""Splits punctuation on a piece of text.""" | |||||
if text in self.never_split: | |||||
return [text] | |||||
chars = list(text) | |||||
i = 0 | |||||
start_new_word = True | |||||
output = [] | |||||
while i < len(chars): | |||||
char = chars[i] | |||||
if _is_punctuation(char): | |||||
output.append([char]) | |||||
start_new_word = True | |||||
else: | |||||
if start_new_word: | |||||
output.append([]) | |||||
start_new_word = False | |||||
output[-1].append(char) | |||||
i += 1 | |||||
return ["".join(x) for x in output] | |||||
def _tokenize_chinese_chars(self, text): | |||||
"""Adds whitespace around any CJK character.""" | |||||
output = [] | |||||
for char in text: | |||||
cp = ord(char) | |||||
if self._is_chinese_char(cp): | |||||
output.append(" ") | |||||
output.append(char) | |||||
output.append(" ") | |||||
else: | |||||
output.append(char) | |||||
return "".join(output) | |||||
def _is_chinese_char(self, cp): | |||||
"""Checks whether CP is the codepoint of a CJK character.""" | |||||
# This defines a "chinese character" as anything in the CJK Unicode block: | |||||
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) | |||||
# | |||||
# Note that the CJK Unicode block is NOT all Japanese and Korean characters, | |||||
# despite its name. The modern Korean Hangul alphabet is a different block, | |||||
# as is Japanese Hiragana and Katakana. Those alphabets are used to write | |||||
# space-separated words, so they are not treated specially and handled | |||||
# like the all of the other languages. | |||||
if ((cp >= 0x4E00 and cp <= 0x9FFF) or # | |||||
(cp >= 0x3400 and cp <= 0x4DBF) or # | |||||
(cp >= 0x20000 and cp <= 0x2A6DF) or # | |||||
(cp >= 0x2A700 and cp <= 0x2B73F) or # | |||||
(cp >= 0x2B740 and cp <= 0x2B81F) or # | |||||
(cp >= 0x2B820 and cp <= 0x2CEAF) or | |||||
(cp >= 0xF900 and cp <= 0xFAFF) or # | |||||
(cp >= 0x2F800 and cp <= 0x2FA1F)): # | |||||
return True | |||||
return False | |||||
def _clean_text(self, text): | |||||
"""Performs invalid character removal and whitespace cleanup on text.""" | |||||
output = [] | |||||
for char in text: | |||||
cp = ord(char) | |||||
if cp == 0 or cp == 0xfffd or _is_control(char): | |||||
continue | |||||
if _is_whitespace(char): | |||||
output.append(" ") | |||||
else: | |||||
output.append(char) | |||||
return "".join(output) | |||||
class WordpieceTokenizer(object): | |||||
"""Runs WordPiece tokenization.""" | |||||
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100): | |||||
self.vocab = vocab | |||||
self.unk_token = unk_token | |||||
self.max_input_chars_per_word = max_input_chars_per_word | |||||
def tokenize(self, text): | |||||
"""Tokenizes a piece of text into its word pieces. | |||||
This uses a greedy longest-match-first algorithm to perform tokenization | |||||
using the given vocabulary. | |||||
For example: | |||||
input = "unaffable" | |||||
output = ["un", "##aff", "##able"] | |||||
Args: | |||||
text: A single token or whitespace separated tokens. This should have | |||||
already been passed through `BasicTokenizer`. | |||||
Returns: | |||||
A list of wordpiece tokens. | |||||
""" | |||||
output_tokens = [] | |||||
for token in whitespace_tokenize(text): | |||||
chars = list(token) | |||||
if len(chars) > self.max_input_chars_per_word: | |||||
output_tokens.append(self.unk_token) | |||||
continue | |||||
is_bad = False | |||||
start = 0 | |||||
sub_tokens = [] | |||||
while start < len(chars): | |||||
end = len(chars) | |||||
cur_substr = None | |||||
while start < end: | |||||
substr = "".join(chars[start:end]) | |||||
if start > 0: | |||||
substr = "##" + substr | |||||
if substr in self.vocab: | |||||
cur_substr = substr | |||||
break | |||||
end -= 1 | |||||
if cur_substr is None: | |||||
is_bad = True | |||||
break | |||||
sub_tokens.append(cur_substr) | |||||
start = end | |||||
if is_bad: | |||||
output_tokens.append(self.unk_token) | |||||
else: | |||||
output_tokens.extend(sub_tokens) | |||||
return output_tokens | |||||
def _is_whitespace(char): | |||||
"""Checks whether `chars` is a whitespace character.""" | |||||
# \t, \n, and \r are technically contorl characters but we treat them | |||||
# as whitespace since they are generally considered as such. | |||||
if char == " " or char == "\t" or char == "\n" or char == "\r": | |||||
return True | |||||
cat = unicodedata.category(char) | |||||
if cat == "Zs": | |||||
return True | |||||
return False | |||||
def _is_control(char): | |||||
"""Checks whether `chars` is a control character.""" | |||||
# These are technically control characters but we count them as whitespace | |||||
# characters. | |||||
if char == "\t" or char == "\n" or char == "\r": | |||||
return False | |||||
cat = unicodedata.category(char) | |||||
if cat.startswith("C"): | |||||
return True | |||||
return False | |||||
def _is_punctuation(char): | |||||
"""Checks whether `chars` is a punctuation character.""" | |||||
cp = ord(char) | |||||
# We treat all non-letter/number ASCII as punctuation. | |||||
# Characters such as "^", "$", and "`" are not in the Unicode | |||||
# Punctuation class but we treat them as punctuation anyways, for | |||||
# consistency. | |||||
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or | |||||
(cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)): | |||||
return True | |||||
cat = unicodedata.category(char) | |||||
if cat.startswith("P"): | |||||
return True | |||||
return False | |||||
@@ -182,8 +182,9 @@ class TestDataSetMethods(unittest.TestCase): | |||||
def test_apply2(self): | def test_apply2(self): | ||||
def split_sent(ins): | def split_sent(ins): | ||||
return ins['raw_sentence'].split() | return ins['raw_sentence'].split() | ||||
csv_loader = CSVLoader(headers=['raw_sentence', 'label'],sep='\t') | |||||
dataset = csv_loader.load('test/data_for_tests/tutorial_sample_dataset.csv') | |||||
csv_loader = CSVLoader(headers=['raw_sentence', 'label'], sep='\t') | |||||
data_bundle = csv_loader.load('test/data_for_tests/tutorial_sample_dataset.csv') | |||||
dataset = data_bundle.datasets['train'] | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | ||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | dataset.apply(split_sent, new_field_name='words', is_input=True) | ||||
# print(dataset) | # print(dataset) | ||||
@@ -1,15 +0,0 @@ | |||||
import unittest | |||||
from fastNLP.core.const import Const | |||||
from fastNLP.io.data_loader import MNLILoader | |||||
class TestDataLoader(unittest.TestCase): | |||||
def test_mnli_loader(self): | |||||
ds = MNLILoader().process('test/data_for_tests/sample_mnli.tsv', | |||||
to_lower=True, get_index=True, seq_len_type='mask') | |||||
self.assertTrue('train' in ds.datasets) | |||||
self.assertTrue(len(ds.datasets) == 1) | |||||
self.assertTrue(len(ds.datasets['train']) == 11) | |||||
self.assertTrue(isinstance(ds.datasets['train'][0][Const.INPUT_LENS(0)], list)) |
@@ -1,77 +0,0 @@ | |||||
import unittest | |||||
import os | |||||
from fastNLP.io import CSVLoader, JsonLoader | |||||
from fastNLP.io.data_loader import SSTLoader, SNLILoader, Conll2003Loader, PeopleDailyCorpusLoader | |||||
class TestDatasetLoader(unittest.TestCase): | |||||
def test_Conll2003Loader(self): | |||||
""" | |||||
Test the the loader of Conll2003 dataset | |||||
""" | |||||
dataset_path = "test/data_for_tests/conll_2003_example.txt" | |||||
loader = Conll2003Loader() | |||||
dataset_2003 = loader.load(dataset_path) | |||||
def test_PeopleDailyCorpusLoader(self): | |||||
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt") | |||||
def test_CSVLoader(self): | |||||
ds = CSVLoader(sep='\t', headers=['words', 'label']) \ | |||||
.load('test/data_for_tests/tutorial_sample_dataset.csv') | |||||
assert len(ds) > 0 | |||||
def test_SNLILoader(self): | |||||
ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') | |||||
assert len(ds) == 3 | |||||
def test_JsonLoader(self): | |||||
ds = JsonLoader().load('test/data_for_tests/sample_snli.jsonl') | |||||
assert len(ds) == 3 | |||||
def no_test_SST(self): | |||||
train_data = """(3 (2 (2 The) (2 Rock)) (4 (3 (2 is) (4 (2 destined) (2 (2 (2 (2 (2 to) (2 (2 be) (2 (2 the) (2 (2 21st) (2 (2 (2 Century) (2 's)) (2 (3 new) (2 (2 ``) (2 Conan)))))))) (2 '')) (2 and)) (3 (2 that) (3 (2 he) (3 (2 's) (3 (2 going) (3 (2 to) (4 (3 (2 make) (3 (3 (2 a) (3 splash)) (2 (2 even) (3 greater)))) (2 (2 than) (2 (2 (2 (2 (1 (2 Arnold) (2 Schwarzenegger)) (2 ,)) (2 (2 Jean-Claud) (2 (2 Van) (2 Damme)))) (2 or)) (2 (2 Steven) (2 Segal))))))))))))) (2 .))) | |||||
(4 (4 (4 (2 The) (4 (3 gorgeously) (3 (2 elaborate) (2 continuation)))) (2 (2 (2 of) (2 ``)) (2 (2 The) (2 (2 (2 Lord) (2 (2 of) (2 (2 the) (2 Rings)))) (2 (2 '') (2 trilogy)))))) (2 (3 (2 (2 is) (2 (2 so) (2 huge))) (2 (2 that) (3 (2 (2 (2 a) (2 column)) (2 (2 of) (2 words))) (2 (2 (2 (2 can) (1 not)) (3 adequately)) (2 (2 describe) (2 (3 (2 (2 co-writer\/director) (2 (2 Peter) (3 (2 Jackson) (2 's)))) (3 (2 expanded) (2 vision))) (2 (2 of) (2 (2 (2 J.R.R.) (2 (2 Tolkien) (2 's))) (2 Middle-earth))))))))) (2 .))) | |||||
(3 (3 (2 (2 (2 (2 (2 Singer\/composer) (2 (2 Bryan) (2 Adams))) (2 (2 contributes) (2 (2 (2 a) (2 slew)) (2 (2 of) (2 songs))))) (2 (2 --) (2 (2 (2 (2 a) (2 (2 few) (3 potential))) (2 (2 (2 hits) (2 ,)) (2 (2 (2 a) (2 few)) (1 (1 (2 more) (1 (2 simply) (2 intrusive))) (2 (2 to) (2 (2 the) (2 story))))))) (2 --)))) (2 but)) (3 (4 (2 the) (3 (2 whole) (2 package))) (2 (3 certainly) (3 (2 captures) (2 (1 (2 the) (2 (2 (2 intended) (2 (2 ,) (2 (2 er) (2 ,)))) (3 spirit))) (2 (2 of) (2 (2 the) (2 piece)))))))) (2 .)) | |||||
(2 (2 (2 You) (2 (2 'd) (2 (2 think) (2 (2 by) (2 now))))) (2 (2 America) (2 (2 (2 would) (1 (2 have) (2 (2 (2 had) (1 (2 enough) (2 (2 of) (2 (2 plucky) (2 (2 British) (1 eccentrics)))))) (4 (2 with) (4 (3 hearts) (3 (2 of) (3 gold))))))) (2 .)))) | |||||
""" | |||||
test_data = """(3 (2 Yet) (3 (2 (2 the) (2 act)) (3 (4 (3 (2 is) (3 (2 still) (4 charming))) (2 here)) (2 .)))) | |||||
(4 (2 (2 Whether) (2 (2 (2 (2 or) (1 not)) (3 (2 you) (2 (2 're) (3 (3 enlightened) (2 (2 by) (2 (2 any) (2 (2 of) (2 (2 Derrida) (2 's))))))))) (2 (2 lectures) (2 (2 on) (2 (2 ``) (2 (2 (2 (2 (2 (2 the) (2 other)) (2 '')) (2 and)) (2 ``)) (2 (2 the) (2 self)))))))) (3 (2 ,) (3 (2 '') (3 (2 Derrida) (3 (3 (2 is) (4 (2 an) (4 (4 (2 undeniably) (3 (4 (3 fascinating) (2 and)) (4 playful))) (2 fellow)))) (2 .)))))) | |||||
(4 (3 (2 (2 Just) (2 (2 the) (2 labour))) (3 (2 involved) (3 (2 in) (4 (2 creating) (3 (3 (2 the) (3 (3 layered) (2 richness))) (3 (2 of) (3 (2 (2 the) (2 imagery)) (2 (2 in) (3 (2 (2 this) (2 chiaroscuro)) (2 (2 of) (2 (2 (2 madness) (2 and)) (2 light)))))))))))) (3 (3 (2 is) (4 astonishing)) (2 .))) | |||||
(3 (3 (2 Part) (3 (2 of) (4 (2 (2 the) (3 charm)) (2 (2 of) (2 (2 Satin) (2 Rouge)))))) (3 (3 (2 is) (3 (2 that) (3 (2 it) (2 (1 (2 avoids) (2 (2 the) (1 obvious))) (3 (2 with) (3 (3 (3 humour) (2 and)) (2 lightness))))))) (2 .))) | |||||
(4 (2 (2 a) (2 (2 screenplay) (2 more))) (3 (4 ingeniously) (2 (2 constructed) (2 (2 (2 (2 than) (2 ``)) (2 Memento)) (2 ''))))) | |||||
(3 (2 ``) (3 (2 (2 Extreme) (2 Ops)) (3 (2 '') (4 (4 (3 exceeds) (2 expectations)) (2 .))))) | |||||
""" | |||||
train, test = 'train--', 'test--' | |||||
with open(train, 'w', encoding='utf-8') as f: | |||||
f.write(train_data) | |||||
with open(test, 'w', encoding='utf-8') as f: | |||||
f.write(test_data) | |||||
loader = SSTLoader() | |||||
info = loader.process( | |||||
{train: train, test: test}, | |||||
train_ds=[train], | |||||
src_vocab_op=dict(min_freq=2) | |||||
) | |||||
assert len(list(info.vocabs.items())) == 2 | |||||
assert len(list(info.datasets.items())) == 2 | |||||
print(info.vocabs) | |||||
print(info.datasets) | |||||
os.remove(train), os.remove(test) | |||||
# def test_import(self): | |||||
# import fastNLP | |||||
# from fastNLP.io import SNLILoader | |||||
# ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True, | |||||
# get_index=True, seq_len_type='seq_len', extra_split=['-']) | |||||
# assert 'train' in ds.datasets | |||||
# assert len(ds.datasets) == 1 | |||||
# assert len(ds.datasets['train']) == 3 | |||||
# | |||||
# ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True, | |||||
# get_index=True, seq_len_type='seq_len') | |||||
# assert 'train' in ds.datasets | |||||
# assert len(ds.datasets) == 1 | |||||
# assert len(ds.datasets['train']) == 3 |