@@ -47,7 +47,7 @@ from fastNLP.core.dataset import DataSet | |||
from fastNLP.api.utils import load_url | |||
from fastNLP.api.processor import ModelProcessor | |||
from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader | |||
from fastNLP.io.dataset_loader import cut_long_sentence, ConllLoader | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
@@ -61,6 +61,85 @@ model_urls = { | |||
} | |||
class ConllCWSReader(object): | |||
"""Deprecated. Use ConllLoader for all types of conll-format files.""" | |||
def __init__(self): | |||
pass | |||
def load(self, path, cut_long_sent=False): | |||
""" | |||
The returned DataSet contains only one field, raw_sentence, whose content is a str.
The input is assumed to be in CoNLL format, with sentences separated by blank lines and 7 columns per line, i.e.
:: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.strip().split()) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_char_lst(sample) | |||
if res is None: | |||
continue | |||
line = ' '.join(res) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for raw_sentence in sents: | |||
ds.append(Instance(raw_sentence=raw_sentence)) | |||
return ds | |||
def get_char_lst(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
return text | |||
class ConllxDataLoader(ConllLoader): | |||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'words', 'pos_tags', 'heads', 'labels', | |||
] | |||
indexs = [ | |||
1, 3, 6, 7, | |||
] | |||
super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) | |||
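# Usage sketch (added for illustration; the file path below is hypothetical). Since ConllxDataLoader
# now merely preconfigures ConllLoader, the generic loader can be used directly:
loader = ConllLoader(headers=['words', 'pos_tags', 'heads', 'labels'], indexs=[1, 3, 6, 7])
ds = loader.load('path/to/train.conllx')  # one Instance per sentence, fields named by `headers`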
class API: | |||
""" | |||
This is the documentation for the API class
@@ -449,6 +449,9 @@ class DataSet(object): | |||
:return dataset: the read data set | |||
""" | |||
import warnings | |||
warnings.warn('read_csv is deprecated, use CSVLoader instead', | |||
category=DeprecationWarning) | |||
with open(csv_path, "r", encoding='utf-8') as f: | |||
start_idx = 0 | |||
if headers is None: | |||
@@ -66,28 +66,28 @@ class Trainer(object): | |||
is insufficient, the same effect can be achieved by setting batch_size=32 and update_every=4
""" | |||
super(Trainer, self).__init__() | |||
if not isinstance(train_data, DataSet): | |||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||
# check metrics and dev_data | |||
if (not metrics) and dev_data is not None: | |||
raise ValueError("No metric for dev_data evaluation.") | |||
if metrics and (dev_data is None): | |||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||
# check update every | |||
assert update_every >= 1, "update_every must be no less than 1." | |||
self.update_every = int(update_every) | |||
# check save_path | |||
if not (save_path is None or isinstance(save_path, str)): | |||
raise ValueError("save_path can only be None or `str`.") | |||
# prepare evaluate | |||
metrics = _prepare_metrics(metrics) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
@@ -97,19 +97,19 @@ class Trainer(object): | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
elif len(metrics) > 0: | |||
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric') | |||
# prepare loss | |||
losser = _prepare_losser(loss) | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, BaseSampler): | |||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||
if check_code_level > -1: | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
self.train_data = train_data | |||
self.dev_data = dev_data # If None, No validation. | |||
self.model = model | |||
@@ -130,18 +130,18 @@ class Trainer(object): | |||
self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
else: | |||
if optimizer is None: | |||
optimizer = Adam(lr=0.01, weight_decay=0) | |||
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) | |||
self.use_tqdm = use_tqdm | |||
self.pbar = None | |||
self.print_every = abs(self.print_every) | |||
if self.dev_data is not None: | |||
self.tester = Tester(model=self.model, | |||
data=self.dev_data, | |||
@@ -149,13 +149,13 @@ class Trainer(object): | |||
batch_size=self.batch_size, | |||
use_cuda=self.use_cuda, | |||
verbose=0) | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True): | |||
""" | |||
@@ -205,18 +205,18 @@ class Trainer(object): | |||
self.model = self.model.cuda() | |||
self._model_device = self.model.parameters().__next__().device | |||
self._mode(self.model, is_test=False) | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
start_time = time.time() | |||
print("training epochs started " + self.start_time, flush=True) | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
self.callback_manager.on_train_end() | |||
except (CallbackException, KeyboardInterrupt) as e: | |||
self.callback_manager.on_exception(e) | |||
if self.dev_data is not None and hasattr(self, 'best_dev_perf'): | |||
print( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
@@ -234,9 +234,9 @@ class Trainer(object): | |||
finally: | |||
pass | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
return results | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm | |||
@@ -245,7 +245,7 @@ class Trainer(object): | |||
self.step = 0 | |||
self.epoch = 0 | |||
start = time.time() | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar if isinstance(pbar, tqdm) else None | |||
avg_loss = 0 | |||
@@ -263,23 +263,23 @@ class Trainer(object): | |||
# negative sampling; replace unknown; re-weight batch_y | |||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
prediction = self._data_forward(self.model, batch_x) | |||
# edit prediction | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y).mean() | |||
avg_loss += loss.item() | |||
loss = loss / self.update_every | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
self._grad_backward(loss) | |||
self.callback_manager.on_backward_end() | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if (self.step + 1) % self.print_every == 0: | |||
avg_loss = avg_loss / self.print_every | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
if self.use_tqdm: | |||
print_output = "loss:{0:<6.5f}".format(avg_loss) | |||
pbar.update(self.print_every) | |||
@@ -291,7 +291,7 @@ class Trainer(object): | |||
pbar.set_postfix_str(print_output) | |||
avg_loss = 0 | |||
self.callback_manager.on_batch_end() | |||
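# validate_every > 0: validate every `validate_every` steps;
# validate_every < 0: validate once at the end of each epoch.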
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
and self.dev_data is not None: | |||
@@ -300,20 +300,20 @@ class Trainer(object): | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
pbar.write(eval_str + '\n') | |||
# ================= mini-batch end ==================== # | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
pbar.close() | |||
self.pbar = None | |||
# ============ tqdm end ============== # | |||
def _do_validation(self, epoch, step): | |||
self.callback_manager.on_valid_begin() | |||
res = self.tester.test() | |||
is_better_eval = False | |||
if self._better_eval_result(res): | |||
if self.save_path is not None: | |||
@@ -328,7 +328,7 @@ class Trainer(object): | |||
# get validation results; adjust optimizer | |||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | |||
return res | |||
def _mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -340,21 +340,21 @@ class Trainer(object): | |||
model.eval() | |||
else: | |||
model.train() | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
""" | |||
if (self.step + 1) % self.update_every == 0: | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
x = _build_args(network.forward, **x) | |||
y = network(**x) | |||
if not isinstance(y, dict): | |||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | |||
return y | |||
def _grad_backward(self, loss): | |||
"""Compute gradient with link rules. | |||
@@ -365,7 +365,7 @@ class Trainer(object): | |||
if self.step % self.update_every == 0: | |||
self.model.zero_grad() | |||
loss.backward() | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
@@ -374,7 +374,7 @@ class Trainer(object): | |||
:return: a scalar | |||
""" | |||
return self.losser(predict, truth) | |||
def _save_model(self, model, model_name, only_param=False): | |||
""" 存储不含有显卡信息的state_dict或model | |||
:param model: | |||
@@ -395,7 +395,7 @@ class Trainer(object): | |||
model.cpu() | |||
torch.save(model, model_path) | |||
model.to(self._model_device) | |||
def _load_model(self, model, model_name, only_param=False): | |||
# return a bool indicating whether the model was reloaded successfully
if self.save_path is not None: | |||
@@ -410,7 +410,7 @@ class Trainer(object): | |||
else: | |||
return False | |||
return True | |||
def _better_eval_result(self, metrics): | |||
"""Check if the current epoch yields better validation results. | |||
@@ -461,7 +461,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
check_level=0): | |||
# check the get_loss method
model_device = model.parameters().__next__().device
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_device)
@@ -485,13 +485,13 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
print(info_str) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
refined_batch_x = _build_args(model.forward, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
func_signature = get_func_signature(model.forward) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||
# loss check | |||
try: | |||
loss = losser(pred_dict, batch_y) | |||
@@ -515,7 +515,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
model.zero_grad() | |||
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: | |||
break | |||
if dev_data is not None: | |||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||
batch_size=batch_size, verbose=-1) | |||
@@ -529,7 +529,7 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||
# metric_list: the metrics used for evaluation, passed in at Trainer initialization
if isinstance(metrics, tuple): | |||
loss, metrics = metrics | |||
if isinstance(metrics, dict): | |||
if len(metrics) == 1: | |||
# only single metric, just use it | |||
@@ -540,7 +540,7 @@ def _check_eval_results(metrics, metric_key, metric_list): | |||
if metrics_name not in metrics: | |||
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") | |||
metric_dict = metrics[metrics_name] | |||
if len(metric_dict) == 1: | |||
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] | |||
elif len(metric_dict) > 1 and metric_key is None: | |||
@@ -1,71 +1,13 @@ | |||
import os | |||
import json | |||
from nltk.tree import Tree | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.io.base_loader import DataLoaderRegister | |||
from fastNLP.io.file_reader import read_csv, read_json, read_conll | |||
def convert_seq_dataset(data): | |||
"""Create an DataSet instance that contains no labels. | |||
:param data: list of list of strings, [num_examples, \*]. | |||
Example:: | |||
[ | |||
[word_11, word_12, ...], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
dataset = DataSet() | |||
for word_seq in data: | |||
dataset.append(Instance(word_seq=word_seq)) | |||
return dataset | |||
def convert_seq2tag_dataset(data): | |||
"""Convert list of data into DataSet. | |||
:param data: list of list of strings, [num_examples, \*]. | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], label_1 ], | |||
[ [word_21, word_22, ...], label_2 ], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
dataset = DataSet() | |||
for sample in data: | |||
dataset.append(Instance(word_seq=sample[0], label=sample[1])) | |||
return dataset | |||
def convert_seq2seq_dataset(data): | |||
"""Convert list of data into DataSet. | |||
:param data: list of list of strings, [num_examples, \*]. | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
:return: a DataSet. | |||
""" | |||
dataset = DataSet() | |||
for sample in data: | |||
dataset.append(Instance(word_seq=sample[0], label_seq=sample[1])) | |||
return dataset | |||
def download_from_url(url, path): | |||
def _download_from_url(url, path): | |||
from tqdm import tqdm | |||
import requests | |||
@@ -81,7 +23,7 @@ def download_from_url(url, path): | |||
t.update(len(chunk)) | |||
return | |||
def uncompress(src, dst): | |||
def _uncompress(src, dst): | |||
import zipfile, gzip, tarfile, os | |||
def unzip(src, dst): | |||
@@ -134,243 +76,6 @@ class DataSetLoader: | |||
raise NotImplementedError | |||
class NativeDataSetLoader(DataSetLoader): | |||
"""A simple example of DataSetLoader | |||
""" | |||
def __init__(self): | |||
super(NativeDataSetLoader, self).__init__() | |||
def load(self, path): | |||
ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t") | |||
ds.set_input("raw_sentence") | |||
ds.set_target("label") | |||
return ds | |||
DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') | |||
class RawDataSetLoader(DataSetLoader): | |||
"""A simple example of raw data reader | |||
""" | |||
def __init__(self): | |||
super(RawDataSetLoader, self).__init__() | |||
def load(self, data_path, split=None): | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
lines = lines if split is None else [l.split(split) for l in lines] | |||
lines = list(filter(lambda x: len(x) > 0, lines)) | |||
return self.convert(lines) | |||
def convert(self, data): | |||
return convert_seq_dataset(data) | |||
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') | |||
class DummyPOSReader(DataSetLoader): | |||
"""A simple reader for a dummy POS tagging dataset. | |||
In these datasets, each line is split by "\\\\t": the first column is the word and the second
column is the label. Different sentences are separated by an empty line.
E.g:: | |||
Tom label1 | |||
and label2 | |||
Jerry label1 | |||
. label3 | |||
(separated by an empty line) | |||
Hello label4 | |||
world label5 | |||
! label3 | |||
In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. | |||
""" | |||
def __init__(self): | |||
super(DummyPOSReader, self).__init__() | |||
def load(self, data_path): | |||
""" | |||
:return data: three-level list | |||
Example:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
data = self.parse(lines) | |||
return self.convert(data) | |||
@staticmethod | |||
def parse(lines): | |||
data = [] | |||
sentence = [] | |||
for line in lines: | |||
line = line.strip() | |||
if len(line) > 1: | |||
sentence.append(line.split('\t')) | |||
else: | |||
words = [] | |||
labels = [] | |||
for tokens in sentence: | |||
words.append(tokens[0]) | |||
labels.append(tokens[1]) | |||
data.append([words, labels]) | |||
sentence = [] | |||
if len(sentence) != 0: | |||
words = [] | |||
labels = [] | |||
for tokens in sentence: | |||
words.append(tokens[0]) | |||
labels.append(tokens[1]) | |||
data.append([words, labels]) | |||
return data | |||
def convert(self, data): | |||
"""Convert lists of strings into Instances with Fields. | |||
""" | |||
return convert_seq2seq_dataset(data) | |||
DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos') | |||
class DummyCWSReader(DataSetLoader): | |||
"""Load pku dataset for Chinese word segmentation. | |||
""" | |||
def __init__(self): | |||
super(DummyCWSReader, self).__init__() | |||
def load(self, data_path, max_seq_len=32): | |||
"""Load pku dataset for Chinese word segmentation. | |||
CWS (Chinese Word Segmentation) pku training dataset format: | |||
1. Each line is a sentence. | |||
2. Each word in a sentence is separated by space. | |||
This function converts the pku dataset into three-level lists with <BMES> labels.
B: beginning of a word | |||
M: middle of a word | |||
E: ending of a word | |||
S: single character | |||
:param str data_path: path to the data set. | |||
:param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into | |||
several sequences. | |||
:return: three-level lists | |||
""" | |||
assert isinstance(max_seq_len, int) and max_seq_len > 0 | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
sentences = f.readlines() | |||
data = [] | |||
for sent in sentences: | |||
tokens = sent.strip().split() | |||
words = [] | |||
labels = [] | |||
for token in tokens: | |||
if len(token) == 1: | |||
words.append(token) | |||
labels.append("S") | |||
else: | |||
words.append(token[0]) | |||
labels.append("B") | |||
for idx in range(1, len(token) - 1): | |||
words.append(token[idx]) | |||
labels.append("M") | |||
words.append(token[-1]) | |||
labels.append("E") | |||
num_samples = len(words) // max_seq_len | |||
if len(words) % max_seq_len != 0: | |||
num_samples += 1 | |||
for sample_idx in range(num_samples): | |||
start = sample_idx * max_seq_len | |||
end = (sample_idx + 1) * max_seq_len | |||
seq_words = words[start:end] | |||
seq_labels = labels[start:end] | |||
data.append([seq_words, seq_labels]) | |||
return self.convert(data) | |||
def convert(self, data): | |||
return convert_seq2seq_dataset(data) | |||
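# Worked example of the <BMES> labeling above (illustrative only): the space-separated input
# "编者按 : 7月" is converted to
#   words  = ['编', '者', '按', ':', '7', '月']
#   labels = ['B',  'M',  'E',  'S',  'B', 'E']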
class DummyClassificationReader(DataSetLoader): | |||
"""Loader for a dummy classification data set""" | |||
def __init__(self): | |||
super(DummyClassificationReader, self).__init__() | |||
def load(self, data_path): | |||
assert os.path.exists(data_path) | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
lines = f.readlines() | |||
data = self.parse(lines) | |||
return self.convert(data) | |||
@staticmethod | |||
def parse(lines): | |||
"""每行第一个token是标签,其余是字/词;由空格分隔。 | |||
:param lines: lines from dataset | |||
:return: list(list(list())): the three levels of lists are words, sentences, and the dataset
""" | |||
dataset = list() | |||
for line in lines: | |||
line = line.strip().split() | |||
label = line[0] | |||
words = line[1:] | |||
if len(words) <= 1: | |||
continue | |||
sentence = [words, label] | |||
dataset.append(sentence) | |||
return dataset | |||
def convert(self, data): | |||
return convert_seq2tag_dataset(data) | |||
class DummyLMReader(DataSetLoader): | |||
"""A Dummy Language Model Dataset Reader | |||
""" | |||
def __init__(self): | |||
super(DummyLMReader, self).__init__() | |||
def load(self, data_path): | |||
if not os.path.exists(data_path): | |||
raise FileNotFoundError("file {} not found.".format(data_path)) | |||
with open(data_path, "r", encoding="utf=8") as f: | |||
text = " ".join(f.readlines()) | |||
tokens = text.strip().split() | |||
data = self.sentence_cut(tokens) | |||
return self.convert(data) | |||
def sentence_cut(self, tokens, sentence_length=15): | |||
start_idx = 0 | |||
data_set = [] | |||
for idx in range(len(tokens) // sentence_length): | |||
x = tokens[start_idx * idx: start_idx * idx + sentence_length] | |||
y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1] | |||
if start_idx * idx + sentence_length + 1 >= len(tokens): | |||
# ad hoc | |||
y.extend(["<unk>"]) | |||
data_set.append([x, y]) | |||
return data_set | |||
def convert(self, data): | |||
pass | |||
class PeopleDailyCorpusLoader(DataSetLoader): | |||
"""人民日报数据集 | |||
""" | |||
@@ -450,8 +155,9 @@ class PeopleDailyCorpusLoader(DataSetLoader): | |||
class ConllLoader: | |||
def __init__(self, headers, indexs=None): | |||
def __init__(self, headers, indexs=None, dropna=True): | |||
self.headers = headers | |||
self.dropna = dropna | |||
if indexs is None: | |||
self.indexs = list(range(len(self.headers))) | |||
else: | |||
@@ -460,33 +166,10 @@ class ConllLoader: | |||
self.indexs = indexs | |||
def load(self, path): | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
start = next(f) | |||
if '-DOCSTART-' not in start: | |||
sample.append(start.split()) | |||
for line in f: | |||
if line.startswith('\n'): | |||
if len(sample): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split()) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
data = [self.get_one(sample) for sample in datalist] | |||
data = filter(lambda x: x is not None, data) | |||
ds = DataSet() | |||
for sample in data: | |||
ins = Instance() | |||
for name, idx in zip(self.headers, self.indexs): | |||
ins.add_field(field_name=name, field=sample[idx]) | |||
ds.append(ins) | |||
for idx, data in read_conll(path, indexes=self.indexs, dropna=self.dropna): | |||
ins = {h:data[idx] for h, idx in zip(self.headers, self.indexs)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def get_one(self, sample): | |||
@@ -501,9 +184,7 @@ class Conll2003Loader(ConllLoader): | |||
"""Loader for conll2003 dataset | |||
More information about the given dataset could be found on
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
@@ -512,198 +193,6 @@ class Conll2003Loader(ConllLoader): | |||
super(Conll2003Loader, self).__init__(headers=headers) | |||
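# Usage sketch (illustrative; the file path is hypothetical): Conll2003Loader is now just a preset
# over ConllLoader, so loading works like any other conll-format file.
conll2003_ds = Conll2003Loader().load('path/to/eng.train')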
class SNLIDataSetReader(DataSetLoader): | |||
"""A data set loader for SNLI data set. | |||
""" | |||
def __init__(self): | |||
super(SNLIDataSetReader, self).__init__() | |||
def load(self, path_list): | |||
""" | |||
:param list path_list: A list of file names, in the order of premise file, hypothesis file, and label file.
:return: A DataSet object. | |||
""" | |||
assert len(path_list) == 3 | |||
line_set = [] | |||
for file in path_list: | |||
if not os.path.exists(file): | |||
raise FileNotFoundError("file {} NOT found".format(file)) | |||
with open(file, 'r', encoding='utf-8') as f: | |||
lines = f.readlines() | |||
line_set.append(lines) | |||
premise_lines, hypothesis_lines, label_lines = line_set | |||
assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines) | |||
data_set = [] | |||
for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines): | |||
p = premise.strip().split() | |||
h = hypothesis.strip().split() | |||
l = label.strip() | |||
data_set.append([p, h, l]) | |||
return self.convert(data_set) | |||
def convert(self, data): | |||
"""Convert a 3D list to a DataSet object. | |||
:param data: A 3D list.
Example:: | |||
[ | |||
[ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], | |||
[ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], | |||
... | |||
] | |||
:return: A DataSet object. | |||
""" | |||
data_set = DataSet() | |||
for example in data: | |||
p, h, l = example | |||
# list, list, str | |||
instance = Instance() | |||
instance.add_field("premise", p) | |||
instance.add_field("hypothesis", h) | |||
instance.add_field("truth", l) | |||
data_set.append(instance) | |||
data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len") | |||
data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len") | |||
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") | |||
data_set.set_target("truth") | |||
return data_set | |||
class ConllCWSReader(object): | |||
"""Deprecated. Use ConllLoader for all types of conll-format files.""" | |||
def __init__(self): | |||
pass | |||
def load(self, path, cut_long_sent=False): | |||
""" | |||
The returned DataSet contains only one field, raw_sentence, whose content is a str.
The input is assumed to be in CoNLL format, with sentences separated by blank lines and 7 columns per line, i.e.
:: | |||
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.strip().split()) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_char_lst(sample) | |||
if res is None: | |||
continue | |||
line = ' '.join(res) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for raw_sentence in sents: | |||
ds.append(Instance(raw_sentence=raw_sentence)) | |||
return ds | |||
def get_char_lst(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
return text | |||
class NaiveCWSReader(DataSetLoader): | |||
""" | |||
This reader assumes the word-segmentation dataset has one of the following forms, i.e. the content is already split by spaces,
for example::
这是 fastNLP , 一个 非常 good 的 包 .
or each part is additionally followed by a POS tag,
for example::
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
""" | |||
def __init__(self, in_word_splitter=None): | |||
super(NaiveCWSReader, self).__init__() | |||
self.in_word_splitter = in_word_splitter | |||
def load(self, filepath, in_word_splitter=None, cut_long_sent=False): | |||
""" | |||
The accepted formats are (by default \\\\t or space is used as the separator)::
这是 fastNLP , 一个 非常 good 的 包 .
and::
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
If the splitter is not None, the second format is assumed: each part such as "也/D" is split by the splitter and the first piece is kept, e.g. "也/D".split('/')[0]
:param filepath: | |||
:param in_word_splitter: | |||
:param cut_long_sent: | |||
:return: | |||
""" | |||
if in_word_splitter == None: | |||
in_word_splitter = self.in_word_splitter | |||
dataset = DataSet() | |||
with open(filepath, 'r') as f: | |||
for line in f: | |||
line = line.strip() | |||
if len(line.replace(' ', '')) == 0:  # empty lines are not accepted
continue | |||
if not in_word_splitter is None: | |||
words = [] | |||
for part in line.split(): | |||
word = part.split(in_word_splitter)[0] | |||
words.append(word) | |||
line = ' '.join(words) | |||
if cut_long_sent: | |||
sents = cut_long_sentence(line) | |||
else: | |||
sents = [line] | |||
for sent in sents: | |||
instance = Instance(raw_sentence=sent) | |||
dataset.append(instance) | |||
return dataset | |||
def cut_long_sentence(sent, max_sample_length=200): | |||
""" | |||
Split sentences longer than max_sample_length into several pieces. Cuts only happen at spaces, so the resulting pieces may be longer or shorter than max_sample_length.
@@ -733,104 +222,6 @@ def cut_long_sentence(sent, max_sample_length=200): | |||
return cutted_sentence | |||
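# Usage sketch (illustrative): because cuts only happen at spaces, the pieces returned below may
# each be somewhat longer or shorter than max_sample_length.
pieces = cut_long_sentence(' '.join(['word'] * 300), max_sample_length=200)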
class ZhConllPOSReader(object): | |||
"""读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
""" | |||
def __init__(self): | |||
pass | |||
def load(self, path): | |||
""" | |||
The returned DataSet contains the following fields::
words: list of str,
tag: list of str, with BMES tags added; e.g. an original word-level sequence ['VP', 'NN', 'NN', ...] becomes ['S-VP', 'B-NN', 'M-NN', ...]
The input is assumed to be in CoNLL format, with sentences separated by blank lines and 7 columns per line, i.e.::
1 编者按 编者按 NN O 11 nmod:topic | |||
2 : : PU O 11 punct | |||
3 7月 7月 NT DATE 4 compound:nn | |||
4 12日 12日 NT DATE 11 nmod:tmod | |||
5 , , PU O 11 punct | |||
1 这 这 DT O 3 det | |||
2 款 款 M O 1 mark:clf | |||
3 飞行 飞行 NN O 8 nsubj | |||
4 从 从 P O 5 case | |||
5 外型 外型 NN O 8 nmod:prep | |||
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
sample = [] | |||
for line in f: | |||
if line.startswith('\n'): | |||
datalist.append(sample) | |||
sample = [] | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split('\t')) | |||
if len(sample) > 0: | |||
datalist.append(sample) | |||
ds = DataSet() | |||
for sample in datalist: | |||
# print(sample) | |||
res = self.get_one(sample) | |||
if res is None: | |||
continue | |||
char_seq = [] | |||
pos_seq = [] | |||
for word, tag in zip(res[0], res[1]): | |||
char_seq.extend(list(word)) | |||
if len(word) == 1: | |||
pos_seq.append('S-{}'.format(tag)) | |||
elif len(word) > 1: | |||
pos_seq.append('B-{}'.format(tag)) | |||
for _ in range(len(word) - 2): | |||
pos_seq.append('M-{}'.format(tag)) | |||
pos_seq.append('E-{}'.format(tag)) | |||
else: | |||
raise ValueError("Zero length of word detected.") | |||
ds.append(Instance(words=char_seq, | |||
tag=pos_seq)) | |||
return ds | |||
def get_one(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
pos_tags = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
pos_tags.append(t2) | |||
return text, pos_tags | |||
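# Worked example of the BMES expansion in load() above (illustrative only): the word-level pair
# ('编者按', 'NN') becomes characters ['编', '者', '按'] tagged ['B-NN', 'M-NN', 'E-NN'], and a
# single-character word ('这', 'DT') becomes ['这'] tagged ['S-DT'].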
class ConllxDataLoader(ConllLoader): | |||
"""返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 | |||
Deprecated. Use ConllLoader for all types of conll-format files. | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'words', 'pos_tags', 'heads', 'labels', | |||
] | |||
indexs = [ | |||
1, 3, 6, 7, | |||
] | |||
super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) | |||
class SSTLoader(DataSetLoader): | |||
"""load SST data in PTB tree format | |||
data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip | |||
@@ -849,10 +240,7 @@ class SSTLoader(DataSetLoader): | |||
""" | |||
:param path: str, path where the data is stored
:return: DataSet with fields 'words', 'pos_tags', 'heads', 'labels' (the parser's labels),
structured roughly as follows, one row per instance (sample)
words pos_tags heads labels | |||
['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] | |||
:return: DataSet.
""" | |||
datalist = [] | |||
with open(path, 'r', encoding='utf-8') as f: | |||
@@ -867,7 +255,6 @@ class SSTLoader(DataSetLoader): | |||
@staticmethod | |||
def get_one(data, subtree): | |||
from nltk.tree import Tree | |||
tree = Tree.fromstring(data) | |||
if subtree: | |||
return [(t.leaves(), t.label()) for t in tree.subtrees()] | |||
@@ -879,26 +266,72 @@ class JsonLoader(DataSetLoader): | |||
every line contains a JSON object (like a dict);
``fields`` maps the dict keys that need to be loaded to their new field names
""" | |||
def __init__(self, **fields): | |||
def __init__(self, dropna=False, fields=None): | |||
super(JsonLoader, self).__init__() | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.dropna = dropna | |||
self.fields = None | |||
self.fields_list = None | |||
if fields: | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def load(self, path): | |||
ds = DataSet() | |||
for idx, d in read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
ins = d if self.fields is None else {self.fields[k]: v for k, v in d.items()}
ds.append(Instance(**ins)) | |||
return ds | |||
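# Usage sketch (illustrative; the path and field names are hypothetical): load two fields per JSON
# line and rename 'gold_label' to 'target', dropping lines where a requested field is missing.
json_ds = JsonLoader(dropna=True, fields={'sentence1': None, 'gold_label': 'target'}).load('path/to/data.jsonl')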
class SNLILoader(JsonLoader): | |||
""" | |||
data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip | |||
""" | |||
def __init__(self): | |||
fields = { | |||
'sentence1_parse': 'words1', | |||
'sentence2_parse': 'words2', | |||
'gold_label': 'target', | |||
} | |||
super(SNLILoader, self).__init__(fields=fields) | |||
def load(self, path): | |||
ds = super(SNLILoader, self).load(path) | |||
def parse_tree(x): | |||
t = Tree.fromstring(x) | |||
return t.leaves() | |||
ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1') | |||
ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2') | |||
ds.drop(lambda x: x['target'] == '-') | |||
return ds | |||
class CSVLoader(DataSetLoader): | |||
"""Load data from a CSV file and return a DataSet object. | |||
:param str csv_path: path to the CSV file | |||
:param List[str] or Tuple[str] headers: headers of the CSV file | |||
:param str sep: delimiter in CSV file. Default: "," | |||
:param bool dropna: If True, drop rows that have fewer entries than headers.
:return dataset: the read data set | |||
""" | |||
def __init__(self, headers=None, sep=",", dropna=True): | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def load(self, path): | |||
with open(path, 'r', encoding='utf-8') as f: | |||
datas = [json.loads(l) for l in f] | |||
ds = DataSet() | |||
for d in datas: | |||
ins = Instance() | |||
for k, v in d.items(): | |||
if k in self.fields: | |||
ins.add_field(self.fields[k], v) | |||
ds.append(ins) | |||
for idx, data in read_csv(path, headers=self.headers, | |||
sep=self.sep, dropna=self.dropna): | |||
ds.append(Instance(**data)) | |||
return ds | |||
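# Usage sketch (illustrative; the path is hypothetical): load a tab-separated file whose first
# column is the raw sentence and whose second column is the label.
csv_ds = CSVLoader(headers=['raw_sentence', 'label'], sep='\t', dropna=True).load('path/to/data.tsv')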
def add_seg_tag(data): | |||
def _add_seg_tag(data): | |||
""" | |||
:param data: list of ([word], [pos], [heads], [head_tags]) | |||
@@ -0,0 +1,112 @@ | |||
import json | |||
def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
""" | |||
Construct a generator to read csv items | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param headers: file's headers; if None, the file's first line is used as headers. default: None
:param sep: separator for each column. default: ','
:param dropna: whether to ignore and drop invalid data;
if False, raise ValueError when reading invalid data. default: True
:return: generator that yields (line number, csv item) for each valid line
""" | |||
with open(path, 'r', encoding=encoding) as f: | |||
start_idx = 0 | |||
if headers is None: | |||
headers = f.readline().rstrip('\r\n') | |||
headers = headers.split(sep) | |||
start_idx += 1 | |||
elif not isinstance(headers, (list, tuple)): | |||
raise TypeError("headers should be list or tuple, not {}." \ | |||
.format(type(headers))) | |||
for line_idx, line in enumerate(f, start_idx): | |||
contents = line.rstrip('\r\n').split(sep) | |||
if len(contents) != len(headers): | |||
if dropna: | |||
continue | |||
else: | |||
raise ValueError("Line {} has {} parts, while header has {} parts." \ | |||
.format(line_idx, len(contents), len(headers))) | |||
_dict = {} | |||
for header, content in zip(headers, contents): | |||
_dict[header] = content | |||
yield line_idx, _dict | |||
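# Usage sketch (illustrative; the path is hypothetical): iterate over a headerless tab-separated
# file, silently skipping malformed rows.
for line_idx, item in read_csv('path/to/data.tsv', headers=['words', 'label'], sep='\t', dropna=True):
    print(line_idx, item['words'], item['label'])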
def read_json(path, encoding='utf-8', fields=None, dropna=True): | |||
""" | |||
Construct a generator to read json items | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param fields: the json object's fields that are needed; if None, all fields are kept. default: None
:param dropna: whether to ignore and drop invalid data;
if False, raise ValueError when reading invalid data. default: True
:return: generator that yields (line number, json item) for each valid line
""" | |||
if fields: | |||
fields = set(fields) | |||
with open(path, 'r', encoding=encoding) as f: | |||
for line_idx, line in enumerate(f): | |||
data = json.loads(line) | |||
if fields is None: | |||
yield line_idx, data | |||
continue | |||
_res = {} | |||
for k, v in data.items(): | |||
if k in fields: | |||
_res[k] = v | |||
if len(_res) < len(fields): | |||
if dropna: | |||
continue | |||
else: | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
yield line_idx, _res | |||
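# Usage sketch (illustrative; the path and field names are hypothetical): keep two fields of each
# JSON line and fail loudly on lines where either field is missing.
for line_idx, item in read_json('path/to/data.jsonl', fields=['sentence1', 'gold_label'], dropna=False):
    print(line_idx, item['gold_label'])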
def read_conll(path, encoding='utf-8', indexes=None, dropna=True): | |||
""" | |||
Construct a generator to read conll items | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param indexes: the conll columns (by index) that are needed; if None, all columns are kept. default: None
:param dropna: whether to ignore and drop invalid data;
if False, raise ValueError when reading invalid data. default: True
:return: generator that yields (line number, conll item) for each valid sample
""" | |||
def parse_conll(sample): | |||
sample = list(map(list, zip(*sample))) | |||
sample = sample if indexes is None else [sample[i] for i in indexes]
for f in sample: | |||
if len(f) <= 0: | |||
raise ValueError('empty field') | |||
return sample | |||
with open(path, 'r', encoding=encoding) as f: | |||
sample = [] | |||
start = next(f) | |||
if '-DOCSTART-' not in start: | |||
sample.append(start.split()) | |||
for line_idx, line in enumerate(f, 1): | |||
if line.startswith('\n'): | |||
if len(sample): | |||
try: | |||
res = parse_conll(sample) | |||
sample = [] | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna:
sample = []
continue
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split()) | |||
if len(sample) > 0: | |||
try: | |||
res = parse_conll(sample) | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
return | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) |
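# Usage sketch (illustrative; the path is hypothetical): read a CoNLL-X style file keeping the
# word and POS columns; each yielded item is a list of column-wise token lists.
for line_idx, columns in read_conll('path/to/train.conllx', indexes=[1, 3], dropna=True):
    words, pos_tags = columns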
@@ -110,5 +110,5 @@ class ESIM(BaseModel): | |||
def predict(self, words1, words2, seq_len1, seq_len2): | |||
prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred'] | |||
return torch.argmax(prediction, dim=-1) | |||
return {'pred': torch.argmax(prediction, dim=-1)} | |||
@@ -1,4 +1,6 @@ | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.utils.rnn as rnn | |||
from fastNLP.modules.utils import initial_parameter | |||
@@ -19,21 +21,44 @@ class LSTM(nn.Module): | |||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | |||
bidirectional=False, bias=True, initial_method=None, get_hidden=False): | |||
super(LSTM, self).__init__() | |||
self.batch_first = batch_first | |||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, | |||
dropout=dropout, bidirectional=bidirectional) | |||
self.get_hidden = get_hidden | |||
initial_parameter(self, initial_method) | |||
def forward(self, x, h0=None, c0=None): | |||
def forward(self, x, seq_lens=None, h0=None, c0=None): | |||
if h0 is not None and c0 is not None: | |||
x, (ht, ct) = self.lstm(x, (h0, c0)) | |||
hx = (h0, c0) | |||
else: | |||
x, (ht, ct) = self.lstm(x) | |||
if self.get_hidden: | |||
return x, (ht, ct) | |||
hx = None | |||
if seq_lens is not None and not isinstance(x, rnn.PackedSequence): | |||
print('padding') | |||
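# pack_padded_sequence expects sequences sorted by decreasing length, so sort here and restore
# the original batch order after padding the packed output back.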
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||
if self.batch_first: | |||
x = x[sort_idx] | |||
else: | |||
x = x[:, sort_idx] | |||
x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) | |||
output, hx = self.lstm(x, hx) # -> [N,L,C] | |||
output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first) | |||
_, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) | |||
if self.batch_first: | |||
output = output[unsort_idx] | |||
else: | |||
output = output[:, unsort_idx] | |||
else: | |||
return x | |||
output, hx = self.lstm(x, hx) | |||
if self.get_hidden: | |||
return output, hx | |||
return output | |||
if __name__ == "__main__": | |||
lstm = LSTM(10) | |||
lstm = LSTM(input_size=2, hidden_size=2, get_hidden=False) | |||
x = torch.randn((3, 5, 2)) | |||
seq_lens = torch.tensor([5,1,2]) | |||
y = lstm(x, seq_lens) | |||
print(x) | |||
print(y) | |||
print(x.size(), y.size(), ) |
@@ -202,20 +202,6 @@ class TestDataSetMethods(unittest.TestCase): | |||
self.assertTrue(isinstance(ans, FieldArray)) | |||
self.assertEqual(ans.content, [[5, 6]] * 10) | |||
def test_reader(self): | |||
# just check that it runs
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
ds = DataSet().read_pos("test/data_for_tests/people.txt") | |||
self.assertTrue(isinstance(ds, DataSet)) | |||
self.assertTrue(len(ds) > 0) | |||
def test_add_null(self): | |||
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||
ds = DataSet() | |||
@@ -0,0 +1,3 @@ | |||
{"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"} | |||
{"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"} | |||
{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"} |
@@ -1,8 +1,7 @@ | |||
import unittest | |||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \ | |||
ZhConllPOSReader, ConllxDataLoader | |||
from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ | |||
CSVLoader, SNLILoader | |||
class TestDatasetLoader(unittest.TestCase): | |||
@@ -17,3 +16,11 @@ class TestDatasetLoader(unittest.TestCase): | |||
def test_PeopleDailyCorpusLoader(self): | |||
data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt") | |||
def test_CSVLoader(self): | |||
ds = CSVLoader(sep='\t', headers=['words', 'label'])\ | |||
.load('test/data_for_tests/tutorial_sample_dataset.csv') | |||
assert len(ds) > 0 | |||
def test_SNLILoader(self): | |||
ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') | |||
assert len(ds) == 3 |
@@ -1,9 +0,0 @@ | |||
import unittest | |||
class TestUtils(unittest.TestCase): | |||
def test_case_1(self): | |||
pass | |||
def test_case_2(self): | |||
pass |
@@ -379,6 +379,14 @@ class TestTutorial(unittest.TestCase): | |||
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||
train_data_2[-1], dev_data_2[-1] | |||
for data in [train_data, dev_data, test_data]: | |||
data.rename_field('premise', 'words1') | |||
data.rename_field('hypothesis', 'words2') | |||
data.rename_field('premise_len', 'seq_len1') | |||
data.rename_field('hypothesis_len', 'seq_len2') | |||
data.set_input('words1', 'words2', 'seq_len1', 'seq_len2') | |||
# step 1: load model parameters (optional)
from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||
args = ConfigSection() | |||