diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index dfe35f77..b16fe165 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -3,7 +3,6 @@ from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
 from .losses import Loss
-from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
 from .optimizer import Optimizer
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index a4d7a8ae..1e7d56fd 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -62,8 +62,8 @@ class Batch(object):
 
 
 def to_tensor(batch, dtype):
-    if dtype in (np.int8, np.int16, np.int32, np.int64):
+    if dtype in (int, np.int8, np.int16, np.int32, np.int64):
         batch = torch.LongTensor(batch)
-    if dtype in (np.float32, np.float64):
+    if dtype in (float, np.float32, np.float64):
         batch = torch.FloatTensor(batch)
     return batch
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index cdca4356..3dbea8eb 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -1,4 +1,5 @@
 import _pickle as pickle
+
 import numpy as np
 
 from fastNLP.core.fieldarray import FieldArray
@@ -66,10 +67,12 @@ class DataSet(object):
         def __init__(self, dataset, idx):
             self.dataset = dataset
             self.idx = idx
+
         def __getitem__(self, item):
             assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx])
             assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
             return self.dataset.field_arrays[item][self.idx]
+
         def __repr__(self):
             return self.dataset[self.idx].__repr__()
 
@@ -339,6 +342,6 @@ class DataSet(object):
             pickle.dump(self, f)
 
     @staticmethod
-    def load(self, path):
+    def load(path):
         with open(path, 'rb') as f:
             return pickle.load(f)
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 3bbbf9e2..2a9e89cd 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -69,9 +69,20 @@ class LossBase(object):
                                 f"positional argument.).")
 
     def _fast_param_map(self, pred_dict, target_dict):
+        """
+
+        Internal helper. When pred_dict and target_dict are unambiguous (e.g. each
+        contains exactly one element), users do not need to pass a key map.
+        :param pred_dict:
+        :param target_dict:
+        :return: dict. If it is non-empty, pass it to self.get_loss; otherwise fall back to the normal key mapping.
+ """ + fast_param = {} if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: - return tuple(pred_dict.values())[0], tuple(target_dict.values())[0] - return None + fast_param['pred'] = list(pred_dict.values())[0] + fast_param['target'] = list(pred_dict.values())[0] + return fast_param + return fast_param def __call__(self, pred_dict, target_dict, check=False): """ @@ -81,8 +92,8 @@ class LossBase(object): :return: """ fast_param = self._fast_param_map(pred_dict, target_dict) - if fast_param is not None: - loss = self.get_loss(*fast_param) + if fast_param: + loss = self.get_loss(**fast_param) return loss if not self._checked: diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index c17d408b..f8279d0a 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -82,7 +82,9 @@ class MetricBase(object): """ fast_param = {} if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1: - return pred_dict.values[0] and target_dict.values[0] + fast_param['pred'] = list(pred_dict.values())[0] + fast_param['target'] = list(pred_dict.values())[0] + return fast_param return fast_param def __call__(self, pred_dict, target_dict): @@ -304,118 +306,6 @@ def _prepare_metrics(metrics): return _metrics -class Evaluator(object): - def __init__(self): - pass - - def __call__(self, predict, truth): - """ - - :param predict: list of tensors, the network outputs from all batches. - :param truth: list of dict, the ground truths from all batch_y. - :return: - """ - raise NotImplementedError - - -class ClassifyEvaluator(Evaluator): - def __init__(self): - super(ClassifyEvaluator, self).__init__() - - def __call__(self, predict, truth): - y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] - y_prob = torch.cat(y_prob, dim=0) - y_pred = torch.argmax(y_prob, dim=-1) - y_true = torch.cat(truth, dim=0) - acc = float(torch.sum(y_pred == y_true)) / len(y_true) - return {"accuracy": acc} - - -class SeqLabelEvaluator(Evaluator): - def __init__(self): - super(SeqLabelEvaluator, self).__init__() - - def __call__(self, predict, truth, **_): - """ - - :param predict: list of List, the network outputs from all batches. - :param truth: list of dict, the ground truths from all batch_y. - :return accuracy: - """ - total_correct, total_count = 0., 0. - for x, y in zip(predict, truth): - x = torch.tensor(x) - y = y.to(x) # make sure they are in the same device - mask = (y > 0) - correct = torch.sum(((x == y) * mask).long()) - total_correct += float(correct) - total_count += float(torch.sum(mask.long())) - accuracy = total_correct / total_count - return {"accuracy": float(accuracy)} - - -class SeqLabelEvaluator2(Evaluator): - # 上面的evaluator应该是错误的 - def __init__(self, seq_lens_field_name='word_seq_origin_len'): - super(SeqLabelEvaluator2, self).__init__() - self.end_tagidx_set = set() - self.seq_lens_field_name = seq_lens_field_name - - def __call__(self, predict, truth, **_): - """ - - :param predict: list of batch, the network outputs from all batches. - :param truth: list of dict, the ground truths from all batch_y. 
-        :return accuracy:
-        """
-        seq_lens = _[self.seq_lens_field_name]
-        corr_count = 0
-        pred_count = 0
-        truth_count = 0
-        for x, y, seq_len in zip(predict, truth, seq_lens):
-            x = x.cpu().numpy()
-            y = y.cpu().numpy()
-            for idx, s_l in enumerate(seq_len):
-                x_ = x[idx]
-                y_ = y[idx]
-                x_ = x_[:s_l]
-                y_ = y_[:s_l]
-                flag = True
-                start = 0
-                for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)):
-                    if x_i in self.end_tagidx_set:
-                        truth_count += 1
-                        for j in range(start, idx_i + 1):
-                            if y_[j] != x_[j]:
-                                flag = False
-                                break
-                        if flag:
-                            corr_count += 1
-                        flag = True
-                        start = idx_i + 1
-                    if y_i in self.end_tagidx_set:
-                        pred_count += 1
-        P = corr_count / (float(pred_count) + 1e-6)
-        R = corr_count / (float(truth_count) + 1e-6)
-        F = 2 * P * R / (P + R + 1e-6)
-
-        return {"P": P, 'R': R, 'F': F}
-
-
-class SNLIEvaluator(Evaluator):
-    def __init__(self):
-        super(SNLIEvaluator, self).__init__()
-
-    def __call__(self, predict, truth):
-        y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
-        y_prob = torch.cat(y_prob, dim=0)
-        y_pred = torch.argmax(y_prob, dim=-1)
-        truth = [t['truth'] for t in truth]
-        y_true = torch.cat(truth, dim=0).view(-1)
-        acc = float(torch.sum(y_pred == y_true)) / y_true.size(0)
-        return {"accuracy": acc}
-
-
 def _conver_numpy(x):
     """convert input data to numpy array
 
@@ -467,11 +357,11 @@ def _check_data(y_true, y_pred):
     type_true, y_true = _label_types(y_true)
     type_pred, y_pred = _label_types(y_pred)
 
-    type_set = set(['binary', 'multiclass'])
+    type_set = {'binary', 'multiclass'}
     if type_true in type_set and type_pred in type_set:
         return type_true if type_true == type_pred else 'multiclass', y_true, y_pred
 
-    type_set = set(['multiclass-multioutput', 'multilabel'])
+    type_set = {'multiclass-multioutput', 'multilabel'}
     if type_true in type_set and type_pred in type_set:
         return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred
 
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index 7cde4844..9ce1d792 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -23,13 +23,13 @@ class Predictor(object):
 
         :param network: a PyTorch model (cpu)
         :param data: a DataSet object.
-        :return: list of list of strings, [num_examples, tag_seq_length]
+        :return: list of batch outputs
         """
         # turn on the testing mode; clean up the history
         self.mode(network, test=True)
         batch_output = []
 
-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
+        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
 
         for batch_x, _ in data_iterator:
             with torch.no_grad():
diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py
index f5e83c6b..d568acf3 100644
--- a/fastNLP/core/sampler.py
+++ b/fastNLP/core/sampler.py
@@ -55,7 +55,7 @@ class BucketSampler(BaseSampler):
 
     def __call__(self, data_set):
 
-        seq_lens = data_set[self.seq_lens_field_name].content
+        seq_lens = data_set.get_fields()[self.seq_lens_field_name].content
         total_sample_num = len(seq_lens)
 
         bucket_indexes = []
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 14577635..e8cc0e22 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -1,12 +1,5 @@
 from collections import Counter
 
-def isiterable(p_object):
-    try:
-        _ = iter(p_object)
-    except TypeError:
-        return False
-    return True
-
 
 def check_build_vocab(func):
     """A decorator to make sure the indexing is built before used.
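Note for reviewers: the fast path patched above in `LossBase._fast_param_map` and `MetricBase._fast_param_map` behaves like the standalone sketch below. This is illustrative only, not fastNLP code; the real methods additionally require `len(self.param_map) == 2`, and the sample dict keys are made up.

```python
# Illustrative sketch of the fast-path mapping patched above (not fastNLP code).
# When both dicts hold exactly one entry each, their values are mapped straight
# to 'pred'/'target', so users never need to supply an explicit key map.
def fast_param_map(pred_dict, target_dict):
    fast_param = {}
    if len(pred_dict) == 1 and len(target_dict) == 1:
        fast_param['pred'] = list(pred_dict.values())[0]
        fast_param['target'] = list(target_dict.values())[0]
    return fast_param


print(fast_param_map({'output': [0.1, 0.9]}, {'label_seq': [1]}))
# {'pred': [0.1, 0.9], 'target': [1]}
print(fast_param_map({'output': [0.1, 0.9], 'aux': [0.2]}, {'label_seq': [1]}))
# {} -- ambiguous, so __call__ falls back to the checked name-mapping path
```

Returning an empty dict rather than `None` is what allows the simpler truthiness test `if fast_param:` in `__call__` above.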
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index 697bcd78..493a740c 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -1,3 +1,4 @@
+import os
 import unittest
 
 from fastNLP.core.dataset import DataSet
@@ -90,6 +91,18 @@ class TestDataSet(unittest.TestCase):
         self.assertTrue("rx" in ds.field_arrays)
         self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
 
+        ds.apply(lambda ins: len(ins["y"]), new_field_name="y")
+        self.assertEqual(ds.field_arrays["y"].content[0], 2)
+
+        res = ds.apply(lambda ins: len(ins["x"]))
+        self.assertTrue(isinstance(res, list) and len(res) > 0)
+        self.assertEqual(res[0], 4)
+
+    def test_drop(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
+        ds.drop(lambda ins: len(ins["y"]) < 3)
+        self.assertEqual(len(ds), 20)
+
     def test_contains(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         self.assertTrue("x" in ds)
@@ -132,9 +145,17 @@ class TestDataSet(unittest.TestCase):
         dataset.apply(split_sent, new_field_name='words')
         # print(dataset)
 
+    def test_save_load(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
+        ds.save("./my_ds.pkl")
+        self.assertTrue(os.path.exists("./my_ds.pkl"))
+
+        ds_1 = DataSet.load("./my_ds.pkl")
+        os.remove("my_ds.pkl")
 
 class TestDataSetIter(unittest.TestCase):
     def test__repr__(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         for iter in ds:
             self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
+
diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py
index c22bac5b..c0b8a592 100644
--- a/test/core/test_fieldarray.py
+++ b/test/core/test_fieldarray.py
@@ -75,3 +75,25 @@ class TestFieldArray(unittest.TestCase):
         indices = [0, 1, 3, 4, 6]
         for a, b in zip(fa[indices], x[indices]):
             self.assertListEqual(a.tolist(), b.tolist())
+
+    def test_append(self):
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+            fa.append(0)
+
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
+            fa.append([1, 2, 3, 4, 5])
+
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+            fa.append([])
+
+        with self.assertRaises(Exception):
+            fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+            fa.append(["str", 0, 0, 0, 1.89])
+
+        fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+        fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
+        self.assertEqual(len(fa), 3)
+        self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index 76352aba..9286a26f 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -4,6 +4,7 @@ import numpy as np
 import torch
 
 from fastNLP.core.metrics import AccuracyMetric
+from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score
 
 
 class TestAccuracyMetric(unittest.TestCase):
@@ -132,3 +133,15 @@ class TestAccuracyMetric(unittest.TestCase):
             print(e)
             return
         self.assertTrue(True, False), "No exception catches."
+
+
+class TestUsefulFunctions(unittest.TestCase):
+    # test some useful-looking functions in metrics.py
+    def test_case_1(self):
+        # multi-class
+        _ = accuracy_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)))
+        _ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+        _ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+        _ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None)
+
+        # running without error is sufficient
diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py
index 7b4f5da9..8be5f289 100644
--- a/test/core/test_predictor.py
+++ b/test/core/test_predictor.py
@@ -1,6 +1,34 @@
 import unittest
 
+import numpy as np
+import torch
+
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.predictor import Predictor
+from fastNLP.modules.encoder.linear import Linear
+
+
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+
 
 class TestPredictor(unittest.TestCase):
     def test(self):
-        pass
+        predictor = Predictor()
+        model = Linear(2, 1)
+        data = prepare_fake_dataset()
+        data.set_input("x")
+        ans = predictor.predict(model, data)
+        self.assertEqual(len(ans), 2000)
+        self.assertTrue(isinstance(ans[0], torch.Tensor))
diff --git a/test/core/test_sampler.py b/test/core/test_sampler.py
index 5da0e6db..b23af470 100644
--- a/test/core/test_sampler.py
+++ b/test/core/test_sampler.py
@@ -1,9 +1,11 @@
+import random
 import unittest
 
 import torch
 
+from fastNLP.core.dataset import DataSet
 from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
-    k_means_1d, k_means_bucketing, simple_sort_bucketing
+    k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler
 
 
 class TestSampler(unittest.TestCase):
@@ -40,3 +42,11 @@ class TestSampler(unittest.TestCase):
     def test_simple_sort_bucketing(self):
         _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
         assert len(_) == 10
+
+    def test_BucketSampler(self):
+        sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len")
+        data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
+        data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")
+        indices = sampler(data_set)
+        self.assertEqual(len(indices), 10)
+        # only checks that it runs; results are not verified
diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py
index e74ec4b5..38fb6e0e 100644
--- a/test/core/test_trainer.py
+++ b/test/core/test_trainer.py
@@ -30,7 +30,7 @@ def prepare_fake_dataset():
 
 
 def prepare_fake_dataset2(*args, size=100):
-    ys = np.random.randint(4, size=100)
+    ys = np.random.randint(4, size=100, dtype=np.int64)
    data = {'y': ys}
     for arg in args:
         data[arg] = np.random.randn(size, 5)
@@ -213,12 +213,12 @@ class TrainerTestGround(unittest.TestCase):
         dataset = prepare_fake_dataset2('x1', 'x_unused')
         dataset.rename_field('x_unused', 'x2')
         dataset.set_input('x1', 'x2', 'y')
-        dataset.set_target('x1')
+        dataset.set_target('x1', 'x2')
 
         class Model(nn.Module):
             def __init__(self):
                 super().__init__()
                 self.fc = nn.Linear(5, 4)
 
-            def forward(self, x1, x2, y):
+            def forward(self, x1, x2):
                 x1 = self.fc(x1)
                 x2 = self.fc(x2)
                 x = x1 + x2
@@ -226,15 +226,14 @@ class TrainerTestGround(unittest.TestCase):
                 return {'pred': x}
 
         model = Model()
-        with self.assertRaises(NameError):
-            trainer = Trainer(
-                train_data=dataset,
-                model=model,
-                dev_data=dataset,
-                losser=CrossEntropyLoss(),
-                metrics=AccuracyMetric(),
-                use_tqdm=False,
-                print_every=2)
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            dev_data=dataset,
+            losser=CrossEntropyLoss(),
+            metrics=AccuracyMetric(),
+            use_tqdm=False,
+            print_every=2)
 
     def test_case2(self):
         # check metrics Wrong
diff --git a/tutorials/fastnlp_tutorial_1204.ipynb b/tutorials/fastnlp_tutorial_1204.ipynb
new file mode 100644
index 00000000..1fa1adca
--- /dev/null
+++ b/tutorials/fastnlp_tutorial_1204.ipynb
@@ -0,0 +1,768 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "fastNLP Quick-Start Tutorial\n",
+    "-------\n",
+    "\n",
+    "fastNLP provides convenient utilities for data preprocessing, model training and model testing."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "sys.path.append('/Users/yh/Desktop/fastNLP/fastNLP/')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "DataSet & Instance\n",
+    "------\n",
+    "\n",
+    "fastNLP stores and processes data with DataSet and Instance. A DataSet represents a dataset and an Instance represents a single sample; a DataSet holds multiple Instances, and each Instance can store arbitrary user-defined fields.\n",
+    "\n",
+    "A number of read_* methods make it easy to load data from files into a DataSet."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "8529\n"
+     ]
+    }
+   ],
+   "source": [
+    "from fastNLP import DataSet\n",
+    "from fastNLP import Instance\n",
+    "\n",
+    "# read data from a CSV file into a DataSet\n",
+    "dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
+    "print(len(dataset))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
+      "'label': 1}\n",
+      "{'raw_sentence': -LRB- Tries -RRB- to parody a genre that 's already a joke in the United States .,\n",
+      "'label': 1}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# use integer indexing [k] to fetch the k-th sample\n",
+    "print(dataset[0])\n",
+    "\n",
+    "# negative indices work as well\n",
+    "print(dataset[-3])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Instance\n",
+    "An Instance represents one sample, made up of one or more fields (attributes/features), each with a name and a value.\n",
+    "\n",
+    "The fields of an Instance are defined when it is constructed, using the \"field_name=field_value\" syntax."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "{'raw_sentence': fake data,\n",
+       "'label': 0}"
+      ]
+     },
+     "execution_count": 6,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "# DataSet.append(Instance) adds a new sample\n",
+    "dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
+    "dataset[-1]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## The DataSet.apply method\n",
+    "A workhorse for data preprocessing"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
+      "'label': 1}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# convert all text to lowercase\n",
+    "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
+    "print(dataset[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
+      "'label': 1}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# convert label to int\n",
+    "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n",
+    "print(dataset[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "RuntimeError",
+     "evalue": "Cannot create FieldArray with an empty list.",
+     "output_type": "error",
+     "traceback": [
+      "---------------------------------------------------------------------------",
+      "RuntimeError                              Traceback (most recent call last)",
+      "<ipython-input> in <module>()\n      2 def split_sent(ins):\n      3     return ins['raw_sentence'].split()\n----> 4 dataset.apply(split_sent, new_field_name='words')\n      5 print(dataset[0])",
+      "fastNLP/core/dataset.py in apply(self, func, new_field_name, **kwargs)\n--> 267     self.add_field(name=new_field_name, fields=results, **extra_param)",
+      "fastNLP/core/dataset.py in add_field(self, name, fields, padding_val, is_input, is_target)\n--> 160     self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, is_input=is_input)",
+      "fastNLP/core/fieldarray.py in __init__(self, name, content, padding_val, is_target, is_input)\n---> 38     self.is_input = is_input",
+      "fastNLP/core/fieldarray.py in is_input(self, value)\n---> 48     self.pytype = self._type_detection(self.content)",
+      "fastNLP/core/fieldarray.py in _type_detection(self, content)\n---> 73     type_set = set([self._type_detection(x) for x in content])",
+      "fastNLP/core/fieldarray.py in _type_detection(self, content)\n     83     if len(content) == 0:\n---> 84         raise RuntimeError(\"Cannot create FieldArray with an empty list.\")",
+      "RuntimeError: Cannot create FieldArray with an empty list."
+     ]
+    }
+   ],
+   "source": [
+    "# split sentences on whitespace\n",
+    "def split_sent(ins):\n",
+    "    return ins['raw_sentence'].split()\n",
+    "dataset.apply(split_sent, new_field_name='words')\n",
+    "print(dataset[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n",
+      "'label': 1,\n",
+      "'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n",
+      "'seq_len': 37}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# add sequence-length information\n",
+    "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n",
+    "print(dataset[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## DataSet.drop\n",
+    "Filtering out samples"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 19,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "38\n"
+     ]
+    }
+   ],
+   "source": [
+    "dataset.drop(lambda x: x['seq_len'] <= 3)\n",
+    "print(len(dataset))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Configuring the DataSet\n",
+    "1. Which fields are features and which are labels\n",
+    "2. Splitting into training/validation sets"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# specify which fields of the DataSet should be turned into tensors\n",
+    "\n",
+    "# set target: the gold labels used by the loss and by evaluation\n",
+    "dataset.set_target(\"label\")\n",
+    "# set input: used in the model's forward\n",
+    "dataset.set_input(\"words\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "27\n",
+      "11"
+     ]
+    }
+   ],
+   "source": [
+    "# split off a test set and a training set\n",
+    "\n",
+    "test_data, train_data = dataset.split(0.3)\n",
+    "print(len(test_data))\n",
+    "print(len(train_data))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Vocabulary\n",
+    "------\n",
+    "\n",
+    "fastNLP's Vocabulary makes it easy to build a vocabulary and map words to indices"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n",
+      "'label': 2,\n",
+      "'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n",
+      "'seq_len': 25}\n"
+     ]
+    }
+   ],
+   "source": [
+    "from fastNLP import Vocabulary\n",
+    "\n",
+    "# build the vocabulary with Vocabulary.add(word)\n",
+    "vocab = Vocabulary(min_freq=2)\n",
+    "train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
+    "vocab.build_vocab()\n",
+    "\n",
+    "# index sentences with Vocabulary.to_index(word)\n",
+    "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
+    "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
+    "\n",
+    "\n",
+    "print(test_data[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Model\n",
+    "Define a PyTorch model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "CNNText(\n",
+       "  (embed): Embedding(\n",
+       "    (embed): Embedding(32, 50, padding_idx=0)\n",
+       "    (dropout): Dropout(p=0.0)\n",
+       "  )\n",
+       "  (conv_pool): ConvMaxpool(\n",
+       "    (convs): ModuleList(\n",
+       "      (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n",
+       "      (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n",
+       "      (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n",
+       "    )\n",
+       "  )\n",
+       "  (dropout): Dropout(p=0.1)\n",
+       "  (fc): Linear(\n",
+       "    (linear): Linear(in_features=12, out_features=5, bias=True)\n",
+       "  )\n",
+       ")"
+      ]
+     },
+     "execution_count": 23,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from fastNLP.models import CNNText\n",
+    "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
+    "model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is the forward method of the model above. If you are not sure what a forward method is, please see our PyTorch tutorial.\n",
+    "\n",
+    "Note two things:\n",
+    "1. The forward parameter is named **word_seq**; remember this.\n",
+    "2. forward returns a **dict** that contains a key named **output**.\n",
+    "\n",
+    "```Python\n",
+    "    def forward(self, word_seq):\n",
+    "        \"\"\"\n",
+    "\n",
+    "        :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
+    "        :return output: dict of torch.LongTensor, [batch_size, num_classes]\n",
+    "        \"\"\"\n",
+    "        x = self.embed(word_seq)  # [N,L] -> [N,L,C]\n",
+    "        x = self.conv_pool(x)  # [N,L,C] -> [N,C]\n",
+    "        x = self.dropout(x)\n",
+    "        x = self.fc(x)  # [N,C] -> [N, N_class]\n",
+    "        return {'output': x}\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "This is the predict method of the model above. Unlike forward, it directly outputs the final predictions for the task.\n",
+    "\n",
+    "Note two things:\n",
+    "1. The predict parameter is also named **word_seq**.\n",
+    "2. predict also returns a **dict**, which contains a key named **predict**.\n",
+    "\n",
+    "```\n",
+    "    def predict(self, word_seq):\n",
+    "        \"\"\"\n",
+    "\n",
+    "        :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
+    "        :return predict: dict of torch.LongTensor, [batch_size, seq_len]\n",
+    "        \"\"\"\n",
+    "        output = self(word_seq)\n",
+    "        _, predict = output['output'].max(dim=1)\n",
+    "        return {'predict': predict}\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Trainer & Tester\n",
+    "------\n",
+    "\n",
+    "Train the model with fastNLP's Trainer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP import Trainer\n",
+    "from copy import deepcopy\n",
+    "from fastNLP.core.losses import CrossEntropyLoss\n",
+    "from fastNLP.core.metrics import AccuracyMetric\n",
+    "\n",
+    "\n",
+    "# rename DataSet fields so they match the parameter names of the model's forward\n",
+    "# forward's parameter is called word_seq, so the field originally named words is renamed to word_seq\n",
+    "# this demonstrates the **naming convention**\n",
+    "train_data.rename_field('words', 'word_seq')\n",
+    "test_data.rename_field('words', 'word_seq')\n",
+    "\n",
+    "# likewise, rename label to label_seq\n",
+    "train_data.rename_field('label', 'label_seq')\n",
+    "test_data.rename_field('label', 'label_seq')"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### loss\n",
+    "Training a model requires a loss function.\n",
+    "\n",
+    "Below is the cross-entropy loss commonly used in classification. Note its **constructor arguments**.\n",
+    "\n",
+    "The pred argument names a key of the dict returned by the model's forward, here \"output\".\n",
+    "\n",
+    "The target argument names the label field of the dataset, here \"label_seq\"."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Metric\n",
+    "Define the evaluation metric.\n",
+    "\n",
+    "Here we use accuracy. The parameter \"naming convention\" is the same as above.\n",
+    "\n",
+    "The pred argument names a key of the dict returned by the model's predict method, here \"predict\".\n",
+    "\n",
+    "The target argument names the label field of the dataset, here \"label_seq\"."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 30,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "training epochs started 2018-12-04 22:51:24\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.407407\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.518519\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.481481\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.592593\n"
+     ]
+    }
+   ],
+   "source": [
+    "# instantiate a Trainer with the model and the data, then train\n",
+    "# first, overfit on test_data\n",
+    "copy_model = deepcopy(model)\n",
+    "overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n",
+    "                          losser=loss,\n",
+    "                          metrics=metric,\n",
+    "                          save_path=None,\n",
+    "                          batch_size=32,\n",
+    "                          n_epochs=5)\n",
+    "overfit_trainer.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 31,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "training epochs started 2018-12-04 22:52:01\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.296296\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.222222\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.259259\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.296296\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.259259\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Train finished!\n"
+     ]
+    }
+   ],
+   "source": [
+    "# train on train_data, validating on test_data\n",
+    "trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
+    "                  losser=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
+    "                  metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
+    "                  save_path=None,\n",
+    "                  batch_size=32,\n",
+    "                  n_epochs=5)\n",
+    "trainer.train()\n",
+    "print('Train finished!')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 33,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[tester] \n",
+      "AccuracyMetric: acc=0.259259\n",
+      "{'AccuracyMetric': {'acc': 0.259259}}\n"
+     ]
+    }
+   ],
+   "source": [
+    "# evaluate on test_data with Tester\n",
+    "from fastNLP import Tester\n",
+    "\n",
+    "tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
+    "                batch_size=4)\n",
+    "acc = tester.test()\n",
+    "print(acc)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.7"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
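As a companion to the notebook added above, here is a condensed, untested recap of the full tutorial pipeline. It is a sketch only: it assumes the same `../sentence.csv` file and the fastNLP API as of this commit, and it writes the indexed words directly to a `word_seq` field instead of the notebook's separate rename step.

```python
# Condensed, untested recap of the tutorial pipeline above (a sketch, assuming
# the same ../sentence.csv file and the fastNLP API as of this commit).
from fastNLP import DataSet, Vocabulary, Trainer, Tester
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.models import CNNText

# load and preprocess the data
dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\t')
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
dataset.apply(lambda x: int(x['label']), new_field_name='label')
dataset.apply(lambda x: x['raw_sentence'].split(), new_field_name='words')
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
dataset.drop(lambda x: x['seq_len'] <= 3)

# split, then build the vocabulary on the training portion only
test_data, train_data = dataset.split(0.3)
vocab = Vocabulary(min_freq=2)
train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
vocab.build_vocab()

# index both splits and mark inputs/targets to match the model's forward
for ds in (train_data, test_data):
    ds.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq')
    ds.rename_field('label', 'label_seq')
    ds.set_input('word_seq')
    ds.set_target('label_seq')

# train and evaluate
model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,
                  losser=CrossEntropyLoss(pred="output", target="label_seq"),
                  metrics=AccuracyMetric(pred="predict", target="label_seq"),
                  save_path=None, batch_size=32, n_epochs=5)
trainer.train()

tester = Tester(data=test_data, model=model,
                metrics=AccuracyMetric(pred="predict", target="label_seq"), batch_size=4)
print(tester.test())
```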