@@ -4,48 +4,48 @@ fastNLP.io package

Submodules
----------

fastNLP.io.base\_loader module
------------------------------
fastNLP.io.base_loader module
-----------------------------

.. automodule:: fastNLP.io.base_loader
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.io.config\_io module
----------------------------
fastNLP.io.config_io module
---------------------------

.. automodule:: fastNLP.io.config_io
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.io.dataset\_loader module
---------------------------------
fastNLP.io.dataset_loader module
--------------------------------

.. automodule:: fastNLP.io.dataset_loader
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.io.embed\_loader module
-------------------------------
fastNLP.io.embed_loader module
------------------------------

.. automodule:: fastNLP.io.embed_loader
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.io.file\_reader module
------------------------------
fastNLP.io.file_reader module
-----------------------------

.. automodule:: fastNLP.io.file_reader
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.io.model\_io module
---------------------------
fastNLP.io.model_io module
--------------------------

.. automodule:: fastNLP.io.model_io
    :members:
@@ -4,8 +4,8 @@ fastNLP.models package

Submodules
----------

fastNLP.models.base\_model module
---------------------------------
fastNLP.models.base_model module
--------------------------------

.. automodule:: fastNLP.models.base_model
    :members:

@@ -20,64 +20,64 @@ fastNLP.models.bert module
    :undoc-members:
    :show-inheritance:

fastNLP.models.biaffine\_parser module
--------------------------------------
fastNLP.models.biaffine_parser module
-------------------------------------

.. automodule:: fastNLP.models.biaffine_parser
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.char\_language\_model module
-------------------------------------------
fastNLP.models.char_language_model module
-----------------------------------------

.. automodule:: fastNLP.models.char_language_model
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.cnn\_text\_classification module
-----------------------------------------------
fastNLP.models.cnn_text_classification module
---------------------------------------------

.. automodule:: fastNLP.models.cnn_text_classification
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.enas\_controller module
--------------------------------------
fastNLP.models.enas_controller module
-------------------------------------

.. automodule:: fastNLP.models.enas_controller
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.enas\_model module
---------------------------------
fastNLP.models.enas_model module
--------------------------------

.. automodule:: fastNLP.models.enas_model
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.enas\_trainer module
-----------------------------------
fastNLP.models.enas_trainer module
----------------------------------

.. automodule:: fastNLP.models.enas_trainer
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.enas\_utils module
---------------------------------
fastNLP.models.enas_utils module
--------------------------------

.. automodule:: fastNLP.models.enas_utils
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.models.sequence\_modeling module
----------------------------------------
fastNLP.models.sequence_modeling module
---------------------------------------

.. automodule:: fastNLP.models.sequence_modeling
    :members:

@@ -92,8 +92,8 @@ fastNLP.models.snli module
    :undoc-members:
    :show-inheritance:

fastNLP.models.star\_transformer module
---------------------------------------
fastNLP.models.star_transformer module
--------------------------------------

.. automodule:: fastNLP.models.star_transformer
    :members:
@@ -12,32 +12,32 @@ fastNLP.modules.aggregator.attention module
    :undoc-members:
    :show-inheritance:

fastNLP.modules.aggregator.avg\_pool module
-------------------------------------------
fastNLP.modules.aggregator.avg_pool module
------------------------------------------

.. automodule:: fastNLP.modules.aggregator.avg_pool
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.modules.aggregator.kmax\_pool module
--------------------------------------------
fastNLP.modules.aggregator.kmax_pool module
-------------------------------------------

.. automodule:: fastNLP.modules.aggregator.kmax_pool
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.modules.aggregator.max\_pool module
-------------------------------------------
fastNLP.modules.aggregator.max_pool module
------------------------------------------

.. automodule:: fastNLP.modules.aggregator.max_pool
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.modules.aggregator.self\_attention module
-------------------------------------------------
fastNLP.modules.aggregator.self_attention module
------------------------------------------------

.. automodule:: fastNLP.modules.aggregator.self_attention
    :members:
@@ -4,8 +4,8 @@ fastNLP.modules.encoder package

Submodules
----------

fastNLP.modules.encoder.char\_embedding module
----------------------------------------------
fastNLP.modules.encoder.char_embedding module
---------------------------------------------

.. automodule:: fastNLP.modules.encoder.char_embedding
    :members:

@@ -20,8 +20,8 @@ fastNLP.modules.encoder.conv module
    :undoc-members:
    :show-inheritance:

fastNLP.modules.encoder.conv\_maxpool module
--------------------------------------------
fastNLP.modules.encoder.conv_maxpool module
-------------------------------------------

.. automodule:: fastNLP.modules.encoder.conv_maxpool
    :members:

@@ -52,16 +52,16 @@ fastNLP.modules.encoder.lstm module
    :undoc-members:
    :show-inheritance:

fastNLP.modules.encoder.masked\_rnn module
------------------------------------------
fastNLP.modules.encoder.masked_rnn module
-----------------------------------------

.. automodule:: fastNLP.modules.encoder.masked_rnn
    :members:
    :undoc-members:
    :show-inheritance:

fastNLP.modules.encoder.star\_transformer module
------------------------------------------------
fastNLP.modules.encoder.star_transformer module
-----------------------------------------------

.. automodule:: fastNLP.modules.encoder.star_transformer
    :members:

@@ -76,8 +76,8 @@ fastNLP.modules.encoder.transformer module
    :undoc-members:
    :show-inheritance:

fastNLP.modules.encoder.variational\_rnn module
-----------------------------------------------
fastNLP.modules.encoder.variational_rnn module
----------------------------------------------

.. automodule:: fastNLP.modules.encoder.variational_rnn
    :members:

@@ -21,8 +21,8 @@ fastNLP.modules.dropout module
    :undoc-members:
    :show-inheritance:

fastNLP.modules.other\_modules module
-------------------------------------
fastNLP.modules.other_modules module
------------------------------------

.. automodule:: fastNLP.modules.other_modules
    :members:
@@ -1,14 +1,114 @@
"""
Introduction to fastNLP.core.DataSet

DataSet is the container fastNLP uses to hold data. A DataSet can be seen as a table in which each row is a sample (called an Instance in fastNLP) and each column is a feature (called a field in fastNLP).
DataSet is the container fastNLP uses to hold data. A DataSet can be seen as a table in which each row is an instance (or sample) and each column is a feature.

.. _DataSet:

csv-table::
   :header: "Field1", "Field2", "Field3"
   :widths:20, 10, 10
.. csv-table:: Following is a demo layout of DataSet
   :header: "sentence", "words", "seq_len"

   "This is the first instance .", "[This, is, the, first, instance, .]", 6
   "Second instance .", "[Second, instance, .]", 3
   "Third instance .", "[Third, instance, .]", 3
   "...", "[...]", "..."

Inside fastNLP, each row is an Instance_ object and each column is a FieldArray_ object.
1. Creating a DataSet

    There are three main ways to create a DataSet

    1. From a dict

        Example::

            from fastNLP import DataSet
            data = {'sentence': ["This is the first instance .", "Second instance .", "Third instance ."],
                    'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],
                    'seq_len': [6, 3, 3]}
            dataset = DataSet(data)
            # the value for each key of the dict must be a list, and all the lists must have the same length

    2. By building Instances

        Example::

            from fastNLP import DataSet
            from fastNLP import Instance
            dataset = DataSet()
            instance = Instance(sentence="This is the first instance",
                                words=['this', 'is', 'the', 'first', 'instance', '.'],
                                seq_len=6)
            dataset.append(instance)
            # more instances can be appended, but every appended instance must have exactly the same fields as the first one

    3. From a list of Instances

        Example::

            from fastNLP import DataSet
            from fastNLP import Instance
            instances = []
            instances.append(Instance(sentence="This is the first instance",
                                      words=['this', 'is', 'the', 'first', 'instance', '.'],
                                      seq_len=6))
            instances.append(Instance(sentence="Second instance .",
                                      words=['Second', 'instance', '.'],
                                      seq_len=3))
            dataset = DataSet(instances)
2. Basic usage of a DataSet

    1. Reading content from a text file  # TODO reference DataLoader

        Example::

            from fastNLP import DataSet
            from fastNLP import Instance
            dataset = DataSet()
            filepath = 'some/text/file'
            # suppose each line of the file is tab-separated (sentence\tlabel):
            # This is a fantastic day	positive
            # The bad weather	negative
            # .....
            with open(filepath, 'r') as f:
                for line in f:
                    sent, label = line.strip().split('\t')
                    dataset.append(Instance(sentence=sent, label=label))

    2. Indexing; the result is a shallow copy of the DataSet's content

        Example::

            import numpy as np
            from fastNLP import DataSet
            dataset = DataSet({'a': np.arange(10), 'b': [[_] for _ in range(10)]})
            dataset[0]  # use a single index to fetch one instance
            >>{'a': 0 type=int, 'b': [0] type=list}  # an Instance is returned
            dataset[1:3]  # use a slice to get a new DataSet
            >>DataSet({'a': 1 type=int, 'b': [1] type=list}, {'a': 2 type=int, 'b': [2] type=list})

    3. Processing the content of a DataSet

        Example::

            from fastNLP import DataSet
            data = {'sentence': ["This is the first instance .", "Second instance .", "Third instance ."]}
            dataset = DataSet(data)
            # split the sentences into words; see the DataSet.apply() method for details
            dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words')
            # or use DataSet.apply_field()
            dataset.apply_field(lambda sent: sent.split(), field_name='sentence', new_field_name='words')
    4. Deleting content from a DataSet

        Example::

            from fastNLP import DataSet
            dataset = DataSet({'a': list(range(-5, 5))})
            # with inplace=False, return a new DataSet from which the matching instances have been removed
            dropped_dataset = dataset.drop(lambda ins: ins['a'] < 0, inplace=False)
            # delete the matching instances from dataset itself
            dataset.drop(lambda ins: ins['a'] < 0)  # the number of instances in dataset decreases
"This is the first instance", ['This', 'is', 'the', 'first', 'instance'], 5 | |||
"Second instance", ['Second', 'instance'], 2 | |||
""" | |||
@@ -22,7 +122,6 @@ from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance
from fastNLP.core.utils import get_func_signature


class DataSet(object):
    """DataSet is the collection of examples.
    DataSet provides instance-level interface. You can append and access an instance of the DataSet.

@@ -87,10 +186,7 @@ class DataSet(object):
        return inner_iter_func()

    def __getitem__(self, idx):
        """Fetch Instance(s) at the `idx` position(s) in the dataset.
        Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify
        the origin instance(s) of the DataSet.
        If you want to make in-place changes to all Instances, use `apply` method.
        """Given an int index, return an Instance; given a slice, return a new DataSet containing that slice.

        :param idx: can be int or slice.
        :return: If `idx` is int, return an Instance object.
@@ -145,33 +241,48 @@ class DataSet(object):
    def __repr__(self):
        return "DataSet(" + self.__inner_repr__() + ")"

    def append(self, ins):
    def append(self, instance):
        """Append an Instance object to the end of the DataSet.

        If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet.

        :param ins: an Instance object
        :param instance: an Instance object
        """
        if len(self.field_arrays) == 0:
            # DataSet has no field yet
            for name, field in ins.fields.items():
            for name, field in instance.fields.items():
                field = field.tolist() if isinstance(field, np.ndarray) else field
                self.field_arrays[name] = FieldArray(name, [field])  # the first sample must be wrapped in a list
        else:
            if len(self.field_arrays) != len(ins.fields):
            if len(self.field_arrays) != len(instance.fields):
                raise ValueError(
                    "DataSet object has {} fields, but attempt to append an Instance object with {} fields."
                    .format(len(self.field_arrays), len(ins.fields)))
            for name, field in ins.fields.items():
                    .format(len(self.field_arrays), len(instance.fields)))
            for name, field in instance.fields.items():
                assert name in self.field_arrays
                self.field_arrays[name].append(field)
    def add_fieldarray(self, field_name, fieldarray):
        """Add a FieldArray to the DataSet.

        :param str field_name: the name of the field being added
        :param FieldArray fieldarray: the content of the field to add to the DataSet
        :return:
        """
        if not isinstance(fieldarray, FieldArray):
            raise TypeError("Only fastNLP.FieldArray supported.")
        if len(self) != len(fieldarray):
            raise RuntimeError(f"The field to add must have the same size as dataset. "
                               f"Dataset size {len(self)} != field size {len(fieldarray)}")
        self.field_arrays[field_name] = fieldarray
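A minimal sketch of add_fieldarray, constructing the FieldArray the same way append() does above; the field contents are hypothetical::

    from fastNLP import DataSet
    from fastNLP.core.fieldarray import FieldArray
    dataset = DataSet({'words': [['a', 'b'], ['c']]})
    # the new FieldArray must have exactly len(dataset) entries, or a RuntimeError is raised
    dataset.add_fieldarray('seq_len', FieldArray('seq_len', [2, 1]))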
    def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False):
        """Add a new field

        :param str field_name: the name of the new field
        :param list fields: the content of the new field
        :param None, Padder padder: if None, no padding is performed.
        :param None,Padder padder: if None, no padding is performed.
        :param bool is_input: whether the new field is an input field
        :param bool is_target: whether the new field is a target field
        :param bool ignore_type: whether to skip type checking on the new field

@@ -179,18 +290,28 @@ class DataSet(object):
        if len(self.field_arrays) != 0:
            if len(self) != len(fields):
                raise RuntimeError(f"The field to append must have the same size as dataset. "
                raise RuntimeError(f"The field to add must have the same size as dataset. "
                                   f"Dataset size {len(self)} != field size {len(fields)}")
        self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input,
                                                   padder=padder, ignore_type=ignore_type)
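A minimal usage sketch of add_field together with the has_field/delete_field helpers defined just below; the field values are hypothetical::

    from fastNLP import DataSet
    dataset = DataSet({'words': [['a', 'b'], ['c']]})
    # the new field must have the same length as the dataset
    dataset.add_field('seq_len', [2, 1], is_input=True)
    assert dataset.has_field('seq_len')
    dataset.delete_field('seq_len')
    assert not dataset.has_field('seq_len')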
    def delete_field(self, field_name):
        """Delete a field
        """Delete the field named field_name

        :param str field_name: the name of the field to delete.
        """
        self.field_arrays.pop(field_name)

    def has_field(self, field_name):
        """Check whether the DataSet has a field named field_name

        :param str field_name: the name of the field
        :return: bool
        """
        if isinstance(field_name, str):
            return field_name in self.field_arrays
        return False

    def get_field(self, field_name):
        """Get the field named field_name
@@ -318,25 +439,21 @@ class DataSet(object):
    def apply_field(self, func, field_name, new_field_name=None, **kwargs):
        """Pass the `field_name` field of each instance in the DataSet to func and collect its return values.

        :param callable func: the input is the `field_name` field of an instance.
        :param str field_name: which field is passed to func.
        :param str, None new_field_name: which field to put func's return values in

            1. str, put func's return values in a new field named `new_field_name`; if the name is the same as an existing field, the old field is overwritten

            2. None, do not create a new field
        :param kwargs: three keyword arguments are accepted

        :param callable func: the input is the content of the `field_name` field of an instance.
        :param str field_name: which field is passed to func.
        :param None,str new_field_name: put func's return values in the field new_field_name; if the name is the same as an existing field, the old field is overwritten. If None, no new field is created.
        :param optional kwargs: is_input, is_target and ignore_type are supported

            1. is_input: bool, if True, set the `new_field_name` field as input
            2. is_target: bool, if True, set the `new_field_name` field as target
            3. ignore_type: bool, if True, set ignore_type on the `new_field_name` field and skip type checking
        :return: list(Any), whose elements are func's return values, so the list length equals the length of the DataSet
        """
        assert len(self)!=0, "Null DataSet cannot use apply()."
        assert len(self)!=0, "Null DataSet cannot use apply_field()."
        if field_name not in self:
            raise KeyError("DataSet has no field named `{}`.".format(field_name))
        results = []
@@ -388,23 +505,19 @@ class DataSet(object):
                               ignore_type=extra_param.get("ignore_type", False))

    def apply(self, func, new_field_name=None, **kwargs):
        """Pass each instance in the DataSet to func and collect its return values.

        :param callable func: the argument is an instance in the DataSet
        :param str, None new_field_name: which field to put func's return values in

            1. str, put func's return values in a new field named `new_field_name`; if the name is the same as an existing field, the old field is overwritten

            2. None, do not create a new field
        :param kwargs: three keyword arguments are accepted
        :return: List[], whose elements are func's return values, so the list length equals the length of the DataSet

        """ Pass each instance in the DataSet to func and collect its return values.

        :param callable func: the argument is an Instance in the DataSet
        :param None,str new_field_name: put func's return values in the field new_field_name; if the name is the same as an existing field, the old field is overwritten. If None, no new field is created.
        :param optional kwargs: is_input, is_target and ignore_type are supported

            1. is_input: bool, if True, set the `new_field_name` field as input
            2. is_target: bool, if True, set the `new_field_name` field as target
            3. ignore_type: bool, if True, set ignore_type on the `new_field_name` field and skip type checking
        :return: list(Any), whose elements are func's return values, so the list length equals the length of the DataSet
        """
        assert len(self)!=0, "Null DataSet cannot use apply()."
        idx = -1
@@ -426,10 +539,10 @@ class DataSet(object):
        return results

    def drop(self, func, inplace=True):
        """func takes an instance and returns a bool; when it returns True, that instance is deleted.
        """func takes an instance and returns a bool; when it returns True, that instance is removed from the DataSet (or left out of the returned DataSet).

        :param callable func: takes an instance as its argument and returns a bool. When True, the instance is dropped
        :param bool inplace: whether to delete the instances from the current DataSet directly. If False, the current DataSet is unchanged and a new DataSet with those instances removed is returned

        :return: DataSet
        """
@@ -440,10 +553,13 @@ class DataSet(object):
            return self
        else:
            results = [ins for ins in self if not func(ins)]
            dataset = DataSet(results)
            for field_name, field in self.field_arrays.items():
                dataset.field_arrays[field_name].to(field)
            return dataset
            if len(results)!=0:
                dataset = DataSet(results)
                for field_name, field in self.field_arrays.items():
                    dataset.field_arrays[field_name].to(field)
                return dataset
            else:
                return DataSet()

    def split(self, ratio):
        """Split the DataSet according to ratio and return two DataSets
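A short usage sketch matching the code above: the default inplace=True shrinks the DataSet itself, while inplace=False leaves it untouched and returns a new DataSet without the matching instances::

    from fastNLP import DataSet
    dataset = DataSet({'a': list(range(-5, 5))})
    kept = dataset.drop(lambda ins: ins['a'] < 0, inplace=False)
    # kept contains the 5 instances with a >= 0; dataset still has all 10
    dataset.drop(lambda ins: ins['a'] < 0)
    # now dataset itself has shrunk to the 5 instances with a >= 0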
@@ -1,4 +1,9 @@
"""
FieldArray is how one column of a DataSet_ is stored

.. _FieldArray:

"""

import numpy as np

@@ -1,3 +1,14 @@
"""
Documentation for Instance

.. _Instance:

"""


class Instance(object):
    """An Instance is an example of data.
    Example::
@@ -24,47 +24,50 @@ def _prepare_cache_filepath(filepath):
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)


# TODO we could also save the parameters used when caching, and warn at load time if they do not match.
def cache_results(cache_filepath, refresh=False, verbose=1):
def cache_results(_cache_fp, _refresh=False, _verbose=1):
    def wrapper_(func):
        signature = inspect.signature(func)
        for key, _ in signature.parameters.items():
            if key in ('cache_filepath', 'refresh', 'verbose'):
                raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
        def wrapper(*args, **kwargs):
            if 'cache_filepath' in kwargs:
                _cache_filepath = kwargs.pop('cache_filepath')
                assert isinstance(_cache_filepath, str), "cache_filepath can only be str."
            if '_cache_fp' in kwargs:
                cache_filepath = kwargs.pop('_cache_fp')
                assert isinstance(cache_filepath, str), "_cache_fp can only be str."
            else:
                cache_filepath = _cache_fp
            if '_refresh' in kwargs:
                refresh = kwargs.pop('_refresh')
                assert isinstance(refresh, bool), "_refresh can only be bool."
            else:
                _cache_filepath = cache_filepath
            if 'refresh' in kwargs:
                _refresh = kwargs.pop('refresh')
                assert isinstance(_refresh, bool), "refresh can only be bool."
                refresh = _refresh
            if '_verbose' in kwargs:
                verbose = kwargs.pop('_verbose')
                assert isinstance(verbose, int), "_verbose can only be integer."
            else:
                _refresh = refresh
            if 'verbose' in kwargs:
                _verbose = kwargs.pop('verbose')
                assert isinstance(_verbose, int), "verbose can only be integer."
                verbose = _verbose

            refresh_flag = True

            if _cache_filepath is not None and _refresh is False:
            if cache_filepath is not None and refresh is False:
                # load data
                if os.path.exists(_cache_filepath):
                    with open(_cache_filepath, 'rb') as f:
                if os.path.exists(cache_filepath):
                    with open(cache_filepath, 'rb') as f:
                        results = _pickle.load(f)
                        if verbose==1:
                            print("Read cache from {}.".format(_cache_filepath))
                            print("Read cache from {}.".format(cache_filepath))
                        refresh_flag = False

            if refresh_flag:
                results = func(*args, **kwargs)
                if _cache_filepath is not None:
                if cache_filepath is not None:
                    if results is None:
                        raise RuntimeError("The return value is None. Delete the decorator.")
                    _prepare_cache_filepath(_cache_filepath)
                    with open(_cache_filepath, 'wb') as f:
                    _prepare_cache_filepath(cache_filepath)
                    with open(cache_filepath, 'wb') as f:
                        _pickle.dump(results, f)
                    print("Save cache to {}.".format(_cache_filepath))
                    print("Save cache to {}.".format(cache_filepath))

            return results

        return wrapper
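A minimal usage sketch of the renamed decorator; the function name and cache path here are hypothetical::

    @cache_results('caches/prepared_data.pkl')
    def prepare_data():
        # ...expensive preprocessing that returns a picklable object...
        return {'vocab_size': 10000}

    data = prepare_data()                # first call: compute and save the cache
    data = prepare_data()                # later calls: read the cache back
    data = prepare_data(_refresh=True)   # force recomputation and overwrite the cache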
@@ -1,7 +1,6 @@
import os

import numpy as np
import torch

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import BaseLoader

@@ -14,120 +13,6 @@ class EmbedLoader(BaseLoader):
    def __init__(self):
        super(EmbedLoader, self).__init__()

    @staticmethod
    def _load_glove(emb_file):
        """Read file as a glove embedding

        file format:
            embeddings are split by line,
            for one embedding, word and numbers split by space
        Example::

            word_1 float_1 float_2 ... float_emb_dim
            word_2 float_1 float_2 ... float_emb_dim
            ...
        """
        emb = {}
        with open(emb_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = list(filter(lambda w: len(w) > 0, line.strip().split(' ')))
                if len(line) > 2:
                    emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
        return emb

    @staticmethod
    def _load_pretrain(emb_file, emb_type):
        """Read txt data from embedding file and convert to np.array as pre-trained embedding

        :param str emb_file: the pre-trained embedding file path
        :param str emb_type: the pre-trained embedding data format
        :return: a dict of ``{str: np.array}``
        """
        if emb_type == 'glove':
            return EmbedLoader._load_glove(emb_file)
        else:
            raise Exception("embedding type {} not support yet".format(emb_type))

    @staticmethod
    def load_embedding(emb_dim, emb_file, emb_type, vocab):
        """Load the pre-trained embedding and combine with the given dictionary.

        :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
        :param str emb_file: the pre-trained embedding file path.
        :param str emb_type: the pre-trained embedding format, support glove now
        :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
        :return (embedding_tensor, vocab):
            embedding_tensor - Tensor of shape (len(word_dict), emb_dim);
            vocab - input vocab or vocab built by pre-train
        """
        pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
        if vocab is None:
            # build vocabulary from pre-trained embedding
            vocab = Vocabulary()
            for w in pretrain.keys():
                vocab.add(w)
        embedding_tensor = torch.randn(len(vocab), emb_dim)
        for w, v in pretrain.items():
            if len(v.shape) > 1 or emb_dim != v.shape[0]:
                raise ValueError(
                    "Pretrained embedding dim is {}. Dimension dismatched. Required {}".format(v.shape, (emb_dim,)))
            if vocab.has_word(w):
                embedding_tensor[vocab[w]] = v
        return embedding_tensor, vocab

    @staticmethod
    def parse_glove_line(line):
        line = line.split()
        if len(line) <= 2:
            raise RuntimeError("something goes wrong in parsing glove embedding")
        return line[0], line[1:]

    @staticmethod
    def str_list_2_vec(line):
        try:
            return torch.Tensor(list(map(float, line)))
        except Exception:
            raise RuntimeError("something goes wrong in parsing glove embedding")

    @staticmethod
    def fast_load_embedding(emb_dim, emb_file, vocab):
        """Fast load the pre-trained embedding and combine with the given dictionary.
        This loading method uses line-by-line operation.

        :param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
        :param str emb_file: the pre-trained embedding file path.
        :param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
        :return embedding_matrix: numpy.ndarray
        """
        if vocab is None:
            raise RuntimeError("You must provide a vocabulary.")
        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32)
        hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
        with open(emb_file, "r", encoding="utf-8") as f:
            startline = f.readline()
            if len(startline.split()) > 2:
                f.seek(0)
            for line in f:
                word, vector = EmbedLoader.parse_glove_line(line)
                if word in vocab:
                    vector = EmbedLoader.str_list_2_vec(vector)
                    if len(vector.shape) > 1 or emb_dim != vector.shape[0]:
                        raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,)))
                    embedding_matrix[vocab[word]] = vector
                    hit_flags[vocab[word]] = 1

        if np.sum(hit_flags) < len(vocab):
            # some words from vocab are missing in pre-trained embedding
            # we normally sample each dimension
            vocab_embed = embedding_matrix[np.where(hit_flags)]
            sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
                                               size=(len(vocab) - np.sum(hit_flags), emb_dim))
            embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
        return embedding_matrix

    @staticmethod
    def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'):
        """
@@ -36,8 +36,7 @@ def viterbi_decode(feats, transitions, mask=None, unpad=False):
    vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
    vscore = feats[0]

    vscore += transitions[n_tags, :n_tags]
    trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
    trans_score = transitions.view(1, n_tags, n_tags).data
    for i in range(1, seq_len):
        prev_score = vscore.view(batch_size, n_tags, 1)
        cur_score = feats[i].view(batch_size, 1, n_tags)
@@ -155,7 +155,7 @@ print('test len {}'.format(len(test_data)))
def train(path):
    # test saving pipeline
    save_pipe(path)

    embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
    embed = EmbedLoader.load_with_vocab(emb_file_name, word_v)
    embed = torch.tensor(embed, dtype=torch.float32)

    # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
@@ -1,4 +1,5 @@
numpy>=1.14.2
torch>=0.4.0
tensorboardX
tqdm>=4.28.1
tqdm>=4.28.1
nltk>=3.4.1
@@ -89,17 +89,17 @@ class TestCache(unittest.TestCase):
    def test_duplicate_keyword(self):
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_verbose(a, verbose):
            def func_verbose(a, _verbose):
                pass
            func_verbose(0, 1)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_cache(a, cache_filepath):
            def func_cache(a, _cache_fp):
                pass
            func_cache(1, 2)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_refresh(a, refresh):
            def func_refresh(a, _refresh):
                pass
            func_refresh(1, 2)
@@ -6,12 +6,6 @@ from fastNLP.io.embed_loader import EmbedLoader

class TestEmbedLoader(unittest.TestCase):
    def test_case(self):
        vocab = Vocabulary()
        vocab.update(["the", "in", "I", "to", "of", "hahaha"])
        embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
        self.assertEqual(tuple(embedding.shape), (len(vocab), 50))

    def test_load_with_vocab(self):
        vocab = Vocabulary()
        glove = "test/data_for_tests/glove.6B.50d_test.txt"