- add DataSet, Instance, Field to represent data in different levels - encapsulate batching method in Batch class - modify samplers in action.py to fit Batch - preprocessor.run returns DataSet, instead of list - Use Batch in Trainer/Tester - add required_arg "task" in Trainer/Tester - remove SeqLabelTrainer/SeqLabelTester dependencies successfully. They empty classes to deprecate. - modify SeqLabeling model, add another argument in forward, in order to compute mask inside model - test\model\seq_labeling.py workstags/v0.1.0
@@ -168,19 +168,7 @@ class BaseSampler(object): | |||
""" | |||
def __init__(self, data_set): | |||
""" | |||
:param data_set: multi-level list, of shape [num_example, *] | |||
""" | |||
self.data_set_length = len(data_set) | |||
self.data = data_set | |||
def __len__(self): | |||
return self.data_set_length | |||
def __iter__(self): | |||
def __call__(self, *args, **kwargs): | |||
raise NotImplementedError | |||
@@ -189,16 +177,8 @@ class SequentialSampler(BaseSampler): | |||
""" | |||
def __init__(self, data_set): | |||
""" | |||
:param data_set: multi-level list | |||
""" | |||
super(SequentialSampler, self).__init__(data_set) | |||
def __iter__(self): | |||
return iter(self.data) | |||
def __call__(self, data_set): | |||
return list(range(len(data_set))) | |||
class RandomSampler(BaseSampler): | |||
@@ -206,17 +186,9 @@ class RandomSampler(BaseSampler): | |||
""" | |||
def __init__(self, data_set): | |||
""" | |||
def __call__(self, data_set): | |||
return list(np.random.permutation(len(data_set))) | |||
:param data_set: multi-level list | |||
""" | |||
super(RandomSampler, self).__init__(data_set) | |||
self.order = np.random.permutation(self.data_set_length) | |||
def __iter__(self): | |||
return iter((self.data[idx] for idx in self.order)) | |||
class Batchifier(object): | |||
@@ -0,0 +1,126 @@ | |||
from collections import defaultdict | |||
import torch | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.field import TextField, LabelField | |||
from fastNLP.core.instance import Instance | |||
class Batch(object): | |||
"""Batch is an iterable object which iterates over mini-batches. | |||
:: | |||
for batch_x, batch_y in Batch(data_set): | |||
""" | |||
def __init__(self, dataset, batch_size, sampler, use_cuda): | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
self.sampler = sampler | |||
self.use_cuda = use_cuda | |||
self.idx_list = None | |||
self.curidx = 0 | |||
def __iter__(self): | |||
self.idx_list = self.sampler(self.dataset) | |||
self.curidx = 0 | |||
self.lengths = self.dataset.get_length() | |||
return self | |||
def __next__(self): | |||
""" | |||
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||
batch_x also contains an item (str: list of int) about origin lengths, | |||
which means ("field_name_origin_len": origin lengths). | |||
E.g. | |||
:: | |||
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||
The names of fields are defined in preprocessor's convert_to_dataset method. | |||
""" | |||
if self.curidx >= len(self.idx_list): | |||
raise StopIteration | |||
else: | |||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | |||
padding_length = {field_name: max(field_length[self.curidx: endidx]) | |||
for field_name, field_length in self.lengths.items()} | |||
origin_lengths = {field_name: field_length[self.curidx: endidx] | |||
for field_name, field_length in self.lengths.items()} | |||
batch_x, batch_y = defaultdict(list), defaultdict(list) | |||
for idx in range(self.curidx, endidx): | |||
x, y = self.dataset.to_tensor(idx, padding_length) | |||
for name, tensor in x.items(): | |||
batch_x[name].append(tensor) | |||
for name, tensor in y.items(): | |||
batch_y[name].append(tensor) | |||
batch_origin_length = {} | |||
# combine instances into a batch | |||
for batch in (batch_x, batch_y): | |||
for name, tensor_list in batch.items(): | |||
if self.use_cuda: | |||
batch[name] = torch.stack(tensor_list, dim=0).cuda() | |||
else: | |||
batch[name] = torch.stack(tensor_list, dim=0) | |||
# add origin lengths in batch_x | |||
for name, tensor in batch_x.items(): | |||
if self.use_cuda: | |||
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]).cuda() | |||
else: | |||
batch_origin_length[name + "_origin_len"] = torch.LongTensor(origin_lengths[name]) | |||
batch_x.update(batch_origin_length) | |||
self.curidx += endidx | |||
return batch_x, batch_y | |||
if __name__ == "__main__": | |||
"""simple running example | |||
""" | |||
texts = ["i am a cat", | |||
"this is a test of new batch", | |||
"haha" | |||
] | |||
labels = [0, 1, 0] | |||
# prepare vocabulary | |||
vocab = {} | |||
for text in texts: | |||
for tokens in text.split(): | |||
if tokens not in vocab: | |||
vocab[tokens] = len(vocab) | |||
print("vocabulary: ", vocab) | |||
# prepare input dataset | |||
data = DataSet() | |||
for text, label in zip(texts, labels): | |||
x = TextField(text.split(), False) | |||
y = LabelField(label, is_target=True) | |||
ins = Instance(text=x, label=y) | |||
data.append(ins) | |||
# use vocabulary to index data | |||
data.index_field("text", vocab) | |||
# define naive sampler for batch class | |||
class SeqSampler: | |||
def __call__(self, dataset): | |||
return list(range(len(dataset))) | |||
# use batch to iterate dataset | |||
data_iterator = Batch(data, 2, SeqSampler(), False) | |||
for epoch in range(1): | |||
for batch_x, batch_y in data_iterator: | |||
print(batch_x) | |||
print(batch_y) | |||
# do stuff |
@@ -7,23 +7,36 @@ class DataSet(list): | |||
self.name = name | |||
if instances is not None: | |||
self.extend(instances) | |||
def index_all(self, vocab): | |||
for ins in self: | |||
ins.index_all(vocab) | |||
def index_field(self, field_name, vocab): | |||
for ins in self: | |||
ins.index_field(field_name, vocab) | |||
def to_tensor(self, idx: int, padding_length: dict): | |||
"""Convert an instance in a dataset to tensor. | |||
:param idx: int, the index of the instance in the dataset. | |||
:param padding_length: int | |||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||
""" | |||
ins = self[idx] | |||
return ins.to_tensor(padding_length) | |||
def get_length(self): | |||
"""Fetch lengths of all fields in all instances in a dataset. | |||
:return lengths: dict of (str: list). The str is the field name. | |||
The list contains lengths of this field in all instances. | |||
""" | |||
lengths = defaultdict(list) | |||
for ins in self: | |||
for field_name, field_length in ins.get_length().items(): | |||
lengths[field_name].append(field_length) | |||
return lengths | |||
@@ -1,18 +1,23 @@ | |||
import torch | |||
class Field(object): | |||
"""A field defines a data type. | |||
""" | |||
def __init__(self, is_target: bool): | |||
self.is_target = is_target | |||
def index(self, vocab): | |||
pass | |||
raise NotImplementedError | |||
def get_length(self): | |||
pass | |||
raise NotImplementedError | |||
def to_tensor(self, padding_length): | |||
pass | |||
raise NotImplementedError | |||
class TextField(Field): | |||
def __init__(self, text: list, is_target): | |||
@@ -31,25 +36,38 @@ class TextField(Field): | |||
return self._index | |||
def get_length(self): | |||
"""Fetch the length of the text field. | |||
:return length: int, the length of the text. | |||
""" | |||
return len(self.text) | |||
def to_tensor(self, padding_length: int): | |||
"""Convert text field to tensor. | |||
:param padding_length: int | |||
:return tensor: torch.LongTensor, of shape [padding_length, ] | |||
""" | |||
pads = [] | |||
if self._index is None: | |||
print('error') | |||
raise RuntimeError("Indexing not done before to_tensor in TextField.") | |||
if padding_length > self.get_length(): | |||
pads = [0 for i in range(padding_length - self.get_length())] | |||
# (length, ) | |||
pads = [0] * (padding_length - self.get_length()) | |||
return torch.LongTensor(self._index + pads) | |||
class LabelField(Field): | |||
def __init__(self, label, is_target=True): | |||
super(LabelField, self).__init__(is_target) | |||
self.label = label | |||
self._index = None | |||
def get_length(self): | |||
"""Fetch the length of the label field. | |||
:return length: int, the length of the label, always 1. | |||
""" | |||
return 1 | |||
def index(self, vocab): | |||
@@ -58,13 +76,13 @@ class LabelField(Field): | |||
else: | |||
pass | |||
return self._index | |||
def to_tensor(self, padding_length): | |||
if self._index is None: | |||
return torch.LongTensor([self.label]) | |||
else: | |||
return torch.LongTensor([self._index]) | |||
if __name__ == "__main__": | |||
tf = TextField("test the code".split()) | |||
tf = TextField("test the code".split(), is_target=False) |
@@ -0,0 +1,53 @@ | |||
class Instance(object): | |||
"""An instance which consists of Fields is an example in the DataSet. | |||
""" | |||
def __init__(self, **fields): | |||
self.fields = fields | |||
self.has_index = False | |||
self.indexes = {} | |||
def add_field(self, field_name, field): | |||
self.fields[field_name] = field | |||
def get_length(self): | |||
"""Fetch the length of all fields in the instance. | |||
:return length: dict of (str: int), which means (field name: field length). | |||
""" | |||
length = {name: field.get_length() for name, field in self.fields.items()} | |||
return length | |||
def index_field(self, field_name, vocab): | |||
"""use `vocab` to index certain field | |||
""" | |||
self.indexes[field_name] = self.fields[field_name].index(vocab) | |||
def index_all(self, vocab): | |||
"""use `vocab` to index all fields | |||
""" | |||
if self.has_index: | |||
print("error") | |||
return self.indexes | |||
indexes = {name: field.index(vocab) for name, field in self.fields.items()} | |||
self.indexes = indexes | |||
return indexes | |||
def to_tensor(self, padding_length: dict): | |||
"""Convert instance to tensor. | |||
:param padding_length: dict of (str: int), which means (field name: padding_length of this field) | |||
:return tensor_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||
tensor_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [padding_length, ]) | |||
""" | |||
tensor_x = {} | |||
tensor_y = {} | |||
for name, field in self.fields.items(): | |||
if field.is_target: | |||
tensor_y[name] = field.to_tensor(padding_length[name]) | |||
else: | |||
tensor_x[name] = field.to_tensor(padding_length[name]) | |||
return tensor_x, tensor_y |
@@ -3,6 +3,10 @@ import os | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.field import TextField, LabelField | |||
from fastNLP.core.instance import Instance | |||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
DEFAULT_RESERVED_LABEL = ['<reserved-2>', | |||
@@ -84,7 +88,7 @@ class BasePreprocess(object): | |||
return len(self.label2index) | |||
def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10): | |||
"""Main preprocessing pipeline. | |||
"""Main pre-processing pipeline. | |||
:param train_dev_data: three-level list, with either single label or multiple labels in a sample. | |||
:param test_data: three-level list, with either single label or multiple labels in a sample. (optional) | |||
@@ -92,7 +96,9 @@ class BasePreprocess(object): | |||
:param train_dev_split: float, between [0, 1]. The ratio of training data used as validation set. | |||
:param cross_val: bool, whether to do cross validation. | |||
:param n_fold: int, the number of folds of cross validation. Only useful when cross_val is True. | |||
:return results: a tuple of datasets after preprocessing. | |||
:return results: multiple datasets after pre-processing. If test_data is provided, return one more dataset. | |||
If train_dev_split > 0, return one more dataset - the dev set. If cross_val is True, each dataset | |||
is a list of DataSet objects; Otherwise, each dataset is a DataSet object. | |||
""" | |||
if pickle_exist(pickle_path, "word2id.pkl") and pickle_exist(pickle_path, "class2id.pkl"): | |||
@@ -111,68 +117,87 @@ class BasePreprocess(object): | |||
index2label = self.build_reverse_dict(self.label2index) | |||
save_pickle(index2label, pickle_path, "id2class.pkl") | |||
data_train = [] | |||
data_dev = [] | |||
train_set = [] | |||
dev_set = [] | |||
if not cross_val: | |||
if not pickle_exist(pickle_path, "data_train.pkl"): | |||
data_train.extend(self.to_index(train_dev_data)) | |||
if train_dev_split > 0 and not pickle_exist(pickle_path, "data_dev.pkl"): | |||
split = int(len(data_train) * train_dev_split) | |||
data_dev = data_train[: split] | |||
data_train = data_train[split:] | |||
save_pickle(data_dev, pickle_path, "data_dev.pkl") | |||
split = int(len(train_dev_data) * train_dev_split) | |||
data_dev = train_dev_data[: split] | |||
data_train = train_dev_data[split:] | |||
train_set = self.convert_to_dataset(data_train, self.word2index, self.label2index) | |||
dev_set = self.convert_to_dataset(data_dev, self.word2index, self.label2index) | |||
save_pickle(dev_set, pickle_path, "data_dev.pkl") | |||
print("{} of the training data is split for validation. ".format(train_dev_split)) | |||
save_pickle(data_train, pickle_path, "data_train.pkl") | |||
else: | |||
train_set = self.convert_to_dataset(train_dev_data, self.word2index, self.label2index) | |||
save_pickle(train_set, pickle_path, "data_train.pkl") | |||
else: | |||
data_train = load_pickle(pickle_path, "data_train.pkl") | |||
train_set = load_pickle(pickle_path, "data_train.pkl") | |||
if pickle_exist(pickle_path, "data_dev.pkl"): | |||
data_dev = load_pickle(pickle_path, "data_dev.pkl") | |||
dev_set = load_pickle(pickle_path, "data_dev.pkl") | |||
else: | |||
# cross_val is True | |||
if not pickle_exist(pickle_path, "data_train_0.pkl"): | |||
# cross validation | |||
data_idx = self.to_index(train_dev_data) | |||
data_cv = self.cv_split(data_idx, n_fold) | |||
data_cv = self.cv_split(train_dev_data, n_fold) | |||
for i, (data_train_cv, data_dev_cv) in enumerate(data_cv): | |||
data_train_cv = self.convert_to_dataset(data_train_cv, self.word2index, self.label2index) | |||
data_dev_cv = self.convert_to_dataset(data_dev_cv, self.word2index, self.label2index) | |||
save_pickle( | |||
data_train_cv, pickle_path, | |||
"data_train_{}.pkl".format(i)) | |||
save_pickle( | |||
data_dev_cv, pickle_path, | |||
"data_dev_{}.pkl".format(i)) | |||
data_train.append(data_train_cv) | |||
data_dev.append(data_dev_cv) | |||
train_set.append(data_train_cv) | |||
dev_set.append(data_dev_cv) | |||
print("{}-fold cross validation.".format(n_fold)) | |||
else: | |||
for i in range(n_fold): | |||
data_train_cv = load_pickle(pickle_path, "data_train_{}.pkl".format(i)) | |||
data_dev_cv = load_pickle(pickle_path, "data_dev_{}.pkl".format(i)) | |||
data_train.append(data_train_cv) | |||
data_dev.append(data_dev_cv) | |||
train_set.append(data_train_cv) | |||
dev_set.append(data_dev_cv) | |||
# prepare test data if provided | |||
data_test = [] | |||
test_set = [] | |||
if test_data is not None: | |||
if not pickle_exist(pickle_path, "data_test.pkl"): | |||
data_test = self.to_index(test_data) | |||
save_pickle(data_test, pickle_path, "data_test.pkl") | |||
test_set = self.convert_to_dataset(test_data, self.word2index, self.label2index) | |||
save_pickle(test_set, pickle_path, "data_test.pkl") | |||
# return preprocessed results | |||
results = [data_train] | |||
results = [train_set] | |||
if cross_val or train_dev_split > 0: | |||
results.append(data_dev) | |||
results.append(dev_set) | |||
if test_data: | |||
results.append(data_test) | |||
results.append(test_set) | |||
if len(results) == 1: | |||
return results[0] | |||
else: | |||
return tuple(results) | |||
def build_dict(self, data): | |||
raise NotImplementedError | |||
label2index = DEFAULT_WORD_TO_INDEX.copy() | |||
word2index = DEFAULT_WORD_TO_INDEX.copy() | |||
for example in data: | |||
for word in example[0]: | |||
if word not in word2index: | |||
word2index[word] = len(word2index) | |||
label = example[1] | |||
if isinstance(label, str): | |||
# label is a string | |||
if label not in label2index: | |||
label2index[label] = len(label2index) | |||
elif isinstance(label, list): | |||
# label is a list of strings | |||
for single_label in label: | |||
if single_label not in label2index: | |||
label2index[single_label] = len(label2index) | |||
return word2index, label2index | |||
def to_index(self, data): | |||
raise NotImplementedError | |||
def build_reverse_dict(self, word_dict): | |||
id2word = {word_dict[w]: w for w in word_dict} | |||
@@ -186,11 +211,23 @@ class BasePreprocess(object): | |||
return data_train, data_dev | |||
def cv_split(self, data, n_fold): | |||
"""Split data for cross validation.""" | |||
"""Split data for cross validation. | |||
:param data: list of string | |||
:param n_fold: int | |||
:return data_cv: | |||
:: | |||
[ | |||
(data_train, data_dev), # 1st fold | |||
(data_train, data_dev), # 2nd fold | |||
... | |||
] | |||
""" | |||
data_copy = data.copy() | |||
np.random.shuffle(data_copy) | |||
fold_size = round(len(data_copy) / n_fold) | |||
data_cv = [] | |||
for i in range(n_fold - 1): | |||
start = i * fold_size | |||
@@ -202,154 +239,62 @@ class BasePreprocess(object): | |||
data_dev = data_copy[start:] | |||
data_train = data_copy[:start] | |||
data_cv.append((data_train, data_dev)) | |||
return data_cv | |||
def convert_to_dataset(self, data, vocab, label_vocab): | |||
"""Convert list of indices into a DataSet object. | |||
class SeqLabelPreprocess(BasePreprocess): | |||
"""Preprocess pipeline, including building mapping from words to index, from index to words, | |||
from labels/classes to index, from index to labels/classes. | |||
data of three-level list which have multiple labels in each sample. | |||
:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
""" | |||
def __init__(self): | |||
super(SeqLabelPreprocess, self).__init__() | |||
def build_dict(self, data): | |||
"""Add new words with indices into self.word_dict, new labels with indices into self.label_dict. | |||
:param data: three-level list | |||
:: | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
:return word2index: dict of {str, int} | |||
label2index: dict of {str, int} | |||
:param data: list. Entries are strings. | |||
:param vocab: a dict, mapping string (token) to index (int). | |||
:param label_vocab: a dict, mapping string (label) to index (int). | |||
:return data_set: a DataSet object | |||
""" | |||
# In seq labeling, both word seq and label seq need to be padded to the same length in a mini-batch. | |||
label2index = DEFAULT_WORD_TO_INDEX.copy() | |||
word2index = DEFAULT_WORD_TO_INDEX.copy() | |||
use_word_seq = False | |||
use_label_seq = False | |||
data_set = DataSet() | |||
for example in data: | |||
for word, label in zip(example[0], example[1]): | |||
if word not in word2index: | |||
word2index[word] = len(word2index) | |||
if label not in label2index: | |||
label2index[label] = len(label2index) | |||
return word2index, label2index | |||
words, label = example[0], example[1] | |||
instance = Instance() | |||
def to_index(self, data): | |||
"""Convert word strings and label strings into indices. | |||
if isinstance(words, list): | |||
x = TextField(words, is_target=False) | |||
instance.add_field("word_seq", x) | |||
use_word_seq = True | |||
else: | |||
raise NotImplementedError("words is a {}".format(type(words))) | |||
if isinstance(label, list): | |||
y = TextField(label, is_target=True) | |||
instance.add_field("label_seq", y) | |||
use_label_seq = True | |||
elif isinstance(label, str): | |||
y = LabelField(label, is_target=True) | |||
instance.add_field("label", y) | |||
else: | |||
raise NotImplementedError("label is a {}".format(type(label))) | |||
:param data: three-level list | |||
:: | |||
data_set.append(instance) | |||
if use_word_seq: | |||
data_set.index_field("word_seq", vocab) | |||
if use_label_seq: | |||
data_set.index_field("label_seq", label_vocab) | |||
return data_set | |||
[ | |||
[ [word_11, word_12, ...], [label_1, label_1, ...] ], | |||
[ [word_21, word_22, ...], [label_2, label_1, ...] ], | |||
... | |||
] | |||
:return data_index: the same shape as data, but each string is replaced by its corresponding index | |||
""" | |||
data_index = [] | |||
for example in data: | |||
word_list = [] | |||
label_list = [] | |||
for word, label in zip(example[0], example[1]): | |||
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||
label_list.append(self.label2index.get(label, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||
data_index.append([word_list, label_list]) | |||
return data_index | |||
class SeqLabelPreprocess(BasePreprocess): | |||
def __init__(self): | |||
super(SeqLabelPreprocess, self).__init__() | |||
class ClassPreprocess(BasePreprocess): | |||
""" Preprocess pipeline for classification datasets. | |||
Preprocess pipeline, including building mapping from words to index, from index to words, | |||
from labels/classes to index, from index to labels/classes. | |||
design for data of three-level list which has a single label in each sample. | |||
:: | |||
[ | |||
[ [word_11, word_12, ...], label_1 ], | |||
[ [word_21, word_22, ...], label_2 ], | |||
... | |||
] | |||
""" | |||
def __init__(self): | |||
super(ClassPreprocess, self).__init__() | |||
def build_dict(self, data): | |||
"""Build vocabulary.""" | |||
# build vocabulary from scratch if nothing exists | |||
word2index = DEFAULT_WORD_TO_INDEX.copy() | |||
label2index = DEFAULT_WORD_TO_INDEX.copy() | |||
# collect every word and label | |||
for sent, label in data: | |||
if len(sent) <= 1: | |||
continue | |||
if label not in label2index: | |||
label2index[label] = len(label2index) | |||
for word in sent: | |||
if word not in word2index: | |||
word2index[word] = len(word2index) | |||
return word2index, label2index | |||
def to_index(self, data): | |||
"""Convert word strings and label strings into indices. | |||
:param data: three-level list | |||
:: | |||
[ | |||
[ [word_11, word_12, ...], label_1 ], | |||
[ [word_21, word_22, ...], label_2 ], | |||
... | |||
] | |||
:return data_index: the same shape as data, but each string is replaced by its corresponding index | |||
""" | |||
data_index = [] | |||
for example in data: | |||
word_list = [] | |||
# example[0] is the word list, example[1] is the single label | |||
for word in example[0]: | |||
word_list.append(self.word2index.get(word, DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL])) | |||
label_index = self.label2index.get(example[1], DEFAULT_WORD_TO_INDEX[DEFAULT_UNKNOWN_LABEL]) | |||
data_index.append([word_list, label_index]) | |||
return data_index | |||
def infer_preprocess(pickle_path, data): | |||
"""Preprocess over inference data. Transform three-level list of strings into that of index. | |||
:: | |||
[ | |||
[word_11, word_12, ...], | |||
[word_21, word_22, ...], | |||
... | |||
] | |||
""" | |||
word2index = load_pickle(pickle_path, "word2id.pkl") | |||
data_index = [] | |||
for example in data: | |||
data_index.append([word2index.get(w, DEFAULT_UNKNOWN_LABEL) for w in example]) | |||
return data_index | |||
if __name__ == "__main__": | |||
p = BasePreprocess() | |||
train_dev_data = [[["I", "am", "a", "good", "student", "."], "0"], | |||
[["You", "are", "pretty", "."], "1"] | |||
] | |||
training_set = p.run(train_dev_data) | |||
print(training_set) |
@@ -2,8 +2,8 @@ import numpy as np | |||
import torch | |||
from fastNLP.core.action import Action | |||
from fastNLP.core.action import RandomSampler, Batchifier | |||
from fastNLP.modules import utils | |||
from fastNLP.core.action import RandomSampler | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.saver.logger import create_logger | |||
logger = create_logger(__name__, "./train_test.log") | |||
@@ -35,16 +35,16 @@ class BaseTester(object): | |||
""" | |||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||
This is used to warn users of essential settings in the training. | |||
Obviously, "required_args" is the subset of "default_args". | |||
The value in "default_args" to the keys in "required_args" is simply for type check. | |||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||
""" | |||
# add required arguments here | |||
required_args = {} | |||
required_args = {"task" # one of ("seq_label", "text_classify") | |||
} | |||
for req_key in required_args: | |||
if req_key not in kwargs: | |||
logger.error("Tester lacks argument {}".format(req_key)) | |||
raise ValueError("Tester lacks argument {}".format(req_key)) | |||
self._task = kwargs["task"] | |||
for key in default_args: | |||
if key in kwargs: | |||
@@ -83,10 +83,10 @@ class BaseTester(object): | |||
self.eval_history.clear() | |||
self.batch_output.clear() | |||
iterator = iter(Batchifier(RandomSampler(dev_data), self.batch_size, drop_last=False)) | |||
data_iterator = Batch(dev_data, self.batch_size, sampler=RandomSampler(), use_cuda=self.use_cuda) | |||
step = 0 | |||
for batch_x, batch_y in self.make_batch(iterator): | |||
for batch_x, batch_y in data_iterator: | |||
with torch.no_grad(): | |||
prediction = self.data_forward(network, batch_x) | |||
eval_results = self.evaluate(prediction, batch_y) | |||
@@ -112,7 +112,8 @@ class BaseTester(object): | |||
def data_forward(self, network, x): | |||
"""A forward pass of the model. """ | |||
raise NotImplementedError | |||
y = network(**x) | |||
return y | |||
def evaluate(self, predict, truth): | |||
"""Compute evaluation metrics. | |||
@@ -121,7 +122,26 @@ class BaseTester(object): | |||
:param truth: Tensor | |||
:return eval_results: can be anything. It will be stored in self.eval_history | |||
""" | |||
raise NotImplementedError | |||
batch_size, max_len = predict.size(0), predict.size(1) | |||
if "label_seq" in truth: | |||
truth = truth["label_seq"] | |||
elif "label" in truth: | |||
truth = truth["label"] | |||
else: | |||
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | |||
loss = self._model.loss(predict, truth) / batch_size | |||
prediction = self._model.prediction(predict) | |||
# pad prediction to equal length | |||
for pred in prediction: | |||
if len(pred) < max_len: | |||
pred += [0] * (max_len - len(pred)) | |||
results = torch.Tensor(prediction).view(-1, ) | |||
# make sure "results" is in the same device as "truth" | |||
results = results.to(truth) | |||
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | |||
return [float(loss), float(accuracy)] | |||
@property | |||
def metrics(self): | |||
@@ -131,7 +151,9 @@ class BaseTester(object): | |||
:return : variable number of outputs | |||
""" | |||
raise NotImplementedError | |||
batch_loss = np.mean([x[0] for x in self.eval_history]) | |||
batch_accuracy = np.mean([x[1] for x in self.eval_history]) | |||
return batch_loss, batch_accuracy | |||
def show_metrics(self): | |||
"""Customize evaluation outputs in Trainer. | |||
@@ -140,10 +162,8 @@ class BaseTester(object): | |||
:return print_str: str | |||
""" | |||
raise NotImplementedError | |||
def make_batch(self, iterator): | |||
raise NotImplementedError | |||
loss, accuracy = self.metrics | |||
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | |||
def make_eval_output(self, predictions, eval_results): | |||
"""Customize Tester outputs. | |||
@@ -152,78 +172,21 @@ class BaseTester(object): | |||
:param eval_results: Tensor | |||
:return: str, to be printed. | |||
""" | |||
raise NotImplementedError | |||
return self.show_metrics() | |||
class SeqLabelTester(BaseTester): | |||
"""Tester for sequence labeling. | |||
""" | |||
def __init__(self, **test_args): | |||
""" | |||
:param test_args: a dict-like object that has __getitem__ method, can be accessed by "test_args["key_str"]" | |||
""" | |||
test_args.update({"task": "seq_label"}) | |||
print( | |||
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester with argument 'task'='seq_label'.") | |||
super(SeqLabelTester, self).__init__(**test_args) | |||
self.max_len = None | |||
self.mask = None | |||
self.seq_len = None | |||
def data_forward(self, network, inputs): | |||
"""This is only for sequence labeling with CRF decoder. | |||
:param network: a PyTorch model | |||
:param inputs: tuple of (x, seq_len) | |||
x: Tensor of shape [batch_size, max_len], where max_len is the maximum length of the mini-batch | |||
after padding. | |||
seq_len: list of int, the lengths of sequences before padding. | |||
:return y: Tensor of shape [batch_size, max_len] | |||
""" | |||
if not isinstance(inputs, tuple): | |||
raise RuntimeError("output_length must be true for sequence modeling.") | |||
# unpack the returned value from make_batch | |||
x, seq_len = inputs[0], inputs[1] | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = utils.seq_mask(seq_len, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
if torch.cuda.is_available() and self.use_cuda: | |||
mask = mask.cuda() | |||
self.mask = mask | |||
self.seq_len = seq_len | |||
y = network(x) | |||
return y | |||
def evaluate(self, predict, truth): | |||
"""Compute metrics (or loss). | |||
:param predict: Tensor, [batch_size, max_len, tag_size] | |||
:param truth: Tensor, [batch_size, max_len] | |||
:return: | |||
""" | |||
batch_size, max_len = predict.size(0), predict.size(1) | |||
loss = self._model.loss(predict, truth, self.mask) / batch_size | |||
prediction = self._model.prediction(predict, self.mask) | |||
results = torch.Tensor(prediction).view(-1, ) | |||
# make sure "results" is in the same device as "truth" | |||
results = results.to(truth) | |||
accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0] | |||
return [float(loss), float(accuracy)] | |||
def metrics(self): | |||
batch_loss = np.mean([x[0] for x in self.eval_history]) | |||
batch_accuracy = np.mean([x[1] for x in self.eval_history]) | |||
return batch_loss, batch_accuracy | |||
def show_metrics(self): | |||
"""This is called by Trainer to print evaluation on dev set. | |||
:return print_str: str | |||
""" | |||
loss, accuracy = self.metrics() | |||
return "dev loss={:.2f}, accuracy={:.2f}".format(loss, accuracy) | |||
def make_batch(self, iterator): | |||
return Action.make_batch(iterator, use_cuda=self.use_cuda, output_length=True) | |||
class ClassificationTester(BaseTester): | |||
@@ -236,9 +199,6 @@ class ClassificationTester(BaseTester): | |||
""" | |||
super(ClassificationTester, self).__init__(**test_args) | |||
def make_batch(self, iterator, max_len=None): | |||
return Action.make_batch(iterator, use_cuda=self.use_cuda, max_len=max_len) | |||
def data_forward(self, network, x): | |||
"""Forward through network.""" | |||
logits = network(x) | |||
@@ -4,15 +4,14 @@ import time | |||
from datetime import timedelta | |||
import torch | |||
import tensorboardX | |||
from tensorboardX import SummaryWriter | |||
from fastNLP.core.action import Action | |||
from fastNLP.core.action import RandomSampler, Batchifier | |||
from fastNLP.core.action import RandomSampler | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.loss import Loss | |||
from fastNLP.core.optimizer import Optimizer | |||
from fastNLP.core.tester import SeqLabelTester, ClassificationTester | |||
from fastNLP.modules import utils | |||
from fastNLP.saver.logger import create_logger | |||
from fastNLP.saver.model_saver import ModelSaver | |||
@@ -50,16 +49,16 @@ class BaseTrainer(object): | |||
""" | |||
"required_args" is the collection of arguments that users must pass to Trainer explicitly. | |||
This is used to warn users of essential settings in the training. | |||
Obviously, "required_args" is the subset of "default_args". | |||
The value in "default_args" to the keys in "required_args" is simply for type check. | |||
Specially, "required_args" does not have default value, so they have nothing to do with "default_args". | |||
""" | |||
# add required arguments here | |||
required_args = {} | |||
required_args = {"task" # one of ("seq_label", "text_classify") | |||
} | |||
for req_key in required_args: | |||
if req_key not in kwargs: | |||
logger.error("Trainer lacks argument {}".format(req_key)) | |||
raise ValueError("Trainer lacks argument {}".format(req_key)) | |||
self._task = kwargs["task"] | |||
for key in default_args: | |||
if key in kwargs: | |||
@@ -90,13 +89,14 @@ class BaseTrainer(object): | |||
self._optimizer_proto = default_args["optimizer"] | |||
self._summary_writer = SummaryWriter(self.pickle_path + 'tensorboard_logs') | |||
self._graph_summaried = False | |||
self._best_accuracy = 0.0 | |||
def train(self, network, train_data, dev_data=None): | |||
"""General Training Procedure | |||
:param network: a model | |||
:param train_data: three-level list, the training set. | |||
:param dev_data: three-level list, the validation data (optional) | |||
:param train_data: a DataSet instance, the training data | |||
:param dev_data: a DataSet instance, the validation data (optional) | |||
""" | |||
# transfer model to gpu if available | |||
if torch.cuda.is_available() and self.use_cuda: | |||
@@ -128,7 +128,8 @@ class BaseTrainer(object): | |||
# turn on network training mode | |||
self.mode(network, test=False) | |||
# prepare mini-batch iterator | |||
data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False)) | |||
data_iterator = Batch(train_data, batch_size=self.batch_size, sampler=RandomSampler(), | |||
use_cuda=self.use_cuda) | |||
logger.info("prepared data iterator") | |||
# one forward and backward pass | |||
@@ -157,7 +158,7 @@ class BaseTrainer(object): | |||
- epoch: int, | |||
""" | |||
step = 0 | |||
for batch_x, batch_y in self.make_batch(data_iterator): | |||
for batch_x, batch_y in data_iterator: | |||
prediction = self.data_forward(network, batch_x) | |||
@@ -166,10 +167,6 @@ class BaseTrainer(object): | |||
self.update() | |||
self._summary_writer.add_scalar("loss", loss.item(), global_step=step) | |||
if not self._graph_summaried: | |||
self._summary_writer.add_graph(network, batch_x) | |||
self._graph_summaried = True | |||
if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0: | |||
end = time.time() | |||
diff = timedelta(seconds=round(end - kwargs["start"])) | |||
@@ -204,9 +201,6 @@ class BaseTrainer(object): | |||
network_copy = copy.deepcopy(network) | |||
self.train(network_copy, train_data_cv[i], dev_data_cv[i]) | |||
def make_batch(self, iterator): | |||
raise NotImplementedError | |||
def mode(self, network, test): | |||
Action.mode(network, test) | |||
@@ -224,7 +218,12 @@ class BaseTrainer(object): | |||
self._optimizer.step() | |||
def data_forward(self, network, x): | |||
raise NotImplementedError | |||
y = network(**x) | |||
if not self._graph_summaried: | |||
if self._task == "seq_label": | |||
self._summary_writer.add_graph(network, (x["word_seq"], x["word_seq_origin_len"]), verbose=False) | |||
self._graph_summaried = True | |||
return y | |||
def grad_backward(self, loss): | |||
"""Compute gradient with link rules. | |||
@@ -243,6 +242,12 @@ class BaseTrainer(object): | |||
:param truth: ground truth label vector | |||
:return: a scalar | |||
""" | |||
if "label_seq" in truth: | |||
truth = truth["label_seq"] | |||
elif "label" in truth: | |||
truth = truth["label"] | |||
else: | |||
raise NotImplementedError("Unknown key {} in batch_y.".format(truth.keys())) | |||
return self._loss_func(predict, truth) | |||
def define_loss(self): | |||
@@ -270,7 +275,12 @@ class BaseTrainer(object): | |||
:param validator: a Tester instance | |||
:return: bool, True means current results on dev set is the best. | |||
""" | |||
raise NotImplementedError | |||
loss, accuracy = validator.metrics() | |||
if accuracy > self._best_accuracy: | |||
self._best_accuracy = accuracy | |||
return True | |||
else: | |||
return False | |||
def save_model(self, network, model_name): | |||
"""Save this model with such a name. | |||
@@ -291,55 +301,11 @@ class SeqLabelTrainer(BaseTrainer): | |||
"""Trainer for Sequence Labeling | |||
""" | |||
def __init__(self, **kwargs): | |||
kwargs.update({"task": "seq_label"}) | |||
print( | |||
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer with argument 'task'='seq_label'.") | |||
super(SeqLabelTrainer, self).__init__(**kwargs) | |||
# self.vocab_size = kwargs["vocab_size"] | |||
# self.num_classes = kwargs["num_classes"] | |||
self.max_len = None | |||
self.mask = None | |||
self.best_accuracy = 0.0 | |||
def data_forward(self, network, inputs): | |||
if not isinstance(inputs, tuple): | |||
raise RuntimeError("output_length must be true for sequence modeling. Receive {}".format(type(inputs[0]))) | |||
# unpack the returned value from make_batch | |||
x, seq_len = inputs[0], inputs[1] | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = utils.seq_mask(seq_len, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
if torch.cuda.is_available() and self.use_cuda: | |||
mask = mask.cuda() | |||
self.mask = mask | |||
y = network(x) | |||
return y | |||
def get_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
:param predict: prediction label vector, [batch_size, max_len, tag_size] | |||
:param truth: ground truth label vector, [batch_size, max_len] | |||
:return loss: a scalar | |||
""" | |||
batch_size, max_len = predict.size(0), predict.size(1) | |||
assert truth.shape == (batch_size, max_len) | |||
loss = self._model.loss(predict, truth, self.mask) | |||
return loss | |||
def best_eval_result(self, validator): | |||
loss, accuracy = validator.metrics() | |||
if accuracy > self.best_accuracy: | |||
self.best_accuracy = accuracy | |||
return True | |||
else: | |||
return False | |||
def make_batch(self, iterator): | |||
return Action.make_batch(iterator, output_length=True, use_cuda=self.use_cuda) | |||
def _create_validator(self, valid_args): | |||
return SeqLabelTester(**valid_args) | |||
@@ -361,9 +327,6 @@ class ClassificationTrainer(BaseTrainer): | |||
logits = network(x) | |||
return logits | |||
def make_batch(self, iterator): | |||
return Action.make_batch(iterator, output_length=False, use_cuda=self.use_cuda) | |||
def get_acc(self, y_logit, y_true): | |||
"""Compute accuracy.""" | |||
y_pred = torch.argmax(y_logit, dim=-1) | |||
@@ -1,86 +0,0 @@ | |||
from collections import defaultdict | |||
import torch | |||
class Batch(object): | |||
def __init__(self, dataset, sampler, batch_size): | |||
self.dataset = dataset | |||
self.sampler = sampler | |||
self.batch_size = batch_size | |||
self.idx_list = None | |||
self.curidx = 0 | |||
def __iter__(self): | |||
self.idx_list = self.sampler(self.dataset) | |||
self.curidx = 0 | |||
self.lengths = self.dataset.get_length() | |||
return self | |||
def __next__(self): | |||
if self.curidx >= len(self.idx_list): | |||
raise StopIteration | |||
else: | |||
endidx = min(self.curidx + self.batch_size, len(self.idx_list)) | |||
padding_length = {field_name : max(field_length[self.curidx: endidx]) | |||
for field_name, field_length in self.lengths.items()} | |||
batch_x, batch_y = defaultdict(list), defaultdict(list) | |||
for idx in range(self.curidx, endidx): | |||
x, y = self.dataset.to_tensor(idx, padding_length) | |||
for name, tensor in x.items(): | |||
batch_x[name].append(tensor) | |||
for name, tensor in y.items(): | |||
batch_y[name].append(tensor) | |||
for batch in (batch_x, batch_y): | |||
for name, tensor_list in batch.items(): | |||
print(name, " ", tensor_list) | |||
batch[name] = torch.stack(tensor_list, dim=0) | |||
self.curidx += endidx | |||
return batch_x, batch_y | |||
if __name__ == "__main__": | |||
"""simple running example | |||
""" | |||
from field import TextField, LabelField | |||
from instance import Instance | |||
from dataset import DataSet | |||
texts = ["i am a cat", | |||
"this is a test of new batch", | |||
"haha" | |||
] | |||
labels = [0, 1, 0] | |||
# prepare vocabulary | |||
vocab = {} | |||
for text in texts: | |||
for tokens in text.split(): | |||
if tokens not in vocab: | |||
vocab[tokens] = len(vocab) | |||
# prepare input dataset | |||
data = DataSet() | |||
for text, label in zip(texts, labels): | |||
x = TextField(text.split(), False) | |||
y = LabelField(label, is_target=True) | |||
ins = Instance(text=x, label=y) | |||
data.append(ins) | |||
# use vocabulary to index data | |||
data.index_field("text", vocab) | |||
# define naive sampler for batch class | |||
class SeqSampler: | |||
def __call__(self, dataset): | |||
return list(range(len(dataset))) | |||
# use bacth to iterate dataset | |||
batcher = Batch(data, SeqSampler(), 2) | |||
for epoch in range(3): | |||
for batch_x, batch_y in batcher: | |||
print(batch_x) | |||
print(batch_y) | |||
# do stuff | |||
@@ -1,38 +0,0 @@ | |||
class Instance(object): | |||
def __init__(self, **fields): | |||
self.fields = fields | |||
self.has_index = False | |||
self.indexes = {} | |||
def add_field(self, field_name, field): | |||
self.fields[field_name] = field | |||
def get_length(self): | |||
length = {name : field.get_length() for name, field in self.fields.items()} | |||
return length | |||
def index_field(self, field_name, vocab): | |||
"""use `vocab` to index certain field | |||
""" | |||
self.indexes[field_name] = self.fields[field_name].index(vocab) | |||
def index_all(self, vocab): | |||
"""use `vocab` to index all fields | |||
""" | |||
if self.has_index: | |||
print("error") | |||
return self.indexes | |||
indexes = {name : field.index(vocab) for name, field in self.fields.items()} | |||
self.indexes = indexes | |||
return indexes | |||
def to_tensor(self, padding_length: dict): | |||
tensorX = {} | |||
tensorY = {} | |||
for name, field in self.fields.items(): | |||
if field.is_target: | |||
tensorY[name] = field.to_tensor(padding_length[name]) | |||
else: | |||
tensorX[name] = field.to_tensor(padding_length[name]) | |||
return tensorX, tensorY |
@@ -4,6 +4,20 @@ from fastNLP.models.base_model import BaseModel | |||
from fastNLP.modules import decoder, encoder | |||
def seq_mask(seq_len, max_len): | |||
"""Create a mask for the sequences. | |||
:param seq_len: list or torch.LongTensor | |||
:param max_len: int | |||
:return mask: torch.LongTensor | |||
""" | |||
if isinstance(seq_len, list): | |||
seq_len = torch.LongTensor(seq_len) | |||
mask = [torch.ge(seq_len, i + 1) for i in range(max_len)] | |||
mask = torch.stack(mask, 1) | |||
return mask | |||
class SeqLabeling(BaseModel): | |||
""" | |||
PyTorch Network for sequence labeling | |||
@@ -20,13 +34,17 @@ class SeqLabeling(BaseModel): | |||
self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim) | |||
self.Linear = encoder.linear.Linear(hidden_dim, num_classes) | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||
self.mask = None | |||
def forward(self, x): | |||
def forward(self, word_seq, word_seq_origin_len): | |||
""" | |||
:param x: LongTensor, [batch_size, mex_len] | |||
:param word_seq: LongTensor, [batch_size, mex_len] | |||
:param word_seq_origin_len: LongTensor, [batch_size,], the origin lengths of the sequences. | |||
:return y: [batch_size, mex_len, tag_size] | |||
""" | |||
x = self.Embedding(x) | |||
self.mask = self.make_mask(word_seq, word_seq_origin_len) | |||
x = self.Embedding(word_seq) | |||
# [batch_size, max_len, word_emb_dim] | |||
x = self.Rnn(x) | |||
# [batch_size, max_len, hidden_size * direction] | |||
@@ -34,27 +52,32 @@ class SeqLabeling(BaseModel): | |||
# [batch_size, max_len, num_classes] | |||
return x | |||
def loss(self, x, y, mask): | |||
def loss(self, x, y): | |||
""" | |||
Negative log likelihood loss. | |||
:param x: Tensor, [batch_size, max_len, tag_size] | |||
:param y: Tensor, [batch_size, max_len] | |||
:param mask: ByteTensor, [batch_size, ,max_len] | |||
:return loss: a scalar Tensor | |||
""" | |||
x = x.float() | |||
y = y.long() | |||
total_loss = self.Crf(x, y, mask) | |||
total_loss = self.Crf(x, y, self.mask) | |||
return torch.mean(total_loss) | |||
def prediction(self, x, mask): | |||
def make_mask(self, x, seq_len): | |||
batch_size, max_len = x.size(0), x.size(1) | |||
mask = seq_mask(seq_len, max_len) | |||
mask = mask.byte().view(batch_size, max_len) | |||
mask = mask.to(x) | |||
return mask | |||
def prediction(self, x): | |||
""" | |||
:param x: FloatTensor, [batch_size, max_len, tag_size] | |||
:param mask: ByteTensor, [batch_size, max_len] | |||
:return prediction: list of [decode path(list)] | |||
""" | |||
tag_seq = self.Crf.viterbi_decode(x, mask) | |||
tag_seq = self.Crf.viterbi_decode(x, self.mask) | |||
return tag_seq | |||
@@ -81,11 +104,14 @@ class AdvSeqLabel(SeqLabeling): | |||
self.Crf = decoder.CRF.ConditionalRandomField(num_classes) | |||
def forward(self, x): | |||
def forward(self, x, seq_len): | |||
""" | |||
:param x: LongTensor, [batch_size, mex_len] | |||
:param seq_len: list of int. | |||
:return y: [batch_size, mex_len, tag_size] | |||
""" | |||
self.mask = self.make_mask(x, seq_len) | |||
batch_size = x.size(0) | |||
max_len = x.size(1) | |||
x = self.Embedding(x) | |||
@@ -15,11 +15,11 @@ from fastNLP.core.optimizer import Optimizer | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files") | |||
parser.add_argument("-t", "--train", type=str, default="./data_for_tests/people.txt", | |||
parser.add_argument("-t", "--train", type=str, default="../data_for_tests/people.txt", | |||
help="path to the training data") | |||
parser.add_argument("-c", "--config", type=str, default="./data_for_tests/config", help="path to the config file") | |||
parser.add_argument("-c", "--config", type=str, default="../data_for_tests/config", help="path to the config file") | |||
parser.add_argument("-m", "--model_name", type=str, default="seq_label_model.pkl", help="the name of the model") | |||
parser.add_argument("-i", "--infer", type=str, default="data_for_tests/people_infer.txt", | |||
parser.add_argument("-i", "--infer", type=str, default="../data_for_tests/people_infer.txt", | |||
help="data used for inference") | |||
args = parser.parse_args() | |||
@@ -86,7 +86,7 @@ def train_and_test(): | |||
trainer = SeqLabelTrainer( | |||
epochs=trainer_args["epochs"], | |||
batch_size=trainer_args["batch_size"], | |||
validate=trainer_args["validate"], | |||
validate=False, | |||
use_cuda=trainer_args["use_cuda"], | |||
pickle_path=pickle_path, | |||
save_best_dev=trainer_args["save_best_dev"], | |||
@@ -139,5 +139,5 @@ def train_and_test(): | |||
if __name__ == "__main__": | |||
# train_and_test() | |||
infer() | |||
train_and_test() | |||
# infer() |
@@ -115,4 +115,4 @@ def train(): | |||
if __name__ == "__main__": | |||
train() | |||
infer() | |||
# infer() |