|
|
@@ -5,8 +5,59 @@ import numpy as np |
|
|
|
from fastNLP.core.fieldarray import FieldArray |
|
|
|
|
|
|
|
|
|
|
|
class TestFieldArrayInit(unittest.TestCase): |
|
|
|
""" |
|
|
|
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: |
|
|
|
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) |
|
|
|
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) |
|
|
|
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) |
|
|
|
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; |
|
|
|
然后后面的样本使用FieldArray.append进行添加。 |
|
|
|
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) |
|
|
|
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) |
|
|
|
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) |
|
|
|
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) |
|
|
|
""" |
|
|
|
|
|
|
|
def test_init_v1(self): |
|
|
|
# 二维list |
|
|
|
fa = FieldArray("x", [[1, 2], [3, 4]] * 5, is_input=True) |
|
|
|
|
|
|
|
def test_init_v2(self): |
|
|
|
# 二维array |
|
|
|
fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5), is_input=True) |
|
|
|
|
|
|
|
def test_init_v3(self): |
|
|
|
# 三维list |
|
|
|
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) |
|
|
|
|
|
|
|
def test_init_v4(self): |
|
|
|
# 一维list |
|
|
|
val = [1, 2, 3, 4] |
|
|
|
fa = FieldArray("x", [val], is_input=True) |
|
|
|
fa.append(val) |
|
|
|
|
|
|
|
def test_init_v5(self): |
|
|
|
# 一维array |
|
|
|
val = np.array([1, 2, 3, 4]) |
|
|
|
fa = FieldArray("x", [val], is_input=True) |
|
|
|
fa.append(val) |
|
|
|
|
|
|
|
def test_init_v6(self): |
|
|
|
# 二维array |
|
|
|
val = [[1, 2], [3, 4]] |
|
|
|
fa = FieldArray("x", [val], is_input=True) |
|
|
|
fa.append(val) |
|
|
|
|
|
|
|
def test_init_v7(self): |
|
|
|
# 二维list |
|
|
|
val = np.array([[1, 2], [3, 4]]) |
|
|
|
fa = FieldArray("x", [val], is_input=True) |
|
|
|
fa.append(val) |
|
|
|
|
|
|
|
|
|
|
|
class TestFieldArray(unittest.TestCase): |
|
|
|
def test(self): |
|
|
|
def test_main(self): |
|
|
|
fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True) |
|
|
|
self.assertEqual(len(fa), 5) |
|
|
|
fa.append(6) |
|
|
|