From f26f11608baa202ab18ee627e75e4229a62b6d06 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Tue, 4 Dec 2018 22:57:26 +0800 Subject: [PATCH] =?UTF-8?q?*=20=E6=9B=B4=E6=96=B0=E6=95=99=E7=A8=8B?= =?UTF-8?q?=EF=BC=8C=E6=94=BE=E5=9C=A8=E5=9C=A8./tutorial=20*=20remove=20u?= =?UTF-8?q?nused=20codes=20in=20metrics.py=20*=20add=20tests=20for=20DataS?= =?UTF-8?q?et=20*=20add=20tests=20for=20FieldArray=20*=20add=20tests=20for?= =?UTF-8?q?=20metrics.py=20*=20fix=20predictor,=20add=20tests=20for=20pred?= =?UTF-8?q?ictor=20*=20fix=20bucket=20sampler,=20add=20tests=20for=20bucke?= =?UTF-8?q?t=20sampler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 1 - fastNLP/core/dataset.py | 5 +- fastNLP/core/metrics.py | 116 +-- fastNLP/core/predictor.py | 4 +- fastNLP/core/sampler.py | 2 +- fastNLP/core/vocabulary.py | 7 - test/core/test_dataset.py | 23 + test/core/test_fieldarray.py | 22 + test/core/test_metrics.py | 13 + test/core/test_predictor.py | 30 +- test/core/test_sampler.py | 12 +- tutorials/fastnlp_tutorial_1204.ipynb | 1209 +++++++++++++++++++++++++ 12 files changed, 1316 insertions(+), 128 deletions(-) create mode 100644 tutorials/fastnlp_tutorial_1204.ipynb diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index dfe35f77..b16fe165 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -3,7 +3,6 @@ from .dataset import DataSet from .fieldarray import FieldArray from .instance import Instance from .losses import Loss -from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator from .optimizer import Optimizer from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler from .tester import Tester diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index cdca4356..3dbea8eb 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,4 +1,5 @@ import _pickle as pickle + import numpy as np from fastNLP.core.fieldarray import FieldArray @@ -66,10 +67,12 @@ class DataSet(object): def __init__(self, dataset, idx): self.dataset = dataset self.idx = idx + def __getitem__(self, item): assert item in self.dataset.field_arrays, "no such field:{} in instance {}".format(item, self.dataset[self.idx]) assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) return self.dataset.field_arrays[item][self.idx] + def __repr__(self): return self.dataset[self.idx].__repr__() @@ -339,6 +342,6 @@ class DataSet(object): pickle.dump(self, f) @staticmethod - def load(self, path): + def load(path): with open(path, 'rb') as f: return pickle.load(f) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index c17d408b..5d808f6a 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -304,118 +304,6 @@ def _prepare_metrics(metrics): return _metrics -class Evaluator(object): - def __init__(self): - pass - - def __call__(self, predict, truth): - """ - - :param predict: list of tensors, the network outputs from all batches. - :param truth: list of dict, the ground truths from all batch_y. - :return: - """ - raise NotImplementedError - - -class ClassifyEvaluator(Evaluator): - def __init__(self): - super(ClassifyEvaluator, self).__init__() - - def __call__(self, predict, truth): - y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] - y_prob = torch.cat(y_prob, dim=0) - y_pred = torch.argmax(y_prob, dim=-1) - y_true = torch.cat(truth, dim=0) - acc = float(torch.sum(y_pred == y_true)) / len(y_true) - return {"accuracy": acc} - - -class SeqLabelEvaluator(Evaluator): - def __init__(self): - super(SeqLabelEvaluator, self).__init__() - - def __call__(self, predict, truth, **_): - """ - - :param predict: list of List, the network outputs from all batches. - :param truth: list of dict, the ground truths from all batch_y. - :return accuracy: - """ - total_correct, total_count = 0., 0. - for x, y in zip(predict, truth): - x = torch.tensor(x) - y = y.to(x) # make sure they are in the same device - mask = (y > 0) - correct = torch.sum(((x == y) * mask).long()) - total_correct += float(correct) - total_count += float(torch.sum(mask.long())) - accuracy = total_correct / total_count - return {"accuracy": float(accuracy)} - - -class SeqLabelEvaluator2(Evaluator): - # 上面的evaluator应该是错误的 - def __init__(self, seq_lens_field_name='word_seq_origin_len'): - super(SeqLabelEvaluator2, self).__init__() - self.end_tagidx_set = set() - self.seq_lens_field_name = seq_lens_field_name - - def __call__(self, predict, truth, **_): - """ - - :param predict: list of batch, the network outputs from all batches. - :param truth: list of dict, the ground truths from all batch_y. - :return accuracy: - """ - seq_lens = _[self.seq_lens_field_name] - corr_count = 0 - pred_count = 0 - truth_count = 0 - for x, y, seq_len in zip(predict, truth, seq_lens): - x = x.cpu().numpy() - y = y.cpu().numpy() - for idx, s_l in enumerate(seq_len): - x_ = x[idx] - y_ = y[idx] - x_ = x_[:s_l] - y_ = y_[:s_l] - flag = True - start = 0 - for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)): - if x_i in self.end_tagidx_set: - truth_count += 1 - for j in range(start, idx_i + 1): - if y_[j] != x_[j]: - flag = False - break - if flag: - corr_count += 1 - flag = True - start = idx_i + 1 - if y_i in self.end_tagidx_set: - pred_count += 1 - P = corr_count / (float(pred_count) + 1e-6) - R = corr_count / (float(truth_count) + 1e-6) - F = 2 * P * R / (P + R + 1e-6) - - return {"P": P, 'R': R, 'F': F} - - -class SNLIEvaluator(Evaluator): - def __init__(self): - super(SNLIEvaluator, self).__init__() - - def __call__(self, predict, truth): - y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict] - y_prob = torch.cat(y_prob, dim=0) - y_pred = torch.argmax(y_prob, dim=-1) - truth = [t['truth'] for t in truth] - y_true = torch.cat(truth, dim=0).view(-1) - acc = float(torch.sum(y_pred == y_true)) / y_true.size(0) - return {"accuracy": acc} - - def _conver_numpy(x): """convert input data to numpy array @@ -467,11 +355,11 @@ def _check_data(y_true, y_pred): type_true, y_true = _label_types(y_true) type_pred, y_pred = _label_types(y_pred) - type_set = set(['binary', 'multiclass']) + type_set = {'binary', 'multiclass'} if type_true in type_set and type_pred in type_set: return type_true if type_true == type_pred else 'multiclass', y_true, y_pred - type_set = set(['multiclass-multioutput', 'multilabel']) + type_set = {'multiclass-multioutput', 'multilabel'} if type_true in type_set and type_pred in type_set: return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py index 7cde4844..9ce1d792 100644 --- a/fastNLP/core/predictor.py +++ b/fastNLP/core/predictor.py @@ -23,13 +23,13 @@ class Predictor(object): :param network: a PyTorch model (cpu) :param data: a DataSet object. - :return: list of list of strings, [num_examples, tag_seq_length] + :return: list of batch outputs """ # turn on the testing mode; clean up the history self.mode(network, test=True) batch_output = [] - data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False) + data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) for batch_x, _ in data_iterator: with torch.no_grad(): diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index f5e83c6b..d568acf3 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -55,7 +55,7 @@ class BucketSampler(BaseSampler): def __call__(self, data_set): - seq_lens = data_set[self.seq_lens_field_name].content + seq_lens = data_set.get_fields()[self.seq_lens_field_name].content total_sample_num = len(seq_lens) bucket_indexes = [] diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index 14577635..e8cc0e22 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -1,12 +1,5 @@ from collections import Counter -def isiterable(p_object): - try: - _ = iter(p_object) - except TypeError: - return False - return True - def check_build_vocab(func): """A decorator to make sure the indexing is built before used. diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 8ca2ed86..a4deb304 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -1,3 +1,4 @@ +import os import unittest from fastNLP.core.dataset import DataSet @@ -90,6 +91,18 @@ class TestDataSet(unittest.TestCase): self.assertTrue("rx" in ds.field_arrays) self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) + ds.apply(lambda ins: len(ins["y"]), new_field_name="y") + self.assertEqual(ds.field_arrays["y"].content[0], 2) + + res = ds.apply(lambda ins: len(ins["x"])) + self.assertTrue(isinstance(res, list) and len(res) > 0) + self.assertTrue(res[0], 4) + + def test_drop(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) + ds.drop(lambda ins: len(ins["y"]) < 3) + self.assertEqual(len(ds), 20) + def test_contains(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) self.assertTrue("x" in ds) @@ -125,9 +138,19 @@ class TestDataSet(unittest.TestCase): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) + def test_save_load(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + ds.save("./my_ds.pkl") + self.assertTrue(os.path.exists("./my_ds.pkl")) + + ds_1 = DataSet.load("./my_ds.pkl") + os.remove("my_ds.pkl") + # 能跑通就行 + class TestDataSetIter(unittest.TestCase): def test__repr__(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) for iter in ds: self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}") + diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py index c22bac5b..c0b8a592 100644 --- a/test/core/test_fieldarray.py +++ b/test/core/test_fieldarray.py @@ -75,3 +75,25 @@ class TestFieldArray(unittest.TestCase): indices = [0, 1, 3, 4, 6] for a, b in zip(fa[indices], x[indices]): self.assertListEqual(a.tolist(), b.tolist()) + + def test_append(self): + with self.assertRaises(Exception): + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa.append(0) + + with self.assertRaises(Exception): + fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True) + fa.append([1, 2, 3, 4, 5]) + + with self.assertRaises(Exception): + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa.append([]) + + with self.assertRaises(Exception): + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa.append(["str", 0, 0, 0, 1.89]) + + fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True) + fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) + self.assertEqual(len(fa), 3) + self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 76352aba..9286a26f 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -4,6 +4,7 @@ import numpy as np import torch from fastNLP.core.metrics import AccuracyMetric +from fastNLP.core.metrics import accuracy_score, recall_score, precision_score, f1_score class TestAccuracyMetric(unittest.TestCase): @@ -132,3 +133,15 @@ class TestAccuracyMetric(unittest.TestCase): print(e) return self.assertTrue(True, False), "No exception catches." + + +class TestUsefulFunctions(unittest.TestCase): + # 测试metrics.py中一些看上去挺有用的函数 + def test_case_1(self): + # multi-class + _ = accuracy_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1))) + _ = precision_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None) + _ = recall_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None) + _ = f1_score(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), average=None) + + # 跑通即可 diff --git a/test/core/test_predictor.py b/test/core/test_predictor.py index 7b4f5da9..8be5f289 100644 --- a/test/core/test_predictor.py +++ b/test/core/test_predictor.py @@ -1,6 +1,34 @@ import unittest +import numpy as np +import torch + +from fastNLP.core.dataset import DataSet +from fastNLP.core.instance import Instance +from fastNLP.core.predictor import Predictor +from fastNLP.modules.encoder.linear import Linear + + +def prepare_fake_dataset(): + mean = np.array([-3, -3]) + cov = np.array([[1, 0], [0, 1]]) + class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) + + mean = np.array([3, 3]) + cov = np.array([[1, 0], [0, 1]]) + class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) + + data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + + [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) + return data_set + class TestPredictor(unittest.TestCase): def test(self): - pass + predictor = Predictor() + model = Linear(2, 1) + data = prepare_fake_dataset() + data.set_input("x") + ans = predictor.predict(model, data) + self.assertEqual(len(ans), 2000) + self.assertTrue(isinstance(ans[0], torch.Tensor)) diff --git a/test/core/test_sampler.py b/test/core/test_sampler.py index 5da0e6db..b23af470 100644 --- a/test/core/test_sampler.py +++ b/test/core/test_sampler.py @@ -1,9 +1,11 @@ +import random import unittest import torch +from fastNLP.core.dataset import DataSet from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ - k_means_1d, k_means_bucketing, simple_sort_bucketing + k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler class TestSampler(unittest.TestCase): @@ -40,3 +42,11 @@ class TestSampler(unittest.TestCase): def test_simple_sort_bucketing(self): _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) assert len(_) == 10 + + def test_BucketSampler(self): + sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len") + data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10}) + data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len") + indices = sampler(data_set) + self.assertEqual(len(indices), 10) + # 跑通即可,不验证效果 diff --git a/tutorials/fastnlp_tutorial_1204.ipynb b/tutorials/fastnlp_tutorial_1204.ipynb new file mode 100644 index 00000000..1a002750 --- /dev/null +++ b/tutorials/fastnlp_tutorial_1204.ipynb @@ -0,0 +1,1209 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "fastNLP上手教程\n", + "-------\n", + "\n", + "fastNLP提供方便的数据预处理,训练和测试模型的功能" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('C:/Users/zyfeng/Desktop/FudanNLP/fastNLP')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "DataSet & Instance\n", + "------\n", + "\n", + "fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n", + "\n", + "有一些read_*方法,可以轻松从文件读取数据,存成DataSet。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "38" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "from fastNLP import Instance\n", + "\n", + "# 从csv读取数据到DataSet\n", + "dataset = DataSet.read_csv('./test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), sep='\\t')\n", + "print(len(dataset))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# 使用数字索引[k],获取第k个样本\n", + "print(dataset[0])\n", + "\n", + "# 索引也可以是负数\n", + "print(dataset[-3])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instance\n", + "Instance表示一个样本,由一个或多个field(域,属性,特征)组成,每个field有名字和值。\n", + "\n", + "在初始化Instance时即可定义它包含的域,使用 \"field_name=field_value\"的写法。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'raw_sentence': fake data,\n'label': 0}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# DataSet.append(Instance)加入新数据\n", + "dataset.append(Instance(raw_sentence='fake data', label='0'))\n", + "dataset[-1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DataSet.apply方法\n", + "数据预处理利器" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" + ] + } + ], + "source": [ + "# 将所有数字转为小写\n", + "dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n", + "print(dataset[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# label转int\n", + "dataset.apply(lambda x: int(x['label']), new_field_name='label')\n", + "print(dataset[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']}" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# 使用空格分割句子\n", + "def split_sent(ins):\n", + " return ins['raw_sentence'].split()\n", + "dataset.apply(split_sent, new_field_name='words')\n", + "print(dataset[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n'seq_len': 37}" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# 增加长度信息\n", + "dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n", + "print(dataset[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DataSet.drop\n", + "筛选数据" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "38" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "dataset.drop(lambda x: x['seq_len'] <= 3)\n", + "print(len(dataset))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 配置DataSet\n", + "1. 哪些域是特征,哪些域是标签\n", + "2. 切分训练集/验证集" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "# 设置DataSet中,哪些field要转为tensor\n", + "\n", + "# set target,loss或evaluate中的golden,计算loss,模型评估时使用\n", + "dataset.set_target(\"label\")\n", + "# set input,模型forward时使用\n", + "dataset.set_input(\"words\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "27" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11" + ] + } + ], + "source": [ + "# 分出测试集、训练集\n", + "\n", + "test_data, train_data = dataset.split(0.3)\n", + "print(len(test_data))\n", + "print(len(train_data))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Vocabulary\n", + "------\n", + "\n", + "fastNLP中的Vocabulary轻松构建词表,将词转成数字" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'raw_sentence': that the chuck norris `` grenade gag '' occurs about 7 times during windtalkers is a good indication of how serious-minded the film is .,\n'label': 2,\n'words': [6, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 8, 24, 1, 5, 1, 1, 2, 15, 10, 3],\n'seq_len': 25}" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from fastNLP import Vocabulary\n", + "\n", + "# 构建词表, Vocabulary.add(word)\n", + "vocab = Vocabulary(min_freq=2)\n", + "train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n", + "vocab.build_vocab()\n", + "\n", + "# index句子, Vocabulary.to_index(word)\n", + "train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n", + "test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n", + "\n", + "\n", + "print(test_data[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model\n", + "定义一个PyTorch模型" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "CNNText(\n (embed): Embedding(\n (embed): Embedding(32, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP.models import CNNText\n", + "model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n", + "model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n", + "\n", + "注意两点:\n", + "1. forward参数名字叫**word_seq**,请记住。\n", + "2. forward的返回值是一个**dict**,其中有个key的名字叫**output**。\n", + "\n", + "```Python\n", + " def forward(self, word_seq):\n", + " \"\"\"\n", + "\n", + " :param word_seq: torch.LongTensor, [batch_size, seq_len]\n", + " :return output: dict of torch.LongTensor, [batch_size, num_classes]\n", + " \"\"\"\n", + " x = self.embed(word_seq) # [N,L] -> [N,L,C]\n", + " x = self.conv_pool(x) # [N,L,C] -> [N,C]\n", + " x = self.dropout(x)\n", + " x = self.fc(x) # [N,C] -> [N, N_class]\n", + " return {'output': x}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "这是上述模型的predict方法,是用来直接输出该任务的预测结果,与forward目的不同。\n", + "\n", + "注意两点:\n", + "1. predict参数名也叫**word_seq**。\n", + "2. predict的返回值是也一个**dict**,其中有个key的名字叫**predict**。\n", + "\n", + "```\n", + " def predict(self, word_seq):\n", + " \"\"\"\n", + "\n", + " :param word_seq: torch.LongTensor, [batch_size, seq_len]\n", + " :return predict: dict of torch.LongTensor, [batch_size, seq_len]\n", + " \"\"\"\n", + " output = self(word_seq)\n", + " _, predict = output['output'].max(dim=1)\n", + " return {'predict': predict}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Trainer & Tester\n", + "------\n", + "\n", + "使用fastNLP的Trainer训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer\n", + "from copy import deepcopy\n", + "from fastNLP.core.losses import CrossEntropyLoss\n", + "from fastNLP.core.metrics import AccuracyMetric\n", + "\n", + "\n", + "# 更改DataSet中对应field的名称,与模型的forward的参数名一致\n", + "# 因为forward的参数叫word_seq, 所以要把原本叫words的field改名为word_seq\n", + "# 这里的演示是让你了解这种**命名规则**\n", + "train_data.rename_field('words', 'word_seq')\n", + "test_data.rename_field('words', 'word_seq')\n", + "\n", + "# 顺便把label换名为label_seq\n", + "train_data.rename_field('label', 'label_seq')\n", + "test_data.rename_field('label', 'label_seq')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### loss\n", + "训练模型需要提供一个损失函数\n", + "\n", + "下面提供了一个在分类问题中常用的交叉熵损失。注意它的**初始化参数**。\n", + "\n", + "pred参数对应的是模型的forward返回的dict的一个key的名字,这里是\"output\"。\n", + "\n", + "target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Metric\n", + "定义评价指标\n", + "\n", + "这里使用准确率。参数的“命名规则”跟上面类似。\n", + "\n", + "pred参数对应的是模型的predict方法返回的dict的一个key的名字,这里是\"predict\"。\n", + "\n", + "target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "training epochs started 2018-12-04 22:51:24" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\rEpoch 1/5: 0%| | 0/5 [00:00