Browse Source

添加FieldArray对list of np.array的支持

tags/v0.3.1^2
FengZiYjun 5 years ago
parent
commit
864c2238f8
3 changed files with 14 additions and 5 deletions
  1. +5
    -2
      fastNLP/core/fieldarray.py
  2. +3
    -3
      reproduction/POS_tagging/train_pos_tag.py
  3. +6
    -0
      test/core/test_fieldarray.py

+ 5
- 2
fastNLP/core/fieldarray.py View File

@@ -105,6 +105,7 @@ class FieldArray(object):
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
1.4) list of array: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]})
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray;
然后后面的样本使用FieldArray.append进行添加。 然后后面的样本使用FieldArray.append进行添加。
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
@@ -119,10 +120,12 @@ class FieldArray(object):
if isinstance(content, list): if isinstance(content, list):
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list # 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] # 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list]
if len(content) == 1 and isinstance(content[0], np.ndarray):
for idx, item in enumerate(content):
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) # 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field])
# 将[np.array] 转化为 list of list # 将[np.array] 转化为 list of list
content[0] = content[0].tolist()
# 也可以支持[array, array, array]的情况
if isinstance(item, np.ndarray):
content[idx] = content[idx].tolist()
elif isinstance(content, np.ndarray): elif isinstance(content, np.ndarray):
content = content.tolist() # convert np.ndarray into 2-D list content = content.tolist() # convert np.ndarray into 2-D list
else: else:


+ 3
- 3
reproduction/POS_tagging/train_pos_tag.py View File

@@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None):
target="truth", target="truth",
seq_lens="word_seq_origin_len"), seq_lens="word_seq_origin_len"),
dev_data=dev_data, metric_key="f", dev_data=dev_data, metric_key="f",
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save_0")
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117")
trainer.train(load_best_model=True) trainer.train(load_best_model=True)


# save model & pipeline # save model & pipeline
@@ -102,14 +102,14 @@ def train(train_data_path, dev_data_path, checkpoint=None):


pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag])
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab}
torch.save(save_dict, "model_pp.pkl")
torch.save(save_dict, "model_pp_0117.pkl")
print("pipeline saved") print("pipeline saved")




def run_test(test_path): def run_test(test_path):
test_data = ZhConllPOSReader().load(test_path) test_data = ZhConllPOSReader().load(test_path)


with open("model_pp.pkl", "rb") as f:
with open("model_pp_0117.pkl", "rb") as f:
save_dict = torch.load(f) save_dict = torch.load(f)
tag_vocab = save_dict["tag_vocab"] tag_vocab = save_dict["tag_vocab"]
pipeline = save_dict["pipeline"] pipeline = save_dict["pipeline"]


+ 6
- 0
test/core/test_fieldarray.py View File

@@ -31,6 +31,12 @@ class TestFieldArrayInit(unittest.TestCase):
# 三维list # 三维list
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True)


def test_init_v7(self):
    """Check that FieldArray can be built from a list of np.ndarray.

    The content is a plain Python list whose elements are 2-D integer
    arrays; the test asserts the element types the FieldArray reports.
    """
    # list of array
    fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True)
    self.assertEqual(fa.pytype, int)
    # ``np.int`` was merely an alias for the builtin ``int``; it was
    # deprecated in NumPy 1.20 and removed in 1.24, so compare against
    # ``int`` directly (identical comparison on older NumPy releases).
    self.assertEqual(fa.dtype, int)

def test_init_v4(self): def test_init_v4(self):
# 一维list # 一维list
val = [1, 2, 3, 4] val = [1, 2, 3, 4]


Loading…
Cancel
Save