
conflict fix

tags/v0.2.0^2
yh authored 6 years ago
commit e779409cf8
15 changed files with 904 additions and 147 deletions
1.  +0   -1    fastNLP/core/__init__.py
2.  +2   -2    fastNLP/core/batch.py
3.  +4   -1    fastNLP/core/dataset.py
4.  +15  -4    fastNLP/core/losses.py
5.  +5   -115  fastNLP/core/metrics.py
6.  +2   -2    fastNLP/core/predictor.py
7.  +1   -1    fastNLP/core/sampler.py
8.  +0   -7    fastNLP/core/vocabulary.py
9.  +21  -0    test/core/test_dataset.py
10. +22  -0    test/core/test_fieldarray.py
11. +13  -0    test/core/test_metrics.py
12. +29  -1    test/core/test_predictor.py
13. +11  -1    test/core/test_sampler.py
14. +11  -12   test/core/test_trainer.py
15. +768 -0    tutorials/fastnlp_tutorial_1204.ipynb

+0 -1  fastNLP/core/__init__.py

@@ -3,7 +3,6 @@ from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import Loss
-from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
 from .optimizer import Optimizer
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester


+2 -2  fastNLP/core/batch.py

@@ -62,8 +62,8 @@ class Batch(object):


 def to_tensor(batch, dtype):
-    if dtype in (np.int8, np.int16, np.int32, np.int64):
+    if dtype in (int, np.int8, np.int16, np.int32, np.int64):
         batch = torch.LongTensor(batch)
-    if dtype in (np.float32, np.float64):
+    if dtype in (float, np.float32, np.float64):
         batch = torch.FloatTensor(batch)
     return batch
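
The two added entries matter because a field's detected type can be the Python built-in `int` or `float` rather than a numpy dtype; such batches previously fell through both branches and were returned unconverted. A standalone sketch of the fixed helper (not the full batch.py source):

```python
import numpy as np
import torch

def to_tensor(batch, dtype):
    # built-in int/float are now accepted alongside the numpy dtypes
    if dtype in (int, np.int8, np.int16, np.int32, np.int64):
        batch = torch.LongTensor(batch)
    if dtype in (float, np.float32, np.float64):
        batch = torch.FloatTensor(batch)
    return batch

print(to_tensor([[1, 2], [3, 4]], int).dtype)      # torch.int64
print(to_tensor([[1.0, 2.0]], np.float32).dtype)   # torch.float32
```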

+4 -1  fastNLP/core/dataset.py

@@ -1,4 +1,5 @@
 import _pickle as pickle

+import numpy as np

 from fastNLP.core.fieldarray import FieldArray
@@ -66,10 +67,12 @@ class DataSet(object):
        def __init__(self, dataset, idx):
            self.dataset = dataset
            self.idx = idx
+
        def __getitem__(self, item):
            assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
            assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
            return self.dataset.field_arrays[item][self.idx]
+
        def __repr__(self):
            return self.dataset[self.idx].__repr__()

@@ -339,6 +342,6 @@ class DataSet(object):
            pickle.dump(self, f)

    @staticmethod
-    def load(self, path):
+    def load(path):
        with open(path, 'rb') as f:
            return pickle.load(f)
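
With `load` now a true staticmethod (the leftover `self` parameter previously made `DataSet.load(path)` fail with a missing-argument error), saving and loading round-trip as the new test_save_load in test/core/test_dataset.py below exercises:

```python
from fastNLP.core.dataset import DataSet

ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.save("./my_ds.pkl")                   # pickle the whole DataSet to disk
ds_loaded = DataSet.load("./my_ds.pkl")  # callable on the class, no instance needed
assert len(ds_loaded) == len(ds)
```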

+15 -4  fastNLP/core/losses.py

@@ -69,9 +69,20 @@ class LossBase(object):
                    f"positional argument.).")

    def _fast_param_map(self, pred_dict, target_dict):
+        """Map pred_dict and target_dict onto pred/target automatically when the mapping
+        is unequivocal, e.g. when pred_dict and target_dict each hold exactly one element,
+        so users do not need to pass a key_map. Only used as an inner function.
+
+        :param pred_dict:
+        :param target_dict:
+        :return: dict. If it is not {}, pass it to self.get_loss; otherwise fall back to key mapping.
+        """
+        fast_param = {}
        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            return tuple(pred_dict.values())[0], tuple(target_dict.values())[0]
-        return None
+            fast_param['pred'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
+            return fast_param
+        return fast_param

    def __call__(self, pred_dict, target_dict, check=False):
        """
@@ -81,8 +92,8 @@
        :return:
        """
        fast_param = self._fast_param_map(pred_dict, target_dict)
-        if fast_param is not None:
-            loss = self.get_loss(*fast_param)
+        if fast_param:
+            loss = self.get_loss(**fast_param)
            return loss

        if not self._checked:
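
The fast path now hands back a dict so `__call__` can forward it with keyword arguments, and an empty dict is falsy, which is why the `is not None` test becomes a plain truthiness check. A simplified sketch of the mapping (`fast_param_map` is a hypothetical stand-in that drops the `param_map` length check the real method also performs):

```python
def fast_param_map(pred_dict, target_dict):
    # unambiguous only when each dict holds exactly one value
    fast_param = {}
    if len(pred_dict) == 1 and len(target_dict) == 1:
        fast_param['pred'] = list(pred_dict.values())[0]
        fast_param['target'] = list(target_dict.values())[0]
    return fast_param

print(fast_param_map({'output': [0.1, 0.9]}, {'label_seq': 1}))
# {'pred': [0.1, 0.9], 'target': 1}  -> forwarded as self.get_loss(**fast_param)
print(fast_param_map({'a': 1, 'b': 2}, {'label_seq': 1}))
# {}  -> falsy, so __call__ falls back to the explicit key mapping
```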


+5 -115  fastNLP/core/metrics.py

@@ -82,7 +82,9 @@ class MetricBase(object):
        """
        fast_param = {}
        if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            return pred_dict.values[0] and target_dict.values[0]
+            fast_param['pred'] = list(pred_dict.values())[0]
+            fast_param['target'] = list(target_dict.values())[0]
+            return fast_param
        return fast_param

    def __call__(self, pred_dict, target_dict):
@@ -304,118 +306,6 @@ def _prepare_metrics(metrics):
    return _metrics


-class Evaluator(object):
-    def __init__(self):
-        pass
-
-    def __call__(self, predict, truth):
-        """
-
-        :param predict: list of tensors, the network outputs from all batches.
-        :param truth: list of dict, the ground truths from all batch_y.
-        :return:
-        """
-        raise NotImplementedError
-
-
-class ClassifyEvaluator(Evaluator):
-    def __init__(self):
-        super(ClassifyEvaluator, self).__init__()
-
-    def __call__(self, predict, truth):
-        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
-        y_prob = torch.cat(y_prob, dim=0)
-        y_pred = torch.argmax(y_prob, dim=-1)
-        y_true = torch.cat(truth, dim=0)
-        acc = float(torch.sum(y_pred == y_true)) / len(y_true)
-        return {"accuracy": acc}
-
-
-class SeqLabelEvaluator(Evaluator):
-    def __init__(self):
-        super(SeqLabelEvaluator, self).__init__()
-
-    def __call__(self, predict, truth, **_):
-        """
-
-        :param predict: list of List, the network outputs from all batches.
-        :param truth: list of dict, the ground truths from all batch_y.
-        :return accuracy:
-        """
-        total_correct, total_count = 0., 0.
-        for x, y in zip(predict, truth):
-            x = torch.tensor(x)
-            y = y.to(x)  # make sure they are on the same device
-            mask = (y > 0)
-            correct = torch.sum(((x == y) * mask).long())
-            total_correct += float(correct)
-            total_count += float(torch.sum(mask.long()))
-        accuracy = total_correct / total_count
-        return {"accuracy": float(accuracy)}
-
-
-class SeqLabelEvaluator2(Evaluator):
-    # the evaluator above appears to be wrong
-    def __init__(self, seq_lens_field_name='word_seq_origin_len'):
-        super(SeqLabelEvaluator2, self).__init__()
-        self.end_tagidx_set = set()
-        self.seq_lens_field_name = seq_lens_field_name
-
-    def __call__(self, predict, truth, **_):
-        """
-
-        :param predict: list of batch, the network outputs from all batches.
-        :param truth: list of dict, the ground truths from all batch_y.
-        :return accuracy:
-        """
-        seq_lens = _[self.seq_lens_field_name]
-        corr_count = 0
-        pred_count = 0
-        truth_count = 0
-        for x, y, seq_len in zip(predict, truth, seq_lens):
-            x = x.cpu().numpy()
-            y = y.cpu().numpy()
-            for idx, s_l in enumerate(seq_len):
-                x_ = x[idx]
-                y_ = y[idx]
-                x_ = x_[:s_l]
-                y_ = y_[:s_l]
-                flag = True
-                start = 0
-                for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)):
-                    if x_i in self.end_tagidx_set:
-                        truth_count += 1
-                        for j in range(start, idx_i + 1):
-                            if y_[j] != x_[j]:
-                                flag = False
-                                break
-                        if flag:
-                            corr_count += 1
-                        flag = True
-                        start = idx_i + 1
-                    if y_i in self.end_tagidx_set:
-                        pred_count += 1
-        P = corr_count / (float(pred_count) + 1e-6)
-        R = corr_count / (float(truth_count) + 1e-6)
-        F = 2 * P * R / (P + R + 1e-6)
-
-        return {"P": P, 'R': R, 'F': F}
-
-
-class SNLIEvaluator(Evaluator):
-    def __init__(self):
-        super(SNLIEvaluator, self).__init__()
-
-    def __call__(self, predict, truth):
-        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
-        y_prob = torch.cat(y_prob, dim=0)
-        y_pred = torch.argmax(y_prob, dim=-1)
-        truth = [t['truth'] for t in truth]
-        y_true = torch.cat(truth, dim=0).view(-1)
-        acc = float(torch.sum(y_pred == y_true)) / y_true.size(0)
-        return {"accuracy": acc}
-
-
 def _conver_numpy(x):
    """convert input data to numpy array

@@ -467,11 +357,11 @@ def _check_data(y_true, y_pred):
    type_true, y_true = _label_types(y_true)
    type_pred, y_pred = _label_types(y_pred)

-    type_set = set(['binary', 'multiclass'])
+    type_set = {'binary', 'multiclass'}
    if type_true in type_set and type_pred in type_set:
        return type_true if type_true == type_pred else 'multiclass', y_true, y_pred

-    type_set = set(['multiclass-multioutput', 'multilabel'])
+    type_set = {'multiclass-multioutput', 'multilabel'}
    if type_true in type_set and type_pred in type_set:
        return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred


+2 -2  fastNLP/core/predictor.py

@@ -23,13 +23,13 @@ class Predictor(object):

        :param network: a PyTorch model (cpu)
        :param data: a DataSet object.
-        :return: list of list of strings, [num_examples, tag_seq_length]
+        :return: list of batch outputs
        """
        # turn on the testing mode; clean up the history
        self.mode(network, test=True)
        batch_output = []

-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
+        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

        for batch_x, _ in data_iterator:
            with torch.no_grad():
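
Here `as_numpy` replaces the removed `use_cuda` flag on `Batch`; presumably it controls whether batches come back as numpy arrays instead of torch tensors, and the predictor wants tensors. A usage sketch under that assumption (`dataset` stands for any DataSet whose input fields have been set):

```python
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler

# iterate a DataSet in fixed order, yielding torch tensors (as_numpy=False)
data_iterator = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
for batch_x, batch_y in data_iterator:
    # batch_x holds the input fields, batch_y the target fields
    ...
```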


+1 -1  fastNLP/core/sampler.py

@@ -55,7 +55,7 @@ class BucketSampler(BaseSampler):

    def __call__(self, data_set):

-        seq_lens = data_set[self.seq_lens_field_name].content
+        seq_lens = data_set.get_fields()[self.seq_lens_field_name].content
        total_sample_num = len(seq_lens)

        bucket_indexes = []
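
`BucketSampler` now reaches the FieldArray through `get_fields()`, the DataSet's name-to-FieldArray mapping, instead of indexing the DataSet itself. A usage sketch mirroring the new test_BucketSampler added further down:

```python
import random

from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import BucketSampler

ds = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
ds.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")

sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len")
indices = sampler(ds)  # one index per sample, ordered by length buckets
assert len(indices) == 10
```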


+0 -7  fastNLP/core/vocabulary.py

@@ -1,12 +1,5 @@
 from collections import Counter

-
-def isiterable(p_object):
-    try:
-        _ = iter(p_object)
-    except TypeError:
-        return False
-    return True


 def check_build_vocab(func):
    """A decorator to make sure the indexing is built before used.


+21 -0  test/core/test_dataset.py

@@ -1,3 +1,4 @@
+import os
 import unittest

 from fastNLP.core.dataset import DataSet
@@ -90,6 +91,18 @@ class TestDataSet(unittest.TestCase):
        self.assertTrue("rx" in ds.field_arrays)
        self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])

+        ds.apply(lambda ins: len(ins["y"]), new_field_name="y")
+        self.assertEqual(ds.field_arrays["y"].content[0], 2)
+
+        res = ds.apply(lambda ins: len(ins["x"]))
+        self.assertTrue(isinstance(res, list) and len(res) > 0)
+        self.assertEqual(res[0], 4)
+
+    def test_drop(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
+        ds.drop(lambda ins: len(ins["y"]) < 3)
+        self.assertEqual(len(ds), 20)
+
    def test_contains(self):
        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
        self.assertTrue("x" in ds)
@@ -132,9 +145,17 @@ class TestDataSet(unittest.TestCase):
        dataset.apply(split_sent, new_field_name='words')
        # print(dataset)

+    def test_save_load(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        ds.save("./my_ds.pkl")
+        self.assertTrue(os.path.exists("./my_ds.pkl"))
+
+        ds_1 = DataSet.load("./my_ds.pkl")
+        os.remove("my_ds.pkl")
+
 class TestDataSetIter(unittest.TestCase):
    def test__repr__(self):
        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
        for iter in ds:
            self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")


+22 -0  test/core/test_fieldarray.py

@@ -75,3 +75,25 @@ class TestFieldArray(unittest.TestCase):
        indices = [0, 1, 3, 4, 6]
        for a, b in zip(fa[indices], x[indices]):
            self.assertListEqual(a.tolist(), b.tolist())
+
+    def test_append(self):
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+            fa.append(0)
+
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
+            fa.append([1, 2, 3, 4, 5])
+
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+            fa.append([])
+
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+            fa.append(["str", 0, 0, 0, 1.89])
+
+        fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+        fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
+        self.assertEqual(len(fa), 3)
+        self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])

+13 -0  test/core/test_metrics.py

@@ -4,6 +4,7 @@ import numpy as np
 import torch

 from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score


 class TestAccuracyMetric(unittest.TestCase):
@@ -132,3 +133,15 @@ class TestAccuracyMetric(unittest.TestCase):
            print(e)
            return
        self.assertTrue(False, "No exception caught.")
+
+
+class TestUsefulFunctions(unittest.TestCase):
+    # test some functions in metrics.py that look quite useful
+    def test_case_1(self):
+        # multi-class
+        _ = accuracy_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)))
+        _ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+        _ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+        _ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+
+        # just check that they run

+29 -1  test/core/test_predictor.py

@@ -1,6 +1,34 @@
 import unittest

+import numpy as np
+import torch
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
 from fastNLP.core.predictor import Predictor
+from fastNLP.modules.encoder.linear import Linear
+
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
+
 class TestPredictor(unittest.TestCase):
    def test(self):
-        pass
+        predictor = Predictor()
+        model = Linear(2, 1)
+        data = prepare_fake_dataset()
+        data.set_input("x")
+        ans = predictor.predict(model, data)
+        self.assertEqual(len(ans), 2000)
+        self.assertTrue(isinstance(ans[0], torch.Tensor))

+11 -1  test/core/test_sampler.py

@@ -1,9 +1,11 @@
+import random
 import unittest

 import torch

+from fastNLP.core.dataset import DataSet
 from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
-    k_means_1d, k_means_bucketing, simple_sort_bucketing
+    k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler


 class TestSampler(unittest.TestCase):
@@ -40,3 +42,11 @@ class TestSampler(unittest.TestCase):
    def test_simple_sort_bucketing(self):
        _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
        assert len(_) == 10
+
+    def test_BucketSampler(self):
+        sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len")
+        data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
+        data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")
+        indices = sampler(data_set)
+        self.assertEqual(len(indices), 10)
+        # just make sure it runs; correctness is not checked

+11 -12  test/core/test_trainer.py

@@ -30,7 +30,7 @@ def prepare_fake_dataset():


 def prepare_fake_dataset2(*args, size=100):
-    ys = np.random.randint(4, size=100)
+    ys = np.random.randint(4, size=100, dtype=np.int64)
    data = {'y': ys}
    for arg in args:
        data[arg] = np.random.randn(size, 5)
@@ -213,12 +213,12 @@ class TrainerTestGround(unittest.TestCase):
        dataset = prepare_fake_dataset2('x1', 'x_unused')
        dataset.rename_field('x_unused', 'x2')
        dataset.set_input('x1', 'x2', 'y')
-        dataset.set_target('x1')
+        dataset.set_target('x1', 'x2')
        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)
-            def forward(self, x1, x2, y):
+            def forward(self, x1, x2):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
                x = x1 + x2
@@ -226,15 +226,14 @@ class TrainerTestGround(unittest.TestCase):
                return {'pred': x}

        model = Model()
-        with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                dev_data=dataset,
-                losser=CrossEntropyLoss(),
-                metrics=AccuracyMetric(),
-                use_tqdm=False,
-                print_every=2)
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            dev_data=dataset,
+            losser=CrossEntropyLoss(),
+            metrics=AccuracyMetric(),
+            use_tqdm=False,
+            print_every=2)

    def test_case2(self):
        # check metrics Wrong


+768 -0  tutorials/fastnlp_tutorial_1204.ipynb

@@ -0,0 +1,768 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastNLP上手教程\n",
"-------\n",
"\n",
"fastNLP提供方便的数据预处理,训练和测试模型的功能"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('/Users/yh/Desktop/fastNLP/fastNLP/')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DataSet & Instance\n",
"------\n",
"\n",
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n",
"\n",
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8529\n"
]
}
],
"source": [
"from fastNLP import DataSet\n",
"from fastNLP import Instance\n",
"\n",
"# 从csv读取数据到DataSet\n",
"dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
"print(len(dataset))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
"'label': 1}\n",
"{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n",
"'label': 1}\n"
]
}
],
"source": [
"# 使用数字索引[k],获取第k个样本\n",
"print(dataset[0])\n",
"\n",
"# 索引也可以是负数\n",
"print(dataset[-3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instance\n",
"Instance表示一个样本,由一个或多个field(域,属性,特征)组成,每个field有名字和值。\n",
"\n",
"在初始化Instance时即可定义它包含的域,使用 \"field_name=field_value\"的写法。"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': fake data,\n",
"'label': 0}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# DataSet.append(Instance)加入新数据\n",
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
"dataset[-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DataSet.apply方法\n",
"数据预处理利器"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
"'label': 1}\n"
]
}
],
"source": [
"# 将所有数字转为小写\n",
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
"'label': 1}\n"
]
}
],
"source": [
"# label转int\n",
"dataset.apply(lambda x: int(x['label']), new_field_name='label')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "Cannot create FieldArray with an empty list.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-9-d70cf5545af4>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msplit_sent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mins\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'raw_sentence'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_sent\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_field_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'words'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, func, new_field_name, **kwargs)\u001b[0m\n\u001b[1;32m 265\u001b[0m **extra_param)\n\u001b[1;32m 266\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 267\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_field_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfields\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mextra_param\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 268\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 269\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/dataset.py\u001b[0m in \u001b[0;36madd_field\u001b[0;34m(self, name, fields, padding_val, is_input, is_target)\u001b[0m\n\u001b[1;32m 158\u001b[0m f\"Dataset size {len(self)} != field size {len(fields)}\")\n\u001b[1;32m 159\u001b[0m self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,\n\u001b[0;32m--> 160\u001b[0;31m is_input=is_input)\n\u001b[0m\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdelete_field\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, name, content, padding_val, is_target, is_input)\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_input\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 38\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 39\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_target\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_target\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mis_target\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36mis_input\u001b[0;34m(self, value)\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mis_input\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mis_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_to_np_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpytype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_input\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# strict check 2-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Please provide 2-D list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 73\u001b[0;31m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_type_detection\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 74\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype_set\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m2\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mint\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mfloat\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtype_set\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 75\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0mfloat\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/Desktop/fastNLP/fastNLP/fastNLP/core/fieldarray.py\u001b[0m in \u001b[0;36m_type_detection\u001b[0;34m(self, content)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;31m# content is a 1-D list\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Cannot create FieldArray with an empty list.\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0mtype_set\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mitem\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mcontent\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Cannot create FieldArray with an empty list."
]
}
],
"source": [
"# 使用空格分割句子\n",
"def split_sent(ins):\n",
" return ins['raw_sentence'].split()\n",
"dataset.apply(split_sent, new_field_name='words')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
"'label': 1,\n",
"'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n",
"'seq_len': 37}\n"
]
}
],
"source": [
"# 增加长度信息\n",
"dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n",
"print(dataset[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DataSet.drop\n",
"筛选数据"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"38\n"
]
}
],
"source": [
"dataset.drop(lambda x: x['seq_len'] <= 3)\n",
"print(len(dataset))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 配置DataSet\n",
"1. 哪些域是特征,哪些域是标签\n",
"2. 切分训练集/验证集"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"# 设置DataSet中,哪些field要转为tensor\n",
"\n",
"# set target,loss或evaluate中的golden,计算loss,模型评估时使用\n",
"dataset.set_target(\"label\")\n",
"# set input,模型forward时使用\n",
"dataset.set_input(\"words\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"27\n",
"11"
]
}
],
"source": [
"# 分出测试集、训练集\n",
"\n",
"test_data, train_data = dataset.split(0.3)\n",
"print(len(test_data))\n",
"print(len(train_data))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Vocabulary\n",
"------\n",
"\n",
"fastNLP中的Vocabulary轻松构建词表,将词转成数字"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n",
"'label': 2,\n",
"'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n",
"'seq_len': 25}\n"
]
}
],
"source": [
"from fastNLP import Vocabulary\n",
"\n",
"# 构建词表, Vocabulary.add(word)\n",
"vocab = Vocabulary(min_freq=2)\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
"vocab.build_vocab()\n",
"\n",
"# index句子, Vocabulary.to_index(word)\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
"\n",
"\n",
"print(test_data[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model\n",
"定义一个PyTorch模型"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CNNText(\n",
" (embed): Embedding(\n",
" (embed): Embedding(32, 50, padding_idx=0)\n",
" (dropout): Dropout(p=0.0)\n",
" )\n",
" (conv_pool): ConvMaxpool(\n",
" (convs): ModuleList(\n",
" (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
" (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
" (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1)\n",
" (fc): Linear(\n",
" (linear): Linear(in_features=12, out_features=5, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fastNLP.models import CNNText\n",
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n",
"\n",
"注意两点:\n",
"1. forward参数名字叫**word_seq**,请记住。\n",
"2. forward的返回值是一个**dict**,其中有个key的名字叫**output**。\n",
"\n",
"```Python\n",
" def forward(self, word_seq):\n",
" \"\"\"\n",
"\n",
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
" :return output: dict of torch.LongTensor, [batch_size, num_classes]\n",
" \"\"\"\n",
" x = self.embed(word_seq) # [N,L] -> [N,L,C]\n",
" x = self.conv_pool(x) # [N,L,C] -> [N,C]\n",
" x = self.dropout(x)\n",
" x = self.fc(x) # [N,C] -> [N, N_class]\n",
" return {'output': x}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这是上述模型的predict方法,是用来直接输出该任务的预测结果,与forward目的不同。\n",
"\n",
"注意两点:\n",
"1. predict参数名也叫**word_seq**。\n",
"2. predict的返回值是也一个**dict**,其中有个key的名字叫**predict**。\n",
"\n",
"```\n",
" def predict(self, word_seq):\n",
" \"\"\"\n",
"\n",
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
" :return predict: dict of torch.LongTensor, [batch_size, seq_len]\n",
" \"\"\"\n",
" output = self(word_seq)\n",
" _, predict = output['output'].max(dim=1)\n",
" return {'predict': predict}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trainer & Tester\n",
"------\n",
"\n",
"使用fastNLP的Trainer训练模型"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from copy import deepcopy\n",
"from fastNLP.core.losses import CrossEntropyLoss\n",
"from fastNLP.core.metrics import AccuracyMetric\n",
"\n",
"\n",
"# 更改DataSet中对应field的名称,与模型的forward的参数名一致\n",
"# 因为forward的参数叫word_seq, 所以要把原本叫words的field改名为word_seq\n",
"# 这里的演示是让你了解这种**命名规则**\n",
"train_data.rename_field('words', 'word_seq')\n",
"test_data.rename_field('words', 'word_seq')\n",
"\n",
"# 顺便把label换名为label_seq\n",
"train_data.rename_field('label', 'label_seq')\n",
"test_data.rename_field('label', 'label_seq')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### loss\n",
"训练模型需要提供一个损失函数\n",
"\n",
"下面提供了一个在分类问题中常用的交叉熵损失。注意它的**初始化参数**。\n",
"\n",
"pred参数对应的是模型的forward返回的dict的一个key的名字,这里是\"output\"。\n",
"\n",
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metric\n",
"定义评价指标\n",
"\n",
"这里使用准确率。参数的“命名规则”跟上面类似。\n",
"\n",
"pred参数对应的是模型的predict方法返回的dict的一个key的名字,这里是\"predict\"。\n",
"\n",
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-04 22:51:24\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
}
],
"source": [
"# 实例化Trainer,传入模型和数据,进行训练\n",
"# 先在test_data拟合\n",
"copy_model = deepcopy(model)\n",
"overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n",
" losser=loss,\n",
" metrics=metric,\n",
" save_path=None,\n",
" batch_size=32,\n",
" n_epochs=5)\n",
"overfit_trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-04 22:52:01\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.259259\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train finished!\n"
]
}
],
"source": [
"# 用train_data训练,在test_data验证\n",
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
" losser=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
" metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
" save_path=None,\n",
" batch_size=32,\n",
" n_epochs=5)\n",
"trainer.train()\n",
"print('Train finished!')"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tester] \n",
"AccuracyMetric: acc=0.259259\n",
"{'AccuracyMetric': {'acc': 0.259259}}\n"
]
}
],
"source": [
"# 调用Tester在test_data上评价效果\n",
"from fastNLP import Tester\n",
"\n",
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
" batch_size=4)\n",
"acc = tester.test()\n",
"print(acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
