@@ -105,6 +105,7 @@ class FieldArray(object): | |||||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | ||||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | ||||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | ||||
1.4) list of array: DataSet({"x": [np.array([1,2,3]), np.array([1,2,3])]}) | |||||
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | ||||
然后后面的样本使用FieldArray.append进行添加。 | 然后后面的样本使用FieldArray.append进行添加。 | ||||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | ||||
@@ -119,10 +120,12 @@ class FieldArray(object): | |||||
if isinstance(content, list): | if isinstance(content, list): | ||||
# 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list | # 如果DataSet使用dict初始化, content 可能是二维list/二维array/三维list | ||||
# 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] | # 如果DataSet使用list of Instance 初始化, content可能是 [list]/[array]/[2D list] | ||||
if len(content) == 1 and isinstance(content[0], np.ndarray): | |||||
for idx, item in enumerate(content): | |||||
# 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) | # 这是使用list of Instance 初始化时第一个样本:FieldArray(name, [field]) | ||||
# 将[np.array] 转化为 list of list | # 将[np.array] 转化为 list of list | ||||
content[0] = content[0].tolist() | |||||
# 也可以支持[array, array, array]的情况 | |||||
if isinstance(item, np.ndarray): | |||||
content[idx] = content[idx].tolist() | |||||
elif isinstance(content, np.ndarray): | elif isinstance(content, np.ndarray): | ||||
content = content.tolist() # convert np.ndarray into 2-D list | content = content.tolist() # convert np.ndarray into 2-D list | ||||
else: | else: | ||||
@@ -93,7 +93,7 @@ def train(train_data_path, dev_data_path, checkpoint=None): | |||||
target="truth", | target="truth", | ||||
seq_lens="word_seq_origin_len"), | seq_lens="word_seq_origin_len"), | ||||
dev_data=dev_data, metric_key="f", | dev_data=dev_data, metric_key="f", | ||||
use_tqdm=True, use_cuda=True, print_every=5, n_epochs=6, save_path="./save_0") | |||||
use_tqdm=True, use_cuda=True, print_every=10, n_epochs=20, save_path="./save_0117") | |||||
trainer.train(load_best_model=True) | trainer.train(load_best_model=True) | ||||
# save model & pipeline | # save model & pipeline | ||||
@@ -102,14 +102,14 @@ def train(train_data_path, dev_data_path, checkpoint=None): | |||||
pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) | pp = Pipeline([vocab_proc, seq_len_proc, set_input_proc, model_proc, id2tag]) | ||||
save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | save_dict = {"pipeline": pp, "model": model, "tag_vocab": tag_proc.vocab} | ||||
torch.save(save_dict, "model_pp.pkl") | |||||
torch.save(save_dict, "model_pp_0117.pkl") | |||||
print("pipeline saved") | print("pipeline saved") | ||||
def run_test(test_path): | def run_test(test_path): | ||||
test_data = ZhConllPOSReader().load(test_path) | test_data = ZhConllPOSReader().load(test_path) | ||||
with open("model_pp.pkl", "rb") as f: | |||||
with open("model_pp_0117.pkl", "rb") as f: | |||||
save_dict = torch.load(f) | save_dict = torch.load(f) | ||||
tag_vocab = save_dict["tag_vocab"] | tag_vocab = save_dict["tag_vocab"] | ||||
pipeline = save_dict["pipeline"] | pipeline = save_dict["pipeline"] | ||||
@@ -31,6 +31,12 @@ class TestFieldArrayInit(unittest.TestCase): | |||||
# 三维list | # 三维list | ||||
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) | fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True) | ||||
def test_init_v7(self): | |||||
# list of array | |||||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True) | |||||
self.assertEqual(fa.pytype, int) | |||||
self.assertEqual(fa.dtype, np.int) | |||||
def test_init_v4(self): | def test_init_v4(self): | ||||
# 一维list | # 一维list | ||||
val = [1, 2, 3, 4] | val = [1, 2, 3, 4] | ||||