diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 20d7e5e0..96854e72 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -105,6 +105,7 @@ class FieldArray(object): 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) + 1.4) list of array: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]}) 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; 然后后面的样本使用FieldArray.append进行添加。 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) @@ -119,10 +120,12 @@ class FieldArray(object): if isinstance(content, list): # 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list # 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] - if len(content) == 1 and isinstance(content[0], np.ndarray): + for idx, item in enumerate(content): # 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) # 将[np.array] 转化为 list of list - content[0] = content[0].tolist() + # 也可以支持[array, array, array]的情况 + if isinstance(item, np.ndarray): + content[idx] = content[idx].tolist() elif isinstance(content, np.ndarray): content = content.tolist() # convert np.ndarray into 2-D list else: diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py index 4bdc23c7..6448c32b 100644 --- a/reproduction/POS_tagging/train_pos_tag.py +++ b/reproduction/POS_tagging/train_pos_tag.py @@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None): target="truth", seq_lens="word_seq_origin_len"), dev_data=dev_data, metric_key="f", - use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save_0") + use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117") trainer.train(load_best_model=True) # save model & pipeline @@ -102,14 +102,14 @@ def train(train_data_path, dev_data_path, checkpoint=None): pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} - torch.save(save_dict, "model_pp.pkl") + torch.save(save_dict, "model_pp_0117.pkl") print("pipeline saved") def run_test(test_path): test_data = ZhConllPOSReader().load(test_path) - with open("model_pp.pkl", "rb") as f: + with open("model_pp_0117.pkl", "rb") as f: save_dict = torch.load(f) tag_vocab = save_dict["tag_vocab"] pipeline = save_dict["pipeline"] diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index 834545c0..151d9335 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -31,6 +31,12 @@ class TestFieldArrayInit(unittest.TestCase): # 三维list fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) + def test_init_v7(self): + # list of array + fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) + self.assertEqual(fa.pytype, int) + self.assertEqual(fa.dtype, np.int) + def test_init_v4(self): # 一维list val = [1, 2, 3, 4]