From b93ca9bb3059b8c82a7a5a7ae71c2d51ec006dee Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Thu, 17 Jan 2019 15:39:13 +0800 Subject: [PATCH] =?UTF-8?q?*=20FieldArray=E6=B7=BB=E5=8A=A0=E5=AF=B9list?= =?UTF-8?q?=20of=20np.array=E7=9A=84=E6=94=AF=E6=8C=81=20*=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95=EF=BC=9AFieldArray=E7=9A=84?= =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/fieldarray.py | 8 +++- reproduction/POS_tagging/train_pos_tag.py | 1 + test/core/test_fieldarray.py | 53 ++++++++++++++++++++++- test/test_tutorials.py | 4 +- 4 files changed, 61 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py index 4cde86ab..20d7e5e0 100644 --- a/fastNLP/core/fieldarray.py +++ b/fastNLP/core/fieldarray.py @@ -112,13 +112,17 @@ class FieldArray(object): 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) - 注意:np.array必须仅在最外层,即np.array([np.array, np.array]) 和 list of np.array不考虑 类型检查(dtype check)发生在当该field被设置为is_input或者is_target时。 """ self.name = name if isinstance(content, list): - content = content + # 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list + # 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] + if len(content) == 1 and isinstance(content[0], np.ndarray): + # 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) + # 将[np.array] 转化为 list of list + content[0] = content[0].tolist() elif isinstance(content, np.ndarray): content = content.tolist() # convert np.ndarray into 2-D list else: diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/POS_tagging/train_pos_tag.py index e817db44..4bdc23c7 100644 --- a/reproduction/POS_tagging/train_pos_tag.py +++ b/reproduction/POS_tagging/train_pos_tag.py @@ -144,6 +144,7 @@ if __name__ == "__main__": parser.add_argument("--train", type=str, help="training conll file", 
default="/home/zyfeng/data/sample.conllx") parser.add_argument("--dev", type=str, help="dev conll file", default="/home/zyfeng/data/sample.conllx") parser.add_argument("--test", type=str, help="test conll file", default=None) + parser.add_argument("--save", type=str, help="path to save", default=None) parser.add_argument("-c", "--restart", action="store_true", help="whether to continue training") parser.add_argument("-cp", "--checkpoint", type=str, help="checkpoint of the trained model") diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index da287916..834545c0 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -5,8 +5,59 @@ import numpy as np from fastNLP.core.fieldarray import FieldArray +class TestFieldArrayInit(unittest.TestCase): + """ + 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: + 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) + 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) + 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) + 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; + 然后后面的样本使用FieldArray.append进行添加。 + 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) + 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) + 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) + 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) + """ + + def test_init_v1(self): + # 二维list + fa = FieldArray("x", [[1, 2], [3, 4]] * 5, is_input=True) + + def test_init_v2(self): + # 二维array + fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5), is_input=True) + + def test_init_v3(self): + # 三维list + fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) + + def test_init_v4(self): + # 一维list + val = [1, 2, 3, 4] + fa = FieldArray("x", [val], is_input=True) + fa.append(val) + + def test_init_v5(self): + # 一维array + val = np.array([1, 2, 3, 4]) + fa = FieldArray("x", [val], is_input=True) + fa.append(val) + + def test_init_v6(self): + # 二维list + val = [[1, 
2], [3, 4]] + fa = FieldArray("x", [val], is_input=True) + fa.append(val) + + def test_init_v7(self): + # 二维array + val = np.array([[1, 2], [3, 4]]) + fa = FieldArray("x", [val], is_input=True) + fa.append(val) + + class TestFieldArray(unittest.TestCase): - def test(self): + def test_main(self): fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) self.assertEqual(len(fa), 5) fa.append(6) diff --git a/test/test_tutorials.py b/test/test_tutorials.py index ee48c23b..68c874fa 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -408,12 +408,12 @@ class TestTutorial(unittest.TestCase): model=model, loss=CrossEntropyLoss(pred='pred', target='label'), metrics=AccuracyMetric(), - n_epochs=5, + n_epochs=3, batch_size=16, print_every=-1, validate_every=-1, dev_data=dev_data, - use_cuda=True, + use_cuda=False, optimizer=Adam(lr=1e-3, weight_decay=0), check_code_level=-1, metric_key='acc',