* Refactor the dtype detection code so that FieldArray's __init__ and append share it, for better code reuse
* Type detection is now entirely FieldArray's responsibility; DataSet cooperates with it

Tests:
* Reorganize the dtype-related test code
* Add tests for all tutorials

Other:
* Flesh out a complete Conll dataset loader
* Upgrade the POS tag model training script
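As a rough illustration of the behaviour this change targets (a sketch only; it assumes the DataSet/Instance API shown in the diff and tests below, and the printed dtype is the expected result rather than verified output):

```python
import numpy as np
from fastNLP import DataSet, Instance

# Path 1: dict construction -- add_field builds each FieldArray directly.
ds1 = DataSet({"x": np.array([[1, 2], [3, 4]]), "y": [[5.0, 6.0], [7.0, 8.0]]})

# Path 2: list-of-Instance construction -- the first Instance initializes the
# FieldArray, later Instances go through FieldArray.append.
ds2 = DataSet([Instance(x=np.array([1, 2, 3, 4])) for _ in range(2)])

# Type detection is deferred until a field is marked as input or target.
ds2.set_input("x")
print(ds2.field_arrays["x"].dtype)  # expected: np.int64
```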
@@ -2,8 +2,8 @@ import _pickle as pickle | |||
import numpy as np | |||
from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP.core.fieldarray import AutoPadder | |||
from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.io.base_loader import DataLoaderRegister | |||
@@ -142,7 +142,8 @@ class DataSet(object): | |||
if len(self.field_arrays) == 0: | |||
# DataSet has no field yet | |||
for name, field in ins.fields.items(): | |||
self.field_arrays[name] = FieldArray(name, [field]) | |||
field = field.tolist() if isinstance(field, np.ndarray) else field
self.field_arrays[name] = FieldArray(name, [field])  # the first sample must be wrapped in a list
else: | |||
if len(self.field_arrays) != len(ins.fields): | |||
raise ValueError( | |||
@@ -290,9 +291,11 @@ class DataSet(object): | |||
extra_param['is_input'] = old_field.is_input | |||
if 'is_target' not in extra_param: | |||
extra_param['is_target'] = old_field.is_target | |||
self.add_field(name=new_field_name, fields=results) | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], | |||
is_target=extra_param["is_target"]) | |||
else: | |||
self.add_field(name=new_field_name, fields=results) | |||
self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None)) | |||
else: | |||
return results | |||
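A small sketch of what this hunk changes (assumed behaviour based on the lines above; the field name "x" is illustrative):

```python
from fastNLP import DataSet

ds = DataSet({"x": [[1, 2], [3, 4]]})
ds.set_input("x")

# Overwriting an existing field via apply() now carries over its
# is_input / is_target flags instead of dropping them.
ds.apply(lambda ins: [v * 2 for v in ins["x"]], new_field_name="x")
print(ds.field_arrays["x"].is_input)  # expected: True
```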
@@ -334,13 +337,14 @@ class DataSet(object): | |||
train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||
train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||
train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||
train_set.field_arrays[field_name].is_2d_list = self.field_arrays[field_name].is_2d_list | |||
train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input | |||
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target | |||
dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder | |||
dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype | |||
dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype | |||
dev_set.field_arrays[field_name].is_2d_list = self.field_arrays[field_name].is_2d_list | |||
dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim | |||
return train_set, dev_set | |||
@@ -100,6 +100,22 @@ class FieldArray(object): | |||
""" | |||
def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): | |||
"""DataSet在初始化时会有两类方法对FieldArray操作: | |||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | |||
然后后面的样本使用FieldArray.append进行添加。 | |||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
注意:np.array必须仅在最外层,即np.array([np.array, np.array]) 和 list of np.array不考虑 | |||
类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 | |||
""" | |||
self.name = name | |||
if isinstance(content, list): | |||
content = content | |||
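A sketch of the construction behaviour described in the docstring above (assuming the FieldArray attributes shown in this diff; the expected values mirror test_support_np_array further down):

```python
import numpy as np
from fastNLP.core.fieldarray import FieldArray

# ndarray content is converted to a nested Python list at construction;
# dtype/pytype stay None until is_input or is_target is set.
fa = FieldArray("x", np.array([[1.1, 2.2], [3.3, 4.4]]))
assert isinstance(fa.content, list)
assert fa.pytype is None and fa.dtype is None

fa.is_input = True           # triggers _set_dtype() -> _type_detection()
print(fa.pytype, fa.dtype)   # expected: float, np.float64
```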
@@ -107,31 +123,39 @@ class FieldArray(object): | |||
content = content.tolist() # convert np.ndarray into 2-D list | |||
else: | |||
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) | |||
self.content = content | |||
if len(content) == 0: | |||
raise RuntimeError("Cannot initialize FieldArray with empty list.") | |||
self.content = content # 1-D, 2-D or 3-D list; shapes may be ragged
self.content_dim = None # how many dimensions the content list has
self.set_padder(padder) | |||
self._is_target = None | |||
self._is_input = None | |||
self.BASIC_TYPES = (int, float, str) # basic Python types accepted in content; np.array is not included here
self.BASIC_TYPES = (int, float, str, np.ndarray) | |||
self.is_2d_list = False | |||
self.pytype = None # int, float, str, or np.ndarray | |||
self.dtype = None # np.int64, np.float64, np.str | |||
self.pytype = None | |||
self.dtype = None | |||
self._is_input = None | |||
self._is_target = None | |||
if is_input is not None: | |||
if is_input is not None or is_target is not None: | |||
self.is_input = is_input | |||
if is_target is not None: | |||
self.is_target = is_target | |||
def _set_dtype(self): | |||
self.pytype = self._type_detection(self.content) | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
@property | |||
def is_input(self): | |||
return self._is_input | |||
@is_input.setter | |||
def is_input(self, value): | |||
""" | |||
Called when field_array.is_input = True / False is assigned.
""" | |||
if value is True: | |||
self.pytype = self._type_detection(self.content) | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
self._set_dtype() | |||
self._is_input = value | |||
@property | |||
@@ -140,46 +164,99 @@ class FieldArray(object): | |||
@is_target.setter | |||
def is_target(self, value): | |||
""" | |||
Called when field_array.is_target = True / False is assigned.
""" | |||
if value is True: | |||
self.pytype = self._type_detection(self.content) | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
self._set_dtype() | |||
self._is_target = value | |||
def _type_detection(self, content): | |||
""" | |||
:param content: a list of int, float, str or np.ndarray, or a list of list of one. | |||
:return type: one of int, float, str, np.ndarray | |||
"""当该field被设置为is_input或者is_target时被调用 | |||
""" | |||
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): | |||
# content is a 2-D list | |||
if not all(isinstance(_, list) for _ in content): # strict check 2-D list | |||
raise TypeError("Please provide 2-D list.") | |||
type_set = set([self._type_detection(x) for x in content]) | |||
if len(type_set) == 2 and int in type_set and float in type_set: | |||
type_set = {float} | |||
elif len(type_set) > 1: | |||
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) | |||
self.is_2d_list = True | |||
if len(content) == 0: | |||
raise RuntimeError("Empty list in Field {}.".format(self.name)) | |||
type_set = set([type(item) for item in content]) | |||
if list in type_set: | |||
if len(type_set) > 1: | |||
# list mixed with non-list elements
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set))
# list with more than one dimension
inner_type_set = set() | |||
for l in content: | |||
[inner_type_set.add(type(obj)) for obj in l] | |||
if list not in inner_type_set: | |||
# 2-D list
self.content_dim = 2 | |||
return self._basic_type_detection(inner_type_set) | |||
else: | |||
if len(inner_type_set) == 1: | |||
# list with more than two dimensions
inner_inner_type_set = set() | |||
for _2d_list in content: | |||
for _1d_list in _2d_list: | |||
[inner_inner_type_set.add(type(obj)) for obj in _1d_list] | |||
if list in inner_inner_type_set: | |||
raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") | |||
# 3-D list
self.content_dim = 3 | |||
return self._basic_type_detection(inner_inner_type_set) | |||
else: | |||
# list mixed with non-list elements
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) | |||
else: | |||
# 1-D list
for content_type in type_set: | |||
if content_type not in self.BASIC_TYPES: | |||
raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( | |||
self.name, self.BASIC_TYPES, content_type)) | |||
self.content_dim = 1 | |||
return self._basic_type_detection(type_set) | |||
def _basic_type_detection(self, type_set): | |||
""" | |||
:param type_set: a set of Python types | |||
:return: one of self.BASIC_TYPES | |||
""" | |||
if len(type_set) == 1: | |||
return type_set.pop() | |||
elif isinstance(content, list): | |||
# content is a 1-D list | |||
if len(content) == 0: | |||
# the old error is not informative enough. | |||
raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.") | |||
type_set = set([type(item) for item in content]) | |||
if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: | |||
return type_set.pop() | |||
elif len(type_set) == 2 and float in type_set and int in type_set: | |||
elif len(type_set) == 2: | |||
# more than one basic type; may need an up-cast
if float in type_set and int in type_set: | |||
# up-cast int to float | |||
return float | |||
else: | |||
raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) | |||
# str mixed with int or float
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
else: | |||
raise TypeError("Cannot create FieldArray with type {}".format(type(content))) | |||
# str, int and float mixed together
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
def _1d_list_check(self, val): | |||
"""如果不是1D list就报错 | |||
""" | |||
type_set = set((type(obj) for obj in val)) | |||
if any(obj not in self.BASIC_TYPES for obj in type_set): | |||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
self._basic_type_detection(type_set) | |||
# otherwise: _basic_type_detection will raise error | |||
return True | |||
def _2d_list_check(self, val): | |||
"""如果不是2D list 就报错 | |||
""" | |||
type_set = set(type(obj) for obj in val) | |||
if list(type_set) != [list]: | |||
raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) | |||
inner_type_set = set() | |||
for l in val: | |||
for obj in l: | |||
inner_type_set.add(type(obj)) | |||
self._basic_type_detection(inner_type_set) | |||
return True | |||
@staticmethod | |||
def _map_to_np_type(basic_type): | |||
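A sketch of the detection and up-cast rules implemented above (expected results inferred from the code in this hunk, not verified output):

```python
from fastNLP.core.fieldarray import FieldArray

# int and float mixed in one field are up-cast to float.
fa = FieldArray("x", [[1, 2.5], [3, 4]], is_input=True)
print(fa.pytype, fa.dtype, fa.content_dim)  # expected: float, np.float64, 2

# str mixed with numbers is rejected.
try:
    FieldArray("bad", [[1, "a"]], is_input=True)
except RuntimeError as e:
    print(e)  # "Mixed data types in Field bad: ..."
```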
@@ -194,38 +271,39 @@ class FieldArray(object): | |||
:param val: int, float, str, or a list of one. | |||
""" | |||
if self.is_target is True or self.is_input is True: | |||
# only check type when used as target or input | |||
if isinstance(val, list): | |||
pass | |||
elif isinstance(val, tuple): # make sure the outermost layer is a list
val = list(val) | |||
elif isinstance(val, np.ndarray): | |||
val = val.tolist() | |||
elif any((isinstance(val, t) for t in self.BASIC_TYPES)): | |||
pass | |||
else: | |||
raise RuntimeError( | |||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||
val_type = type(val) | |||
if val_type == list: # shape check | |||
if self.is_2d_list is False: | |||
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") | |||
if self.is_input is True or self.is_target is True: | |||
if type(val) == list: | |||
if len(val) == 0: | |||
raise RuntimeError("Cannot append an empty list.") | |||
val_list_type = set([type(_) for _ in val]) # type check | |||
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: | |||
# up-cast int to float | |||
val_type = float | |||
elif len(val_list_type) == 1: | |||
val_type = val_list_type.pop() | |||
raise ValueError("Cannot append an empty list.") | |||
if self.content_dim == 2 and self._1d_list_check(val): | |||
# 1-D list check
pass | |||
elif self.content_dim == 3 and self._2d_list_check(val): | |||
# 2-D list check
pass | |||
else: | |||
raise TypeError("Cannot append a list of {}".format(val_list_type)) | |||
else: | |||
if self.is_2d_list is True: | |||
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.") | |||
if val_type == float and self.pytype == int: | |||
# up-cast | |||
self.pytype = float | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
elif val_type == int and self.pytype == float: | |||
pass | |||
elif val_type == self.pytype: | |||
pass | |||
raise RuntimeError( | |||
"Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) | |||
elif type(val) in self.BASIC_TYPES and self.content_dim == 1: | |||
# scalar check
if type(val) == float and self.pytype == int: | |||
self.pytype = float | |||
self.dtype = self._map_to_np_type(self.pytype) | |||
else: | |||
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) | |||
raise RuntimeError( | |||
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) | |||
self.content.append(val) | |||
def __getitem__(self, indices): | |||
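A sketch of the append-time dimension check added above (expected behaviour inferred from this hunk):

```python
from fastNLP.core.fieldarray import FieldArray

fa = FieldArray("x", [[1, 2], [3, 4]], is_input=True)  # content_dim == 2
fa.append([5, 6])        # OK: a 1-D list, one dimension below the field

try:
    fa.append([[7, 8]])  # a 2-D value into a 2-D field is rejected
except (ValueError, RuntimeError) as e:
    print(e)
```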
@@ -11,6 +11,10 @@ class Instance(object): | |||
""" | |||
def __init__(self, **fields): | |||
""" | |||
:param fields: each value may be a 1-D or 2-D list or np.array
""" | |||
self.fields = fields | |||
def add_field(self, field_name, field): | |||
@@ -32,5 +36,5 @@ class Instance(object): | |||
def __repr__(self): | |||
s = '\'' | |||
return "{" + ",\n".join( | |||
"\'" + field_name + "\': " + str(self.fields[field_name]) +\ | |||
"\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" |
@@ -858,9 +858,22 @@ class ConllPOSReader(object): | |||
ds.append(Instance(words=char_seq, | |||
tag=pos_seq)) | |||
return ds | |||
def get_one(self, sample): | |||
if len(sample) == 0: | |||
return None | |||
text = [] | |||
pos_tags = [] | |||
for w in sample: | |||
t1, t2, t3, t4 = w[1], w[3], w[6], w[7] | |||
if t3 == '_': | |||
return None | |||
text.append(t1) | |||
pos_tags.append(t2) | |||
return text, pos_tags | |||
class ConllxDataLoader(object): | |||
def load(self, path): | |||
@@ -879,7 +892,12 @@ class ConllxDataLoader(object): | |||
datalist.append(sample) | |||
data = [self.get_one(sample) for sample in datalist] | |||
return list(filter(lambda x: x is not None, data)) | |||
data_list = list(filter(lambda x: x is not None, data)) | |||
ds = DataSet() | |||
for example in data_list: | |||
ds.append(Instance(words=example[0], tag=example[1])) | |||
return ds | |||
def get_one(self, sample): | |||
sample = list(map(list, zip(*sample))) | |||
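With this change the loader can feed a Trainer directly; a usage sketch (the file path is a placeholder):

```python
from fastNLP.io.dataset_loader import ConllxDataLoader

# ConllxDataLoader.load() now returns a DataSet of Instances with
# "words" and "tag" fields instead of a plain list of tuples.
ds = ConllxDataLoader().load("path/to/train.conllx")  # placeholder path
print(len(ds), ds[0])
```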
@@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' | |||
[model] | |||
rnn_hidden_units = 300 | |||
word_emb_dim = 100 | |||
word_emb_dim = 300 | |||
dropout = 0.5 | |||
use_crf = true | |||
print_every_step = 10 | |||
@@ -8,16 +8,16 @@ import torch | |||
# in order to run fastNLP without installation | |||
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) | |||
from fastNLP.api.pipeline import Pipeline | |||
from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor | |||
from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInputProcessor, IndexerProcessor | |||
from fastNLP.core.metrics import SpanFPreRecMetric | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP.io.config_io import ConfigLoader, ConfigSection | |||
from fastNLP.models.sequence_modeling import AdvSeqLabel | |||
from fastNLP.io.dataset_loader import ZhConllPOSReader | |||
from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader | |||
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor | |||
cfgfile = './pos_tag.cfg' | |||
pickle_path = "save" | |||
@@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id): | |||
return embedding_tensor | |||
def train(checkpoint=None): | |||
def train(train_data_path, dev_data_path, checkpoint=None): | |||
# load config | |||
train_param = ConfigSection() | |||
model_param = ConfigSection() | |||
@@ -43,24 +43,36 @@ def train(checkpoint=None): | |||
print("config loaded") | |||
# Data Loader | |||
dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") | |||
print("loading training set...") | |||
dataset = ConllxDataLoader().load(train_data_path) | |||
print("loading dev set...") | |||
dev_data = ConllxDataLoader().load(dev_data_path) | |||
print(dataset) | |||
print("dataset transformed") | |||
print("================= dataset ready =====================") | |||
dataset.rename_field("tag", "truth") | |||
dev_data.rename_field("tag", "truth") | |||
vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") | |||
tag_proc = VocabIndexerProcessor("truth") | |||
seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) | |||
set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth") | |||
vocab_proc(dataset) | |||
tag_proc(dataset) | |||
seq_len_proc(dataset) | |||
# index dev set | |||
word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab | |||
dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]], new_field_name="word_seq") | |||
dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]], new_field_name="truth") | |||
dev_data.apply(lambda ins: len(ins["word_seq"]), new_field_name="word_seq_origin_len") | |||
# set input & target | |||
dataset.set_input("word_seq", "word_seq_origin_len", "truth") | |||
dev_data.set_input("word_seq", "word_seq_origin_len", "truth") | |||
dataset.set_target("truth", "word_seq_origin_len") | |||
print("processors defined") | |||
dev_data.set_target("truth", "word_seq_origin_len") | |||
# dataset.set_is_target(tag_ids=True) | |||
model_param["vocab_size"] = vocab_proc.get_vocab_size() | |||
@@ -71,7 +83,7 @@ def train(checkpoint=None): | |||
if checkpoint is None: | |||
# pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) | |||
pre_trained = None | |||
model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) | |||
model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained) | |||
print(model) | |||
else: | |||
model = torch.load(checkpoint) | |||
@@ -80,33 +92,71 @@ def train(checkpoint=None): | |||
trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", | |||
target="truth", | |||
seq_lens="word_seq_origin_len"), | |||
dev_data=dataset, metric_key="f", | |||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") | |||
dev_data=dev_data, metric_key="f", | |||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save_0") | |||
trainer.train(load_best_model=True) | |||
# save model & pipeline | |||
model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") | |||
id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") | |||
pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) | |||
pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) | |||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | |||
torch.save(save_dict, "model_pp.pkl") | |||
print("pipeline saved") | |||
torch.save(model, "./save/best_model.pkl") | |||
def run_test(test_path): | |||
test_data = ZhConllPOSReader().load(test_path) | |||
with open("model_pp.pkl", "rb") as f: | |||
save_dict = torch.load(f) | |||
tag_vocab = save_dict["tag_vocab"] | |||
pipeline = save_dict["pipeline"] | |||
index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) | |||
pipeline.pipeline = [index_tag] + pipeline.pipeline | |||
pipeline(test_data) | |||
test_data.set_target("truth") | |||
prediction = test_data.field_arrays["predict"].content | |||
truth = test_data.field_arrays["truth"].content | |||
seq_len = test_data.field_arrays["word_seq_origin_len"].content | |||
# padding by hand | |||
max_length = max([len(seq) for seq in prediction]) | |||
for idx in range(len(prediction)): | |||
prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) | |||
truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) | |||
evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", | |||
seq_lens="word_seq_origin_len") | |||
evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, | |||
{"truth": torch.Tensor(truth)}) | |||
test_result = evaluator.get_metric() | |||
f1 = round(test_result['f'] * 100, 2) | |||
pre = round(test_result['pre'] * 100, 2) | |||
rec = round(test_result['rec'] * 100, 2) | |||
return {"F1": f1, "precision": pre, "recall": rec} | |||
if __name__ == "__main__": | |||
parser = argparse.ArgumentParser() | |||
parser.add_argument("--train", type=str, help="training conll file", default="/home/zyfeng/data/sample.conllx") | |||
parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx") | |||
parser.add_argument("--test", type=str, help="test conll file", default=None) | |||
parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") | |||
parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") | |||
args = parser.parse_args() | |||
if args.restart is True: | |||
# resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
if args.checkpoint is None: | |||
raise RuntimeError("Please provide the checkpoint. -cp ") | |||
train(args.checkpoint) | |||
if args.test is not None: | |||
print(run_test(args.test)) | |||
else: | |||
# train from scratch: python train_pos_tag.py
train() | |||
if args.restart is True: | |||
# resume training: python train_pos_tag.py -c -cp ./save/best_model.pkl
if args.checkpoint is None: | |||
raise RuntimeError("Please provide the checkpoint. -cp ") | |||
train(args.train, args.dev, args.checkpoint) | |||
else: | |||
# train from scratch: python train_pos_tag.py
train(args.train, args.dev) |
@@ -89,3 +89,12 @@ class TestCase1(unittest.TestCase): | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
def test_list_of_numpy_to_tensor(self): | |||
ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + | |||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
print(x, y) |
@@ -6,15 +6,29 @@ from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP.core.instance import Instance | |||
class TestDataSet(unittest.TestCase): | |||
class TestDataSetInit(unittest.TestCase): | |||
"""初始化DataSet的办法有以下几种: | |||
1) 用dict: | |||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
2) 用list of Instance: | |||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
只接受纯list或者最外层ndarray | |||
""" | |||
def test_init_v1(self): | |||
# 1-D list
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
def test_init_v2(self): | |||
# from a dict
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
@@ -28,6 +42,8 @@ class TestDataSet(unittest.TestCase): | |||
with self.assertRaises(ValueError): | |||
_ = DataSet(0.00001) | |||
class TestDataSetMethods(unittest.TestCase): | |||
def test_append(self): | |||
dd = DataSet() | |||
for _ in range(3): | |||
@@ -42,13 +42,13 @@ class TestFieldArray(unittest.TestCase): | |||
self.assertEqual(fa.pytype, str) | |||
def test_support_np_array(self): | |||
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True) | |||
self.assertEqual(fa.dtype, np.ndarray) | |||
self.assertEqual(fa.pytype, np.ndarray) | |||
fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) | |||
self.assertEqual(fa.dtype, np.float64) | |||
self.assertEqual(fa.pytype, float) | |||
fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) | |||
self.assertEqual(fa.dtype, np.ndarray) | |||
self.assertEqual(fa.pytype, np.ndarray) | |||
self.assertEqual(fa.dtype, np.float64) | |||
self.assertEqual(fa.pytype, float) | |||
fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) | |||
# in this case, pytype is actually a float. We do not care about it. | |||
@@ -1,8 +1,8 @@ | |||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
import fastNLP | |||
import unittest | |||
import fastNLP | |||
from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric | |||
data_file = """ | |||
1 The _ DET DT _ 3 det _ _ | |||
2 new _ ADJ JJ _ 3 amod _ _ | |||
@@ -41,6 +41,7 @@ data_file = """ | |||
""" | |||
def init_data(): | |||
ds = fastNLP.DataSet() | |||
v = {'word_seq': fastNLP.Vocabulary(), | |||
@@ -60,18 +61,19 @@ def init_data(): | |||
data.append(line) | |||
for name in ['word_seq', 'pos_seq', 'label_true']: | |||
ds.apply(lambda x: ['<st>']+list(x[name]), new_field_name=name) | |||
ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name) | |||
ds.apply(lambda x: v[name].add_word_lst(x[name])) | |||
for name in ['word_seq', 'pos_seq', 'label_true']: | |||
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | |||
ds.apply(lambda x: [0]+list(map(int, x['arc_true'])), new_field_name='arc_true') | |||
ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') | |||
ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | |||
ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | |||
ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | |||
return ds, v['word_seq'], v['pos_seq'], v['label_true'] | |||
class TestBiaffineParser(unittest.TestCase): | |||
def test_train(self): | |||
ds, v1, v2, v3 = init_data() | |||
@@ -84,5 +86,6 @@ class TestBiaffineParser(unittest.TestCase): | |||
n_epochs=10, use_cuda=False, use_tqdm=False) | |||
trainer.train(load_best_model=False) | |||
if __name__ == '__main__': | |||
unittest.main() | |||
unittest.main() |
@@ -1,91 +0,0 @@ | |||
import unittest | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import Tester | |||
from fastNLP import Vocabulary | |||
from fastNLP.core.losses import CrossEntropyLoss | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP.models import CNNText | |||
class TestTutorial(unittest.TestCase): | |||
def test_tutorial(self): | |||
# 从csv读取数据到DataSet | |||
sample_path = "test/data_for_tests/tutorial_sample_dataset.csv" | |||
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||
sep='\t') | |||
print(len(dataset)) | |||
print(dataset[0]) | |||
dataset.append(Instance(raw_sentence='fake data', label='0')) | |||
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||
# label转int | |||
dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||
# 使用空格分割句子 | |||
def split_sent(ins): | |||
return ins['raw_sentence'].split() | |||
dataset.apply(split_sent, new_field_name='words') | |||
# 增加长度信息 | |||
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||
print(len(dataset)) | |||
print(dataset[0]) | |||
# DataSet.drop(func)筛除数据 | |||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||
print(len(dataset)) | |||
# 设置DataSet中,哪些field要转为tensor | |||
# set target,loss或evaluate中的golden,计算loss,模型评估时使用 | |||
dataset.set_target("label") | |||
# set input,模型forward时使用 | |||
dataset.set_input("words") | |||
# 分出测试集、训练集 | |||
test_data, train_data = dataset.split(0.5) | |||
print(len(test_data)) | |||
print(len(train_data)) | |||
# 构建词表, Vocabulary.add(word) | |||
vocab = Vocabulary(min_freq=2) | |||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||
vocab.build_vocab() | |||
# index句子, Vocabulary.to_index(word) | |||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
print(test_data[0]) | |||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
from fastNLP import Trainer | |||
from copy import deepcopy | |||
# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | |||
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | |||
train_data.rename_field('label', 'label_seq') | |||
test_data.rename_field('words', 'word_seq') | |||
test_data.rename_field('label', 'label_seq') | |||
# 实例化Trainer,传入模型和数据,进行训练 | |||
copy_model = deepcopy(model) | |||
overfit_trainer = Trainer(train_data=test_data, model=copy_model, | |||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||
dev_data=test_data, save_path="./save") | |||
overfit_trainer.train() | |||
trainer = Trainer(train_data=train_data, model=model, | |||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, | |||
dev_data=test_data, save_path="./save") | |||
trainer.train() | |||
print('Train finished!') | |||
# 使用fastNLP的Tester测试脚本 | |||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
batch_size=4) | |||
acc = tester.test() | |||
print(acc) |
@@ -0,0 +1,432 @@ | |||
import unittest | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import Vocabulary | |||
from fastNLP.core.losses import CrossEntropyLoss | |||
from fastNLP.core.metrics import AccuracyMetric | |||
class TestTutorial(unittest.TestCase): | |||
def test_fastnlp_10min_tutorial(self): | |||
# 从csv读取数据到DataSet | |||
sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" | |||
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), | |||
sep='\t') | |||
print(len(dataset)) | |||
print(dataset[0]) | |||
print(dataset[-3]) | |||
dataset.append(Instance(raw_sentence='fake data', label='0')) | |||
# 将所有数字转为小写 | |||
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||
# label转int | |||
dataset.apply(lambda x: int(x['label']), new_field_name='label') | |||
# 使用空格分割句子 | |||
def split_sent(ins): | |||
return ins['raw_sentence'].split() | |||
dataset.apply(split_sent, new_field_name='words') | |||
# 增加长度信息 | |||
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') | |||
print(len(dataset)) | |||
print(dataset[0]) | |||
# DataSet.drop(func)筛除数据 | |||
dataset.drop(lambda x: x['seq_len'] <= 3) | |||
print(len(dataset)) | |||
# 设置DataSet中,哪些field要转为tensor | |||
# set target,loss或evaluate中的golden,计算loss,模型评估时使用 | |||
dataset.set_target("label") | |||
# set input,模型forward时使用 | |||
dataset.set_input("words", "seq_len") | |||
# 分出测试集、训练集 | |||
test_data, train_data = dataset.split(0.5) | |||
print(len(test_data)) | |||
print(len(train_data)) | |||
# 构建词表, Vocabulary.add(word) | |||
vocab = Vocabulary(min_freq=2) | |||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||
vocab.build_vocab() | |||
# index句子, Vocabulary.to_index(word) | |||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') | |||
print(test_data[0]) | |||
# 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.sampler import RandomSampler | |||
batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) | |||
for batch_x, batch_y in batch_iterator: | |||
print("batch_x has: ", batch_x) | |||
print("batch_y has: ", batch_y) | |||
break | |||
from fastNLP.models import CNNText | |||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
from fastNLP import Trainer | |||
from copy import deepcopy | |||
# 更改DataSet中对应field的名称,要以模型的forward等参数名一致 | |||
train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 | |||
train_data.rename_field('label', 'label_seq') | |||
test_data.rename_field('words', 'word_seq') | |||
test_data.rename_field('label', 'label_seq') | |||
loss = CrossEntropyLoss(pred="output", target="label_seq") | |||
metric = AccuracyMetric(pred="predict", target="label_seq") | |||
# 实例化Trainer,传入模型和数据,进行训练 | |||
# 先在test_data拟合(确保模型的实现是正确的) | |||
copy_model = deepcopy(model) | |||
overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, | |||
loss=loss, | |||
metrics=metric, | |||
save_path=None, | |||
batch_size=32, | |||
n_epochs=5) | |||
overfit_trainer.train() | |||
# 用train_data训练,在test_data验证 | |||
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, | |||
loss=CrossEntropyLoss(pred="output", target="label_seq"), | |||
metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
save_path=None, | |||
batch_size=32, | |||
n_epochs=5) | |||
trainer.train() | |||
print('Train finished!') | |||
# 调用Tester在test_data上评价效果 | |||
from fastNLP import Tester | |||
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), | |||
batch_size=4) | |||
acc = tester.test() | |||
print(acc) | |||
def test_fastnlp_1min_tutorial(self): | |||
# tutorials/fastnlp_1min_tutorial.ipynb | |||
data_path = "tutorials/sample_data/tutorial_sample_dataset.csv" | |||
ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t') | |||
print(ds[1]) | |||
# 将所有数字转为小写 | |||
ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') | |||
# label转int | |||
ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True) | |||
def split_sent(ins): | |||
return ins['raw_sentence'].split() | |||
ds.apply(split_sent, new_field_name='words', is_input=True) | |||
# 分割训练集/验证集 | |||
train_data, dev_data = ds.split(0.3) | |||
print("Train size: ", len(train_data)) | |||
print("Test size: ", len(dev_data)) | |||
from fastNLP import Vocabulary | |||
vocab = Vocabulary(min_freq=2) | |||
train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) | |||
# index句子, Vocabulary.to_index(word) | |||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', | |||
is_input=True) | |||
dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', | |||
is_input=True) | |||
from fastNLP.models import CNNText | |||
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric | |||
trainer = Trainer(model=model, | |||
train_data=train_data, | |||
dev_data=dev_data, | |||
loss=CrossEntropyLoss(), | |||
metrics=AccuracyMetric() | |||
) | |||
trainer.train() | |||
print('Train finished!') | |||
def test_fastnlp_advanced_tutorial(self): | |||
import os | |||
os.chdir("tutorials/fastnlp_advanced_tutorial") | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import Vocabulary | |||
from fastNLP import Trainer | |||
from fastNLP import Tester | |||
# ### Instance | |||
# Instance表示一个样本,由一个或者多个field(域、属性、特征)组成,每个field具有自己的名字以及值 | |||
# 在初始化Instance的时候可以定义它包含的field,使用"field_name=field_value"的写法 | |||
# In[2]: | |||
# 组织一个Instance,这个Instance由premise、hypothesis、label三个field组成 | |||
instance = Instance(premise='an premise example .', hypothesis='an hypothesis example.', label=1) | |||
instance | |||
# In[3]: | |||
data_set = DataSet([instance] * 5) | |||
data_set.append(instance) | |||
data_set[-2:] | |||
# In[4]: | |||
# 如果某一个field的类型与dataset对应的field类型不一样仍可被加入dataset中 | |||
instance2 = Instance(premise='the second premise example .', hypothesis='the second hypothesis example.', | |||
label='1') | |||
try: | |||
data_set.append(instance2) | |||
except: | |||
pass | |||
data_set[-2:] | |||
# In[5]: | |||
# 如果某一个field的名字不对,则该instance不能被append到dataset中 | |||
instance3 = Instance(premises='the third premise example .', hypothesis='the third hypothesis example.', | |||
label=1) | |||
try: | |||
data_set.append(instance3) | |||
except: | |||
print('cannot append instance') | |||
pass | |||
data_set[-2:] | |||
# In[6]: | |||
# 除了文本以外,还可以将tensor作为其中一个field的value | |||
import torch | |||
tensor_ins = Instance(image=torch.randn(5, 5), label=0) | |||
ds = DataSet() | |||
ds.append(tensor_ins) | |||
ds | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
# 从csv读取数据到DataSet | |||
# 类csv文件,即每一行为一个example的文件,都可以使用这种方法进行数据读取 | |||
dataset = DataSet.read_csv('tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\t') | |||
# 查看DataSet的大小 | |||
len(dataset) | |||
# In[8]: | |||
# 使用数字索引[k],获取第k个样本 | |||
dataset[0] | |||
# In[9]: | |||
# 获取的样本是一个Instance | |||
type(dataset[0]) | |||
# In[10]: | |||
# 使用数字索引[a: b],获取第a到第b个样本 | |||
dataset[0: 3] | |||
# In[11]: | |||
# 索引也可以是负数 | |||
dataset[-1] | |||
data_path = ['premise', 'hypothesis', 'label'] | |||
# 读入文件 | |||
with open(data_path[0]) as f: | |||
premise = f.readlines() | |||
with open(data_path[1]) as f: | |||
hypothesis = f.readlines() | |||
with open(data_path[2]) as f: | |||
label = f.readlines() | |||
assert len(premise) == len(hypothesis) and len(hypothesis) == len(label) | |||
# 组织DataSet | |||
data_set = DataSet() | |||
for p, h, l in zip(premise, hypothesis, label): | |||
p = p.strip() # 将行末空格去除 | |||
h = h.strip() # 将行末空格去除 | |||
data_set.append(Instance(premise=p, hypothesis=h, truth=l)) | |||
data_set[0] | |||
# ### DataSet的其他操作 | |||
# 在构建完毕DataSet后,仍然可以对DataSet的内容进行操作,函数接口为DataSet.apply() | |||
# In[13]: | |||
# 将premise域的所有文本转成小写 | |||
data_set.apply(lambda x: x['premise'].lower(), new_field_name='premise') | |||
data_set[-2:] | |||
# In[14]: | |||
# label转int | |||
data_set.apply(lambda x: int(x['truth']), new_field_name='truth') | |||
data_set[-2:] | |||
# In[15]: | |||
# 使用空格分割句子 | |||
def split_sent(ins): | |||
return ins['premise'].split() | |||
data_set.apply(split_sent, new_field_name='premise') | |||
data_set.apply(lambda x: x['hypothesis'].split(), new_field_name='hypothesis') | |||
data_set[-2:] | |||
# In[16]: | |||
# 筛选数据 | |||
origin_data_set_len = len(data_set) | |||
data_set.drop(lambda x: len(x['premise']) <= 6) | |||
origin_data_set_len, len(data_set) | |||
# In[17]: | |||
# 增加长度信息 | |||
data_set.apply(lambda x: [1] * len(x['premise']), new_field_name='premise_len') | |||
data_set.apply(lambda x: [1] * len(x['hypothesis']), new_field_name='hypothesis_len') | |||
data_set[-1] | |||
# In[18]: | |||
# 设定特征域、标签域 | |||
data_set.set_input("premise", "premise_len", "hypothesis", "hypothesis_len") | |||
data_set.set_target("truth") | |||
# In[19]: | |||
# 重命名field | |||
data_set.rename_field('truth', 'label') | |||
data_set[-1] | |||
# In[20]: | |||
# 切分训练、验证集、测试集 | |||
train_data, vad_data = data_set.split(0.5) | |||
dev_data, test_data = vad_data.split(0.4) | |||
len(train_data), len(dev_data), len(test_data) | |||
# In[21]: | |||
# 深拷贝一个数据集 | |||
import copy | |||
train_data_2, dev_data_2 = copy.deepcopy(train_data), copy.deepcopy(dev_data) | |||
del copy | |||
# 初始化词表,该词表最大的vocab_size为10000,词表中每个词出现的最低频率为2,'<unk>'表示未知词语,'<pad>'表示padding词语 | |||
# Vocabulary默认初始化参数为max_size=None, min_freq=None, unknown='<unk>', padding='<pad>' | |||
vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>') | |||
# 构建词表 | |||
train_data.apply(lambda x: [vocab.add(word) for word in x['premise']]) | |||
train_data.apply(lambda x: [vocab.add(word) for word in x['hypothesis']]) | |||
vocab.build_vocab() | |||
# In[23]: | |||
# 根据词表index句子 | |||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') | |||
train_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||
dev_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') | |||
dev_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') | |||
test_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||
train_data[-1], dev_data[-1], test_data[-1] | |||
# 读入vocab文件 | |||
with open('vocab.txt') as f: | |||
lines = f.readlines() | |||
vocabs = [] | |||
for line in lines: | |||
vocabs.append(line.strip()) | |||
# 实例化Vocabulary | |||
vocab_bert = Vocabulary(unknown=None, padding=None) | |||
# 将vocabs列表加入Vocabulary | |||
vocab_bert.add_word_lst(vocabs) | |||
# 构建词表 | |||
vocab_bert.build_vocab() | |||
# 更新unknown与padding的token文本 | |||
vocab_bert.unknown = '[UNK]' | |||
vocab_bert.padding = '[PAD]' | |||
# In[25]: | |||
# 根据词表index句子 | |||
train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise') | |||
train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], | |||
new_field_name='hypothesis') | |||
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise') | |||
dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') | |||
train_data_2[-1], dev_data_2[-1] | |||
# step 1:加载模型参数(非必选) | |||
from fastNLP.io.config_io import ConfigSection, ConfigLoader | |||
args = ConfigSection() | |||
ConfigLoader().load_config("./data/config", {"esim_model": args}) | |||
args["vocab_size"] = len(vocab) | |||
args.data | |||
# In[27]: | |||
# step 2:加载ESIM模型 | |||
from fastNLP.models import ESIM | |||
model = ESIM(**args.data) | |||
model | |||
# In[28]: | |||
# 另一个例子:加载CNN文本分类模型 | |||
from fastNLP.models import CNNText | |||
cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) | |||
cnn_text_model | |||
from fastNLP import CrossEntropyLoss | |||
from fastNLP import Adam | |||
from fastNLP import AccuracyMetric | |||
trainer = Trainer( | |||
train_data=train_data, | |||
model=model, | |||
loss=CrossEntropyLoss(pred='pred', target='label'), | |||
metrics=AccuracyMetric(), | |||
n_epochs=5, | |||
batch_size=16, | |||
print_every=-1, | |||
validate_every=-1, | |||
dev_data=dev_data, | |||
use_cuda=True, | |||
optimizer=Adam(lr=1e-3, weight_decay=0), | |||
check_code_level=-1, | |||
metric_key='acc', | |||
use_tqdm=False, | |||
) | |||
trainer.train() | |||
tester = Tester( | |||
data=test_data, | |||
model=model, | |||
metrics=AccuracyMetric(), | |||
batch_size=args["batch_size"], | |||
) | |||
tester.test() | |||
os.chdir("../..") |