
Modify the test files for callback

tags/v0.4.10
ChenXin 5 years ago
parent commit 38c2ef7d74
4 changed files with 100 additions and 166 deletions
  1. fastNLP/core/callback.py (+9, -6)
  2. test/core/test_dataset.py (+1, -1)
  3. test/test_tutorials.py (+7, -7)
  4. tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb (+83, -152)

fastNLP/core/callback.py (+9, -6)

@@ -584,7 +584,9 @@ class TensorboardCallback(Callback):
         path = os.path.join(save_dir, 'tensorboard_logs_{}'.format(self.trainer.start_time))
         if tensorboardX_flag:
             self._summary_writer = SummaryWriter(path)
+        else:
+            self._summary_writer = None

     def on_batch_begin(self, batch_x, batch_y, indices):
         if "model" in self.options and self.graph_added is False:
             # tensorboardX has a serious bug here; drawing the model graph is not possible for now
@@ -596,10 +598,10 @@ class TensorboardCallback(Callback):
             self.graph_added = True

     def on_backward_begin(self, loss):
-        if "loss" in self.options:
+        if "loss" in self.options and self._summary_writer:
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.trainer.step)
-        if "model" in self.options:
+        if "model" in self.options and self._summary_writer:
             for name, param in self.trainer.model.named_parameters():
                 if param.requires_grad:
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.trainer.step)
@@ -608,15 +610,16 @@ class TensorboardCallback(Callback):
                                                     global_step=self.trainer.step)

     def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
-        if "metric" in self.options:
+        if "metric" in self.options and self._summary_writer:
             for name, metric in eval_result.items():
                 for metric_key, metric_val in metric.items():
                     self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
                                                     global_step=self.trainer.step)

     def on_train_end(self):
-        self._summary_writer.close()
-        del self._summary_writer
+        if self._summary_writer:
+            self._summary_writer.close()
+            del self._summary_writer

     def on_exception(self, exception):
         if hasattr(self, "_summary_writer"):
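Note on this change: the callback now degrades gracefully when tensorboardX is not installed. `_summary_writer` falls back to None and every logging hook guards on it. Below is a minimal, self-contained sketch of that guard pattern, assuming only that the tensorboardX import may fail; the class is trimmed for illustration and is not the committed code.

# A minimal sketch of the guard pattern this commit applies; names mirror
# the diff above, but the class itself is illustrative only.
import os

try:
    from tensorboardX import SummaryWriter
    tensorboardX_flag = True
except ImportError:
    tensorboardX_flag = False

class SketchTensorboardLogger:
    def __init__(self, save_dir):
        path = os.path.join(save_dir, 'tensorboard_logs')
        # Fall back to None instead of raising, so training still runs
        # when tensorboardX is absent.
        self._summary_writer = SummaryWriter(path) if tensorboardX_flag else None

    def log_loss(self, loss_value, step):
        # Every hook checks the writer first, matching the
        # "... and self._summary_writer" conditions in the diff.
        if self._summary_writer:
            self._summary_writer.add_scalar("loss", loss_value, global_step=step)

    def close(self):
        if self._summary_writer:
            self._summary_writer.close()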


test/core/test_dataset.py (+1, -1)

@@ -172,7 +172,7 @@ class TestDataSetMethods(unittest.TestCase):
         def split_sent(ins):
             return ins['raw_sentence'].split()

         csv_loader = CSVLoader(headers=['raw_sentence', 'label'],sep='\t')
-        dataset = csv_loader.load('../data_for_tests/tutorial_sample_dataset.csv')
+        dataset = csv_loader.load('test/data_for_tests/tutorial_sample_dataset.csv')
         dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True)
         dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)
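The new path is relative to the repository root, so this test assumes the runner is invoked from there. A hedged alternative sketch, resolving the file relative to the test module instead of the working directory (`__file__` stands for test/core/test_dataset.py; the helper name is hypothetical):

# Resolve the sample csv relative to this test file, so the test passes
# regardless of the directory the test runner is launched from.
import os

def sample_path():
    test_core_dir = os.path.dirname(os.path.abspath(__file__))
    # ../data_for_tests/tutorial_sample_dataset.csv relative to test/core/
    return os.path.join(test_core_dir, '..', 'data_for_tests',
                        'tutorial_sample_dataset.csv')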


test/test_tutorials.py (+7, -7)

@@ -10,7 +10,7 @@ from fastNLP.core.metrics import AccuracyMetric
 class TestTutorial(unittest.TestCase):
     def test_fastnlp_10min_tutorial(self):
         # Read data from the csv file into a DataSet
-        sample_path = "data_for_tests/tutorial_sample_dataset.csv"
+        sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
         dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
                                    sep='\t')
         print(len(dataset))
@@ -113,14 +113,14 @@ class TestTutorial(unittest.TestCase):

     def test_fastnlp_1min_tutorial(self):
         # tutorials/fastnlp_1min_tutorial.ipynb
-        data_path = "tutorials/sample_data/tutorial_sample_dataset.csv"
+        data_path = "test/data_for_tests/tutorial_sample_dataset.csv"
         ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t')
         print(ds[1])

         # Lowercase all sentences
         ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
         # Convert label to int
-        ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)
+        ds.apply(lambda x: int(x['label']), new_field_name='target', is_target=True)

         def split_sent(ins):
             return ins['raw_sentence'].split()
@@ -137,9 +137,9 @@ class TestTutorial(unittest.TestCase):
         train_data.apply(lambda x: [vocab.add(word) for word in x['words']])

         # Index sentences with Vocabulary.to_index(word)
-        train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq',
+        train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words',
                          is_input=True)
-        dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq',
+        dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words',
                        is_input=True)

         from fastNLP.models import CNNText
@@ -152,14 +152,14 @@ class TestTutorial(unittest.TestCase):
                           dev_data=dev_data,
                           loss=CrossEntropyLoss(),
                           optimizer= Adam(),
-                          metrics=AccuracyMetric(target='label_seq')
+                          metrics=AccuracyMetric(target='target')
                           )
         trainer.train()
         print('Train finished!')

     def test_fastnlp_advanced_tutorial(self):
         import os
-        os.chdir("tutorials/fastnlp_advanced_tutorial")
+        os.chdir("test/tutorials/fastnlp_advanced_tutorial")

         from fastNLP import DataSet
         from fastNLP import Instance
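The renames here (word_seq to words, label_seq to target) line the DataSet fields up with what the model and the metric expect: fastNLP dispatches input fields to the model's forward() parameters by name, and AccuracyMetric(target=...) names the field holding the labels. A hedged sketch of why the names matter, under that assumption; the model below is illustrative only, not fastNLP's CNNText:

# Field names must match forward() parameter names for fastNLP's dispatch.
import torch
import torch.nn as nn

class TinyTextModel(nn.Module):
    def __init__(self, vocab_size, num_classes):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, 16)
        self.fc = nn.Linear(16, num_classes)

    def forward(self, words):
        # The parameter is called `words`, so the DataSet must expose an
        # input field named `words` (hence new_field_name='words' above).
        return {'pred': self.fc(self.embed(words).mean(dim=1))}

# AccuracyMetric(target='target') likewise compares the model's 'pred'
# output against the DataSet field named 'target'.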


tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb (+83, -152)

@@ -170,11 +170,11 @@
   {
    "data": {
     "text/plain": [
-     "DataSet({'image': tensor([[ 4.7106e-01, -1.2246e+00, 3.1234e-01, -1.6781e+00, -8.7967e-01],\n",
-     " [ 1.1454e+00, 1.2236e-01, 3.0258e-01, -1.5454e+00, 8.9201e-01],\n",
-     " [-5.7143e-03, 3.9488e-01, 2.0287e-01, -1.5726e+00, 9.3171e-01],\n",
-     " [ 6.8914e-01, -2.6302e-01, -8.2694e-01, 9.5942e-01, -5.2589e-01],\n",
-     " [-5.7798e-03, -9.1621e-03, 1.0077e-03, 9.1716e-02, 1.0565e+00]]) type=torch.Tensor,\n",
+     "DataSet({'image': tensor([[ 0.3582, -1.0358, 1.4785, -1.5288, -0.9982],\n",
+     " [-0.3973, -0.4294, 0.9215, -1.9631, -1.6556],\n",
+     " [ 0.3313, -1.7714, 0.8729, 0.6976, -1.3172],\n",
+     " [-0.6403, 0.5023, -0.9919, 1.1178, -0.3710],\n",
+     " [-0.3692, 1.8631, -1.3646, -0.7290, -1.0774]]) type=torch.Tensor,\n",
      "'label': 0 type=int})"
     ]
    },
@@ -524,7 +524,11 @@
    "outputs": [],
    "source": [
     "# Set the input and target fields\n",
-    "data_set.set_input(\"premise\", \"premise_len\", \"hypothesis\", \"hypothesis_len\")\n",
+    "data_set.rename_field(\"premise\",\"words1\")\n",
+    "data_set.rename_field(\"premise_len\",\"seq_len1\")\n",
+    "data_set.rename_field(\"hypothesis\",\"words2\")\n",
+    "data_set.rename_field(\"hypothesis_len\",\"seq_len2\")\n",
+    "data_set.set_input(\"words1\", \"seq_len1\", \"words2\", \"seq_len2\")\n",
     "data_set.set_target(\"truth\")"
    ]
   },
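The notebook now renames the SNLI fields so they match the argument names of fastNLP's ESIM forward signature (words1, words2, seq_len1, seq_len2). A minimal, self-contained sketch of this step, assuming a DataSet built from Instances with the tutorial's original field names:

# Build a one-instance DataSet and rename its fields the way the cell
# above does; the sample sentence is made up for illustration.
from fastNLP import DataSet, Instance

ds = DataSet()
ds.append(Instance(premise=['a', 'man', 'sleeps'],
                   premise_len=3,
                   hypothesis=['a', 'man', 'rests'],
                   hypothesis_len=3,
                   truth=0))

# Align field names with the model's expected parameter names.
for old, new in [('premise', 'words1'), ('premise_len', 'seq_len1'),
                 ('hypothesis', 'words2'), ('hypothesis_len', 'seq_len2')]:
    ds.rename_field(old, new)

ds.set_input('words1', 'seq_len1', 'words2', 'seq_len2')
ds.set_target('truth')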
@@ -536,10 +540,10 @@
   {
    "data": {
     "text/plain": [
-     "{'premise': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
-     "'hypothesis': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
-     "'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-     "'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+     "{'words1': ['a', 'woman', 'is', 'walking', 'across', 'the', 'street', 'eating', 'a', 'banana', ',', 'while', 'a', 'man', 'is', 'following', 'with', 'his', 'briefcase', '.'] type=list,\n",
+     "'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+     "'words2': ['a', 'woman', 'eating', 'a', 'banana', 'crosses', 'a', 'street'] type=list,\n",
+     "'seq_len2': [1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
      "'label': 0 type=int}"
     ]
    },
@@ -613,7 +617,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 22,
+   "execution_count": 23,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -622,49 +626,49 @@
    "vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>')\n",
    "\n",
    "# Build the vocabulary\n",
-   "train_data.apply(lambda x: [vocab.add(word) for word in x['premise']])\n",
-   "train_data.apply(lambda x: [vocab.add(word) for word in x['hypothesis']])\n",
+   "train_data.apply(lambda x: [vocab.add(word) for word in x['words1']])\n",
+   "train_data.apply(lambda x: [vocab.add(word) for word in x['words2']])\n",
    "vocab.build_vocab()"
   ]
  },
  {
   "cell_type": "code",
-  "execution_count": 23,
+  "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "({'premise': [2, 10, 9, 2, 15, 115, 6, 11, 5, 132, 17, 2, 76, 9, 77, 55, 3] type=list,\n",
-      " 'hypothesis': [1, 2, 56, 17, 1, 4, 13, 49, 123, 12, 6, 11, 3] type=list,\n",
-      " 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'label': 0 type=int},\n",
-      " {'premise': [50, 124, 10, 7, 68, 91, 92, 38, 2, 55, 3] type=list,\n",
-      " 'hypothesis': [21, 10, 5, 2, 55, 7, 99, 64, 48, 1, 22, 1, 3] type=list,\n",
-      " 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      "({'words1': [2, 9, 4, 2, 75, 85, 7, 86, 76, 77, 87, 88, 89, 2, 90, 3] type=list,\n",
+      " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'words2': [18, 9, 10, 1, 3] type=list,\n",
+      " 'seq_len2': [1, 1, 1, 1, 1] type=list,\n",
       " 'label': 1 type=int},\n",
-      " {'premise': [13, 24, 4, 14, 29, 5, 25, 4, 8, 39, 9, 14, 34, 4, 40, 41, 4, 16, 12, 2, 11, 4, 30, 28, 2, 42, 8, 2, 43, 44, 17, 2, 45, 35, 26, 31, 27, 5, 6, 32, 3] type=list,\n",
-      " 'hypothesis': [37, 49, 123, 30, 28, 2, 55, 12, 2, 11, 3] type=list,\n",
-      " 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'label': 0 type=int})"
+      " {'words1': [22, 32, 5, 110, 81, 111, 112, 5, 82, 3] type=list,\n",
+      " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'words2': [64, 32, 82, 133, 84, 3] type=list,\n",
+      " 'seq_len2': [1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'label': 0 type=int},\n",
+      " {'words1': [2, 9, 97, 1, 20, 7, 54, 5, 1, 1, 70, 2, 11, 110, 2, 62, 3] type=list,\n",
+      " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'words2': [23, 1, 58, 10, 12, 1, 70, 133, 84, 3] type=list,\n",
+      " 'seq_len2': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'label': 1 type=int})"
      ]
     },
-    "execution_count": 23,
+    "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index sentences using the vocabulary\n",
-   "train_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise')\n",
-   "train_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
-   "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise')\n",
-   "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
-   "test_data.apply(lambda x: [vocab.to_index(word) for word in x['premise']], new_field_name='premise')\n",
-   "test_data.apply(lambda x: [vocab.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
+   "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words1']], new_field_name='words1')\n",
+   "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words2']], new_field_name='words2')\n",
+   "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words1']], new_field_name='words1')\n",
+   "dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words2']], new_field_name='words2')\n",
+   "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words1']], new_field_name='words1')\n",
+   "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words2']], new_field_name='words2')\n",
   "train_data[-1], dev_data[-1], test_data[-1]"
  ]
 },
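A small sketch of the vocabulary step shown in this hunk, assuming fastNLP's Vocabulary API as the notebook uses it (add / build_vocab / to_index); the sentences are made up:

# Build a vocabulary and index sentences with it.
from fastNLP import Vocabulary

vocab = Vocabulary(max_size=10000, min_freq=2, unknown='<unk>', padding='<pad>')
sentences = [['a', 'man', 'sleeps'], ['a', 'man', 'rests']]
for sent in sentences:
    for word in sent:
        vocab.add(word)
vocab.build_vocab()

# Words below min_freq map to the index of '<unk>'.
indexed = [[vocab.to_index(w) for w in sent] for sent in sentences]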
@@ -679,7 +683,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 24,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -703,35 +707,35 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 25,
+  "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
-      "({'premise': [1037, 2158, 1998, 1037, 2450, 2892, 1996, 2395, 1999, 2392, 1997, 1037, 10733, 1998, 100, 4825, 1012] type=list,\n",
-      " 'hypothesis': [100, 1037, 3232, 1997, 7884, 1010, 2048, 2111, 3328, 2408, 1996, 2395, 1012] type=list,\n",
-      " 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'label': 0 type=int},\n",
-      " {'premise': [2019, 3080, 2158, 2003, 5948, 4589, 10869, 2012, 1037, 4825, 1012] type=list,\n",
-      " 'hypothesis': [100, 2158, 1999, 1037, 4825, 2003, 3403, 2005, 2010, 7954, 2000, 7180, 1012] type=list,\n",
-      " 'premise_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'hypothesis_len': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
-      " 'label': 1 type=int})"
+      "({'words1': [1037, 2450, 1999, 1037, 2665, 6598, 1998, 7415, 2058, 2014, 2132, 2559, 2875, 1037, 3028, 1012] type=list,\n",
+      " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'words2': [100, 2450, 2003, 3147, 1012] type=list,\n",
+      " 'seq_len2': [1, 1, 1, 1, 1] type=list,\n",
+      " 'label': 1 type=int},\n",
+      " {'words1': [2048, 2308, 1010, 3173, 2833, 100, 16143, 1010, 8549, 1012] type=list,\n",
+      " 'seq_len1': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'words2': [100, 2308, 8549, 2169, 2060, 1012] type=list,\n",
+      " 'seq_len2': [1, 1, 1, 1, 1, 1] type=list,\n",
+      " 'label': 0 type=int})"
      ]
     },
-    "execution_count": 25,
+    "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Index sentences using the vocabulary\n",
-   "train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise')\n",
-   "train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
-   "dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['premise']], new_field_name='premise')\n",
-   "dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis')\n",
+   "train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words1']], new_field_name='words1')\n",
+   "train_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words2']], new_field_name='words2')\n",
+   "dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words1']], new_field_name='words1')\n",
+   "dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['words2']], new_field_name='words2')\n",
   "train_data_2[-1], dev_data_2[-1]"
  ]
 },
@@ -747,7 +751,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 26,
+   "execution_count": 27,
    "metadata": {},
    "outputs": [
     {
@@ -760,10 +764,10 @@
      " 'num_classes': 3,\n",
      " 'gpu': True,\n",
      " 'batch_size': 32,\n",
-     " 'vocab_size': 156}"
+     " 'vocab_size': 143}"
     ]
    },
-   "execution_count": 26,
+   "execution_count": 27,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -779,7 +783,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 27,
+   "execution_count": 28,
    "metadata": {},
    "outputs": [
    {
@@ -788,21 +792,17 @@
     "ESIM(\n",
     " (drop): Dropout(p=0.3)\n",
     " (embedding): Embedding(\n",
-    " (embed): Embedding(156, 300, padding_idx=0)\n",
+    " 143, 300\n",
     " (dropout): Dropout(p=0.3)\n",
     " )\n",
-    " (embedding_layer): Linear(\n",
-    " (linear): Linear(in_features=300, out_features=300, bias=True)\n",
-    " )\n",
+    " (embedding_layer): Linear(in_features=300, out_features=300, bias=True)\n",
     " (encoder): LSTM(\n",
     " (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)\n",
     " )\n",
-    " (bi_attention): Bi_Attention()\n",
+    " (bi_attention): BiAttention()\n",
     " (mean_pooling): MeanPoolWithMask()\n",
     " (max_pooling): MaxPoolWithMask()\n",
-    " (inference_layer): Linear(\n",
-    " (linear): Linear(in_features=1200, out_features=300, bias=True)\n",
-    " )\n",
+    " (inference_layer): Linear(in_features=1200, out_features=300, bias=True)\n",
     " (decoder): LSTM(\n",
     " (lstm): LSTM(300, 300, batch_first=True, bidirectional=True)\n",
     " )\n",
@@ -816,7 +816,7 @@
     ")"
    ]
   },
-  "execution_count": 27,
+  "execution_count": 28,
   "metadata": {},
   "output_type": "execute_result"
  }
@@ -824,49 +824,10 @@
   "source": [
    "# Step 2: load the ESIM model\n",
    "from fastNLP.models import ESIM\n",
-   "model = ESIM(**args.data)\n",
+   "model = ESIM(args[\"vocab_size\"], args[\"embed_dim\"], args[\"hidden_size\"], args[\"dropout\"], args[\"num_classes\"])\n",
   "model"
  ]
 },
-{
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [
-  {
-   "data": {
-    "text/plain": [
-     "CNNText(\n",
-     " (embed): Embedding(\n",
-     " (embed): Embedding(156, 50, padding_idx=0)\n",
-     " (dropout): Dropout(p=0.0)\n",
-     " )\n",
-     " (conv_pool): ConvMaxpool(\n",
-     " (convs): ModuleList(\n",
-     " (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
-     " (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
-     " (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
-     " )\n",
-     " )\n",
-     " (dropout): Dropout(p=0.1)\n",
-     " (fc): Linear(\n",
-     " (linear): Linear(in_features=12, out_features=5, bias=True)\n",
-     " )\n",
-     ")"
-    ]
-   },
-   "execution_count": 28,
-   "metadata": {},
-   "output_type": "execute_result"
-  }
- ],
- "source": [
-  "# Another example: load the CNNText classification model\n",
-  "from fastNLP.models import CNNText\n",
-  "cnn_text_model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
-  "cnn_text_model"
- ]
-},
 {
  "cell_type": "markdown",
  "metadata": {},
@@ -1009,54 +970,25 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "training epochs started 2019-04-14-23-22-28\n",
-    "[epoch: 1 step: 1] train loss: 1.51372 time: 0:00:00\n",
-    "[epoch: 1 step: 2] train loss: 1.26874 time: 0:00:00\n",
-    "[epoch: 1 step: 3] train loss: 1.49786 time: 0:00:00\n",
-    "[epoch: 1 step: 4] train loss: 1.37505 time: 0:00:00\n",
-    "Evaluation at Epoch 1/5. Step:4/20. AccuracyMetric: acc=0.344828\n",
-    "\n",
-    "[epoch: 2 step: 5] train loss: 1.21877 time: 0:00:00\n",
-    "[epoch: 2 step: 6] train loss: 1.14183 time: 0:00:00\n",
-    "[epoch: 2 step: 7] train loss: 1.15934 time: 0:00:00\n",
-    "[epoch: 2 step: 8] train loss: 1.55148 time: 0:00:00\n",
-    "Evaluation at Epoch 2/5. Step:8/20. AccuracyMetric: acc=0.344828\n",
-    "\n",
-    "[epoch: 3 step: 9] train loss: 1.1457 time: 0:00:00\n",
-    "[epoch: 3 step: 10] train loss: 1.0547 time: 0:00:00\n",
-    "[epoch: 3 step: 11] train loss: 1.40139 time: 0:00:00\n",
-    "[epoch: 3 step: 12] train loss: 0.551445 time: 0:00:00\n",
-    "Evaluation at Epoch 3/5. Step:12/20. AccuracyMetric: acc=0.275862\n",
-    "\n",
-    "[epoch: 4 step: 13] train loss: 1.07965 time: 0:00:00\n",
-    "[epoch: 4 step: 14] train loss: 1.04118 time: 0:00:00\n",
-    "[epoch: 4 step: 15] train loss: 1.11719 time: 0:00:00\n",
-    "[epoch: 4 step: 16] train loss: 1.09861 time: 0:00:00\n",
-    "Evaluation at Epoch 4/5. Step:16/20. AccuracyMetric: acc=0.275862\n",
-    "\n",
-    "[epoch: 5 step: 17] train loss: 1.10795 time: 0:00:00\n",
-    "[epoch: 5 step: 18] train loss: 1.26715 time: 0:00:00\n",
-    "[epoch: 5 step: 19] train loss: 1.19875 time: 0:00:00\n",
-    "[epoch: 5 step: 20] train loss: 1.09862 time: 0:00:00\n",
-    "Evaluation at Epoch 5/5. Step:20/20. AccuracyMetric: acc=0.37931\n",
-    "\n",
-    "\n",
-    "In Epoch:5/Step:20, got best dev performance:AccuracyMetric: acc=0.37931\n",
-    "Reloaded the best model.\n"
+    "training epochs started 2019-05-14-19-49-25\n"
   ]
  },
  {
-  "data": {
-   "text/plain": [
-    "{'best_eval': {'AccuracyMetric': {'acc': 0.37931}},\n",
-    " 'best_epoch': 5,\n",
-    " 'best_step': 20,\n",
-    " 'seconds': 0.5}"
-   ]
-  },
-  "execution_count": 29,
-  "metadata": {},
-  "output_type": "execute_result"
+  "ename": "AssertionError",
+  "evalue": "seq_len can only have one dimension, got False.",
+  "output_type": "error",
+  "traceback": [
+   "AssertionError                            Traceback (most recent call last)",
+   "<ipython-input-29-a3d2740dc8ef> in <module>()",
+   "     17           use_tqdm=False,",
+   "     18           )",
+   "---> 19 trainer.train()",
+   "/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/trainer.py in train(self, load_best_model)",
+   "--> 524             self._train()",
+   "/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/trainer.py in _train(self)",
+   "--> 575                 prediction = self._data_forward(self.model, batch_x)",
+   "/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/trainer.py in _data_forward(self, network, x)",
+   "--> 663         y = network(**x)",
+   "/Users/fdujyn/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)",
+   "--> 491             result = self.forward(*input, **kwargs)",
+   "/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/models/snli.py in forward(self, words1, words2, seq_len1, seq_len2, target)",
+   "---> 78             seq_len1 = seq_len_to_mask(seq_len1)",
+   "/Users/fdujyn/anaconda3/lib/python3.6/site-packages/fastNLP/core/utils.py in seq_len_to_mask(seq_len)",
+   "--> 628         assert seq_len.dim() == 1, f\"seq_len can only have one dimension, got {seq_len.dim() == 1}.\"",
+   "AssertionError: seq_len can only have one dimension, got False."
+  ]
  }
 ],
 "source": [
@@ -1073,7 +1005,6 @@
    "          print_every=-1,\n",
    "          validate_every=-1,\n",
    "          dev_data=dev_data,\n",
-   "          use_cuda=True,\n",
    "          optimizer=Adam(lr=1e-3, weight_decay=0),\n",
    "          check_code_level=-1,\n",
    "          metric_key='acc',\n",
@@ -1178,7 +1109,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.0"
+   "version": "3.6.7"
   }
  },
 "nbformat": 4,

