
Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer

tags/v0.2.0^2
yh committed 5 years ago
commit 306eee9690
12 changed files with 561 additions and 61 deletions
 1. +2   -2    fastNLP/core/fieldarray.py
 2. +1   -1    fastNLP/core/metrics.py
 3. +0   -17   fastNLP/core/predictor.py
 4. +7   -3    fastNLP/core/trainer.py
 5. +0   -0    test/core/__init__.py
 6. +5   -3    test/core/test_dataset.py
 7. +5   -5    test/core/test_fieldarray.py
 8. +2   -27   test/core/test_loss.py
 9. +3   -1    test/core/test_metrics.py
10. +8   -0    test/core/test_optimizer.py
11. +2   -2    test/test_tutorial.py
12. +526 -0    tutorials/fastnlp_tutorial_1203.ipynb

+2 -2 fastNLP/core/fieldarray.py

@@ -162,7 +162,7 @@ class FieldArray(object):
         if self.is_input is False and self.is_target is False:
             raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name))
         batch_size = len(indices)
+        # TODO when this fieldArray holds single-value content such as seq_length, padding is unnecessary; needs further discussion
         if not is_iterable(self.content[0]):
             array = np.array([self.content[i] for i in indices], dtype=self.dtype)
         elif self.dtype in (np.int64, np.float64):
@@ -170,7 +170,7 @@
             array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
             for i, idx in enumerate(indices):
                 array[i][:len(self.content[idx])] = self.content[idx]
-        else: # should only be str
+        else:  # should only be str
             array = np.array([self.content[i] for i in indices])
         return array
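For context on the padding branch above: FieldArray.get pads every selected sequence to the longest one in the batch with padding_val, via np.full plus slice assignment. A minimal standalone sketch of that pattern (plain NumPy; pad_batch is an illustrative name, not fastNLP API):

    import numpy as np

    def pad_batch(sequences, indices, padding_val=0, dtype=np.int64):
        # Pad the selected variable-length sequences to the longest one,
        # mirroring the np.full + slice-assignment pattern in FieldArray.get.
        batch_size = len(indices)
        max_len = max(len(sequences[i]) for i in indices)
        array = np.full((batch_size, max_len), padding_val, dtype=dtype)
        for row, idx in enumerate(indices):
            array[row, :len(sequences[idx])] = sequences[idx]
        return array

    print(pad_batch([[1, 2, 3], [4, 5], [6]], indices=[0, 2]))
    # [[1 2 3]
    #  [6 0 0]]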




+1 -1 fastNLP/core/metrics.py

@@ -467,7 +467,7 @@ def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
     precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
     recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
     if isinstance(precision, np.ndarray):
-        res = 2 * precision * recall / (precision + recall)
+        res = 2 * precision * recall / (precision + recall + 1e-10)
         res[(precision + recall) <= 0] = 0
         return res
     return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
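The added 1e-10 guards the element-wise (per-class) case against a 0/0 division when some class has zero precision and zero recall; the masking line that follows then pins those entries to exactly 0. A quick illustration with made-up values (plain NumPy, independent of fastNLP):

    import numpy as np

    precision = np.array([0.5, 0.0])
    recall = np.array([0.5, 0.0])

    # Without the epsilon, the second entry is 0/0 and evaluates to nan.
    res = 2 * precision * recall / (precision + recall + 1e-10)
    res[(precision + recall) <= 0] = 0  # degenerate classes get exactly 0
    print(res)  # [0.5 0. ]  (first entry is off from 0.5 by ~1e-10)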


+0 -17 fastNLP/core/predictor.py

@@ -1,4 +1,3 @@
-import numpy as np
 import torch

 from fastNLP.core.batch import Batch
@@ -48,19 +47,3 @@ class Predictor(object):
         """Forward through network."""
         y = network(**x)
         return y
-
-
-def seq_label_post_processor(batch_outputs, label_vocab):
-    results = []
-    for batch in batch_outputs:
-        for example in np.array(batch):
-            results.append([label_vocab.to_word(int(x)) for x in example])
-    return results
-
-
-def text_classify_post_processor(batch_outputs, label_vocab):
-    results = []
-    for batch_out in batch_outputs:
-        idx = np.argmax(batch_out.detach().numpy(), axis=-1)
-        results.extend([label_vocab.to_word(i) for i in idx])
-    return results

+7 -3 fastNLP/core/trainer.py

@@ -2,11 +2,11 @@ import os
 import time
 from datetime import datetime
 from datetime import timedelta
-from tqdm.autonotebook import tqdm

 import torch
 from tensorboardX import SummaryWriter
 from torch import nn
+from tqdm.autonotebook import tqdm

 from fastNLP.core.batch import Batch
 from fastNLP.core.dataset import DataSet
@@ -24,6 +24,7 @@ from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature

+
 class Trainer(object):
     """Main Training Loop

@@ -263,8 +264,10 @@ class Trainer(object):

     def _do_validation(self):
         res = self.tester.test()
-        for name, num in res.items():
-            self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
+        for name, metric in res.items():
+            for metric_key, metric_val in metric.items():
+                self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
+                                                global_step=self.step)
         if self.save_path is not None and self._better_eval_result(res):
             metric_key = self.metric_key if self.metric_key is not None else "None"
             self._save_model(self.model,
@@ -386,6 +389,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                                 f"should be torch.size([])")
             loss.backward()
         except CheckError as e:
+            # TODO: another error raised if CheckError caught
             pre_func_signature = get_func_signature(model.forward)
             _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
                                  check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
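The rewritten validation loop implies that Tester.test() now returns a nested dict of the form {metric_name: {key: value}} rather than a flat {name: number}, so each (metric, key) pair gets its own TensorBoard tag. A sketch of the tag construction, with an assumed result dict:

    # Assumed shape of Tester.test()'s return value after this change.
    res = {"AccuracyMetric": {"acc": 0.921286}}

    for name, metric in res.items():
        for metric_key, metric_val in metric.items():
            tag = "valid_{}_{}".format(name, metric_key)
            print(tag, metric_val)  # -> valid_AccuracyMetric_acc 0.921286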


+0 -0 test/core/__init__.py


+5 -3 test/core/test_dataset.py

@@ -141,8 +141,10 @@ class TestDataSet(unittest.TestCase):
     def test_apply2(self):
         def split_sent(ins):
             return ins['raw_sentence'].split()
-        dataset = DataSet.read_csv('../../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
-        dataset.drop(lambda x:len(x['raw_sentence'].split())==0)
+
+        dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
+                                   sep='\t')
+        dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
         dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)

@@ -160,9 +162,9 @@ class TestDataSet(unittest.TestCase):
         ds_1 = DataSet.load("./my_ds.pkl")
         os.remove("my_ds.pkl")


 class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         for iter in ds:
             self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")


+5 -5 test/core/test_fieldarray.py

@@ -31,18 +31,18 @@ class TestFieldArray(unittest.TestCase):
         self.assertEqual(fa.pytype, float)
         self.assertEqual(fa.dtype, np.float64)

-        fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=False)
+        fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
         fa.append(10)
         self.assertEqual(fa.pytype, float)
         self.assertEqual(fa.dtype, np.float64)

-        fa = FieldArray("y", ["a", "b", "c", "d"], is_input=False)
+        fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True)
         fa.append("e")
         self.assertEqual(fa.dtype, np.str)
         self.assertEqual(fa.pytype, str)

     def test_support_np_array(self):
-        fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=False)
+        fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True)
         self.assertEqual(fa.dtype, np.ndarray)
         self.assertEqual(fa.pytype, np.ndarray)

@@ -50,12 +50,12 @@ class TestFieldArray(unittest.TestCase):
         self.assertEqual(fa.dtype, np.ndarray)
         self.assertEqual(fa.pytype, np.ndarray)

-        fa = FieldArray("my_field", np.random.rand(3, 5), is_input=False)
+        fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True)
         # in this case, pytype is actually a float. We do not care about it.
         self.assertEqual(fa.dtype, np.float64)

     def test_nested_list(self):
-        fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=False)
+        fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True)
         self.assertEqual(fa.pytype, float)
         self.assertEqual(fa.dtype, np.float64)




+2 -27 test/core/test_loss.py

@@ -6,7 +6,6 @@ import torch as tc
 import torch.nn.functional as F

 import fastNLP.core.losses as loss
-from fastNLP.core.losses import LossFunc


 class TestLoss(unittest.TestCase):
@@ -245,31 +244,7 @@ class TestLoss(unittest.TestCase):
         self.assertEqual(int(los * 1000), int(r * 1000))

     def test_case_8(self):
-        def func(a, b):
-            return F.cross_entropy(a, b)
-
-        def func2(a, truth):
-            return func(a, truth)
-
-        def func3(predict, truth):
-            return func(predict, truth)
-
-        def func4(a, b, c=2):
-            return (a + b) * c
-
-        def func6(a, b, **kwargs):
-            c = kwargs['c']
-            return (a + b) * c
-
-        get_loss = LossFunc(func, {'a': 'predict', 'b': 'truth'})
-        predict = torch.randn(5, 3)
-        truth = torch.LongTensor([1, 0, 1, 2, 1])
-        loss1 = get_loss({'predict': predict}, {'truth': truth})
-        get_loss_2 = LossFunc(func2, {'a': 'predict'})
-        loss2 = get_loss_2({'predict': predict}, {'truth': truth})
-        get_loss_3 = LossFunc(func3)
-        loss3 = get_loss_3({'predict': predict}, {'truth': truth})
-        assert loss1 == loss2 and loss1 == loss3
+        pass


 class TestLoss_v2(unittest.TestCase):
@@ -317,7 +292,7 @@ class TestLosserError(unittest.TestCase):
         target_dict = {'target': torch.zeros(16, 3).long()}
         los = loss.CrossEntropyLoss()

-        print(los(pred_dict=pred_dict, target_dict=target_dict))
+        # print(los(pred_dict=pred_dict, target_dict=target_dict))

     def test_losser3(self):
         # (2) with corrupted size

+3 -1 test/core/test_metrics.py

@@ -4,7 +4,7 @@ import numpy as np
 import torch

 from fastNLP.core.metrics import AccuracyMetric
-from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score
+from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score, pred_topk, accuracy_topk


 class TestAccuracyMetric(unittest.TestCase):
@@ -143,5 +143,7 @@ class TestUsefulFunctions(unittest.TestCase):
         _ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
         _ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
         _ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+        _ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
+        _ = pred_topk(np.random.randint(0, 3, size=(10, 1)))

         # just needs to run without error

+8 -0 test/core/test_optimizer.py

@@ -10,9 +10,13 @@ class TestOptim(unittest.TestCase):
         optim = SGD(torch.nn.Linear(10, 3).parameters())
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("momentum" in optim.__dict__["settings"])
+        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue(isinstance(res, torch.optim.SGD))

         optim = SGD(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue(isinstance(res, torch.optim.SGD))

         optim = SGD(lr=0.002, momentum=0.989)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
@@ -27,9 +31,13 @@ class TestOptim(unittest.TestCase):
         optim = Adam(torch.nn.Linear(10, 3).parameters())
         self.assertTrue("lr" in optim.__dict__["settings"])
         self.assertTrue("weight_decay" in optim.__dict__["settings"])
+        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue(isinstance(res, torch.optim.Adam))

         optim = Adam(lr=0.001)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue(isinstance(res, torch.optim.Adam))

         optim = Adam(lr=0.002, weight_decay=0.989)
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
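The new assertions exercise a deferred-construction pattern: the SGD/Adam wrappers record hyperparameter settings first and only build the real torch.optim object once model parameters are handed over through construct_from_pytorch. A self-contained sketch of the same idea (LazySGD is an illustrative stand-in, not the fastNLP class):

    import torch

    class LazySGD:
        # Store hyperparameters now; construct torch.optim.SGD only when
        # the model parameters become available.
        def __init__(self, **settings):
            self.settings = settings

        def construct_from_pytorch(self, model_params):
            return torch.optim.SGD(model_params, **self.settings)

    optim = LazySGD(lr=0.001, momentum=0.9)
    res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
    assert isinstance(res, torch.optim.SGD)

This lets a Trainer accept an optimizer specification before it knows which model it will train.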


+2 -2 test/test_tutorial.py

@@ -72,13 +72,13 @@ class TestTutorial(unittest.TestCase):
         # instantiate a Trainer with the model and data, then train
         copy_model = deepcopy(model)
         overfit_trainer = Trainer(train_data=test_data, model=copy_model,
-                                  losser=CrossEntropyLoss(pred="output", target="label_seq"),
+                                  loss=CrossEntropyLoss(pred="output", target="label_seq"),
                                   metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
                                   dev_data=test_data, save_path="./save")
         overfit_trainer.train()

         trainer = Trainer(train_data=train_data, model=model,
-                          losser=CrossEntropyLoss(pred="output", target="label_seq"),
+                          loss=CrossEntropyLoss(pred="output", target="label_seq"),
                           metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
                           dev_data=test_data, save_path="./save")
         trainer.train()


+526 -0 tutorials/fastnlp_tutorial_1203.ipynb

@@ -0,0 +1,526 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastNLP上手教程\n",
"-------\n",
"\n",
"fastNLP提供方便的数据预处理,训练和测试模型的功能"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/yh/miniconda2/envs/python3/lib/python3.6/site-packages/tqdm/autonotebook/__init__.py:14: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
" \" (e.g. in jupyter console)\", TqdmExperimentalWarning)\n"
]
}
],
"source": [
"import sys\n",
"sys.path.append('/Users/yh/Desktop/fastNLP/fastNLP/')\n",
"\n",
"import fastNLP as fnlp"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DataSet & Instance\n",
"------\n",
"\n",
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n",
"\n",
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
"'label': 1}\n"
]
}
],
"source": [
"from fastNLP import DataSet\n",
"from fastNLP import Instance\n",
"\n",
"# 从csv读取数据到DataSet\n",
"dataset = DataSet.read_csv('sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': fake data,\n",
"'label': 0}"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# DataSet.append(Instance)加入新数据\n",
"\n",
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
"dataset[-1]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# DataSet.apply(func, new_field_name)对数据预处理\n",
"\n",
"# 将所有数字转为小写\n",
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
"# label转int\n",
"dataset.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
"# 使用空格分割句子\n",
"dataset.drop(lambda x:len(x['raw_sentence'].split())==0)\n",
"def split_sent(ins):\n",
" return ins['raw_sentence'].split()\n",
"dataset.apply(split_sent, new_field_name='words', is_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# DataSet.drop(func)筛除数据\n",
"# 删除低于某个长度的词语\n",
"# dataset.drop(lambda x: len(x['words']) <= 3)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train size: 5971\n",
"Test size: 2558\n"
]
}
],
"source": [
"# 分出测试集、训练集\n",
"\n",
"test_data, train_data = dataset.split(0.3)\n",
"print(\"Train size: \", len(test_data))\n",
"print(\"Test size: \", len(train_data))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Vocabulary\n",
"------\n",
"\n",
"fastNLP中的Vocabulary轻松构建词表,将词转成数字"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': gussied up with so many distracting special effects and visual party tricks that it 's not clear whether we 're supposed to shriek or laugh .,\n",
"'label': 1,\n",
"'label_seq': 1,\n",
"'words': ['gussied', 'up', 'with', 'so', 'many', 'distracting', 'special', 'effects', 'and', 'visual', 'party', 'tricks', 'that', 'it', \"'s\", 'not', 'clear', 'whether', 'we', \"'re\", 'supposed', 'to', 'shriek', 'or', 'laugh', '.'],\n",
"'word_seq': [1, 65, 16, 43, 108, 1, 329, 433, 7, 319, 1313, 1, 12, 10, 11, 27, 1428, 567, 86, 134, 1949, 8, 1, 49, 506, 2]}\n"
]
}
],
"source": [
"from fastNLP import Vocabulary\n",
"\n",
"# 构建词表, Vocabulary.add(word)\n",
"vocab = Vocabulary(min_freq=2)\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
"vocab.build_vocab()\n",
"\n",
"# index句子, Vocabulary.to_index(word)\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
"\n",
"\n",
"print(test_data[0])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_x has: {'words': array([list(['this', 'kind', 'of', 'hands-on', 'storytelling', 'is', 'ultimately', 'what', 'makes', 'shanghai', 'ghetto', 'move', 'beyond', 'a', 'good', ',', 'dry', ',', 'reliable', 'textbook', 'and', 'what', 'allows', 'it', 'to', 'rank', 'with', 'its', 'worthy', 'predecessors', '.']),\n",
" list(['the', 'entire', 'movie', 'is', 'filled', 'with', 'deja', 'vu', 'moments', '.'])],\n",
" dtype=object), 'word_seq': tensor([[ 19, 184, 6, 1, 481, 9, 206, 50, 91, 1210, 1609, 1330,\n",
" 495, 5, 63, 4, 1269, 4, 1, 1184, 7, 50, 1050, 10,\n",
" 8, 1611, 16, 21, 1039, 1, 2],\n",
" [ 3, 711, 22, 9, 1282, 16, 2482, 2483, 200, 2, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0]])}\n",
"batch_y has: {'label_seq': tensor([3, 2])}\n"
]
}
],
"source": [
"# 假设你们需要做强化学习或者gan之类的项目,也许你们可以使用这里的dataset\n",
"from fastNLP.core.batch import Batch\n",
"from fastNLP.core.sampler import RandomSampler\n",
"\n",
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n",
"for batch_x, batch_y in batch_iterator:\n",
" print(\"batch_x has: \", batch_x)\n",
" print(\"batch_y has: \", batch_y)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CNNText(\n",
" (embed): Embedding(\n",
" (embed): Embedding(3470, 50, padding_idx=0)\n",
" (dropout): Dropout(p=0.0)\n",
" )\n",
" (conv_pool): ConvMaxpool(\n",
" (convs): ModuleList(\n",
" (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
" (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
" (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1)\n",
" (fc): Linear(\n",
" (linear): Linear(in_features=12, out_features=5, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 定义一个简单的Pytorch模型\n",
"\n",
"from fastNLP.models import CNNText\n",
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trainer & Tester\n",
"------\n",
"\n",
"使用fastNLP的Trainer训练模型"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from copy import deepcopy\n",
"from fastNLP.core.losses import CrossEntropyLoss\n",
"from fastNLP.core.metrics import AccuracyMetric"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-05 15:37:15\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=1870), HTML(value='')), layout=Layout(display…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10. Step:187/1870. AccuracyMetric: acc=0.351365\n",
"Epoch 2/10. Step:374/1870. AccuracyMetric: acc=0.470943\n",
"Epoch 3/10. Step:561/1870. AccuracyMetric: acc=0.600402\n",
"Epoch 4/10. Step:748/1870. AccuracyMetric: acc=0.702227\n",
"Epoch 5/10. Step:935/1870. AccuracyMetric: acc=0.79099\n",
"Epoch 6/10. Step:1122/1870. AccuracyMetric: acc=0.846424\n",
"Epoch 7/10. Step:1309/1870. AccuracyMetric: acc=0.874058\n",
"Epoch 8/10. Step:1496/1870. AccuracyMetric: acc=0.898844\n",
"Epoch 9/10. Step:1683/1870. AccuracyMetric: acc=0.910568\n",
"Epoch 10/10. Step:1870/1870. AccuracyMetric: acc=0.921286\n",
"\r"
]
}
],
"source": [
"# 进行overfitting测试\n",
"copy_model = deepcopy(model)\n",
"overfit_trainer = Trainer(model=copy_model, \n",
" train_data=test_data, \n",
" dev_data=test_data,\n",
" losser=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
" metrics=AccuracyMetric(),\n",
" n_epochs=10,\n",
" save_path=None)\n",
"overfit_trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-05 15:37:41\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=400), HTML(value='')), layout=Layout(display=…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"ename": "AttributeError",
"evalue": "'NoneType' object has no attribute 'squeeze'",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-5603b8b11a82>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mn_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m save_path='save/')\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Train finished!'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_summary_writer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mSummaryWriter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_tqdm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_tqdm_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_print_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/trainer.py\u001b[0m in \u001b[0;36m_tqdm_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 206\u001b[0m \u001b[0mpbar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 207\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalidate_every\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdev_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 208\u001b[0;31m \u001b[0meval_res\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_validation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 209\u001b[0m \u001b[0meval_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"Epoch {}/{}. Step:{}/{}. \"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_epochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtotal_steps\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 210\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtester\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_format_eval_results\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meval_res\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/trainer.py\u001b[0m in \u001b[0;36m_do_validation\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtester\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 266\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_summary_writer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"valid_{}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mglobal_step\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msave_path\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_better_eval_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0mmetric_key\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric_key\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric_key\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;34m\"None\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda2/envs/python3/lib/python3.6/site-packages/tensorboardX/writer.py\u001b[0m in \u001b[0;36madd_scalar\u001b[0;34m(self, tag, scalar_value, global_step, walltime)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_check_caffe2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscalar_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 333\u001b[0m \u001b[0mscalar_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mworkspace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFetchBlob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscalar_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 334\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile_writer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtag\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mscalar_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mglobal_step\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwalltime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 335\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 336\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0madd_scalars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain_tag\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtag_scalar_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mglobal_step\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwalltime\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/miniconda2/envs/python3/lib/python3.6/site-packages/tensorboardX/summary.py\u001b[0m in \u001b[0;36mscalar\u001b[0;34m(name, scalar, collections)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_clean_tag\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mscalar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmake_np\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscalar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0;32massert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscalar\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'scalar should be 0D'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0mscalar\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mscalar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mSummary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvalue\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mSummary\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mValue\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtag\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msimple_value\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscalar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'squeeze'"
],
"output_type": "error"
}
],
"source": [
"# 实例化Trainer,传入模型和数据,进行训练\n",
"trainer = Trainer(model=model, \n",
" train_data=train_data, \n",
" dev_data=test_data,\n",
" losser=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
" metrics=AccuracyMetric(),\n",
" n_epochs=5,\n",
" save_path='save/')\n",
"trainer.train()\n",
"print('Train finished!')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Tester\n",
"\n",
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())\n",
"acc = tester.test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# In summary\n",
"\n",
"## fastNLP Trainer的伪代码逻辑\n",
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n",
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n",
" 通过\n",
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n",
" 通过\n",
" DataSet.set_target('label', flag=True)将'label'设置为target\n",
"### 2. 初始化模型\n",
" class Model(nn.Module):\n",
" def __init__(self):\n",
" xxx\n",
" def forward(self, word_seq1, word_seq2):\n",
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n",
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n",
" xxxx\n",
" # 输出必须是一个dict\n",
"### 3. Trainer的训练过程\n",
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n",
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n",
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n",
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n",
" 为了解决以上的问题,我们的loss提供映射机制\n",
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n",
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n",
" (3) 对于Metric是同理的\n",
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n",
" \n",
" \n",
"\n",
"## 一些问题.\n",
"### 1. DataSet中为什么需要设置input和target\n",
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n",
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n",
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n",
" (a)Model.forward的output\n",
" (b)被设置为target的field\n",
" \n",
"\n",
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n",
" (1.1) 构建模型过程中,\n",
" 例如:\n",
" DataSet中x,seq_lens是input,那么forward就应该是\n",
" def forward(self, x, seq_lens):\n",
" pass\n",
" 我们是通过形参名称进行匹配的field的\n",
" \n",
"\n",
"\n",
"### 1. 加载数据到DataSet\n",
"### 2. 使用apply操作对DataSet进行预处理\n",
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n",
"### 3. 构建模型\n",
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n",
" 例如:\n",
" DataSet中x,seq_lens是input,那么forward就应该是\n",
" def forward(self, x, seq_lens):\n",
" pass\n",
" 我们是通过形参名称进行匹配的field的\n",
" (3.2) 模型的forward的output需要是dict类型的。\n",
" 建议将输出设置为{\"pred\": xx}.\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
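The name-matching rule described in the notebook's summary cell (input fields are routed to Model.forward by matching parameter names) can be sketched in a few lines of plain Python; call_with_matching_fields is an illustrative helper, not the actual Trainer internals:

    import inspect

    def call_with_matching_fields(forward, batch):
        # Pick out of the batch exactly those fields whose names match
        # forward's parameters; extra fields are simply ignored.
        params = inspect.signature(forward).parameters
        kwargs = {name: batch[name] for name in params if name in batch}
        return forward(**kwargs)

    def forward(x, seq_lens):
        return {"pred": x}  # forward must return a dict

    batch = {"x": [1, 2, 3], "seq_lens": [3], "extra_field": None}
    print(call_with_matching_fields(forward, batch))  # {'pred': [1, 2, 3]}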
