From e4f997d52a733c67d62392056eb01924519c2837 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Thu, 17 Jan 2019 12:25:37 +0800 Subject: [PATCH] =?UTF-8?q?refactor=20type=20system=20in=20FieldArray:=20*?= =?UTF-8?q?=20=E9=87=8D=E6=9E=84dtype=E7=9A=84=E6=A3=80=E6=B5=8B=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=EF=BC=8C=E5=9C=A8FieldArray=E7=9A=84=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E5=92=8Cappend=E4=B8=A4=E5=A4=84=EF=BC=8C?= =?UTF-8?q?=E8=BE=BE=E5=88=B0=E6=9B=B4=E5=A5=BD=E7=9A=84=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E5=A4=8D=E7=94=A8=20*=20=E7=B1=BB=E5=9E=8B=E6=A3=80=E6=B5=8B?= =?UTF-8?q?=E7=9A=84=E8=B4=A3=E4=BB=BB=E5=AE=8C=E5=85=A8=E8=90=BD=E5=9C=A8?= =?UTF-8?q?FieldArray=EF=BC=8CDataSet=E4=B8=8E=E4=B9=8B=E9=85=8D=E5=90=88?= =?UTF-8?q?=20=E6=B5=8B=E8=AF=95=EF=BC=9A=20*=20=E6=95=B4=E7=90=86dtype?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E7=9A=84=E6=B5=8B=E8=AF=95=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=20*=20=E7=BB=99=E6=89=80=E6=9C=89tutorial=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=20=E5=85=B6=E4=BB=96=EF=BC=9A=20*=20?= =?UTF-8?q?=E5=AE=8C=E5=96=84=E4=B8=80=E4=B8=AA=E5=AE=8C=E6=95=B4=E7=9A=84?= =?UTF-8?q?Conll=20dataset=20loader=20*=20=E5=8D=87=E7=BA=A7POS=20tag=20mo?= =?UTF-8?q?del=E8=AE=AD=E7=BB=83=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 16 +- fastNLP/core/fieldarray.py | 214 +++++++---- fastNLP/core/instance.py | 6 +- fastNLP/io/dataset_loader.py | 22 +- reproduction/POS_tagging/pos_tag.cfg | 2 +- reproduction/POS_tagging/train_pos_tag.py | 90 ++++- test/core/test_batch.py | 9 + test/core/test_dataset.py | 20 +- test/core/test_fieldarray.py | 10 +- test/models/test_biaffine_parser.py | 15 +- test/test_tutorial.py | 91 ----- test/test_tutorials.py | 432 ++++++++++++++++++++++ 12 files changed, 725 insertions(+), 202 deletions(-) delete mode 100644 test/test_tutorial.py create mode 100644 test/test_tutorials.py diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 2dba3267..f4e64c5d 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -2,8 +2,8 @@ import _pickle as pickle import numpy as np -from fastNLP.core.fieldarray import FieldArray from fastNLP.core.fieldarray import AutoPadder +from fastNLP.core.fieldarray import FieldArray from fastNLP.core.instance import Instance from fastNLP.core.utils import get_func_signature from fastNLP.io.base_loader import DataLoaderRegister @@ -142,7 +142,8 @@ class DataSet(object): if len(self.field_arrays) == 0: # DataSet has no field yet for name, field in ins.fields.items(): - self.field_arrays[name] = FieldArray(name, [field]) + field = field.tolist() if isinstance(field, np.ndarray) else field + self.field_arrays[name] = FieldArray(name, [field]) # 第一个样本,必须用list包装起来 else: if len(self.field_arrays) != len(ins.fields): raise ValueError( @@ -290,9 +291,11 @@ class DataSet(object): extra_param['is_input'] = old_field.is_input if 'is_target' not in extra_param: extra_param['is_target'] = old_field.is_target - self.add_field(name=new_field_name, fields=results) + self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], + is_target=extra_param["is_target"]) else: - self.add_field(name=new_field_name, fields=results) + self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), + is_target=extra_param.get("is_target", None)) else: return results @@ -334,13 +337,14 @@ class DataSet(object): train_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder 
train_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype train_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype - train_set.field_arrays[field_name].is_2d_list = self.field_arrays[field_name].is_2d_list + train_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim + dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target dev_set.field_arrays[field_name].padder = self.field_arrays[field_name].padder dev_set.field_arrays[field_name].dtype = self.field_arrays[field_name].dtype dev_set.field_arrays[field_name].pytype = self.field_arrays[field_name].pytype - dev_set.field_arrays[field_name].is_2d_list = self.field_arrays[field_name].is_2d_list + dev_set.field_arrays[field_name].content_dim = self.field_arrays[field_name].content_dim return train_set, dev_set diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index afb81697..4cde86ab 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -100,6 +100,22 @@ class FieldArray(object): """ def __init__(self, name, content, is_target=None, is_input=None, padder=AutoPadder(pad_val=0)): + """DataSet在初始化时会有两类方法对FieldArray操作: + 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: + 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) + 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) + 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) + 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; + 然后后面的样本使用FieldArray.append进行添加。 + 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) + 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) + 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) + 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) + + 注意:np.array必须仅在最外层,即np.array([np.array, np.array]) 和 list of np.array不考虑 + 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 + + """ self.name = name if isinstance(content, list): content = content @@ -107,31 +123,39 @@ class FieldArray(object): content = content.tolist() # convert np.ndarray into 2-D list else: raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) - self.content = content + if len(content) == 0: + raise RuntimeError("Cannot initialize FieldArray with empty list.") + + self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 + self.content_dim = None # 表示content是多少维的list self.set_padder(padder) - self._is_target = None - self._is_input = None + self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array - self.BASIC_TYPES = (int, float, str, np.ndarray) - self.is_2d_list = False - self.pytype = None # int, float, str, or np.ndarray - self.dtype = None # np.int64, np.float64, np.str + self.pytype = None + self.dtype = None + self._is_input = None + self._is_target = None - if is_input is not None: + if is_input is not None or is_target is not None: self.is_input = is_input - if is_target is not None: self.is_target = is_target + def _set_dtype(self): + self.pytype = self._type_detection(self.content) + self.dtype = self._map_to_np_type(self.pytype) + @property def is_input(self): return self._is_input @is_input.setter def is_input(self, value): + """ + 当 field_array.is_input = True / False 时被调用 + """ if value is True: - self.pytype = self._type_detection(self.content) - self.dtype = self._map_to_np_type(self.pytype) + self._set_dtype() self._is_input = value @property @@ 
-140,46 +164,99 @@ class FieldArray(object): @is_target.setter def is_target(self, value): + """ + 当 field_array.is_target = True / False 时被调用 + """ if value is True: - self.pytype = self._type_detection(self.content) - self.dtype = self._map_to_np_type(self.pytype) + self._set_dtype() self._is_target = value def _type_detection(self, content): - """ - - :param content: a list of int, float, str or np.ndarray, or a list of list of one. - :return type: one of int, float, str, np.ndarray + """当该field被设置为is_input或者is_target时被调用 """ - if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list): - # content is a 2-D list - if not all(isinstance(_, list) for _ in content): # strict check 2-D list - raise TypeError("Please provide 2-D list.") - type_set = set([self._type_detection(x) for x in content]) - if len(type_set) == 2 and int in type_set and float in type_set: - type_set = {float} - elif len(type_set) > 1: - raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set)) - self.is_2d_list = True + if len(content) == 0: + raise RuntimeError("Empty list in Field {}.".format(self.name)) + + type_set = set([type(item) for item in content]) + + if list in type_set: + if len(type_set) > 1: + # list 跟 非list 混在一起 + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + # >1维list + inner_type_set = set() + for l in content: + [inner_type_set.add(type(obj)) for obj in l] + if list not in inner_type_set: + # 二维list + self.content_dim = 2 + return self._basic_type_detection(inner_type_set) + else: + if len(inner_type_set) == 1: + # >2维list + inner_inner_type_set = set() + for _2d_list in content: + for _1d_list in _2d_list: + [inner_inner_type_set.add(type(obj)) for obj in _1d_list] + if list in inner_inner_type_set: + raise RuntimeError("FieldArray cannot handle 4-D or more-D list.") + # 3维list + self.content_dim = 3 + return self._basic_type_detection(inner_inner_type_set) + else: + # list 跟 非list 混在一起 + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, inner_type_set)) + else: + # 一维list + for content_type in type_set: + if content_type not in self.BASIC_TYPES: + raise RuntimeError("Unexpected data type in Field '{}'. Expect one of {}. Got {}.".format( + self.name, self.BASIC_TYPES, content_type)) + self.content_dim = 1 + return self._basic_type_detection(type_set) + + def _basic_type_detection(self, type_set): + """ + :param type_set: a set of Python types + :return: one of self.BASIC_TYPES + """ + if len(type_set) == 1: return type_set.pop() - - elif isinstance(content, list): - # content is a 1-D list - if len(content) == 0: - # the old error is not informative enough. - raise RuntimeError("Cannot create FieldArray with an empty list. 
Or one element in the list is empty.") - type_set = set([type(item) for item in content]) - - if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES: - return type_set.pop() - elif len(type_set) == 2 and float in type_set and int in type_set: + elif len(type_set) == 2: + # 有多个basic type; 可能需要up-cast + if float in type_set and int in type_set: # up-cast int to float return float else: - raise TypeError("Cannot create FieldArray with type {}".format(*type_set)) + # str 跟 int 或者 float 混在一起 + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) else: - raise TypeError("Cannot create FieldArray with type {}".format(type(content))) + # str, int, float混在一起 + raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, type_set)) + + def _1d_list_check(self, val): + """如果不是1D list就报错 + """ + type_set = set((type(obj) for obj in val)) + if any(obj not in self.BASIC_TYPES for obj in type_set): + raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) + self._basic_type_detection(type_set) + # otherwise: _basic_type_detection will raise error + return True + + def _2d_list_check(self, val): + """如果不是2D list 就报错 + """ + type_set = set(type(obj) for obj in val) + if list(type_set) != [list]: + raise ValueError("Mixed data types in Field {}: {}".format(self.name, type_set)) + inner_type_set = set() + for l in val: + for obj in l: + inner_type_set.add(type(obj)) + self._basic_type_detection(inner_type_set) + return True @staticmethod def _map_to_np_type(basic_type): @@ -194,38 +271,39 @@ class FieldArray(object): :param val: int, float, str, or a list of one. """ - if self.is_target is True or self.is_input is True: - # only check type when used as target or input + if isinstance(val, list): + pass + elif isinstance(val, tuple): # 确保最外层是list + val = list(val) + elif isinstance(val, np.ndarray): + val = val.tolist() + elif any((isinstance(val, t) for t in self.BASIC_TYPES)): + pass + else: + raise RuntimeError( + "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) - val_type = type(val) - if val_type == list: # shape check - if self.is_2d_list is False: - raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.") + if self.is_input is True or self.is_target is True: + if type(val) == list: if len(val) == 0: - raise RuntimeError("Cannot append an empty list.") - val_list_type = set([type(_) for _ in val]) # type check - if len(val_list_type) == 2 and int in val_list_type and float in val_list_type: - # up-cast int to float - val_type = float - elif len(val_list_type) == 1: - val_type = val_list_type.pop() + raise ValueError("Cannot append an empty list.") + if self.content_dim == 2 and self._1d_list_check(val): + # 1维list检查 + pass + elif self.content_dim == 3 and self._2d_list_check(val): + # 2维list检查 + pass else: - raise TypeError("Cannot append a list of {}".format(val_list_type)) - else: - if self.is_2d_list is True: - raise RuntimeError("Cannot append a non-list into a 2-D list. 
Please provide a list.") - - if val_type == float and self.pytype == int: - # up-cast - self.pytype = float - self.dtype = self._map_to_np_type(self.pytype) - elif val_type == int and self.pytype == float: - pass - elif val_type == self.pytype: - pass + raise RuntimeError( + "Dimension not matched: expect dim={}, got {}.".format(self.content_dim - 1, val)) + elif type(val) in self.BASIC_TYPES and self.content_dim == 1: + # scalar检查 + if type(val) == float and self.pytype == int: + self.pytype = float + self.dtype = self._map_to_np_type(self.pytype) else: - raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype)) - + raise RuntimeError( + "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) self.content.append(val) def __getitem__(self, indices): diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index a102b51c..5ac52e3f 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -11,6 +11,10 @@ class Instance(object): """ def __init__(self, **fields): + """ + + :param fields: 可能是一维或者二维的 list or np.array + """ self.fields = fields def add_field(self, field_name, field): @@ -32,5 +36,5 @@ class Instance(object): def __repr__(self): s = '\'' return "{" + ",\n".join( - "\'" + field_name + "\': " + str(self.fields[field_name]) +\ + "\'" + field_name + "\': " + str(self.fields[field_name]) + \ f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 2d157da3..fb781c3e 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -858,9 +858,22 @@ class ConllPOSReader(object): ds.append(Instance(words=char_seq, tag=pos_seq)) - return ds + def get_one(self, sample): + if len(sample) == 0: + return None + text = [] + pos_tags = [] + for w in sample: + t1, t2, t3, t4 = w[1], w[3], w[6], w[7] + if t3 == '_': + return None + text.append(t1) + pos_tags.append(t2) + return text, pos_tags + + class ConllxDataLoader(object): def load(self, path): @@ -879,7 +892,12 @@ class ConllxDataLoader(object): datalist.append(sample) data = [self.get_one(sample) for sample in datalist] - return list(filter(lambda x: x is not None, data)) + data_list = list(filter(lambda x: x is not None, data)) + + ds = DataSet() + for example in data_list: + ds.append(Instance(words=example[0], tag=example[1])) + return ds def get_one(self, sample): sample = list(map(list, zip(*sample))) diff --git a/reproduction/POS_tagging/pos_tag.cfg b/reproduction/POS_tagging/pos_tag.cfg index c9ee8320..f8224234 100644 --- a/reproduction/POS_tagging/pos_tag.cfg +++ b/reproduction/POS_tagging/pos_tag.cfg @@ -10,7 +10,7 @@ eval_sort_key = 'accuracy' [model] rnn_hidden_units = 300 -word_emb_dim = 100 +word_emb_dim = 300 dropout = 0.5 use_crf = true print_every_step = 10 diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py index 09a9ba02..e817db44 100644 --- a/reproduction/POS_tagging/train_pos_tag.py +++ b/reproduction/POS_tagging/train_pos_tag.py @@ -8,16 +8,16 @@ import torch # in order to run fastNLP without installation sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) - from fastNLP.api.pipeline import Pipeline -from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor +from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor, SetInputProcessor, IndexerProcessor from fastNLP.core.metrics import SpanFPreRecMetric from 
fastNLP.core.trainer import Trainer from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.models.sequence_modeling import AdvSeqLabel -from fastNLP.io.dataset_loader import ZhConllPOSReader +from fastNLP.io.dataset_loader import ZhConllPOSReader, ConllxDataLoader from fastNLP.api.processor import ModelProcessor, Index2WordProcessor + cfgfile = './pos_tag.cfg' pickle_path = "save" @@ -35,7 +35,7 @@ def load_tencent_embed(embed_path, word2id): return embedding_tensor -def train(checkpoint=None): +def train(train_data_path, dev_data_path, checkpoint=None): # load config train_param = ConfigSection() model_param = ConfigSection() @@ -43,24 +43,36 @@ def train(checkpoint=None): print("config loaded") # Data Loader - dataset = ZhConllPOSReader().load("/home/hyan/train.conllx") + print("loading training set...") + dataset = ConllxDataLoader().load(train_data_path) + print("loading dev set...") + dev_data = ConllxDataLoader().load(dev_data_path) print(dataset) - print("dataset transformed") + print("================= dataset ready =====================") dataset.rename_field("tag", "truth") + dev_data.rename_field("tag", "truth") vocab_proc = VocabIndexerProcessor("words", new_added_filed_name="word_seq") tag_proc = VocabIndexerProcessor("truth") seq_len_proc = SeqLenProcessor(field_name="word_seq", new_added_field_name="word_seq_origin_len", is_input=True) + set_input_proc = SetInputProcessor("word_seq", "word_seq_origin_len", "truth") vocab_proc(dataset) tag_proc(dataset) seq_len_proc(dataset) + # index dev set + word_vocab, tag_vocab = vocab_proc.vocab, tag_proc.vocab + dev_data.apply(lambda ins: [word_vocab.to_index(w) for w in ins["words"]], new_field_name="word_seq") + dev_data.apply(lambda ins: [tag_vocab.to_index(w) for w in ins["truth"]], new_field_name="truth") + dev_data.apply(lambda ins: len(ins["word_seq"]), new_field_name="word_seq_origin_len") + + # set input & target dataset.set_input("word_seq", "word_seq_origin_len", "truth") + dev_data.set_input("word_seq", "word_seq_origin_len", "truth") dataset.set_target("truth", "word_seq_origin_len") - - print("processors defined") + dev_data.set_target("truth", "word_seq_origin_len") # dataset.set_is_target(tag_ids=True) model_param["vocab_size"] = vocab_proc.get_vocab_size() @@ -71,7 +83,7 @@ def train(checkpoint=None): if checkpoint is None: # pre_trained = load_tencent_embed("/home/zyfeng/data/char_tencent_embedding.pkl", vocab_proc.vocab.word2idx) pre_trained = None - model = AdvSeqLabel(model_param, id2words=tag_proc.vocab.idx2word, emb=pre_trained) + model = AdvSeqLabel(model_param, id2words=None, emb=pre_trained) print(model) else: model = torch.load(checkpoint) @@ -80,33 +92,71 @@ def train(checkpoint=None): trainer = Trainer(dataset, model, loss=None, metrics=SpanFPreRecMetric(tag_proc.vocab, pred="predict", target="truth", seq_lens="word_seq_origin_len"), - dev_data=dataset, metric_key="f", - use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save") + dev_data=dev_data, metric_key="f", + use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save_0") trainer.train(load_best_model=True) # save model & pipeline model_proc = ModelProcessor(model, seq_len_field_name="word_seq_origin_len") id2tag = Index2WordProcessor(tag_proc.vocab, "predict", "tag") - pp = Pipeline([vocab_proc, seq_len_proc, model_proc, id2tag]) + pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} 
torch.save(save_dict, "model_pp.pkl") print("pipeline saved") - torch.save(model, "./save/best_model.pkl") + +def run_test(test_path): + test_data = ZhConllPOSReader().load(test_path) + + with open("model_pp.pkl", "rb") as f: + save_dict = torch.load(f) + tag_vocab = save_dict["tag_vocab"] + pipeline = save_dict["pipeline"] + index_tag = IndexerProcessor(vocab=tag_vocab, field_name="tag", new_added_field_name="truth", is_input=False) + pipeline.pipeline = [index_tag] + pipeline.pipeline + + pipeline(test_data) + test_data.set_target("truth") + prediction = test_data.field_arrays["predict"].content + truth = test_data.field_arrays["truth"].content + seq_len = test_data.field_arrays["word_seq_origin_len"].content + + # padding by hand + max_length = max([len(seq) for seq in prediction]) + for idx in range(len(prediction)): + prediction[idx] = list(prediction[idx]) + ([0] * (max_length - len(prediction[idx]))) + truth[idx] = list(truth[idx]) + ([0] * (max_length - len(truth[idx]))) + evaluator = SpanFPreRecMetric(tag_vocab=tag_vocab, pred="predict", target="truth", + seq_lens="word_seq_origin_len") + evaluator({"predict": torch.Tensor(prediction), "word_seq_origin_len": torch.Tensor(seq_len)}, + {"truth": torch.Tensor(truth)}) + test_result = evaluator.get_metric() + f1 = round(test_result['f'] * 100, 2) + pre = round(test_result['pre'] * 100, 2) + rec = round(test_result['rec'] * 100, 2) + + return {"F1": f1, "precision": pre, "recall": rec} if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--train", type=str, help="training conll file", default="/home/zyfeng/data/sample.conllx") + parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx") + parser.add_argument("--test", type=str, help="test conll file", default=None) + parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") args = parser.parse_args() - if args.restart is True: - # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl - if args.checkpoint is None: - raise RuntimeError("Please provide the checkpoint. -cp ") - train(args.checkpoint) + if args.test is not None: + print(run_test(args.test)) else: - # 一次训练 python train_pos_tag.py - train() + if args.restart is True: + # 继续训练 python train_pos_tag.py -c -cp ./save/best_model.pkl + if args.checkpoint is None: + raise RuntimeError("Please provide the checkpoint. 
-cp ") + train(args.train, args.dev, args.checkpoint) + else: + # 一次训练 python train_pos_tag.py + train(args.train, args.dev) diff --git a/test/core/test_batch.py b/test/core/test_batch.py index 77aebea5..7308ebf0 100644 --- a/test/core/test_batch.py +++ b/test/core/test_batch.py @@ -89,3 +89,12 @@ class TestCase1(unittest.TestCase): self.assertEqual(tuple(x["x"].shape), (4, 4)) self.assertTrue(isinstance(y["y"], torch.Tensor)) self.assertEqual(tuple(y["y"].shape), (4, 4)) + + def test_list_of_numpy_to_tensor(self): + ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + + [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) + ds.set_input("x") + ds.set_target("y") + iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) + for x, y in iter: + print(x, y) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 261d42b3..72ced912 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -6,15 +6,29 @@ from fastNLP.core.fieldarray import FieldArray from fastNLP.core.instance import Instance -class TestDataSet(unittest.TestCase): - +class TestDataSetInit(unittest.TestCase): + """初始化DataSet的办法有以下几种: + 1) 用dict: + 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) + 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) + 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) + 2) 用list of Instance: + 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) + 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) + 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) + 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) + + 只接受纯list或者最外层ndarray + """ def test_init_v1(self): + # 一维list ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) def test_init_v2(self): + # 用dict ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) @@ -28,6 +42,8 @@ class TestDataSet(unittest.TestCase): with self.assertRaises(ValueError): _ = DataSet(0.00001) + +class TestDataSetMethods(unittest.TestCase): def test_append(self): dd = DataSet() for _ in range(3): diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 82285462..da287916 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -42,13 +42,13 @@ class TestFieldArray(unittest.TestCase): self.assertEqual(fa.pytype, str) def test_support_np_array(self): - fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True) - self.assertEqual(fa.dtype, np.ndarray) - self.assertEqual(fa.pytype, np.ndarray) + fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True) + self.assertEqual(fa.dtype, np.float64) + self.assertEqual(fa.pytype, float) fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5])) - self.assertEqual(fa.dtype, np.ndarray) - self.assertEqual(fa.pytype, np.ndarray) + self.assertEqual(fa.dtype, np.float64) + self.assertEqual(fa.pytype, float) fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True) # in this case, pytype is actually a float. We do not care about it. 
diff --git a/test/models/test_biaffine_parser.py b/test/models/test_biaffine_parser.py index d87000a0..88ba09b8 100644 --- a/test/models/test_biaffine_parser.py +++ b/test/models/test_biaffine_parser.py @@ -1,8 +1,8 @@ -from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric -import fastNLP - import unittest +import fastNLP +from fastNLP.models.biaffine_parser import BiaffineParser, ParserLoss, ParserMetric + data_file = """ 1 The _ DET DT _ 3 det _ _ 2 new _ ADJ JJ _ 3 amod _ _ @@ -41,6 +41,7 @@ data_file = """ """ + def init_data(): ds = fastNLP.DataSet() v = {'word_seq': fastNLP.Vocabulary(), @@ -60,18 +61,19 @@ def init_data(): data.append(line) for name in ['word_seq', 'pos_seq', 'label_true']: - ds.apply(lambda x: ['']+list(x[name]), new_field_name=name) + ds.apply(lambda x: [''] + list(x[name]), new_field_name=name) ds.apply(lambda x: v[name].add_word_lst(x[name])) for name in ['word_seq', 'pos_seq', 'label_true']: ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) - ds.apply(lambda x: [0]+list(map(int, x['arc_true'])), new_field_name='arc_true') + ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) return ds, v['word_seq'], v['pos_seq'], v['label_true'] + class TestBiaffineParser(unittest.TestCase): def test_train(self): ds, v1, v2, v3 = init_data() @@ -84,5 +86,6 @@ class TestBiaffineParser(unittest.TestCase): n_epochs=10, use_cuda=False, use_tqdm=False) trainer.train(load_best_model=False) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/test/test_tutorial.py b/test/test_tutorial.py deleted file mode 100644 index 68cb6a41..00000000 --- a/test/test_tutorial.py +++ /dev/null @@ -1,91 +0,0 @@ -import unittest - -from fastNLP import DataSet -from fastNLP import Instance -from fastNLP import Tester -from fastNLP import Vocabulary -from fastNLP.core.losses import CrossEntropyLoss -from fastNLP.core.metrics import AccuracyMetric -from fastNLP.models import CNNText - - -class TestTutorial(unittest.TestCase): - def test_tutorial(self): - # 从csv读取数据到DataSet - sample_path = "test/data_for_tests/tutorial_sample_dataset.csv" - dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), - sep='\t') - print(len(dataset)) - print(dataset[0]) - - dataset.append(Instance(raw_sentence='fake data', label='0')) - dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') - # label转int - dataset.apply(lambda x: int(x['label']), new_field_name='label') - - # 使用空格分割句子 - def split_sent(ins): - return ins['raw_sentence'].split() - - dataset.apply(split_sent, new_field_name='words') - # 增加长度信息 - dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') - print(len(dataset)) - print(dataset[0]) - - # DataSet.drop(func)筛除数据 - dataset.drop(lambda x: x['seq_len'] <= 3) - print(len(dataset)) - - # 设置DataSet中,哪些field要转为tensor - # set target,loss或evaluate中的golden,计算loss,模型评估时使用 - dataset.set_target("label") - # set input,模型forward时使用 - dataset.set_input("words") - - # 分出测试集、训练集 - test_data, train_data = dataset.split(0.5) - print(len(test_data)) - print(len(train_data)) - - # 构建词表, Vocabulary.add(word) - vocab = Vocabulary(min_freq=2) - train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) - vocab.build_vocab() - - # index句子, 
Vocabulary.to_index(word) - train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') - test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') - print(test_data[0]) - - model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) - - from fastNLP import Trainer - from copy import deepcopy - - # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 - train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 - train_data.rename_field('label', 'label_seq') - test_data.rename_field('words', 'word_seq') - test_data.rename_field('label', 'label_seq') - - # 实例化Trainer,传入模型和数据,进行训练 - copy_model = deepcopy(model) - overfit_trainer = Trainer(train_data=test_data, model=copy_model, - loss=CrossEntropyLoss(pred="output", target="label_seq"), - metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, - dev_data=test_data, save_path="./save") - overfit_trainer.train() - - trainer = Trainer(train_data=train_data, model=model, - loss=CrossEntropyLoss(pred="output", target="label_seq"), - metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4, - dev_data=test_data, save_path="./save") - trainer.train() - print('Train finished!') - - # 使用fastNLP的Tester测试脚本 - tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), - batch_size=4) - acc = tester.test() - print(acc) diff --git a/test/test_tutorials.py b/test/test_tutorials.py new file mode 100644 index 00000000..ee48c23b --- /dev/null +++ b/test/test_tutorials.py @@ -0,0 +1,432 @@ +import unittest + +from fastNLP import DataSet +from fastNLP import Instance +from fastNLP import Vocabulary +from fastNLP.core.losses import CrossEntropyLoss +from fastNLP.core.metrics import AccuracyMetric + + +class TestTutorial(unittest.TestCase): + def test_fastnlp_10min_tutorial(self): + # 从csv读取数据到DataSet + sample_path = "tutorials/sample_data/tutorial_sample_dataset.csv" + dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'), + sep='\t') + print(len(dataset)) + print(dataset[0]) + print(dataset[-3]) + + dataset.append(Instance(raw_sentence='fake data', label='0')) + # 将所有数字转为小写 + dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') + # label转int + dataset.apply(lambda x: int(x['label']), new_field_name='label') + + # 使用空格分割句子 + def split_sent(ins): + return ins['raw_sentence'].split() + + dataset.apply(split_sent, new_field_name='words') + + # 增加长度信息 + dataset.apply(lambda x: len(x['words']), new_field_name='seq_len') + print(len(dataset)) + print(dataset[0]) + + # DataSet.drop(func)筛除数据 + dataset.drop(lambda x: x['seq_len'] <= 3) + print(len(dataset)) + + # 设置DataSet中,哪些field要转为tensor + # set target,loss或evaluate中的golden,计算loss,模型评估时使用 + dataset.set_target("label") + # set input,模型forward时使用 + dataset.set_input("words", "seq_len") + + # 分出测试集、训练集 + test_data, train_data = dataset.split(0.5) + print(len(test_data)) + print(len(train_data)) + + # 构建词表, Vocabulary.add(word) + vocab = Vocabulary(min_freq=2) + train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) + vocab.build_vocab() + + # index句子, Vocabulary.to_index(word) + train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') + test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words') + print(test_data[0]) + + # 如果你们需要做强化学习或者GAN之类的项目,你们也可以使用这些数据预处理的工具 + from 
fastNLP.core.batch import Batch + from fastNLP.core.sampler import RandomSampler + + batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler()) + for batch_x, batch_y in batch_iterator: + print("batch_x has: ", batch_x) + print("batch_y has: ", batch_y) + break + + from fastNLP.models import CNNText + model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + + from fastNLP import Trainer + from copy import deepcopy + + # 更改DataSet中对应field的名称,要以模型的forward等参数名一致 + train_data.rename_field('words', 'word_seq') # input field 与 forward 参数一致 + train_data.rename_field('label', 'label_seq') + test_data.rename_field('words', 'word_seq') + test_data.rename_field('label', 'label_seq') + + loss = CrossEntropyLoss(pred="output", target="label_seq") + metric = AccuracyMetric(pred="predict", target="label_seq") + + # 实例化Trainer,传入模型和数据,进行训练 + # 先在test_data拟合(确保模型的实现是正确的) + copy_model = deepcopy(model) + overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data, + loss=loss, + metrics=metric, + save_path=None, + batch_size=32, + n_epochs=5) + overfit_trainer.train() + + # 用train_data训练,在test_data验证 + trainer = Trainer(model=model, train_data=train_data, dev_data=test_data, + loss=CrossEntropyLoss(pred="output", target="label_seq"), + metrics=AccuracyMetric(pred="predict", target="label_seq"), + save_path=None, + batch_size=32, + n_epochs=5) + trainer.train() + print('Train finished!') + + # 调用Tester在test_data上评价效果 + from fastNLP import Tester + + tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"), + batch_size=4) + acc = tester.test() + print(acc) + + def test_fastnlp_1min_tutorial(self): + # tutorials/fastnlp_1min_tutorial.ipynb + data_path = "tutorials/sample_data/tutorial_sample_dataset.csv" + ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t') + print(ds[1]) + + # 将所有数字转为小写 + ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence') + # label转int + ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True) + + def split_sent(ins): + return ins['raw_sentence'].split() + + ds.apply(split_sent, new_field_name='words', is_input=True) + + # 分割训练集/验证集 + train_data, dev_data = ds.split(0.3) + print("Train size: ", len(train_data)) + print("Test size: ", len(dev_data)) + + from fastNLP import Vocabulary + vocab = Vocabulary(min_freq=2) + train_data.apply(lambda x: [vocab.add(word) for word in x['words']]) + + # index句子, Vocabulary.to_index(word) + train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', + is_input=True) + dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', + is_input=True) + + from fastNLP.models import CNNText + model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + + from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric + trainer = Trainer(model=model, + train_data=train_data, + dev_data=dev_data, + loss=CrossEntropyLoss(), + metrics=AccuracyMetric() + ) + trainer.train() + print('Train finished!') + + def test_fastnlp_advanced_tutorial(self): + import os + os.chdir("tutorials/fastnlp_advanced_tutorial") + + from fastNLP import DataSet + from fastNLP import Instance + from fastNLP import Vocabulary + from fastNLP import Trainer + from fastNLP import Tester + + # ### Instance + # Instance表示一个样本,由一个或者多个field(域、属性、特征)组成,每个field具有自己的名字以及值 + # 
在初始化Instance的时候可以定义它包含的field,使用"field_name=field_value"的写法 + + # In[2]: + + # 组织一个Instance,这个Instance由premise、hypothesis、label三个field组成 + instance = Instance(premise='an premise example .', hypothesis='an hypothesis example.', label=1) + instance + + # In[3]: + + data_set = DataSet([instance] * 5) + data_set.append(instance) + data_set[-2:] + + # In[4]: + + # 如果某一个field的类型与dataset对应的field类型不一样仍可被加入dataset中 + instance2 = Instance(premise='the second premise example .', hypothesis='the second hypothesis example.', + label='1') + try: + data_set.append(instance2) + except: + pass + data_set[-2:] + + # In[5]: + + # 如果某一个field的名字不对,则该instance不能被append到dataset中 + instance3 = Instance(premises='the third premise example .', hypothesis='the third hypothesis example.', + label=1) + try: + data_set.append(instance3) + except: + print('cannot append instance') + pass + data_set[-2:] + + # In[6]: + + # 除了文本以外,还可以将tensor作为其中一个field的value + import torch + tensor_ins = Instance(image=torch.randn(5, 5), label=0) + ds = DataSet() + ds.append(tensor_ins) + ds + + from fastNLP import DataSet + from fastNLP import Instance + + # 从csv读取数据到DataSet + # 类csv文件,即每一行为一个example的文件,都可以使用这种方法进行数据读取 + dataset = DataSet.read_csv('tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\t') + # 查看DataSet的大小 + len(dataset) + + # In[8]: + + # 使用数字索引[k],获取第k个样本 + dataset[0] + + # In[9]: + + # 获取的样本是一个Instance + type(dataset[0]) + + # In[10]: + + # 使用数字索引[a: b],获取第a到第b个样本 + dataset[0: 3] + + # In[11]: + + # 索引也可以是负数 + dataset[-1] + + data_path = ['premise', 'hypothesis', 'label'] + + # 读入文件 + with open(data_path[0]) as f: + premise = f.readlines() + + with open(data_path[1]) as f: + hypothesis = f.readlines() + + with open(data_path[2]) as f: + label = f.readlines() + + assert len(premise) == len(hypothesis) and len(hypothesis) == len(label) + + # 组织DataSet + data_set = DataSet() + for p, h, l in zip(premise, hypothesis, label): + p = p.strip() # 将行末空格去除 + h = h.strip() # 将行末空格去除 + data_set.append(Instance(premise=p, hypothesis=h, truth=l)) + + data_set[0] + + # ### DataSet的其他操作 + # 在构建完毕DataSet后,仍然可以对DataSet的内容进行操作,函数接口为DataSet.apply() + + # In[13]: + + # 将premise域的所有文本转成小写 + data_set.apply(lambda x: x['premise'].lower(), new_field_name='premise') + data_set[-2:] + + # In[14]: + + # label转int + data_set.apply(lambda x: int(x['truth']), new_field_name='truth') + data_set[-2:] + + # In[15]: + + # 使用空格分割句子 + def split_sent(ins): + return ins['premise'].split() + + data_set.apply(split_sent, new_field_name='premise') + data_set.apply(lambda x: x['hypothesis'].split(), new_field_name='hypothesis') + data_set[-2:] + + # In[16]: + + # 筛选数据 + origin_data_set_len = len(data_set) + data_set.drop(lambda x: len(x['premise']) <= 6) + origin_data_set_len, len(data_set) + + # In[17]: + + # 增加长度信息 + data_set.apply(lambda x: [1] * len(x['premise']), new_field_name='premise_len') + data_set.apply(lambda x: [1] * len(x['hypothesis']), new_field_name='hypothesis_len') + data_set[-1] + + # In[18]: + + # 设定特征域、标签域 + data_set.set_input("premise", "premise_len", "hypothesis", "hypothesis_len") + data_set.set_target("truth") + + # In[19]: + + # 重命名field + data_set.rename_field('truth', 'label') + data_set[-1] + + # In[20]: + + # 切分训练、验证集、测试集 + train_data, vad_data = data_set.split(0.5) + dev_data, test_data = vad_data.split(0.4) + len(train_data), len(dev_data), len(test_data) + + # In[21]: + + # 深拷贝一个数据集 + import copy + train_data_2, dev_data_2 = copy.deepcopy(train_data), copy.deepcopy(dev_data) + del copy + + # 
初始化词表,该词表最大的vocab_size为10000,词表中每个词出现的最低频率为2,''表示未知词语,''表示padding词语 + # Vocabulary默认初始化参数为max_size=None, min_freq=None, unknown='', padding='' + vocab = Vocabulary(max_size=10000, min_freq=2, unknown='', padding='') + + # 构建词表 + train_data.apply(lambda x: [vocab.add(word) for word in x['premise']]) + train_data.apply(lambda x: [vocab.add(word) for word in x['hypothesis']]) + vocab.build_vocab() + + # In[23]: + + # 根据词表index句子 + train_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') + train_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') + dev_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') + dev_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') + test_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise') + test_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') + train_data[-1], dev_data[-1], test_data[-1] + + # 读入vocab文件 + with open('vocab.txt') as f: + lines = f.readlines() + vocabs = [] + for line in lines: + vocabs.append(line.strip()) + + # 实例化Vocabulary + vocab_bert = Vocabulary(unknown=None, padding=None) + # 将vocabs列表加入Vocabulary + vocab_bert.add_word_lst(vocabs) + # 构建词表 + vocab_bert.build_vocab() + # 更新unknown与padding的token文本 + vocab_bert.unknown = '[UNK]' + vocab_bert.padding = '[PAD]' + + # In[25]: + + # 根据词表index句子 + train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise') + train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], + new_field_name='hypothesis') + dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise') + dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') + train_data_2[-1], dev_data_2[-1] + + # step 1:加载模型参数(非必选) + from fastNLP.io.config_io import ConfigSection, ConfigLoader + args = ConfigSection() + ConfigLoader().load_config("./data/config", {"esim_model": args}) + args["vocab_size"] = len(vocab) + args.data + + # In[27]: + + # step 2:加载ESIM模型 + from fastNLP.models import ESIM + model = ESIM(**args.data) + model + + # In[28]: + + # 另一个例子:加载CNN文本分类模型 + from fastNLP.models import CNNText + cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1) + cnn_text_model + + from fastNLP import CrossEntropyLoss + from fastNLP import Adam + from fastNLP import AccuracyMetric + trainer = Trainer( + train_data=train_data, + model=model, + loss=CrossEntropyLoss(pred='pred', target='label'), + metrics=AccuracyMetric(), + n_epochs=5, + batch_size=16, + print_every=-1, + validate_every=-1, + dev_data=dev_data, + use_cuda=True, + optimizer=Adam(lr=1e-3, weight_decay=0), + check_code_level=-1, + metric_key='acc', + use_tqdm=False, + ) + trainer.train() + + tester = Tester( + data=test_data, + model=model, + metrics=AccuracyMetric(), + batch_size=args["batch_size"], + ) + tester.test() + + os.chdir("../..")
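
Usage sketch (not part of the patch): the refactor above moves all dtype checking into FieldArray. Type detection runs once, when a field is marked is_input or is_target, and append() only re-checks the new sample against the recorded content_dim / pytype. The snippet below illustrates the intended behaviour, assuming the attribute names shown in the diff (content_dim, pytype, dtype) and the public fastNLP calls used in the new tests; it is a hedged sketch of the expected semantics, not code taken from the commit.

    import numpy as np

    from fastNLP import DataSet, Instance
    from fastNLP.core.fieldarray import FieldArray

    # 1) dtype detection is deferred until the field becomes input/target.
    fa = FieldArray("x", [[1, 2], [3, 4]], is_input=True)   # 2-D list of int
    assert fa.content_dim == 2 and fa.pytype is int

    # np.ndarray content is converted to a plain list; element type is still detected.
    fa = FieldArray("y", np.array([[1.1, 2.2], [3.3, 4.4]]), is_input=True)
    assert fa.pytype is float and fa.dtype == np.float64

    # 2) append() checks only the dimension/type of the new sample.
    fa.append([5.5, 6.6])    # OK: a 1-D list appended to a 2-D field
    # fa.append(7.7)         # would raise: dimension does not match (expect dim=1)

    # int/float mixtures are up-cast to float, as in the patched _basic_type_detection
    fa2 = FieldArray("z", [1, 2], is_input=True)
    fa2.append(3.3)
    assert fa2.pytype is float

    # 3) DataSet delegates all checks to FieldArray.
    ds_dict = DataSet({"x": [[1, 2], [3, 4]]})               # dict init, 2-D list
    ds_ins = DataSet([Instance(x=np.array([1, 2, 3, 4]))] * 2)  # list of Instance, 1-D array
    ds_ins.set_input("x")                                    # triggers dtype detection

Deferring the check until a field becomes input or target (rather than on every construction) matches the commit message: type-detection responsibility sits entirely in FieldArray, while DataSet only converts top-level np.ndarray samples to lists before handing them over.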