Merge pull request #108 from FengZiYjun/trainer

FastNLP v0.2
Tag: v0.2.0
Xipeng Qiu, 6 years ago (commit 1b477a95b0)
60 changed files with 5432 additions and 1731 deletions
  1. README.md (+33 -14)
  2. docs/quick_tutorial.md (+2 -1)
  3. docs/source/figures/text_classification.png (BIN)
  4. fastNLP/api/model_zoo.py (+2 -6)
  5. fastNLP/api/processor.py (+25 -13)
  6. fastNLP/core/__init__.py (+6 -4)
  7. fastNLP/core/batch.py (+15 -2)
  8. fastNLP/core/dataset.py (+219 -155)
  9. fastNLP/core/fieldarray.py (+140 -22)
  10. fastNLP/core/instance.py (+2 -3)
  11. fastNLP/core/loss.py (+0 -196)
  12. fastNLP/core/losses.py (+358 -0)
  13. fastNLP/core/metrics.py (+279 -257)
  14. fastNLP/core/optimizer.py (+32 -45)
  15. fastNLP/core/predictor.py (+2 -19)
  16. fastNLP/core/sampler.py (+1 -1)
  17. fastNLP/core/tester.py (+69 -36)
  18. fastNLP/core/trainer.py (+301 -277)
  19. fastNLP/core/utils.py (+317 -12)
  20. fastNLP/core/vocabulary.py (+42 -40)
  21. fastNLP/io/base_loader.py (+18 -14)
  22. fastNLP/io/config_io.py (+148 -2)
  23. fastNLP/io/config_loader.py (+0 -149)
  24. fastNLP/io/dataset_loader.py (+43 -38)
  25. fastNLP/io/embed_loader.py (+54 -23)
  26. fastNLP/io/model_io.py (+28 -0)
  27. fastNLP/io/model_loader.py (+0 -28)
  28. fastNLP/models/base_model.py (+14 -3)
  29. fastNLP/models/cnn_text_classification.py (+5 -29)
  30. fastNLP/modules/encoder/char_embedding.py (+1 -1)
  31. reproduction/Biaffine_parser/infer.py (+1 -1)
  32. reproduction/Biaffine_parser/run.py (+2 -3)
  33. reproduction/LSTM+self_attention_sentiment_analysis/main.py (+2 -2)
  34. reproduction/chinese_word_segment/run.py (+2 -3)
  35. requirements.txt (+1 -1)
  36. setup.py (+2 -2)
  37. test/api/test_processor.py (+12 -0)
  38. test/core/__init__.py (+0 -0)
  39. test/core/test_batch.py (+2 -2)
  40. test/core/test_dataset.py (+128 -3)
  41. test/core/test_fieldarray.py (+77 -0)
  42. test/core/test_instance.py (+6 -0)
  43. test/core/test_loss.py (+82 -301)
  44. test/core/test_metrics.py (+145 -0)
  45. test/core/test_optimizer.py (+54 -0)
  46. test/core/test_predictor.py (+29 -1)
  47. test/core/test_sampler.py (+11 -1)
  48. test/core/test_tester.py (+59 -1)
  49. test/core/test_trainer.py (+239 -3)
  50. test/core/test_vocabulary.py (+35 -8)
  51. test/data_for_tests/glove.6B.50d_test.txt (+1 -7)
  52. test/data_for_tests/tutorial_sample_dataset.csv (+77 -0)
  53. test/io/__init__.py (+0 -0)
  54. test/io/test_config_saver.py (+1 -2)
  55. test/io/test_embed_loader.py (+12 -0)
  56. test/test_tutorial.py (+91 -0)
  57. tutorials/fastnlp_10min_tutorial_v2.ipynb (+911 -0)
  58. tutorials/fastnlp_10tmin_tutorial.ipynb (+860 -0)
  59. tutorials/fastnlp_1_minute_tutorial.ipynb (+333 -0)
  60. tutorials/fastnlp_advanced_tutorial.ipynb (+101 -0)

README.md (+33 -14)

@@ -6,16 +6,39 @@
 ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
 [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)

-fastNLP is a modular Natural Language Processing system based on PyTorch, for fast development of NLP tools. It divides the NLP model based on deep learning into different modules. These modules fall into 4 categories: encoder, interaction, aggregation and decoder, while each category contains different implemented modules. Encoder modules encode the input into some abstract representation, interaction modules make the information in the representation interact with each other, aggregation modules aggregate and reduce information, and decoder modules decode the representation into the output. Most current NLP models could be built on these modules, which vastly simplifies the process of developing NLP models. The architecture of fastNLP is as the figure below:
+FastNLP is a modular Natural Language Processing system based on PyTorch, built for fast development of NLP models.

-![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/procedures.PNG)
-![](https://github.com/fastnlp/fastNLP/raw/master/docs/source/figures/text_classification.png)
+A deep learning NLP model is the composition of three types of modules:
+<table>
+<tr>
+<td><b> module type </b></td>
+<td><b> functionality </b></td>
+<td><b> example </b></td>
+</tr>
+<tr>
+<td> encoder </td>
+<td> encode the input into some abstract representation </td>
+<td> embedding, RNN, CNN, transformer
+</tr>
+<tr>
+<td> aggregator </td>
+<td> aggregate and reduce information </td>
+<td> self-attention, max-pooling </td>
+</tr>
+<tr>
+<td> decoder </td>
+<td> decode the representation into the output </td>
+<td> MLP, CRF </td>
+</tr>
+
+For example:
+
+![](docs/source/figures/text_classification.png)

 ## Requirements

 - numpy>=1.14.2
 - torch>=0.4.0
-- torchvision>=0.1.8
 - tensorboardX


@@ -39,12 +62,12 @@ pip install fastNLP
 <td> an open-source NLP library </td>
 </tr>
 <tr>
-<td><b> fastNLP.core </b></td>
-<td> trainer, tester, predictor </td>
+<td><b> fastNLP.api </b></td>
+<td> APIs for end-to-end prediction </td>
 </tr>
 <tr>
-<td><b> fastNLP.loader </b></td>
-<td> all kinds of loaders/readers </td>
+<td><b> fastNLP.core </b></td>
+<td> data representation & train/test presedure </td>
 </tr>
 <tr>
 <td><b> fastNLP.models </b></td>
@@ -55,11 +78,7 @@ pip install fastNLP
 <td> a collection of PyTorch sub-models/components/wheels </td>
 </tr>
 <tr>
-<td><b> fastNLP.saver </b></td>
-<td> all kinds of savers/writers </td>
-</tr>
-<tr>
-<td><b> fastNLP.fastnlp </b></td>
-<td> a high-level interface for prediction </td>
+<td><b> fastNLP.io </b></td>
+<td> readers & savers </td>
 </tr>
 </table>
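
To make the encoder / aggregator / decoder taxonomy above concrete, here is a minimal sketch in plain PyTorch (not fastNLP's own module classes; the layer sizes, vocabulary size, and names are illustrative assumptions):

import torch
import torch.nn as nn

# toy composition of the three module types described in the new README:
# encoder (embedding + CNN) -> aggregator (max-pooling over time) -> decoder (MLP)
class ToyTextClassifier(nn.Module):
    def __init__(self, vocab_size=1000, emb_dim=50, num_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)                  # encoder
        self.conv = nn.Conv1d(emb_dim, 64, kernel_size=3, padding=1)    # encoder
        self.decoder = nn.Linear(64, num_classes)                       # decoder

    def forward(self, word_seq):
        x = self.embed(word_seq)           # [batch, seq_len, emb_dim]
        x = self.conv(x.transpose(1, 2))   # [batch, 64, seq_len]
        x = torch.max(x, dim=2)[0]         # aggregator: max-pooling over time
        return self.decoder(x)             # [batch, num_classes]

logits = ToyTextClassifier()(torch.randint(0, 1000, (4, 20)))
print(logits.shape)  # torch.Size([4, 2])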

docs/quick_tutorial.md (+2 -1)

@@ -1 +1,2 @@
# FastNLP Quick Tutorial


docs/source/figures/text_classification.png (BIN)

Before: 1217 × 543 (54 kB); After: 1699 × 747 (73 kB)

fastNLP/api/model_zoo.py (+2 -6)

@@ -1,5 +1,3 @@
-import torch
-
 import hashlib
 import os
 import re
@@ -7,6 +5,8 @@ import shutil
 import sys
 import tempfile

+import torch
+
 try:
     from requests.utils import urlparse
     from requests import get as urlopen
@@ -132,7 +132,3 @@ if tqdm is None:

         sys.stderr.write('\n')
-
-
-if __name__ == '__main__':
-    pipeline = load_url('http://10.141.208.102:5000/file/download/infer_context-4e86fd93.pkl', model_dir='.')
-    print(type(pipeline))

fastNLP/api/processor.py (+25 -13)

@@ -1,14 +1,15 @@
-import torch
-from collections import defaultdict
 import re
+from collections import defaultdict
+
+import torch

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
 from fastNLP.core.sampler import SequentialSampler
+from fastNLP.core.vocabulary import Vocabulary


-class Processor:
+class Processor(object):
     def __init__(self, field_name, new_added_field_name):
         self.field_name = field_name
         if new_added_field_name is None:
@@ -17,7 +18,7 @@ class Processor:
         self.new_added_field_name = new_added_field_name

     def process(self, *args, **kwargs):
-        pass
+        raise NotImplementedError

     def __call__(self, *args, **kwargs):
         return self.process(*args, **kwargs)
@@ -132,13 +133,14 @@ class Num2TagProcessor(Processor):


 class IndexerProcessor(Processor):
-    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
+    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False, is_input=True):

         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

         super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
         self.vocab = vocab
         self.delete_old_field = delete_old_field
+        self.is_input = is_input

     def set_vocab(self, vocab):
         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
@@ -146,13 +148,14 @@ class IndexerProcessor(Processor):
         self.vocab = vocab

     def process(self, dataset):
-        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
             tokens = ins[self.field_name]
             index = [self.vocab.to_index(token) for token in tokens]
             ins[self.new_added_field_name] = index

-        dataset._set_need_tensor(**{self.new_added_field_name: True})
+        if self.is_input:
+            dataset.set_input(self.new_added_field_name)

         if self.delete_old_field:
             dataset.delete_field(self.field_name)
@@ -161,6 +164,9 @@ class IndexerProcessor(Processor):


 class VocabProcessor(Processor):
+    """Build vocabulary with a field in the data set.
+
+    """
     def __init__(self, field_name):
         super(VocabProcessor, self).__init__(field_name, None)
         self.vocab = Vocabulary()
@@ -178,17 +184,20 @@ class VocabProcessor(Processor):


 class SeqLenProcessor(Processor):
-    def __init__(self, field_name, new_added_field_name='seq_lens'):
+    def __init__(self, field_name, new_added_field_name='seq_lens', is_input=True):
         super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
+        self.is_input = is_input

     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
             length = len(ins[self.field_name])
             ins[self.new_added_field_name] = length
-        dataset._set_need_tensor(**{self.new_added_field_name: True})
+        if self.is_input:
+            dataset.set_input(self.new_added_field_name)
         return dataset


 class ModelProcessor(Processor):
     def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
         """
@@ -238,6 +247,7 @@ class ModelProcessor(Processor):
         device = torch.device(device)
         self.model.to(device)

+
 class Index2WordProcessor(Processor):
     def __init__(self, vocab, field_name, new_added_field_name):
         super(Index2WordProcessor, self).__init__(field_name, new_added_field_name)
@@ -251,26 +261,28 @@ class Index2WordProcessor(Processor):


 class SetTensorProcessor(Processor):
+    # TODO: remove it. It is strange.
     def __init__(self, field_dict, default=False):
         super(SetTensorProcessor, self).__init__(None, None)
         self.field_dict = field_dict
         self.default = default

     def process(self, dataset):
-        set_dict = {name: self.default for name in dataset.get_fields().keys()}
+        set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
         set_dict.update(self.field_dict)
         dataset._set_need_tensor(**set_dict)
         return dataset


 class SetIsTargetProcessor(Processor):
+    # TODO; remove it.
     def __init__(self, field_dict, default=False):
         super(SetIsTargetProcessor, self).__init__(None, None)
         self.field_dict = field_dict
         self.default = default

     def process(self, dataset):
-        set_dict = {name: self.default for name in dataset.get_fields().keys()}
+        set_dict = {name: self.default for name in dataset.get_all_fields().keys()}
         set_dict.update(self.field_dict)
         dataset.set_target(**set_dict)
         return dataset
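
As an illustration of the Processor contract shown above (subclasses override process(), and __call__ delegates to it), a hypothetical processor that is not part of this commit might look like the following; the field names and the lower-casing step are made up, and DataSet.apply is used as defined in the reworked dataset.py below:

from fastNLP.api.processor import Processor

class LowerProcessor(Processor):
    """Hypothetical example: store a lower-cased copy of a string field."""

    def __init__(self, field_name, new_added_field_name=None):
        super(LowerProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        # write the result into the (possibly new) field via DataSet.apply
        dataset.apply(lambda ins: ins[self.field_name].lower(),
                      new_field_name=self.new_added_field_name)
        return dataset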

fastNLP/core/__init__.py (+6 -4)

@@ -1,11 +1,13 @@
 from .batch import Batch
-from .dataset import DataSet
+# from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
-from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
+from .metrics import AccuracyMetric
+from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
 from .vocabulary import Vocabulary
-from .optimizer import Optimizer
-from .loss import Loss
+from ..io.dataset_loader import DataSet
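
A minimal import sketch of the public names re-exported after this change (nothing beyond what the new __init__.py above exposes):

# hypothetical user code: every name below is re-exported by fastNLP.core
from fastNLP.core import Trainer, Tester, DataSet, Vocabulary
from fastNLP.core import CrossEntropyLoss, AccuracyMetric, Adam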

fastNLP/core/batch.py (+15 -2)

@@ -1,3 +1,4 @@
+import numpy as np
 import torch


@@ -25,6 +26,7 @@ class Batch(object):
         self.as_numpy = as_numpy
         self.idx_list = None
         self.curidx = 0
+        self.num_batches = len(dataset)//batch_size + int(len(dataset)%batch_size!=0)

     def __iter__(self):
         self.idx_list = self.sampler(self.dataset)
@@ -41,11 +43,11 @@

             indices = self.idx_list[self.curidx:endidx]

-            for field_name, field in self.dataset.get_fields().items():
+            for field_name, field in self.dataset.get_all_fields().items():
                 if field.is_target or field.is_input:
                     batch = field.get(indices)
                     if not self.as_numpy:
-                        batch = torch.from_numpy(batch)
+                        batch = to_tensor(batch, field.dtype)
                     if field.is_target:
                         batch_y[field_name] = batch
                     if field.is_input:
@@ -54,3 +56,14 @@
             self.curidx = endidx

             return batch_x, batch_y
+
+    def __len__(self):
+        return self.num_batches
+
+
+def to_tensor(batch, dtype):
+    if dtype in (int, np.int8, np.int16, np.int32, np.int64):
+        batch = torch.LongTensor(batch)
+    if dtype in (float, np.float32, np.float64):
+        batch = torch.FloatTensor(batch)
+    return batch
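
The new num_batches attribute and __len__ above amount to a ceiling division; a standalone sketch of the arithmetic (no fastNLP needed):

def num_batches(n_samples, batch_size):
    # same ceiling division as written in Batch.__init__ above
    return n_samples // batch_size + int(n_samples % batch_size != 0)

assert num_batches(10, 4) == 3   # batches of 4, 4, 2
assert num_batches(8, 4) == 2    # exact multiple: no extra batch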

fastNLP/core/dataset.py (+219 -155)

@@ -1,24 +1,11 @@
import _pickle as pickle

import numpy as np import numpy as np
from copy import copy


from fastNLP.core.fieldarray import FieldArray from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance

_READERS = {}


def construct_dataset(sentences):
"""Construct a data set from a list of sentences.

:param sentences: list of list of str
:return dataset: a DataSet object
"""
dataset = DataSet()
for sentence in sentences:
instance = Instance()
instance['raw_sentence'] = sentence
dataset.append(instance)
return dataset
from fastNLP.core.utils import get_func_signature
from fastNLP.io.base_loader import DataLoaderRegister




class DataSet(object): class DataSet(object):
@@ -28,45 +15,13 @@ class DataSet(object):


""" """


class Instance(object):
def __init__(self, dataset, idx=-1, **fields):
self.dataset = dataset
self.idx = idx
self.fields = fields

def __next__(self):
self.idx += 1
if self.idx >= len(self.dataset):
raise StopIteration
return copy(self)

def add_field(self, field_name, field):
"""Add a new field to the instance.

:param field_name: str, the name of the field.
:param field:
"""
self.fields[field_name] = field

def __getitem__(self, name):
return self.dataset[name][self.idx]

def __setitem__(self, name, val):
if name not in self.dataset:
new_fields = [None] * len(self.dataset)
self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val

def __repr__(self):
return "\n".join(['{}: {}'.format(name, repr(self.dataset[name][self.idx])) for name
in self.dataset.get_fields().keys()])

def __init__(self, data=None): def __init__(self, data=None):
""" """


:param data: a dict or a list. If it is a dict, the key is the name of a field and the value is the field.
All values must be of the same length.
If it is a list, it must be a list of Instance objects.
:param data: a dict or a list.
If `data` is a dict, the key is the name of a FieldArray and the value is the FieldArray. All values
must be of the same length.
If `data` is a list, it must be a list of Instance objects.
""" """
self.field_arrays = {} self.field_arrays = {}
if data is not None: if data is not None:
@@ -89,14 +44,95 @@ class DataSet(object):
return item in self.field_arrays return item in self.field_arrays


def __iter__(self): def __iter__(self):
return self.Instance(self)
def iter_func():
for idx in range(len(self)):
yield self[idx]

return iter_func()

def _inner_iter(self):
class Iter_ptr:
def __init__(self, dataset, idx):
self.dataset = dataset
self.idx = idx

def __getitem__(self, item):
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
self.idx])
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
return self.dataset.field_arrays[item][self.idx]

def __repr__(self):
return self.dataset[self.idx].__repr__()

def inner_iter_func():
for idx in range(len(self)):
yield Iter_ptr(self, idx)

return inner_iter_func()

def __getitem__(self, idx):
"""Fetch Instance(s) at the `idx` position(s) in the dataset.
Notice: This method returns a copy of the actual instance(s). Any change to the returned value would not modify
the origin instance(s) of the DataSet.
If you want to make in-place changes to all Instances, use `apply` method.

:param idx: can be int or slice.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
"""
if isinstance(idx, int):
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}")
data_set = DataSet()
for field in self.field_arrays.values():
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __getattr__(self, item):
# Not tested. Don't use !!
if item == "field_arrays":
raise AttributeError
if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item]
try:
reader = DataLoaderRegister.get_reader(item)
return reader
except AttributeError:
raise


def _convert_ins(self, ins_list):
if isinstance(ins_list, list):
for ins in ins_list:
self.append(ins)
def __setstate__(self, state):
self.__dict__ = state

def __getstate__(self):
return self.__dict__

def __len__(self):
"""Fetch the length of the dataset.

:return int length:
"""
if len(self.field_arrays) == 0:
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)

def __inner_repr__(self):
if len(self) < 20:
return ",\n".join([ins.__repr__() for ins in self])
else: else:
self.append(ins_list)
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()

def __repr__(self):
return "DataSet(" + self.__inner_repr__() + ")"


def append(self, ins): def append(self, ins):
"""Add an instance to the DataSet. """Add an instance to the DataSet.
@@ -125,7 +161,9 @@ class DataSet(object):
:param bool is_target: whether this field is label or target. :param bool is_target: whether this field is label or target.
""" """
if len(self.field_arrays) != 0: if len(self.field_arrays) != 0:
assert len(self) == len(fields)
if len(self) != len(fields):
raise RuntimeError(f"The field to append must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fields)}")
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,
is_input=is_input) is_input=is_input)


@@ -136,146 +174,121 @@ class DataSet(object):
""" """
self.field_arrays.pop(name) self.field_arrays.pop(name)


def get_fields(self):
def get_field(self, field_name):
if field_name not in self.field_arrays:
raise KeyError("Field name {} not found in DataSet".format(field_name))
return self.field_arrays[field_name]

def get_all_fields(self):
"""Return all the fields with their names. """Return all the fields with their names.


:return dict field_arrays: the internal data structure of DataSet. :return dict field_arrays: the internal data structure of DataSet.
""" """
return self.field_arrays return self.field_arrays


def __getitem__(self, idx):
"""

:param idx: can be int, slice, or str.
:return: If `idx` is int, return an Instance object.
If `idx` is slice, return a DataSet object.
If `idx` is str, it must be a field name, return the field.

"""
if isinstance(idx, int):
return self.Instance(self, idx, **{name: self.field_arrays[name][idx] for name in self.field_arrays})
elif isinstance(idx, slice):
data_set = DataSet()
for field in self.field_arrays.values():
data_set.add_field(name=field.name,
fields=field.content[idx],
padding_val=field.padding_val,
is_input=field.is_input,
is_target=field.is_target)
return data_set
elif isinstance(idx, str):
return self.field_arrays[idx]
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

def __len__(self):
if len(self.field_arrays) == 0:
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)

def get_length(self): def get_length(self):
"""The same as __len__
"""Fetch the length of the dataset.


:return int length:
""" """
return len(self) return len(self)


def rename_field(self, old_name, new_name): def rename_field(self, old_name, new_name):
"""rename a field
"""Rename a field.

:param str old_name:
:param str new_name:
""" """
if old_name in self.field_arrays: if old_name in self.field_arrays:
self.field_arrays[new_name] = self.field_arrays.pop(old_name) self.field_arrays[new_name] = self.field_arrays.pop(old_name)
self.field_arrays[new_name].name = new_name
else: else:
raise KeyError("{} is not a valid name. ".format(old_name))
raise KeyError("DataSet has no field named {}.".format(old_name))


def set_target(self, **fields):
"""Change the flag of `is_target` for all instance. For fields not set here, leave their `is_target` unchanged.
def set_target(self, *field_names, flag=True):
"""Change the target flag of these fields.


:param key-value pairs for field-name and `is_target` value(True, False).
:param field_names: a sequence of str, indicating field names
:param bool flag: Set these fields as target if True. Unset them if False.
""" """
for name, val in fields.items():
for name in field_names:
if name in self.field_arrays: if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_target = val
self.field_arrays[name].is_target = flag
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
return self


def set_input(self, **fields):
for name, val in fields.items():
def set_input(self, *field_name, flag=True):
"""Set the input flag of these fields.

:param field_name: a sequence of str, indicating field names.
:param bool flag: Set these fields as input if True. Unset them if False.
"""
for name in field_name:
if name in self.field_arrays: if name in self.field_arrays:
assert isinstance(val, bool)
self.field_arrays[name].is_input = val
self.field_arrays[name].is_input = flag
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
return self


def get_input_name(self): def get_input_name(self):
"""Get all field names with `is_input` as True.

:return list field_names: a list of str
"""
return [name for name, field in self.field_arrays.items() if field.is_input] return [name for name, field in self.field_arrays.items() if field.is_input]


def get_target_name(self): def get_target_name(self):
return [name for name, field in self.field_arrays.items() if field.is_target]

def __getattr__(self, item):
# block infinite recursion for copy, pickle
if item == '__setstate__':
raise AttributeError(item)
try:
return self.field_arrays.__getitem__(item)
except KeyError:
pass
try:
reader_cls = _READERS[item]

# add read_*data() support
def _read(*args, **kwargs):
data = reader_cls().load(*args, **kwargs)
self.extend(data)
return self
"""Get all field names with `is_target` as True.


return _read
except KeyError:
raise AttributeError('{} does not exist.'.format(item))

@classmethod
def set_reader(cls, method_name):
"""decorator to add dataloader support
:return list field_names: a list of str
""" """
assert isinstance(method_name, str)

def wrapper(read_cls):
_READERS[method_name] = read_cls
return read_cls

return wrapper
return [name for name, field in self.field_arrays.items() if field.is_target]


def apply(self, func, new_field_name=None):
def apply(self, func, new_field_name=None, **kwargs):
"""Apply a function to every instance of the DataSet. """Apply a function to every instance of the DataSet.


:param func: a function that takes an instance as input. :param func: a function that takes an instance as input.
:param str new_field_name: If not None, results of the function will be stored as a new field. :param str new_field_name: If not None, results of the function will be stored as a new field.
:return results: returned values of the function over all instances.
:param **kwargs: Accept parameters will be
(1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input.
(2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target.
:return results: if new_field_name is not passed, returned values of the function over all instances.
""" """
results = [func(ins) for ins in self]
results = [func(ins) for ins in self._inner_iter()]
if len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(get_func_signature(func=func)))

extra_param = {}
if 'is_input' in kwargs:
extra_param['is_input'] = kwargs['is_input']
if 'is_target' in kwargs:
extra_param['is_target'] = kwargs['is_target']
if new_field_name is not None: if new_field_name is not None:
if new_field_name in self.field_arrays: if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes # overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name] old_field = self.field_arrays[new_field_name]
if 'is_input' not in extra_param:
extra_param['is_input'] = old_field.is_input
if 'is_target' not in extra_param:
extra_param['is_target'] = old_field.is_target
self.add_field(name=new_field_name, self.add_field(name=new_field_name,
fields=results, fields=results,
padding_val=old_field.padding_val, padding_val=old_field.padding_val,
is_input=old_field.is_input,
is_target=old_field.is_target)
**extra_param)
else: else:
self.add_field(name=new_field_name, fields=results)
self.add_field(name=new_field_name, fields=results, **extra_param)
else: else:
return results return results


def drop(self, func): def drop(self, func):
results = [ins for ins in self if not func(ins)]
"""Drop instances if a condition holds.

:param func: a function that takes an Instance object as input, and returns bool.
The instance will be dropped if the function returns True.

"""
results = [ins for ins in self._inner_iter() if not func(ins)]
for name, old_field in self.field_arrays.items(): for name, old_field in self.field_arrays.items():
self.field_arrays[name].content = [ins[name] for ins in results] self.field_arrays[name].content = [ins[name] for ins in results]
# print(self.field_arrays[name])


def split(self, dev_ratio): def split(self, dev_ratio):
"""Split the dataset into training and development(validation) set. """Split the dataset into training and development(validation) set.
@@ -297,30 +310,81 @@ class DataSet(object):
dev_set.append(self[idx]) dev_set.append(self[idx])
for idx in train_indices: for idx in train_indices:
train_set.append(self[idx]) train_set.append(self[idx])
for field_name in self.field_arrays:
train_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
train_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target
dev_set.field_arrays[field_name].is_input = self.field_arrays[field_name].is_input
dev_set.field_arrays[field_name].is_target = self.field_arrays[field_name].is_target

return train_set, dev_set return train_set, dev_set


@classmethod @classmethod
def read_csv(cls, csv_path, headers=None, sep='\t', dropna=True):
with open(csv_path, 'r') as f:
def read_csv(cls, csv_path, headers=None, sep=",", dropna=True):
"""Load data from a CSV file and return a DataSet object.

:param str csv_path: path to the CSV file
:param List[str] or Tuple[str] headers: headers of the CSV file
:param str sep: delimiter in CSV file. Default: ","
:param bool dropna: If True, drop rows that have less entries than headers.
:return DataSet dataset:

"""
with open(csv_path, "r") as f:
start_idx = 0 start_idx = 0
if headers is None: if headers is None:
headers = f.readline().rstrip('\r\n') headers = f.readline().rstrip('\r\n')
headers = headers.split(sep) headers = headers.split(sep)
start_idx += 1 start_idx += 1
else: else:
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(type(headers))
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format(
type(headers))
_dict = {} _dict = {}
for col in headers: for col in headers:
_dict[col] = [] _dict[col] = []
for line_idx, line in enumerate(f, start_idx): for line_idx, line in enumerate(f, start_idx):
contents = line.split(sep)
if len(contents)!=len(headers):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
if dropna: if dropna:
continue continue
else: else:
#TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts."\
.format(line_idx, len(contents), len(headers)))
# TODO change error type
raise ValueError("Line {} has {} parts, while header has {} parts." \
.format(line_idx, len(contents), len(headers)))
for header, content in zip(headers, contents): for header, content in zip(headers, contents):
_dict[header].append(content) _dict[header].append(content)
return cls(_dict) return cls(_dict)

# def read_pos(self):
# return DataLoaderRegister.get_reader('read_pos')

def save(self, path):
"""Save the DataSet object as pickle.

:param str path: the path to the pickle
"""
with open(path, 'wb') as f:
pickle.dump(self, f)

@staticmethod
def load(path):
"""Load a DataSet object from pickle.

:param str path: the path to the pickle
:return DataSet data_set:
"""
with open(path, 'rb') as f:
return pickle.load(f)


def construct_dataset(sentences):
"""Construct a data set from a list of sentences.

:param sentences: list of list of str
:return dataset: a DataSet object
"""
dataset = DataSet()
for sentence in sentences:
instance = Instance()
instance['raw_sentence'] = sentence
dataset.append(instance)
return dataset
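
A short usage sketch of the reworked DataSet API above; the method signatures follow the added code, while the toy sentences, labels, and field names are made up for illustration:

from fastNLP.core.dataset import DataSet

# construct from a dict of equal-length lists (see DataSet.__init__ docstring)
ds = DataSet({"raw_sentence": ["I like it .", "It is bad ."],
              "label": [1, 0]})

# apply() adds a derived field; is_input/is_target set the new field's flags
ds.apply(lambda ins: ins["raw_sentence"].lower().split(),
         new_field_name="word_seq", is_input=True)
ds.set_target("label")

print(len(ds))    # 2
print(ds[0])      # an Instance (a copy, per __getitem__ above)
train, dev = ds.split(0.5)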

fastNLP/core/fieldarray.py (+140 -22)

@@ -6,35 +6,150 @@ class FieldArray(object):
It is the basic element of DataSet class. It is the basic element of DataSet class.


""" """
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):

def __init__(self, name, content, padding_val=0, is_target=None, is_input=None):
""" """


:param str name: the name of the FieldArray :param str name: the name of the FieldArray
:param list content: a list of int, float, or other objects.
:param list content: a list of int, float, str or np.ndarray, or a list of list of one, or a np.ndarray.
:param int padding_val: the integer for padding. Default: 0. :param int padding_val: the integer for padding. Default: 0.
:param bool is_target: If True, this FieldArray is used to compute loss. :param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used to the model input. :param bool is_input: If True, this FieldArray is used to the model input.
""" """
self.name = name self.name = name
if isinstance(content, list):
content = content
elif isinstance(content, np.ndarray):
content = content.tolist() # convert np.ndarray into 2-D list
else:
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
self.content = content self.content = content
self.padding_val = padding_val self.padding_val = padding_val
self.is_target = is_target
self.is_input = is_input
# TODO: auto detect dtype
self.dtype = None

self._is_target = None
self._is_input = None

self.BASIC_TYPES = (int, float, str, np.ndarray)
self.is_2d_list = False
self.pytype = None # int, float, str, or np.ndarray
self.dtype = None # np.int64, np.float64, np.str

if is_input is not None:
self.is_input = is_input
if is_target is not None:
self.is_target = is_target

@property
def is_input(self):
return self._is_input

@is_input.setter
def is_input(self, value):
if value is True:
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)
self._is_input = value

@property
def is_target(self):
return self._is_target

@is_target.setter
def is_target(self, value):
if value is True:
self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype)
self._is_target = value

def _type_detection(self, content):
"""

:param content: a list of int, float, str or np.ndarray, or a list of list of one.
:return type: one of int, float, str, np.ndarray

"""
if isinstance(content, list) and len(content) > 0 and isinstance(content[0], list):
# content is a 2-D list
if not all(isinstance(_, list) for _ in content): # strict check 2-D list
raise TypeError("Please provide 2-D list.")
type_set = set([self._type_detection(x) for x in content])
if len(type_set) == 2 and int in type_set and float in type_set:
type_set = {float}
elif len(type_set) > 1:
raise TypeError("Cannot create FieldArray with more than one type. Provided {}".format(type_set))
self.is_2d_list = True
return type_set.pop()

elif isinstance(content, list):
# content is a 1-D list
if len(content) == 0:
# the old error is not informative enough.
raise RuntimeError("Cannot create FieldArray with an empty list. Or one element in the list is empty.")
type_set = set([type(item) for item in content])

if len(type_set) == 1 and tuple(type_set)[0] in self.BASIC_TYPES:
return type_set.pop()
elif len(type_set) == 2 and float in type_set and int in type_set:
# up-cast int to float
return float
else:
raise TypeError("Cannot create FieldArray with type {}".format(*type_set))
else:
raise TypeError("Cannot create FieldArray with type {}".format(type(content)))

@staticmethod
def _map_to_np_type(basic_type):
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray}
return type_mapping[basic_type]


def __repr__(self): def __repr__(self):
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) return "FieldArray {}: {}".format(self.name, self.content.__repr__())


def append(self, val): def append(self, val):
"""Add a new item to the tail of FieldArray.

:param val: int, float, str, or a list of one.
"""
if self.is_target is True or self.is_input is True:
# only check type when used as target or input

val_type = type(val)
if val_type == list: # shape check
if self.is_2d_list is False:
raise RuntimeError("Cannot append a list into a 1-D FieldArray. Please provide an element.")
if len(val) == 0:
raise RuntimeError("Cannot append an empty list.")
val_list_type = set([type(_) for _ in val]) # type check
if len(val_list_type) == 2 and int in val_list_type and float in val_list_type:
# up-cast int to float
val_type = float
elif len(val_list_type) == 1:
val_type = val_list_type.pop()
else:
raise TypeError("Cannot append a list of {}".format(val_list_type))
else:
if self.is_2d_list is True:
raise RuntimeError("Cannot append a non-list into a 2-D list. Please provide a list.")

if val_type == float and self.pytype == int:
# up-cast
self.pytype = float
self.dtype = self._map_to_np_type(self.pytype)
elif val_type == int and self.pytype == float:
pass
elif val_type == self.pytype:
pass
else:
raise TypeError("Cannot append type {} into type {}".format(val_type, self.pytype))

self.content.append(val) self.content.append(val)


def __getitem__(self, name):
return self.get(name)
def __getitem__(self, indices):
return self.get(indices)


def __setitem__(self, name, val):
assert isinstance(name, int)
self.content[name] = val
def __setitem__(self, idx, val):
assert isinstance(idx, int)
self.content[idx] = val


def get(self, indices): def get(self, indices):
"""Fetch instances based on indices. """Fetch instances based on indices.
@@ -44,29 +159,32 @@ class FieldArray(object):
""" """
if isinstance(indices, int): if isinstance(indices, int):
return self.content[indices] return self.content[indices]
assert self.is_input is True or self.is_target is True
if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name))
batch_size = len(indices) batch_size = len(indices)
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
if not isiterable(self.content[0]):
if self.dtype is None:
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double

if not is_iterable(self.content[0]):
array = np.array([self.content[i] for i in indices], dtype=self.dtype) array = np.array([self.content[i] for i in indices], dtype=self.dtype)
else:
if self.dtype is None:
self.dtype = np.int64
elif self.dtype in (np.int64, np.float64):
max_len = max([len(self.content[i]) for i in indices]) max_len = max([len(self.content[i]) for i in indices])
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)

for i, idx in enumerate(indices): for i, idx in enumerate(indices):
array[i][:len(self.content[idx])] = self.content[idx] array[i][:len(self.content[idx])] = self.content[idx]
else: # should only be str
array = np.array([self.content[i] for i in indices])
return array return array


def __len__(self): def __len__(self):
"""Returns the size of FieldArray.

:return int length:
"""
return len(self.content) return len(self.content)


def isiterable(content):

def is_iterable(content):
try: try:
_ = (e for e in content) _ = (e for e in content)
except TypeError: except TypeError:
return False return False
return True
return True
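
A small sketch of the padding behaviour implemented in FieldArray.get() above (the field name and values are illustrative):

from fastNLP.core.fieldarray import FieldArray

# 2-D int content: get() pads each row to the batch max length with padding_val
fa = FieldArray("word_seq", [[1, 2, 3], [4, 5]], padding_val=0, is_input=True)
print(fa.get([0, 1]))
# expected, per the code above:
# [[1 2 3]
#  [4 5 0]]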

fastNLP/core/instance.py (+2 -3)

@@ -1,5 +1,3 @@
-
-
 class Instance(object):
     """An Instance is an example of data. It is the collection of Fields.

@@ -33,4 +31,5 @@ class Instance(object):
         return self.add_field(name, field)

     def __repr__(self):
-        return self.fields.__repr__()
+        return "{" + ",\n".join(
+            "\'" + field_name + "\': " + str(self.fields[field_name]) for field_name in self.fields) + "}"

fastNLP/core/loss.py (+0 -196)

@@ -1,196 +0,0 @@
import torch

def squash(predict , truth , **kwargs):
'''To reshape tensors in order to fit Loss functions in pytorch

:param predict : Tensor, model output
:param truth : Tensor, truth from dataset
:param **kwargs : extra arguments

:return predict , truth: predict & truth after processing
'''
return predict.view(-1 , predict.size()[-1]) , truth.view(-1,)

def unpad(predict , truth , **kwargs):
'''To process padded sequence output to get true loss
Using pack_padded_sequence() method
This method contains squash()

:param predict : Tensor, [batch_size , max_len , tag_size]
:param truth : Tensor, [batch_size , max_len]
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
kwargs["lens"] : list or LongTensor, [batch_size]
the i-th element is true lengths of i-th sequence
:return predict , truth: predict & truth after processing
'''
if kwargs.get("lens") is None:
return predict , truth
lens = torch.LongTensor(kwargs["lens"])
lens , idx = torch.sort(lens , descending = True)
predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx] , lens , batch_first = True).data
truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx] , lens , batch_first = True).data
return predict , truth

def unpad_mask(predict , truth , **kwargs):
'''To process padded sequence output to get true loss
Using mask() method
This method contains squash()

:param predict : Tensor, [batch_size , max_len , tag_size]
:param truth : Tensor, [batch_size , max_len]
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
kwargs["lens"] : list or LongTensor, [batch_size]
the i-th element is true lengths of i-th sequence
:return predict , truth: predict & truth after processing
'''
if kwargs.get("lens") is None:
return predict , truth
mas = make_mask(kwargs["lens"] , truth.size()[1])
return mask(predict , truth , mask = mas)

def mask(predict , truth , **kwargs):
'''To select specific elements from Tensor
This method contains squash()

:param predict : Tensor, [batch_size , max_len , tag_size]
:param truth : Tensor, [batch_size , max_len]
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist
kwargs["mask"] : ByteTensor, [batch_size , max_len]
the mask Tensor , the position that is 1 will be selected
:return predict , truth: predict & truth after processing
'''
if kwargs.get("mask") is None:
return predict , truth
mask = kwargs["mask"]
predict , truth = squash(predict , truth)
mask = mask.view(-1,)

predict = torch.masked_select(predict.permute(1,0) , mask).view(predict.size()[-1] , -1).permute(1,0)
truth = torch.masked_select(truth , mask)

return predict , truth

def make_mask(lens , tar_len):
'''to generate a mask that select [:lens[i]] for i-th element
embezzle from fastNLP.models.sequence_modeling.seq_mask

:param lens : list or LongTensor, [batch_size]
:param tar_len : int
:return mask : ByteTensor
'''
lens = torch.LongTensor(lens)
mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
mask = torch.stack(mask, 1)
return mask

#map string to function. Just for more elegant using
method_dict = {
"squash" : squash,
"unpad" : unpad,
"unpad_mask" : unpad_mask,
"mask" : mask,
}

loss_function_name = {
"L1Loss".lower() : torch.nn.L1Loss,
"BCELoss".lower() : torch.nn.BCELoss,
"MSELoss".lower() : torch.nn.MSELoss,
"NLLLoss".lower() : torch.nn.NLLLoss,
"KLDivLoss".lower() : torch.nn.KLDivLoss,
"NLLLoss2dLoss".lower() : torch.nn.NLLLoss2d, #every name should end with "loss"
"SmoothL1Loss".lower() : torch.nn.SmoothL1Loss,
"SoftMarginLoss".lower() : torch.nn.SoftMarginLoss,
"PoissonNLLLoss".lower() : torch.nn.PoissonNLLLoss,
"MultiMarginLoss".lower() : torch.nn.MultiMarginLoss,
"CrossEntropyLoss".lower() : torch.nn.CrossEntropyLoss,
"BCEWithLogitsLoss".lower() : torch.nn.BCEWithLogitsLoss,
"MarginRankingLoss".lower() : torch.nn.MarginRankingLoss,
"TripletMarginLoss".lower() : torch.nn.TripletMarginLoss,
"HingeEmbeddingLoss".lower() : torch.nn.HingeEmbeddingLoss,
"CosineEmbeddingLoss".lower() : torch.nn.CosineEmbeddingLoss,
"MultiLabelMarginLoss".lower() : torch.nn.MultiLabelMarginLoss,
"MultiLabelSoftMarginLoss".lower() : torch.nn.MultiLabelSoftMarginLoss,
}

class Loss(object):
'''a Loss object is a callable object represents loss functions
'''

def __init__(self , loss_name , pre_pro = [squash], **kwargs):
'''

:param loss_name: str or None , the name of loss function
:param pre_pro : list of function or str, methods to reform parameters before calculating loss
the strings will be auto translated to pre-defined functions
:param **kwargs: kwargs for torch loss function

pre_pro funcsions should have three arguments: predict, truth, **arg
predict and truth is the necessary parameters in loss function
kwargs is the extra parameters passed-in when calling loss function
pre_pro functions should return two objects, respectively predict and truth that after processed

'''

if loss_name is None:
# this is useful when Trainer.__init__ performs type check
self._loss = None
else:
if not isinstance(loss_name, str):
raise NotImplementedError
else:
self._loss = self._get_loss(loss_name , **kwargs)

self.pre_pro = [f if callable(f) else method_dict.get(f) for f in pre_pro]

def add_pre_pro(self , func):
'''add a pre_pro function

:param func: a function or str, methods to reform parameters before calculating loss
the strings will be auto translated to pre-defined functions
'''
if not callable(func):
func = method_dict.get(func)
if func is None:
return
self.pre_pro.append(func)

@staticmethod
def _get_loss(loss_name , **kwargs):
'''Get loss function from torch

:param loss_name: str, the name of loss function
:param **kwargs: kwargs for torch loss function
:return: A callable loss function object
'''
loss_name = loss_name.strip().lower()
loss_name = "".join(loss_name.split("_"))

if len(loss_name) < 4 or loss_name[-4 : ] != "loss":
loss_name += "loss"
return loss_function_name[loss_name](**kwargs)

def get(self):
'''This method exists just for make some existing codes run error-freely
'''
return self

def __call__(self , predict , truth , **kwargs):
'''call a loss function
predict and truth will be processed by pre_pro methods in order of addition

:param predict : Tensor, model output
:param truth : Tensor, truth from dataset
:param **kwargs : extra arguments, pass to pre_pro functions
for example, if used unpad_mask() in pre_pro, there should be a kwarg named lens
'''
for f in self.pre_pro:
if f is None:
continue
predict , truth = f(predict , truth , **kwargs)

return self._loss(predict , truth)

fastNLP/core/losses.py (+358 -0)

@@ -0,0 +1,358 @@
import inspect
from collections import defaultdict

import torch
import torch.nn.functional as F

from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import _check_function_or_method
from fastNLP.core.utils import get_func_signature


class LossBase(object):
def __init__(self):
self.param_map = {}
self._checked = False

def get_loss(self, *args, **kwargs):
raise NotImplementedError

def _init_param_map(self, key_map=None, **kwargs):
"""Check the validity of key_map and other param map. Add these into self.param_map

:param key_map: dict
:param kwargs:
:return: None
"""
value_counter = defaultdict(set)
if key_map is not None:
if not isinstance(key_map, dict):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
continue
if not isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if not isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
continue
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")

# check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.get_loss)
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {get_func_signature(self.get_loss)}. Please check the "
f"initialization parameters, or change its signature.")

# evaluate should not have varargs.
if func_spect.varargs:
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
f"positional argument.).")

def _fast_param_map(self, pred_dict, target_dict):
"""

Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
such as pred_dict has one element, target_dict has one element
:param pred_dict:
:param target_dict:
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
return fast_param

def __call__(self, pred_dict, target_dict, check=False):
"""
:param pred_dict: A dict from forward function of the network.
:param target_dict: A dict from DataSet.batch_y.
:param check: Boolean. Force to check the mapping functions when it is running.
:return:
"""
fast_param = self._fast_param_map(pred_dict, target_dict)
if fast_param:
loss = self.get_loss(**fast_param)
return loss

if not self._checked:
# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.get_loss)
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {get_func_signature(self.get_loss)}.")

# 2. only part of the param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

# need to wrap inputs in dict.
mapped_pred_dict = {}
mapped_target_dict = {}
duplicated = []
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
not_duplicate_flag = 0
if input_arg in self._reverse_param_map:
mapped_arg = self._reverse_param_map[input_arg]
not_duplicate_flag += 1
else:
mapped_arg = input_arg
if input_arg in pred_dict:
mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
not_duplicate_flag += 1
if input_arg in target_dict:
mapped_target_dict[mapped_arg] = target_dict[input_arg]
not_duplicate_flag += 1
if not_duplicate_flag == 3:
duplicated.append(input_arg)

# missing
if not self._checked:
check_res = _check_arg_dict_list(self.get_loss, [mapped_pred_dict, mapped_target_dict])
# replace missing.
missing = check_res.missing
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"

check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.get_loss))
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)

loss = self.get_loss(**refined_args)
self._checked = True

return loss


class LossFunc(LossBase):
"""A wrapper of user-provided loss function.

"""
def __init__(self, func, key_map=None, **kwargs):
"""

:param func: a callable object, such as a function.
:param dict key_map:
:param kwargs:
"""
super(LossFunc, self).__init__()
_check_function_or_method(func)
if key_map is not None:
if not isinstance(key_map, dict):
raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
self.param_map = key_map
if len(kwargs) > 0:
for key, val in kwargs.items():
self.param_map.update({key: val})

self.get_loss = func


class CrossEntropyLoss(LossBase):
def __init__(self, pred=None, target=None, padding_idx=-100):
# TODO 需要做一些检查,F.cross_entropy在计算时,如果pred是(16, 10 ,4), target的形状按道理应该是(16, 10), 但实际却需要
# TODO (16, 4)
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target)
self.padding_idx = padding_idx

def get_loss(self, pred, target):
return F.cross_entropy(input=pred, target=target,
ignore_index=self.padding_idx)


class L1Loss(LossBase):
def __init__(self, pred=None, target=None):
super(L1Loss, self).__init__()
self._init_param_map(pred=pred, target=target)

def get_loss(self, pred, target):
return F.l1_loss(input=pred, target=target)


class BCELoss(LossBase):
def __init__(self, pred=None, target=None):
super(BCELoss, self).__init__()
self._init_param_map(pred=pred, target=target)

def get_loss(self, pred, target):
return F.binary_cross_entropy(input=pred, target=target)


class NLLLoss(LossBase):
def __init__(self, pred=None, target=None):
super(NLLLoss, self).__init__()
self._init_param_map(pred=pred, target=target)

def get_loss(self, pred, target):
return F.nll_loss(input=pred, target=target)


class LossInForward(LossBase):
def __init__(self, loss_key='loss'):
super().__init__()
if not isinstance(loss_key, str):
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
self.loss_key = loss_key

def get_loss(self, **kwargs):
if self.loss_key not in kwargs:
check_res = CheckRes(missing=[self.loss_key + f"(assign to `{self.loss_key}` " \
f"in `{self.__class__.__name__}`"],
unused=[],
duplicated=[],
required=[],
all_needed=[],
varargs=[])
raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
return kwargs[self.loss_key]

def __call__(self, pred_dict, target_dict, check=False):

loss = self.get_loss(**pred_dict)

if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
if not isinstance(loss, torch.Tensor):
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")

return loss
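
# Sketch of LossInForward: the model's forward already computes the loss and
# returns it under the configured key. The toy model below is hypothetical.
import torch
from torch import nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x, y):
        logits = self.fc(x)
        return {"pred": logits, "loss": nn.functional.cross_entropy(logits, y)}

loss_in_forward = LossInForward(loss_key="loss")
outputs = ToyModel()(torch.randn(4, 10), torch.randint(0, 2, (4,)))
batch_loss = loss_in_forward(pred_dict=outputs, target_dict={})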


def _prepare_losser(losser):
if losser is None:
losser = LossInForward()
return losser
elif isinstance(losser, LossBase):
return losser
else:
raise TypeError(f"Type of loss should be `fastNLP.LossBase`, got {type(losser)}")


def squash(predict, truth, **kwargs):
"""To reshape tensors in order to fit loss functions in pytorch

:param predict : Tensor, model output
:param truth : Tensor, truth from dataset
:param **kwargs : extra arguments

:return predict , truth: predict & truth after processing
"""
return predict.view(-1, predict.size()[-1]), truth.view(-1, )


def unpad(predict, truth, **kwargs):
"""To process padded sequence output to get true loss
Using pack_padded_sequence() method
This method contains squash()

:param predict : Tensor, [batch_size , max_len , tag_size]
:param truth : Tensor, [batch_size , max_len]
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
kwargs["lens"] : list or LongTensor, [batch_size]
the i-th element is true lengths of i-th sequence

:return predict , truth: predict & truth after processing
"""
if kwargs.get("lens") is None:
return predict, truth
lens = torch.LongTensor(kwargs["lens"])
lens, idx = torch.sort(lens, descending=True)
predict = torch.nn.utils.rnn.pack_padded_sequence(predict[idx], lens, batch_first=True).data
truth = torch.nn.utils.rnn.pack_padded_sequence(truth[idx], lens, batch_first=True).data
return predict, truth
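
# Sketch of unpad on a toy padded batch (shapes are illustrative): only the
# first lens[i] positions of each sequence survive after packing.
import torch

scores = torch.randn(2, 3, 4)                 # [batch, max_len, tag_size]
tags = torch.tensor([[1, 2, 0], [3, 0, 0]])   # [batch, max_len]
flat_scores, flat_tags = unpad(scores, tags, lens=[3, 1])
# flat_scores: [4, 4], flat_tags: [4]  (4 == 3 + 1 real positions)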


def unpad_mask(predict, truth, **kwargs):
"""To process padded sequence output to get true loss
Using mask() method
This method contains squash()

:param predict : Tensor, [batch_size , max_len , tag_size]
:param truth : Tensor, [batch_size , max_len]
:param **kwargs : extra arguments, kwargs["lens"] is expected to be exsist
kwargs["lens"] : list or LongTensor, [batch_size]
the i-th element is true lengths of i-th sequence

:return predict , truth: predict & truth after processing
"""
if kwargs.get("lens") is None:
return predict, truth
mas = make_mask(kwargs["lens"], truth.size()[1])
return mask(predict, truth, mask=mas)


def mask(predict, truth, **kwargs):
"""To select specific elements from Tensor
This method contains squash()

:param predict : Tensor, [batch_size , max_len , tag_size]
:param truth : Tensor, [batch_size , max_len]
:param **kwargs : extra arguments, kwargs["mask"] is expected to be exsist
kwargs["mask"] : ByteTensor, [batch_size , max_len]
the mask Tensor , the position that is 1 will be selected

:return predict , truth: predict & truth after processing
"""
if kwargs.get("mask") is None:
return predict, truth
mask = kwargs["mask"]

predict, truth = squash(predict, truth)
mask = mask.view(-1, )

predict = torch.masked_select(predict.permute(1, 0), mask).view(predict.size()[-1], -1).permute(1, 0)
truth = torch.masked_select(truth, mask)

return predict, truth


def make_mask(lens, tar_len):
"""to generate a mask that select [:lens[i]] for i-th element
embezzle from fastNLP.models.sequence_modeling.seq_mask

:param lens : list or LongTensor, [batch_size]
:param tar_len : int

:return mask : ByteTensor
"""
lens = torch.LongTensor(lens)
mask = [torch.ge(lens, i + 1) for i in range(tar_len)]
mask = torch.stack(mask, 1)
return mask
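
# Sketch combining make_mask and mask on the same toy batch: positions past the
# true sequence length are dropped before computing a loss.
import torch

scores = torch.randn(2, 3, 4)                 # [batch, max_len, tag_size]
tags = torch.tensor([[1, 2, 0], [3, 0, 0]])   # [batch, max_len]
m = make_mask([3, 1], tags.size(1))           # [batch, max_len], true inside each real sequence
flat_scores, flat_tags = mask(scores, tags, mask=m)
# flat_scores: [4, 4], flat_tags: [4]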


+ 279
- 257
fastNLP/core/metrics.py View File

@@ -1,288 +1,310 @@
import warnings
import inspect
from collections import defaultdict


import numpy as np
import torch



class Evaluator(object):
def __init__(self):
pass

def __call__(self, predict, truth):
"""

:param predict: list of tensors, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return:
"""
raise NotImplementedError
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import CheckRes
from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import seq_lens_to_masks




class ClassifyEvaluator(Evaluator):
class MetricBase(object):
def __init__(self): def __init__(self):
super(ClassifyEvaluator, self).__init__()
self.param_map = {} # key is param in function, value is input param.
self._checked = False


def __call__(self, predict, truth):
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
y_true = torch.cat(truth, dim=0)
acc = float(torch.sum(y_pred == y_true)) / len(y_true)
return {"accuracy": acc}
def evaluate(self, *args, **kwargs):
raise NotImplementedError


def _init_param_map(self, key_map=None, **kwargs):
"""Check the validity of key_map and other param map. Add these into self.param_map


class SeqLabelEvaluator(Evaluator):
def __init__(self):
super(SeqLabelEvaluator, self).__init__()

def __call__(self, predict, truth, **_):
:param key_map: dict
:param kwargs:
:return: None
"""
value_counter = defaultdict(set)
if key_map is not None:
if not isinstance(key_map, dict):
raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
for key, value in key_map.items():
if value is None:
self.param_map[key] = key
continue
if not isinstance(key, str):
raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
if not isinstance(value, str):
raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
self.param_map[key] = value
value_counter[value].add(key)
for key, value in kwargs.items():
if value is None:
self.param_map[key] = key
continue
if not isinstance(value, str):
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
self.param_map[key] = value
value_counter[value].add(key)
for value, key_set in value_counter.items():
if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")

# check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = [arg for arg in func_spect.args if arg != 'self']
for func_param, input_param in self.param_map.items():
if func_param not in func_args:
raise NameError(
f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change its signature.")

# evaluate should not have varargs.
if func_spect.varargs:
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
f"positional argument.).")

def get_metric(self, reset=True):
raise NotImplementedError

def _fast_param_map(self, pred_dict, target_dict):
""" """


:param predict: list of List, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return accuracy:
Only used as an inner function. When pred_dict and target_dict are unequivocal
(e.g. each has exactly one element), users do not need to pass a key_map.
:param pred_dict:
:param target_dict:
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
""" """
total_correct, total_count = 0., 0.
for x, y in zip(predict, truth):
x = torch.tensor(x)
y = y.to(x) # make sure they are in the same device
mask = (y > 0)
correct = torch.sum(((x == y) * mask).long())
total_correct += float(correct)
total_count += float(torch.sum(mask.long()))
accuracy = total_correct / total_count
return {"accuracy": float(accuracy)}

class SeqLabelEvaluator2(Evaluator):
# the evaluator above is presumably incorrect
def __init__(self, seq_lens_field_name='word_seq_origin_len'):
super(SeqLabelEvaluator2, self).__init__()
self.end_tagidx_set = set()
self.seq_lens_field_name = seq_lens_field_name

def __call__(self, predict, truth, **_):
fast_param = {}
if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
return fast_param

def __call__(self, pred_dict, target_dict):
""" """


:param predict: list of batch, the network outputs from all batches.
:param truth: list of dict, the ground truths from all batch_y.
:return accuracy:
This method will call the self.evaluate method.
Before calling self.evaluate, it first checks the validity of pred_dict and target_dict:
(1) whether self.evaluate has varargs, which is not supported;
(2) whether any param needed by self.evaluate is missing from pred_dict and target_dict;
(3) whether any param needed by self.evaluate is duplicated in pred_dict and target_dict;
(4) whether any param in pred_dict or target_dict is unused by self.evaluate (might cause a warning).
Besides, before passing params into self.evaluate, this function filters out params from pred_dict and
target_dict which are not used by self.evaluate (but if **kwargs is present in self.evaluate, no filtering
will be conducted).
This function also supports _fast_param_map.
:param pred_dict: usually the output of the forward or prediction function
:param target_dict: usually the fields set as target.
:return:
"""
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
if fast_param:
self.evaluate(**fast_param)
return

if not self._checked:
# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
func_args = set([arg for arg in func_spect.args if arg != 'self'])
for func_arg, input_arg in self.param_map.items():
if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")

# 2. only part of the param_map are passed, left are not
for arg in func_args:
if arg not in self.param_map:
self.param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

# need to wrap inputs in dict.
mapped_pred_dict = {}
mapped_target_dict = {}
duplicated = []
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
not_duplicate_flag = 0
if input_arg in self._reverse_param_map:
mapped_arg = self._reverse_param_map[input_arg]
not_duplicate_flag += 1
else:
mapped_arg = input_arg
if input_arg in pred_dict:
mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
not_duplicate_flag += 1
if input_arg in target_dict:
mapped_target_dict[mapped_arg] = target_dict[input_arg]
not_duplicate_flag += 1
if not_duplicate_flag == 3:
duplicated.append(input_arg)

# missing
if not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
# only check missing.
# replace missing.
missing = check_res.missing
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"

check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate))
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

self.evaluate(**refined_args)
self._checked = True

return


class AccuracyMetric(MetricBase):
def __init__(self, pred=None, target=None, seq_lens=None):
super().__init__()

self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)

self.total = 0
self.acc_count = 0

def _fast_param_map(self, pred_dict, target_dict):
""" """
seq_lens = _[self.seq_lens_field_name]
corr_count = 0
pred_count = 0
truth_count = 0
for x, y, seq_len in zip(predict, truth, seq_lens):
x = x.cpu().numpy()
y = y.cpu().numpy()
for idx, s_l in enumerate(seq_len):
x_ = x[idx]
y_ = y[idx]
x_ = x_[:s_l]
y_ = y_[:s_l]
flag = True
start = 0
for idx_i, (x_i, y_i) in enumerate(zip(x_, y_)):
if x_i in self.end_tagidx_set:
truth_count += 1
for j in range(start, idx_i + 1):
if y_[j]!=x_[j]:
flag = False
break
if flag:
corr_count += 1
flag = True
start = idx_i + 1
if y_i in self.end_tagidx_set:
pred_count += 1
P = corr_count / (float(pred_count) + 1e-6)
R = corr_count / (float(truth_count) + 1e-6)
F = 2 * P * R / (P + R + 1e-6)

return {"P": P, 'R':R, 'F': F}



class SNLIEvaluator(Evaluator):
def __init__(self):
super(SNLIEvaluator, self).__init__()

def __call__(self, predict, truth):
y_prob = [torch.nn.functional.softmax(y_logit, dim=-1) for y_logit in predict]
y_prob = torch.cat(y_prob, dim=0)
y_pred = torch.argmax(y_prob, dim=-1)
truth = [t['truth'] for t in truth]
y_true = torch.cat(truth, dim=0).view(-1)
acc = float(torch.sum(y_pred == y_true)) / y_true.size(0)
return {"accuracy": acc}


Only used as an inner function. When pred_dict and target_dict are unequivocal
(e.g. each has exactly one element), users do not need to pass a key_map.
:param pred_dict:
:param target_dict:
:return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
targets = list(target_dict.values())
if len(targets) == 1 and isinstance(targets[0], torch.Tensor):
if len(pred_dict) == 1:
pred = list(pred_dict.values())[0]
fast_param['pred'] = pred
elif len(pred_dict) == 2:
pred1 = list(pred_dict.values())[0]
pred2 = list(pred_dict.values())[1]
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
return fast_param
if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
seq_lens = pred1
pred = pred2
elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
seq_lens = pred2
pred = pred1
else:
return fast_param
fast_param['pred'] = pred
fast_param['seq_lens'] = seq_lens
else:
return fast_param
fast_param['target'] = targets[0]
# TODO need to make sure they all have same batch_size
return fast_param

def evaluate(self, pred, target, seq_lens=None):
"""


def _conver_numpy(x):
"""convert input data to numpy array
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
:param target: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
None, None, torch.Size([B]), torch.Size([B]). Ignored if masks are provided.
:return: dict({'acc': float})
"""
# TODO the error message here needs changing: the user may not know what `pred` refers to, so the real field name should be reported
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")

if seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_lens)}.")

if seq_lens is not None:
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
else:
masks = None


"""
if isinstance(x, np.ndarray):
return x
elif isinstance(x, torch.Tensor):
return x.numpy()
elif isinstance(x, list):
return np.array(x)
raise TypeError('cannot accept object: {}'.format(x))
if pred.size() == target.size():
pass
elif len(pred.size()) == len(target.size()) + 1:
pred = pred.argmax(dim=-1)
else:
raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")


pred = pred.float()
target = target.float()


def _check_same_len(*arrays, axis=0):
"""check if input array list has same length for one dimension
if masks is not None:
self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
self.total += torch.sum(masks.float()).item()
else:
self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
self.total += np.prod(list(pred.size()))


"""
lens = set([x.shape[axis] for x in arrays if x is not None])
return len(lens) == 1
def get_metric(self, reset=True):
evaluate_result = {'acc': round(self.acc_count / self.total, 6)}
if reset:
self.acc_count = 0
self.total = 0
return evaluate_result
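
# Usage sketch for AccuracyMetric on a toy sequence-labelling batch. Shapes and
# field names are illustrative, and it assumes seq_lens_to_masks (imported from
# fastNLP.core.utils above) builds a [B, max_len] length mask.
import torch

acc_metric = AccuracyMetric(pred="pred", target="target", seq_lens="seq_lens")
pred_dict = {"pred": torch.tensor([[1, 2, 0], [3, 1, 0]])}
target_dict = {"target": torch.tensor([[1, 2, 0], [3, 0, 0]]),
               "seq_lens": torch.tensor([3, 2])}
acc_metric(pred_dict, target_dict)
print(acc_metric.get_metric())   # {'acc': 0.8} -- 4 correct out of 5 unpadded positions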




def _label_types(y):
"""Determine the type
- "binary"
- "multiclass"
- "multiclass-multioutput"
- "multilabel"
- "unknown"
def _prepare_metrics(metrics):
""" """
# never squeeze the first dimension
y = y.squeeze() if y.shape[0] > 1 else y.resize(1, -1)
shape = y.shape
if len(shape) < 1:
raise ValueError('cannot accept data: {}'.format(y))
if len(shape) == 1:
return 'multiclass' if np.unique(y).shape[0] > 2 else 'binary', y
if len(shape) == 2:
return 'multiclass-multioutput' if np.unique(y).shape[0] > 2 else 'multilabel', y
return 'unknown', y


def _check_data(y_true, y_pred):
"""Check if y_true and y_pred is same type of data e.g both binary or multiclass


Prepare list of Metric based on input
:param metrics:
:return: List[fastNLP.MetricBase]
""" """
y_true, y_pred = _conver_numpy(y_true), _conver_numpy(y_pred)
if not _check_same_len(y_true, y_pred):
raise ValueError('cannot accept data with different shape {0}, {1}'.format(y_true, y_pred))
type_true, y_true = _label_types(y_true)
type_pred, y_pred = _label_types(y_pred)

type_set = set(['binary', 'multiclass'])
if type_true in type_set and type_pred in type_set:
return type_true if type_true == type_pred else 'multiclass', y_true, y_pred

type_set = set(['multiclass-multioutput', 'multilabel'])
if type_true in type_set and type_pred in type_set:
return type_true if type_true == type_pred else 'multiclass-multioutput', y_true, y_pred

raise ValueError('cannot accept data mixed of {0} and {1} target'.format(type_true, type_pred))


def _weight_sum(y, normalize=True, sample_weight=None):
if normalize:
return np.average(y, weights=sample_weight)
if sample_weight is None:
return y.sum()
else:
return np.dot(y, sample_weight)


def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if y_type == 'multiclass-multioutput':
raise ValueError('cannot accept data type {0}'.format(y_type))
if y_type == 'multilabel':
equel = (y_true == y_pred).sum(1)
count = equel == y_true.shape[1]
else:
count = y_true == y_pred
return _weight_sum(count, normalize=normalize, sample_weight=sample_weight)


def recall_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if average == 'binary':
if y_type != 'binary':
raise ValueError("data type is {} but use average type {}".format(y_type, average))
else:
pos = (y_true == pos_label)
tp = np.logical_and((y_true == y_pred), pos).sum()
pos_sum = pos.sum()
return tp / pos_sum if pos_sum > 0 else 0
elif average == None:
y_labels = set(list(np.unique(y_true)))
if labels is None:
labels = list(y_labels)
else:
for i in labels:
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
warnings.warn('label {} is not contained in data'.format(i), UserWarning)

if y_type in ['binary', 'multiclass']:
y_pred_right = y_true == y_pred
pos_list = [y_true == i for i in labels]
pos_sum_list = [pos_i.sum() for pos_i in pos_list]
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
elif y_type == 'multilabel':
y_pred_right = y_true == y_pred
pos = (y_true == pos_label)
tp = np.logical_and(y_pred_right, pos).sum(0)
pos_sum = pos.sum(0)
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
_metrics = []
if metrics:
if isinstance(metrics, list):
for metric in metrics:
if isinstance(metric, type):
metric = metric()
if isinstance(metric, MetricBase):
metric_name = metric.__class__.__name__
if not callable(metric.evaluate):
raise TypeError(f"{metric_name}.evaluate must be callable, got {type(metric.evaluate)}.")
if not callable(metric.get_metric):
raise TypeError(f"{metric_name}.get_metric must be callable, got {type(metric.get_metric)}.")
_metrics.append(metric)
else:
raise TypeError(
f"The type of metric in metrics must be `fastNLP.MetricBase`, not `{type(metric)}`.")
elif isinstance(metrics, MetricBase):
_metrics = [metrics]
else: else:
raise ValueError('not support targets type {}'.format(y_type))
raise ValueError('not support for average type {}'.format(average))


def precision_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
y_type, y_true, y_pred = _check_data(y_true, y_pred)
if average == 'binary':
if y_type != 'binary':
raise ValueError("data type is {} but use average type {}".format(y_type, average))
else:
pos = (y_true == pos_label)
tp = np.logical_and((y_true == y_pred), pos).sum()
pos_pred = (y_pred == pos_label).sum()
return tp / pos_pred if pos_pred > 0 else 0
elif average == None:
y_labels = set(list(np.unique(y_true)))
if labels is None:
labels = list(y_labels)
else:
for i in labels:
if (i not in y_labels and y_type != 'multilabel') or (y_type == 'multilabel' and i >= y_true.shape[1]):
warnings.warn('label {} is not contained in data'.format(i), UserWarning)

if y_type in ['binary', 'multiclass']:
y_pred_right = y_true == y_pred
pos_list = [y_true == i for i in labels]
pos_sum_list = [(y_pred == i).sum() for i in labels]
return np.array([np.logical_and(y_pred_right, pos_i).sum() / sum_i if sum_i > 0 else 0 \
for pos_i, sum_i in zip(pos_list, pos_sum_list)])
elif y_type == 'multilabel':
y_pred_right = y_true == y_pred
pos = (y_true == pos_label)
tp = np.logical_and(y_pred_right, pos).sum(0)
pos_sum = (y_pred == pos_label).sum(0)
return np.array([tp[i] / pos_sum[i] if pos_sum[i] > 0 else 0 for i in labels])
else:
raise ValueError('not support targets type {}'.format(y_type))
raise ValueError('not support for average type {}'.format(average))


def f1_score(y_true, y_pred, labels=None, pos_label=1, average='binary'):
precision = precision_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
recall = recall_score(y_true, y_pred, labels=labels, pos_label=pos_label, average=average)
if isinstance(precision, np.ndarray):
res = 2 * precision * recall / (precision + recall)
res[(precision + recall) <= 0] = 0
return res
return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0


def classification_report(y_true, y_pred, labels=None, target_names=None, digits=2):
raise NotImplementedError
raise TypeError(f"The type of metrics should be `list[fastNLP.MetricBase]` or `fastNLP.MetricBase`, "
f"got {type(metrics)}.")
return _metrics




def accuracy_topk(y_true, y_prob, k=1): def accuracy_topk(y_true, y_prob, k=1):


+ 32
- 45
fastNLP/core/optimizer.py View File

@@ -2,61 +2,48 @@ import torch




class Optimizer(object): class Optimizer(object):
"""Wrapper of optimizer from framework
def __init__(self, model_params, **kwargs):
if model_params is not None and not hasattr(model_params, "__next__"):
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
self.model_params = model_params
self.settings = kwargs


1. Adam: lr (float), weight_decay (float)
2. AdaGrad
3. RMSProp
4. SGD: lr (float), momentum (float)


"""

def __init__(self, optimizer_name, **kwargs):
class SGD(Optimizer):
def __init__(self, lr=0.01, momentum=0, model_params=None):
""" """
:param optimizer_name: str, the name of the optimizer
:param kwargs: the arguments


:param float lr: learning rate. Default: 0.01
:param float momentum: momentum. Default: 0
:param model_params: a generator. E.g. model.parameters() for PyTorch models.
""" """
self.optim_name = optimizer_name
self.kwargs = kwargs
if not isinstance(lr, float):
raise TypeError("learning rate has to be float.")
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)


@property
def name(self):
"""The name of the optimizer.
def construct_from_pytorch(self, model_params):
if self.model_params is None:
# careful! generator cannot be assigned.
return torch.optim.SGD(model_params, **self.settings)
else:
return torch.optim.SGD(self.model_params, **self.settings)


:return: str
"""
return self.optim_name


@property
def params(self):
"""The arguments used to create the optimizer.
class Adam(Optimizer):
def __init__(self, lr=0.01, weight_decay=0, model_params=None):
"""


:return: dict of (str, *)
:param float lr: learning rate
:param float weight_decay:
:param model_params: a generator. E.g. model.parameters() for PyTorch models.
""" """
return self.kwargs
if not isinstance(lr, float):
raise TypeError("learning rate has to be float.")
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)


def construct_from_pytorch(self, model_params): def construct_from_pytorch(self, model_params):
"""Construct a optimizer from framework over given model parameters."""

if self.optim_name in ["SGD", "sgd"]:
if "lr" in self.kwargs:
if "momentum" not in self.kwargs:
self.kwargs["momentum"] = 0
optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
else:
raise ValueError("requires learning rate for SGD optimizer")

elif self.optim_name in ["adam", "Adam"]:
if "lr" in self.kwargs:
if "weight_decay" not in self.kwargs:
self.kwargs["weight_decay"] = 0
optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
weight_decay=self.kwargs["weight_decay"])
else:
raise ValueError("requires learning rate for Adam optimizer")

if self.model_params is None:
# careful! generator cannot be assigned.
return torch.optim.Adam(model_params, **self.settings)
else: else:
raise NotImplementedError

return optimizer
return torch.optim.Adam(self.model_params, **self.settings)
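
# Sketch of the new optimizer wrappers: parameters can be bound at construction
# time or supplied later via construct_from_pytorch. The toy model is hypothetical.
import torch
from torch import nn

toy_model = nn.Linear(4, 2)

sgd = SGD(lr=0.01, momentum=0.9)                     # parameters bound later
torch_sgd = sgd.construct_from_pytorch(toy_model.parameters())

adam = Adam(lr=0.001, weight_decay=1e-4, model_params=toy_model.parameters())
torch_adam = adam.construct_from_pytorch(None)       # already bound, argument ignored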

+ 2
- 19
fastNLP/core/predictor.py View File

@@ -1,4 +1,3 @@
import numpy as np
import torch import torch


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
@@ -23,13 +22,13 @@ class Predictor(object):


:param network: a PyTorch model (cpu) :param network: a PyTorch model (cpu)
:param data: a DataSet object. :param data: a DataSet object.
:return: list of list of strings, [num_examples, tag_seq_length]
:return: list of batch outputs
""" """
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
self.mode(network, test=True) self.mode(network, test=True)
batch_output = [] batch_output = []


data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), use_cuda=False)
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)


for batch_x, _ in data_iterator: for batch_x, _ in data_iterator:
with torch.no_grad(): with torch.no_grad():
@@ -48,19 +47,3 @@ class Predictor(object):
"""Forward through network.""" """Forward through network."""
y = network(**x) y = network(**x)
return y return y


def seq_label_post_processor(batch_outputs, label_vocab):
results = []
for batch in batch_outputs:
for example in np.array(batch):
results.append([label_vocab.to_word(int(x)) for x in example])
return results


def text_classify_post_processor(batch_outputs, label_vocab):
results = []
for batch_out in batch_outputs:
idx = np.argmax(batch_out.detach().numpy(), axis=-1)
results.extend([label_vocab.to_word(i) for i in idx])
return results

+ 1
- 1
fastNLP/core/sampler.py View File

@@ -55,7 +55,7 @@ class BucketSampler(BaseSampler):


def __call__(self, data_set): def __call__(self, data_set):


seq_lens = data_set[self.seq_lens_field_name].content
seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content
total_sample_num = len(seq_lens) total_sample_num = len(seq_lens)


bucket_indexes = [] bucket_indexes = []


+ 69
- 36
fastNLP/core/tester.py View File

@@ -1,60 +1,88 @@
import itertools
from collections import defaultdict from collections import defaultdict


import torch import torch
from torch import nn


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.dataset import DataSet
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature



class Tester(object): class Tester(object):
"""An collection of model inference and evaluation of performance, used over validation/dev set and test set. """ """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """


def __init__(self, data, model, batch_size=16, use_cuda=False):
def __init__(self, data, model, metrics, batch_size=16, use_cuda=False, verbose=1):
super(Tester, self).__init__() super(Tester, self).__init__()
self.use_cuda = use_cuda

if not isinstance(data, DataSet):
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")

self.metrics = _prepare_metrics(metrics)

self.data = data self.data = data
self.use_cuda = use_cuda
self.batch_size = batch_size self.batch_size = batch_size
self.verbose = verbose
self._model_device = model.parameters().__next__().device

if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
self._model = model.cuda() self._model = model.cuda()
else: else:
self._model = model self._model = model

# check predict
if hasattr(self._model, 'predict'): if hasattr(self._model, 'predict'):
assert callable(self._model.predict)
self._predict_func = self._model.predict self._predict_func = self._model.predict
if not callable(self._predict_func):
_model_name = model.__class__.__name__
raise TypeError(f"`{_model_name}.predict` must be callable to be used "
f"for evaluation, not `{type(self._predict_func)}`.")
else: else:
self._predict_func = self._model
assert hasattr(model, 'evaluate')
self._evaluator = model.evaluate
self.eval_history = [] # evaluation results of all batches
self._predict_func = self._model.forward


def test(self): def test(self):
# turn on the testing mode; clean up the history # turn on the testing mode; clean up the history
network = self._model network = self._model
self.mode(network, is_test=True)
self.eval_history.clear()
output, truths = defaultdict(list), defaultdict(list)
data_iterator = Batch(self.data, self.batch_size, sampler=RandomSampler(), as_numpy=False)

with torch.no_grad():
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(network, batch_x)
assert isinstance(prediction, dict)
for k, v in prediction.items():
output[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
for k, v in output.items():
output[k] = itertools.chain(*v)
for k, v in truths.items():
truths[k] = itertools.chain(*v)
args = _build_args(self._evaluator, **output, **truths)
eval_results = self._evaluator(**args)
print("[tester] {}".format(self.print_eval_results(eval_results)))
self.mode(network, is_test=False)
self._mode(network, is_test=True)
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
eval_results = {}
try:
with torch.no_grad():
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
pred_dict = self._data_forward(self._predict_func, batch_x)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {get_func_signature(self._predict_func)} "
f"must be `dict`, got {type(pred_dict)}.")
for metric in self.metrics:
metric(pred_dict, batch_y)
for metric in self.metrics:
eval_result = metric.get_metric()
if not isinstance(eval_result, dict):
raise TypeError(f"The return value of {get_func_signature(metric.get_metric)} must be "
f"`dict`, got {type(eval_result)}")
metric_name = metric.__class__.__name__
eval_results[metric_name] = eval_result
except CheckError as e:
prev_func_signature = get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0)

if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results return eval_results


def mode(self, model, is_test=False):
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently. """Train mode or Test mode. This is for PyTorch currently.


:param model: a PyTorch model :param model: a PyTorch model
@@ -66,16 +94,21 @@ class Tester(object):
else: else:
model.train() model.train()


def data_forward(self, network, x):
def _data_forward(self, func, x):
"""A forward pass of the model. """ """A forward pass of the model. """
x = _build_args(network.forward, **x)
y = self._predict_func(**x)
x = _build_args(func, **x)
y = func(**x)
return y return y


def print_eval_results(self, results):
def _format_eval_results(self, results):
"""Override this method to support more print formats. """Override this method to support more print formats.


:param results: dict, (str: float) is (metrics name: value) :param results: dict, (str: float) is (metrics name: value)


""" """
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])
_str = ''
for metric_name, metric_result in results.items():
_str += metric_name + ': '
_str += ", ".join([str(key) + "=" + str(value) for key, value in metric_result.items()])
_str += '\n'
return _str[:-1]
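
# What test() returns is a dict keyed by metric class name, each value being
# that metric's get_metric() result. Values below are illustrative and
# "MyMetric" is a hypothetical metric name.
eval_results = {"AccuracyMetric": {"acc": 0.91},
                "MyMetric": {"f": 0.8, "pre": 0.79, "rec": 0.81}}
# _format_eval_results renders this as:
#   AccuracyMetric: acc=0.91
#   MyMetric: f=0.8, pre=0.79, rec=0.81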

+ 301
- 277
fastNLP/core/trainer.py View File

@@ -1,160 +1,275 @@
import os
import time import time
from datetime import timedelta
from datetime import datetime from datetime import datetime
import warnings
from collections import defaultdict
import os
import itertools
import shutil
from datetime import timedelta


from tensorboardX import SummaryWriter
import torch import torch
from tensorboardX import SummaryWriter
from torch import nn
from tqdm.autonotebook import tqdm


from fastNLP.core.batch import Batch from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import BaseSampler
from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import CheckError
from fastNLP.core.utils import _build_args from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _syn_model_data
from fastNLP.core.utils import _check_forward_error
from fastNLP.core.utils import _check_loss_evaluate
from fastNLP.core.utils import _move_dict_value_to_device
from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import get_func_signature



class Trainer(object): class Trainer(object):
"""Main Training Loop """Main Training Loop


""" """
def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
dev_data=None, use_cuda=False, save_path="./save",
optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
**kwargs):
def __init__(self, train_data, model, loss=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
metric_key=None, sampler=RandomSampler(), use_tqdm=True):
"""

:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param LossBase loss: a loss object
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
:param int n_epochs: the number of training epochs
:param int batch_size: batch size for training and validation
:param int print_every: step interval to print the next training information. Default: 50.
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
:param DataSet dev_data: the validation data
:param bool use_cuda: whether to run training on GPU (when CUDA is available).
:param save_path: file path to save models
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of the FastNLP code checker. -1: don't check, 0: ignore, 1: warning, 2: strict.
`ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` will
raise an error if some fields are not used.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better as the indicator gets
smaller, add a `-` character in front of the string. For example
::
metric_key="-PPL" # language model gets better as perplexity gets smaller
:param sampler: method used to generate batch data.
:param use_tqdm: boolean, use tqdm to show train progress.

"""
super(Trainer, self).__init__() super(Trainer, self).__init__()


if not isinstance(train_data, DataSet):
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")

# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

# check save_path
if not (save_path is None or isinstance(save_path, str)):
raise ValueError("save_path can only be None or `str`.")
# prepare evaluate
metrics = _prepare_metrics(metrics)

# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
self.increase_better = True
if metric_key is not None:
self.increase_better = False if metric_key[0] == "-" else True
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
elif len(metrics) > 0:
self.metric_key = metrics[0].__class__.__name__.lower().strip('metric')

# prepare loss
losser = _prepare_losser(loss)

# sampler check
if not isinstance(sampler, BaseSampler):
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

if check_code_level > -1:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
metric_key=metric_key, check_level=check_code_level)

self.train_data = train_data self.train_data = train_data
self.dev_data = dev_data # If None, No validation. self.dev_data = dev_data # If None, No validation.
self.model = model self.model = model
self.losser = losser
self.metrics = metrics
self.n_epochs = int(n_epochs) self.n_epochs = int(n_epochs)
self.batch_size = int(batch_size) self.batch_size = int(batch_size)
self.use_cuda = bool(use_cuda) self.use_cuda = bool(use_cuda)
self.save_path = save_path self.save_path = save_path
self.print_every = int(print_every) self.print_every = int(print_every)
self.validate_every = int(validate_every) self.validate_every = int(validate_every)
self._best_accuracy = 0
self.best_metric_indicator = None
self.sampler = sampler


if need_check_code:
_check_code(dataset=train_data, model=model, dev_data=dev_data)
self._model_device = model.parameters().__next__().device


model_name = model.__class__.__name__
assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
self.loss_func = self.model.get_loss
if isinstance(optimizer, torch.optim.Optimizer): if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer self.optimizer = optimizer
else: else:
self.optimizer = optimizer.construct_from_pytorch(self.model.parameters()) self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())


assert hasattr(self.model, 'evaluate'), "model {} has to have a 'evaluate' function.".format(model_name)
self.evaluator = self.model.evaluate
self.use_tqdm = use_tqdm
if self.use_tqdm:
tester_verbose = 0
else:
tester_verbose = 1


if self.dev_data is not None: if self.dev_data is not None:
self.tester = Tester(model=self.model, self.tester = Tester(model=self.model,
data=self.dev_data, data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size, batch_size=self.batch_size,
use_cuda=self.use_cuda)

for k, v in kwargs.items():
setattr(self, k, v)
use_cuda=self.use_cuda,
verbose=tester_verbose)


self.step = 0 self.step = 0
self.start_time = None # start timestamp self.start_time = None # start timestamp


# print(self.__dict__)

def train(self): def train(self):
"""Start Training. """Start Training.


:return:
""" """
try: try:
if torch.cuda.is_available() and self.use_cuda: if torch.cuda.is_available() and self.use_cuda:
self.model = self.model.cuda() self.model = self.model.cuda()


self.mode(self.model, is_test=False)
self._mode(self.model, is_test=False)


start = time.time()
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
print("training epochs started " + self.start_time)
self.start_time = str(datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
print("training epochs started " + self.start_time, flush=True)
if self.save_path is None: if self.save_path is None:
class psudoSW: class psudoSW:
def __getattr__(self, item): def __getattr__(self, item):
def pass_func(*args, **kwargs): def pass_func(*args, **kwargs):
pass pass

return pass_func return pass_func

self._summary_writer = psudoSW() self._summary_writer = psudoSW()
else: else:
path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
self._summary_writer = SummaryWriter(path) self._summary_writer = SummaryWriter(path)
if self.use_tqdm:
self._tqdm_train()
else:
self._print_train()


epoch = 1
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)

self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self.do_validation()
epoch += 1
finally: finally:
self._summary_writer.close() self._summary_writer.close()
del self._summary_writer del self._summary_writer


def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
"""Training process in one epoch.

kwargs should contain:
- n_print: int, print training information every n steps.
- start: time.time(), the starting time of this step.
- epoch: int,
"""
for batch_x, batch_y in data_iterator:
prediction = self.data_forward(model, batch_x)

loss = self.get_loss(prediction, batch_y)
self.grad_backward(loss)
self.update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if self.print_every > 0 and self.step % self.print_every == 0:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff)
print(print_output)

if self.validate_every > 0 and self.step % self.validate_every == 0:
self.do_validation()

self.step += 1

def do_validation(self):
def _tqdm_train(self):
self.step = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)
total_steps = data_iterator.num_batches*self.n_epochs
epoch = 1
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
ava_loss = 0
for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self.model, batch_x)
loss = self._compute_loss(prediction, batch_y)
ava_loss += loss.item()
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if (self.step+1) % self.print_every == 0:
pbar.set_postfix_str("loss:{0:<6.5f}".format(ava_loss / self.print_every))
ava_loss = 0
pbar.update(1)
self.step += 1
if self.validate_every > 0 and self.step % self.validate_every == 0 \
and self.dev_data is not None:
eval_res = self._do_validation()
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str)
if self.validate_every < 0 and self.dev_data:
eval_res = self._do_validation()
eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str)
if epoch!=self.n_epochs:
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)
pbar.close()

def _print_train(self):
epoch = 1
start = time.time()
while epoch <= self.n_epochs:

data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)

for batch_x, batch_y in data_iterator:
# TODO a problem may arise here: if the user changes the device of the prediction inside the model, things will break
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
prediction = self._data_forward(self.model, batch_x)
loss = self._compute_loss(prediction, batch_y)
self._grad_backward(loss)
self._update()
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self.model.named_parameters():
if param.requires_grad:
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if self.print_every > 0 and self.step % self.print_every == 0:
end = time.time()
diff = timedelta(seconds=round(end - start))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
epoch, self.step, loss.data, diff)
print(print_output)

if (self.validate_every > 0 and self.step % self.validate_every == 0 and
self.dev_data is not None):
self._do_validation()

self.step += 1

# validate_every override validation at end of epochs
if self.dev_data and self.validate_every <= 0:
self._do_validation()
epoch += 1

def _do_validation(self):
res = self.tester.test() res = self.tester.test()
for name, num in res.items():
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
if self.save_path is not None and self.best_eval_result(res):
self.save_model(self.model, 'best_model_' + self.start_time)

def mode(self, model, is_test=False):
for name, metric in res.items():
for metric_key, metric_val in metric.items():
self._summary_writer.add_scalar("valid_{}_{}".format(name, metric_key), metric_val,
global_step=self.step)
if self.save_path is not None and self._better_eval_result(res):
metric_key = self.metric_key if self.metric_key is not None else ""
self._save_model(self.model,
"best_" + "_".join([self.model.__class__.__name__, metric_key, self.start_time]))
return res

def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently. """Train mode or Test mode. This is for PyTorch currently.


:param model: a PyTorch model :param model: a PyTorch model
:param is_test: bool, whether in test mode or not.
:param bool is_test: whether in test mode or not.


""" """
if is_test: if is_test:
@@ -162,18 +277,20 @@ class Trainer(object):
else: else:
model.train() model.train()


def update(self):
def _update(self):
"""Perform weight update on a model. """Perform weight update on a model.


""" """
self.optimizer.step() self.optimizer.step()


def data_forward(self, network, x):
def _data_forward(self, network, x):
x = _build_args(network.forward, **x) x = _build_args(network.forward, **x)
y = network(**x) y = network(**x)
if not isinstance(y, dict):
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
return y return y


def grad_backward(self, loss):
def _grad_backward(self, loss):
"""Compute gradient with link rules. """Compute gradient with link rules.


:param loss: a scalar where back-prop starts :param loss: a scalar where back-prop starts
@@ -183,223 +300,130 @@ class Trainer(object):
self.model.zero_grad() self.model.zero_grad()
loss.backward() loss.backward()


def get_loss(self, predict, truth):
def _compute_loss(self, predict, truth):
"""Compute loss given prediction and ground truth. """Compute loss given prediction and ground truth.


:param predict: prediction label vector
:param truth: ground truth label vector
:param predict: prediction dict, produced by model.forward
:param truth: ground truth dict, produced by batch_y
:return: a scalar :return: a scalar
""" """
assert isinstance(predict, dict) and isinstance(truth, dict)
args = _build_args(self.loss_func, **predict, **truth)
return self.loss_func(**args)

def save_model(self, model, model_name, only_param=False):
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)
else:
torch.save(model, model_name)
return self.losser(predict, truth)

def _save_model(self, model, model_name, only_param=False):
if self.save_path is not None:
model_name = os.path.join(self.save_path, model_name)
if only_param:
torch.save(model.state_dict(), model_name)
else:
torch.save(model, model_name)


def best_eval_result(self, metrics):
def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results. """Check if the current epoch yields better validation results.


:return: bool, True means current results on dev set is the best.
:return bool value: True means the current results on the dev set are the best.
""" """
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
accuracy = list(metrics.values())[0]
else:
accuracy = metrics[self.eval_sort_key]
indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
is_better = True
if self.best_metric_indicator is None:
# first-time validation
self.best_metric_indicator = indicator_val
else: else:
accuracy = metrics

if accuracy > self._best_accuracy:
self._best_accuracy = accuracy
return True
else:
return False
if self.increase_better is True:
if indicator_val > self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
else:
if indicator_val < self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
return is_better




DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2 DEFAULT_CHECK_NUM_BATCH = 2


IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2


def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL):
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None, metric_key=None,
check_level=0):
# check the get_loss method
model_name = model.__class__.__name__
if not hasattr(model, 'get_loss'):
raise AttributeError("{} has to have a 'get_loss' function.".format(model_name))
model_devcie = model.parameters().__next__().device


batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
for batch_count, (batch_x, batch_y) in enumerate(batch): for batch_count, (batch_x, batch_y) in enumerate(batch):
_syn_model_data(model, batch_x, batch_y)
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
# forward check # forward check
if batch_count==0: if batch_count==0:
_check_forward_error(model_func=model.forward, check_level=check_level,
batch_x=batch_x)
_check_forward_error(forward_func=model.forward, dataset=dataset,
batch_x=batch_x, check_level=check_level)


refined_batch_x = _build_args(model.forward, **batch_x) refined_batch_x = _build_args(model.forward, **batch_x)
output = model(**refined_batch_x)
pred_dict = model(**refined_batch_x)
func_signature = get_func_signature(model.forward) func_signature = get_func_signature(model.forward)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
if not isinstance(pred_dict, dict):
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.")


# loss check # loss check
if batch_count == 0:
_check_loss_evaluate(prev_func=model.forward, func=model.get_loss, check_level=check_level,
output=output, batch_y=batch_y)
loss_input = _build_args(model.get_loss, **output, **batch_y)
loss = model.get_loss(**loss_input)

# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise ValueError("The return value of {}.get_loss() should be torch.Tensor, but {} got.".
format(model_name, type(loss)))
if len(loss.size())!=0:
raise ValueError("The size of return value of {}.get_loss() is {}, should be torch.size([])".format(
model_name, loss.size()
))
loss.backward()
try:
loss = losser(pred_dict, batch_y)
# check loss output
if batch_count == 0:
if not isinstance(loss, torch.Tensor):
raise TypeError(
f"The return value of {get_func_signature(losser.get_loss)} should be `torch.Tensor`, "
f"but got `{type(loss)}`.")
if len(loss.size()) != 0:
raise ValueError(
f"The size of return value of {get_func_signature(losser.get_loss)} is {loss.size()}, "
f"should be torch.size([])")
loss.backward()
except CheckError as e:
# TODO: another error raised if CheckError caught
pre_func_signature = get_func_signature(model.forward)
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=dataset, check_level=check_level)
model.zero_grad() model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break break


if dev_data is not None: if dev_data is not None:
if not hasattr(model, 'evaluate'):
raise AttributeError("{} has to have a 'evaluate' function to do evaluation. Or set"
"dev_data to 'None'."
.format(model_name))
outputs, truths = defaultdict(list), defaultdict(list)
dev_batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
with torch.no_grad():
for batch_count, (batch_x, batch_y) in enumerate(dev_batch):
_syn_model_data(model, batch_x, batch_y)

if hasattr(model, 'predict'):
refined_batch_x = _build_args(model.predict, **batch_x)
prev_func = model.predict
output = prev_func(**refined_batch_x)
func_signature = get_func_signature(model.predict)
assert isinstance(output, dict), "The return value of {} should be dict.".format(func_signature)
else:
refined_batch_x = _build_args(model.forward, **batch_x)
prev_func = model.forward
output = prev_func(**refined_batch_x)
for k, v in output.items():
outputs[k].append(v)
for k, v in batch_y.items():
truths[k].append(v)
if batch_count+1>DEFAULT_CHECK_NUM_BATCH:
break
for k, v in outputs.items():
outputs[k] = itertools.chain(*v)
for k, v in truths.items():
truths[k] = itertools.chain(*v)
_check_loss_evaluate(prev_func=prev_func, func=model.evaluate, check_level=check_level,
output=outputs, batch_y=truths)
refined_input = _build_args(model.evaluate, **outputs, **truths)
metrics = model.evaluate(**refined_input)
func_signature = get_func_signature(model.evaluate)
assert isinstance(metrics, dict), "The return value of {} should be dict.". \
format(func_signature)


def _check_forward_error(model_func, check_level, batch_x):
check_res = _check_arg_dict_list(model_func, batch_x)
_missing = ''
_unused = ''
func_signature = get_func_signature(model_func)
if len(check_res.missing)!=0:
_missing = "Function {} misses {}, only provided with {}, " \
".\n".format(func_signature, check_res.missing,
list(batch_x.keys()))
if len(check_res.unused)!=0:
if len(check_res.unused) > 1:
_unused = "{} are not used ".format(check_res.unused)
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1)
evaluate_results = tester.test()
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)


def _check_eval_results(metrics, metric_key, metric_list):
# metrics: the evaluation results returned by Tester
# metric_key: the key used to pick the indicator, from Trainer's initialization
# metric_list: the metrics used for evaluation, from Trainer's initialization
if isinstance(metrics, tuple):
loss, metrics = metrics

if isinstance(metrics, dict):
if len(metrics) == 1:
# only single metric, just use it
metric_dict = list(metrics.values())[0]
metrics_name = list(metrics.keys())[0]
else: else:
_unused = "{} is not used ".format(check_res.unused)
_unused += "in function {}.\n".format(func_signature)
if _missing:
if len(_unused)>0 and STRICT_CHECK_LEVEL:
_error_str = "(1).{}\n(2).{}".format(_missing, _unused)
metrics_name = metric_list[0].__class__.__name__
if metrics_name not in metrics:
raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
metric_dict = metrics[metrics_name]

if len(metric_dict) == 1:
indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
elif len(metric_dict) > 1 and metric_key is None:
raise RuntimeError(
f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
else: else:
_error_str = _missing
# TODO: some custom Error types may be needed here
raise TypeError(_error_str)
if _unused:
if check_level == STRICT_CHECK_LEVEL:
# TODO: some custom Error types may be needed here
raise ValueError(_unused)
elif check_level == WARNING_CHECK_LEVEL:
warnings.warn(message=_unused)

def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):

check_res = _check_arg_dict_list(func, [output, batch_y])
_missing = ''
_unused = ''
_duplicated = ''
func_signature = get_func_signature(func)
prev_func_signature = get_func_signature(prev_func)
if len(check_res.missing)>0:
_missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
"{}(from target in Dataset)." \
.format(func_signature, check_res.missing,
list(output.keys()), prev_func_signature,
list(batch_y.keys()))
if len(check_res.unused)>0:
if len(check_res.unused) > 1:
_unused = "{} are not used ".format(check_res.unused)
else:
_unused = "{} is not used ".format(check_res.unused)
_unused += "in function {}.\n".format(func_signature)
if len(check_res.duplicated)>0:
if len(check_res.duplicated) > 1:
_duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
"them in {} at the same time.".format(check_res.duplicated,
func_signature,
check_res.duplicated,
prev_func_signature)
else:
_duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
"it in {} at the same time.".format(check_res.duplicated,
func_signature,
check_res.duplicated,
prev_func_signature)
_number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
if _number_errs > 0:
_error_strs = []
if _number_errs > 1:
count = 0
order_words = ['Firstly', 'Secondly', 'Thirdly']
if _missing:
_error_strs.append('{}, {}'.format(order_words[count], _missing))
count += 1
if _duplicated:
_error_strs.append('{}, {}'.format(order_words[count], _duplicated))
count += 1
if _unused and check_level == STRICT_CHECK_LEVEL:
_error_strs.append('{}, {}'.format(order_words[count], _unused))
else:
if _unused:
if check_level == STRICT_CHECK_LEVEL:
# TODO: some custom Error types may be needed here
_error_strs.append(_unused)
elif check_level == WARNING_CHECK_LEVEL:
_unused = _unused.strip()
warnings.warn(_unused)
else:
if _missing:
_error_strs.append(_missing)
if _duplicated:
_error_strs.append(_duplicated)

if _error_strs:
raise ValueError('\n' + '\n'.join(_error_strs))
# metric_key is set
if metric_key not in metric_dict:
raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
indicator_val = metric_dict[metric_key]
else:
raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
return indicator_val
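
For context, a minimal sketch of the nested dict shape `_check_eval_results` consumes (metric name -> {key -> value}), e.g. as produced by `Tester.test()` in this PR. The import of this private helper and the numbers are illustrative assumptions, not part of the diff:

```python
from fastNLP.core.trainer import _check_eval_results  # internal helper; module path per this file

evaluate_results = {"AccuracyMetric": {"acc": 0.93, "f1": 0.88}}  # illustrative numbers
# metric_key selects the indicator when a metric reports several values;
# metric_list is only consulted when several metrics are present.
indicator = _check_eval_results(evaluate_results, metric_key="acc", metric_list=[])
print(indicator)  # 0.93; with metric_key=None and multiple keys a RuntimeError is raised
```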

+ 317
- 12
fastNLP/core/utils.py View File

@@ -1,10 +1,15 @@
import _pickle import _pickle
import inspect import inspect
import os import os
import warnings
from collections import Counter from collections import Counter
from collections import namedtuple from collections import namedtuple


CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
import numpy as np
import torch

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'], verbose=False)




def save_pickle(obj, pickle_path, file_name): def save_pickle(obj, pickle_path, file_name):
@@ -50,6 +55,7 @@ def pickle_exist(pickle_path, pickle_name):
else: else:
return False return False



def _build_args(func, **kwargs): def _build_args(func, **kwargs):
spect = inspect.getfullargspec(func) spect = inspect.getfullargspec(func)
if spect.varkw is not None: if spect.varkw is not None:
@@ -64,6 +70,38 @@ def _build_args(func, **kwargs):
return output return output




def _map_args(maps: dict, **kwargs):
# maps: key=old name, value= new name
output = {}
for name, val in kwargs.items():
if name in maps:
assert isinstance(maps[name], str)
output.update({maps[name]: val})
else:
output.update({name: val})
for keys in maps.keys():
if keys not in output.keys():
# TODO: add UNUSED warning.
pass
return output


def _get_arg_list(func):
assert callable(func)
spect = inspect.getfullargspec(func)
if spect.defaults is not None:
args = spect.args[: -len(spect.defaults)]
defaults = spect.args[-len(spect.defaults):]
defaults_val = spect.defaults
else:
args = spect.args
defaults = None
defaults_val = None
varargs = spect.varargs
kwargs = spect.varkw
return args, defaults, defaults_val, varargs, kwargs


# check args # check args
def _check_arg_dict_list(func, args): def _check_arg_dict_list(func, args):
if isinstance(args, dict): if isinstance(args, dict):
@@ -73,8 +111,7 @@ def _check_arg_dict_list(func, args):
assert callable(func) and isinstance(arg_dict_list, (list, tuple)) assert callable(func) and isinstance(arg_dict_list, (list, tuple))
assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict) assert len(arg_dict_list) > 0 and isinstance(arg_dict_list[0], dict)
spect = inspect.getfullargspec(func) spect = inspect.getfullargspec(func)
assert spect.varargs is None, 'Positional Arguments({}) are not supported.'.format(spect.varargs)
all_args = set([arg for arg in spect.args if arg!='self'])
all_args = set([arg for arg in spect.args if arg != 'self'])
defaults = [] defaults = []
if spect.defaults is not None: if spect.defaults is not None:
defaults = [arg for arg in spect.defaults] defaults = [arg for arg in spect.defaults]
@@ -88,19 +125,39 @@ def _check_arg_dict_list(func, args):
input_args = set(input_arg_count.keys()) input_args = set(input_arg_count.keys())
missing = list(require_args - input_args) missing = list(require_args - input_args)
unused = list(input_args - all_args) unused = list(input_args - all_args)
varargs = [] if not spect.varargs else [arg for arg in spect.varargs]
return CheckRes(missing=missing, return CheckRes(missing=missing,
unused=unused, unused=unused,
duplicated=duplicated, duplicated=duplicated,
required=list(require_args), required=list(require_args),
all_needed=list(all_args))
all_needed=list(all_args),
varargs=varargs)



def get_func_signature(func): def get_func_signature(func):
# can only be used in function or class method
"""

Given a function or method, return its signature.
For example:
(1) function
def func(a, b='a', *args):
xxxx
get_func_signature(func) # 'func(a, b='a', *args)'
(2) method
class Demo:
def __init__(self):
xxx
def forward(self, a, b='a', **args):
demo = Demo()
get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)'
:param func: a function or a method
:return: str or None
"""
if inspect.ismethod(func): if inspect.ismethod(func):
class_name = func.__self__.__class__.__name__ class_name = func.__self__.__class__.__name__
signature = inspect.signature(func) signature = inspect.signature(func)
signature_str = str(signature) signature_str = str(signature)
if len(signature_str)>2:
if len(signature_str) > 2:
_self = '(self, ' _self = '(self, '
else: else:
_self = '(self' _self = '(self'
@@ -113,15 +170,263 @@ def get_func_signature(func):
return signature_str return signature_str




# move data to model's device
import torch
def _syn_model_data(model, *args):
assert len(model.state_dict())!=0, "This model has no parameter."
device = model.parameters().__next__().device
def _is_function_or_method(func):
"""

:param func:
:return:
"""
if not inspect.ismethod(func) and not inspect.isfunction(func):
return False
return True


def _check_function_or_method(func):
if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")


def _move_dict_value_to_device(*args, device: torch.device):
"""

Move data to the model's device. Every element in *args should be a dict. This is an in-place change.
:param device: torch.device
:param args:
:return:
"""
if not isinstance(device, torch.device):
raise TypeError(f"device must be `torch.device`, got `{type(device)}`")

for arg in args: for arg in args:
if isinstance(arg, dict): if isinstance(arg, dict):
for key, value in arg.items(): for key, value in arg.items():
if isinstance(value, torch.Tensor): if isinstance(value, torch.Tensor):
arg[key] = value.to(device) arg[key] = value.to(device)
else: else:
raise ValueError("Only support dict type right now.")
raise TypeError("Only support `dict` type right now.")
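
A short usage sketch of the new `_move_dict_value_to_device` helper defined above; the tensors and field names are made up for illustration:

```python
import torch

from fastNLP.core.utils import _move_dict_value_to_device  # module path per this file

batch_x = {"word_seq": torch.zeros(2, 5, dtype=torch.long)}
batch_y = {"label": torch.tensor([0, 1])}
# Both dicts are modified in place; every tensor value ends up on the given device.
_move_dict_value_to_device(batch_x, batch_y, device=torch.device("cpu"))
```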


class CheckError(Exception):
"""

CheckError. Used in losses.LossBase, metrics.MetricBase.
"""

def __init__(self, check_res: CheckRes, func_signature: str):
errs = [f'Problems occurred when calling `{func_signature}`']

if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Passing positional arguments is not supported, please delete it)")
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")
if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}")
if check_res.unused:
errs.append(f"\tunused param: {check_res.unused}")

Exception.__init__(self, '\n'.join(errs))

self.check_res = check_res
self.func_signature = func_signature
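
A hedged sketch of how `CheckError` bundles a `CheckRes` with the offending signature; the concrete field values and the signature string below are invented:

```python
from fastNLP.core.utils import CheckError, CheckRes  # module path per this file

res = CheckRes(missing=["target"], unused=["extra"], duplicated=[], required=["pred", "target"],
               all_needed=["pred", "target"], varargs=[])
try:
    raise CheckError(check_res=res, func_signature="SomeMetric.evaluate(pred, target)")
except CheckError as e:
    print(e.func_signature)     # 'SomeMetric.evaluate(pred, target)'
    print(e.check_res.missing)  # ['target']
```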


IGNORE_CHECK_LEVEL = 0
WARNING_CHECK_LEVEL = 1
STRICT_CHECK_LEVEL = 2


def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_res: CheckRes,
pred_dict: dict, target_dict: dict, dataset, check_level=0):
errs = []
unuseds = []
_unused_field = []
_unused_param = []
suggestions = []
if check_res.varargs:
errs.append(f"\tvarargs: *{check_res.varargs}")
suggestions.append(f"Passing positional arguments is not supported, please delete *{check_res.varargs}.")

if check_res.unused:
for _unused in check_res.unused:
if _unused in target_dict:
_unused_field.append(_unused)
else:
_unused_param.append(_unused)
if _unused_field:
unuseds.append(f"\tunused field: {_unused_field}")
if _unused_param:
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward

module_name = func_signature.split('.')[0]
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")
import re
mapped_missing = []
unmapped_missing = []
input_func_map = {}
for _miss in check_res.missing:
if '(' in _miss:
# if they are like 'SomeParam(assign to xxx)'
_miss = _miss.split('(')[0]
matches = re.findall("(?<=`)[a-zA-Z0-9]*?(?=`)", _miss)
if len(matches) == 2:
fun_arg, module_name = matches
input_func_map[_miss] = fun_arg
if fun_arg == _miss:
unmapped_missing.append(_miss)
else:
mapped_missing.append(_miss)
else:
unmapped_missing.append(_miss)

for _miss in mapped_missing:
if _miss in dataset:
suggestions.append(f"Set {_miss} as target.")
else:
_tmp = ''
if check_res.unused:
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initializing {module_name}."
if _tmp:
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
else:
_tmp = f'Provide {_miss} in DataSet or output of {prev_func_signature}.'
suggestions.append(_tmp)
for _miss in unmapped_missing:
if _miss in dataset:
suggestions.append(f"Set {_miss} as target.")
else:
_tmp = ''
if check_res.unused:
_tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initializing {module_name}."
if _tmp:
_tmp += f' Or provide {_miss} in DataSet or output of {prev_func_signature}.'
else:
_tmp = f'Provide {_miss} in output of {prev_func_signature} or DataSet.'
suggestions.append(_tmp)

if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}.")
suggestions.append(f"Delete {check_res.duplicated} in the output of "
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

if len(errs)>0:
errs.extend(unuseds)
elif check_level == STRICT_CHECK_LEVEL:
errs.extend(unuseds)

if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = ""
if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions):
if idx>0:
sugg_str += '\t\t\t'
sugg_str += f'({idx+1}). {sugg}\n'
sugg_str = sugg_str[:-1]
else:
sugg_str += suggestions[0]
errs.append(f'\ttarget field: {list(target_dict.keys())}')
errs.append(f'\tparam from {prev_func_signature}: {list(pred_dict.keys())}')
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
raise NameError(err_str)
if check_res.unused:
if check_level == WARNING_CHECK_LEVEL:
if not module_name:
module_name = func_signature.split('.')[0]
_unused_warn = f'{check_res.unused} is not used by {module_name}.'
warnings.warn(message=_unused_warn)

def _check_forward_error(forward_func, batch_x, dataset, check_level):
check_res = _check_arg_dict_list(forward_func, batch_x)
func_signature = get_func_signature(forward_func)

errs = []
suggestions = []
_unused = []

if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}")
suggestions.append(f"Passing positional arguments is not supported, please delete *{check_res.varargs}.")
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")
_miss_in_dataset = []
_miss_out_dataset = []
for _miss in check_res.missing:
if _miss in dataset:
_miss_in_dataset.append(_miss)
else:
_miss_out_dataset.append(_miss)
if _miss_in_dataset:
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
if _miss_out_dataset:
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
# if check_res.unused:
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
# f"rename the field in `unused field:`."
suggestions.append(_tmp)

if check_res.unused:
_unused = [f"\tunused field: {check_res.unused}"]
if len(errs)>0:
errs.extend(_unused)
elif check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused)

if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = ""
if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
else:
sugg_str += suggestions[0]
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
raise NameError(err_str)
if _unused:
if check_level == WARNING_CHECK_LEVEL:
_unused_warn = _unused[0] + f' in {func_signature}.'
warnings.warn(message=_unused_warn)


def seq_lens_to_masks(seq_lens, float=False):
"""

Convert seq_lens to masks.
:param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,)
:param float: if True, the returned masks are of float type; otherwise they are of byte type.
:return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
"""
if isinstance(seq_lens, np.ndarray):
assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
raise NotImplementedError
elif isinstance(seq_lens, torch.LongTensor):
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())}."
batch_size = seq_lens.size(0)
max_len = seq_lens.max()
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
masks = indexes.lt(seq_lens.unsqueeze(1))

if float:
masks = masks.float()

return masks
elif isinstance(seq_lens, list):
raise NotImplementedError
else:
raise NotImplementedError

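A quick example of the `torch.LongTensor` branch of `seq_lens_to_masks` (the only branch implemented above); the lengths are illustrative:

```python
import torch

from fastNLP.core.utils import seq_lens_to_masks  # module path per this file

seq_lens = torch.LongTensor([3, 1, 2])
masks = seq_lens_to_masks(seq_lens)                    # shape (3, 3); nonzero where position < length
float_masks = seq_lens_to_masks(seq_lens, float=True)  # same mask as 0./1. floats
print(masks.shape, float_masks.dtype)                  # torch.Size([3, 3]) torch.float32
```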

def seq_mask(seq_len, max_len):
"""Create sequence mask.

:param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
:param max_len: int, the maximum sequence length in a batch.
:return mask: torch.LongTensor, [batch_size, max_len]

"""
if not isinstance(seq_len, torch.Tensor):
seq_len = torch.LongTensor(seq_len)
seq_len = seq_len.view(-1, 1).long() # [batch_size, 1]
seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len]
return torch.gt(seq_len, seq_range) # [batch_size, max_len]
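
And a matching usage of `seq_mask`, which takes an explicit `max_len` instead of inferring it:

```python
import torch

from fastNLP.core.utils import seq_mask  # module path per this file

mask = seq_mask(torch.LongTensor([3, 1, 2]), max_len=4)
print(mask.shape)  # torch.Size([3, 4]); nonzero for positions < length, zero for padding
```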

+ 42
- 40
fastNLP/core/vocabulary.py View File

@@ -1,24 +1,31 @@
from collections import Counter from collections import Counter
from copy import deepcopy


DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1


DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1}
def check_build_vocab(func):
"""A decorator to make sure the indexing is built before used.

"""


def _wrapper(self, *args, **kwargs):
if self.word2idx is None or self.rebuild is True:
self.build_vocab()
return func(self, *args, **kwargs)


def isiterable(p_object):
try:
_ = iter(p_object)
except TypeError:
return False
return True
return _wrapper




def check_build_vocab(func):
def check_build_status(func):
"""A decorator to check whether the vocabulary updates after the last build.

"""

def _wrapper(self, *args, **kwargs): def _wrapper(self, *args, **kwargs):
if self.word2idx is None:
self.build_vocab()
if self.rebuild is False:
self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs) return func(self, *args, **kwargs)


return _wrapper return _wrapper
@@ -36,25 +43,21 @@ class Vocabulary(object):
vocab.to_word(5) vocab.to_word(5)
""" """


def __init__(self, need_default=True, max_size=None, min_freq=None):
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'):
""" """
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
:param int max_size: set the max number of words in Vocabulary. Default: None :param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
""" """
self.max_size = max_size self.max_size = max_size
self.min_freq = min_freq self.min_freq = min_freq
self.word_count = Counter() self.word_count = Counter()
self.has_default = need_default
if self.has_default:
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.padding_label = None
self.unknown_label = None
self.unknown = unknown
self.padding = padding
self.word2idx = None self.word2idx = None
self.idx2word = None self.idx2word = None
self.rebuild = True


@check_build_status
def update(self, word_lst): def update(self, word_lst):
"""Add a list of words into the vocabulary. """Add a list of words into the vocabulary.


@@ -62,6 +65,7 @@ class Vocabulary(object):
""" """
self.word_count.update(word_lst) self.word_count.update(word_lst)


@check_build_status
def add(self, word): def add(self, word):
"""Add a single word into the vocabulary. """Add a single word into the vocabulary.


@@ -69,6 +73,7 @@ class Vocabulary(object):
""" """
self.word_count[word] += 1 self.word_count[word] += 1


@check_build_status
def add_word(self, word): def add_word(self, word):
"""Add a single word into the vocabulary. """Add a single word into the vocabulary.


@@ -76,6 +81,7 @@ class Vocabulary(object):
""" """
self.add(word) self.add(word)


@check_build_status
def add_word_lst(self, word_lst): def add_word_lst(self, word_lst):
"""Add a list of words into the vocabulary. """Add a list of words into the vocabulary.


@@ -87,20 +93,22 @@ class Vocabulary(object):
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`.


""" """
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
else:
self.word2idx = {}
self.word2idx = {}
if self.padding is not None:
self.word2idx[self.padding] = 0
if self.unknown is not None:
self.word2idx[self.unknown] = 1


max_size = min(self.max_size, len(self.word_count)) if self.max_size else None max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
words = self.word_count.most_common(max_size) words = self.word_count.most_common(max_size)
if self.min_freq is not None: if self.min_freq is not None:
words = filter(lambda kv: kv[1] >= self.min_freq, words) words = filter(lambda kv: kv[1] >= self.min_freq, words)
if self.word2idx is not None:
words = filter(lambda kv: kv[0] not in self.word2idx, words)
start_idx = len(self.word2idx) start_idx = len(self.word2idx)
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab() self.build_reverse_vocab()
self.rebuild = False


def build_reverse_vocab(self): def build_reverse_vocab(self):
"""Build 'index to word' dict based on 'word to index' dict. """Build 'index to word' dict based on 'word to index' dict.
@@ -132,8 +140,8 @@ class Vocabulary(object):
""" """
if w in self.word2idx: if w in self.word2idx:
return self.word2idx[w] return self.word2idx[w]
elif self.has_default:
return self.word2idx[self.unknown_label]
if self.unknown is not None:
return self.word2idx[self.unknown]
else: else:
raise ValueError("word {} not in vocabulary".format(w)) raise ValueError("word {} not in vocabulary".format(w))


@@ -148,21 +156,16 @@ class Vocabulary(object):
@property @property
@check_build_vocab @check_build_vocab
def unknown_idx(self): def unknown_idx(self):
if self.unknown_label is None:
if self.unknown is None:
return None return None
return self.word2idx[self.unknown_label]

def __setattr__(self, name, val):
self.__dict__[name] = val
if name in ["unknown_label", "padding_label"]:
self.word2idx = None
return self.word2idx[self.unknown]


@property @property
@check_build_vocab @check_build_vocab
def padding_idx(self): def padding_idx(self):
if self.padding_label is None:
if self.padding is None:
return None return None
return self.word2idx[self.padding_label]
return self.word2idx[self.padding]


@check_build_vocab @check_build_vocab
def to_word(self, idx): def to_word(self, idx):
@@ -188,4 +191,3 @@ class Vocabulary(object):
""" """
self.__dict__.update(state) self.__dict__.update(state)
self.build_reverse_vocab() self.build_reverse_vocab()
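
A small sketch of the reworked `Vocabulary` API shown above (lazy building via the two decorators, configurable `unknown`/`padding`); the words are arbitrary:

```python
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary(min_freq=1, unknown='<unk>', padding='<pad>')
vocab.update(["the", "cat", "sat", "the"])
vocab.add_word("dog")                        # marks the vocabulary for rebuilding
print(vocab.padding_idx, vocab.unknown_idx)  # 0 1 -- indexing is built lazily on first access
print(vocab.to_word(2))                      # "the", the most frequent regular word
```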


+ 18
- 14
fastNLP/io/base_loader.py View File

@@ -31,17 +31,21 @@ class BaseLoader(object):
return obj return obj




class ToyLoader0(BaseLoader):
"""
For CharLM
"""

def __init__(self, data_path):
super(ToyLoader0, self).__init__(data_path)

def load(self):
with open(self.data_path, 'r') as f:
corpus = f.read().lower()
import re
corpus = re.sub(r"<unk>", "unk", corpus)
return corpus.split()
class DataLoaderRegister:
"""Register for dataset loaders."""
_readers = {}

@classmethod
def set_reader(cls, reader_cls, read_fn_name):
# def wrapper(reader_cls):
if read_fn_name in cls._readers:
raise KeyError('duplicate reader: {} and {} for read_func: {}'.format(cls._readers[read_fn_name], reader_cls, read_fn_name))
if hasattr(reader_cls, 'load'):
cls._readers[read_fn_name] = reader_cls().load
return reader_cls

@classmethod
def get_reader(cls, read_fn_name):
if read_fn_name in cls._readers:
return cls._readers[read_fn_name]
raise AttributeError('no read function: {}'.format(read_fn_name))
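
A hedged example of the new `DataLoaderRegister`: importing `fastNLP.io.dataset_loader` (changed later in this diff) registers readers such as `'read_naive'`; the file path below is hypothetical:

```python
import fastNLP.io.dataset_loader  # noqa: F401 -- registers 'read_naive', 'read_rawdata', 'read_pos'
from fastNLP.io.base_loader import DataLoaderRegister

read_naive = DataLoaderRegister.get_reader('read_naive')
ds = read_naive('path/to/data.tsv')  # hypothetical tab-separated file with "sentence<TAB>label" rows
```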

fastNLP/io/config_saver.py → fastNLP/io/config_io.py View File

@@ -1,6 +1,152 @@
import configparser
import json
import os import os


from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))

@staticmethod
def parse(string):
raise NotImplementedError

@staticmethod
def load_config(file_path, sections):
"""
:param file_path: the path of config file
:param sections: the dict of {section_name(string): Section instance}
Example:
test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
:return: returns nothing, but the values of the attributes are saved in the corresponding sections
"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
if not os.path.exists(file_path):
raise FileNotFoundError("config file {} not found. ".format(file_path))
cfg.read(file_path)
for s in sections:
attr_list = [i for i in sections[s].__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]
if s not in cfg:
print('section %s not found in config file' % (s))
continue
gen_sec = cfg[s]
for attr in gen_sec.keys():
try:
val = json.loads(gen_sec[attr])
# print(s, attr, val, type(val))
if attr in attr_list:
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, expected %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
"""
if attr in attr_list then check its type and
update its value.
else add a new attr in sections[s]
"""
setattr(sections[s], attr, val)
except Exception as e:
print("cannot load attribute %s in section %s"
% (attr, s))
pass


class ConfigSection(object):

def __init__(self):
pass

def __getitem__(self, key):
"""
:param key: str, the name of the attribute
:return attr: the value of this attribute
if key not in self.__dict__.keys():
return self[key]
else:
raise AttributeError
"""
if key in self.__dict__.keys():
return getattr(self, key)
raise AttributeError("do NOT have attribute %s" % key)

def __setitem__(self, key, value):
"""
:param key: str, the name of the attribute
:param value: the value of this attribute
if key not in self.__dict__.keys():
self[key] will be added
else:
self[key] will be updated
"""
if key in self.__dict__.keys():
if not isinstance(value, type(getattr(self, key))):
raise AttributeError("attr %s expected %s but got %s" %
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)

def __contains__(self, item):
"""
:param item: The key of item.
:return: True if the key in self.__dict__.keys() else False.
"""
return item in self.__dict__.keys()

def __eq__(self, other):
"""Overwrite the == operator

:param other: Another ConfigSection() object which to be compared.
:return: True if value of each key in each ConfigSection() object are equal to the other, else False.
"""
for k in self.__dict__.keys():
if k not in other.__dict__.keys():
return False
if getattr(self, k) != getattr(other, k):
return False

for k in other.__dict__.keys():
if k not in self.__dict__.keys():
return False
if getattr(self, k) != getattr(other, k):
return False

return True

def __ne__(self, other):
"""Overwrite the != operator

:param other:
:return:
"""
return not self.__eq__(other)

@property
def data(self):
return self.__dict__


if __name__ == "__main__":
config = ConfigLoader('there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""

config.load_config("../../test/data_for_tests/config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))




class ConfigSaver(object): class ConfigSaver(object):
@@ -125,7 +271,7 @@ class ConfigSaver(object):
# logger = create_logger(__name__, "./config_loader.log") # logger = create_logger(__name__, "./config_loader.log")
# logger.warning("section [%s] in config file [%s] has been changed" % ( # logger.warning("section [%s] in config file [%s] has been changed" % (
# section_name, self.file_path # section_name, self.file_path
#))
# ))
change_file = True change_file = True
break break
if not change_file: if not change_file:

+ 0
- 149
fastNLP/io/config_loader.py View File

@@ -1,149 +0,0 @@
import configparser
import json
import os

from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""loader for configuration files"""

def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))

@staticmethod
def parse(string):
raise NotImplementedError

@staticmethod
def load_config(file_path, sections):
"""
:param file_path: the path of config file
:param sections: the dict of {section_name(string): Section instance}
Example:
test_args = ConfigSection()
ConfigLoader("config.cfg", "").load_config("./data_for_tests/config", {"POS_test": test_args})
:return: return nothing, but the value of attributes are saved in sessions
"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
if not os.path.exists(file_path):
raise FileNotFoundError("config file {} not found. ".format(file_path))
cfg.read(file_path)
for s in sections:
attr_list = [i for i in sections[s].__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]
if s not in cfg:
print('section %s not found in config file' % (s))
continue
gen_sec = cfg[s]
for attr in gen_sec.keys():
try:
val = json.loads(gen_sec[attr])
# print(s, attr, val, type(val))
if attr in attr_list:
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, except %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
"""
if attr in attr_list then check its type and
update its value.
else add a new attr in sections[s]
"""
setattr(sections[s], attr, val)
except Exception as e:
print("cannot load attribute %s in section %s"
% (attr, s))
pass


class ConfigSection(object):

def __init__(self):
pass

def __getitem__(self, key):
"""
:param key: str, the name of the attribute
:return attr: the value of this attribute
if key not in self.__dict__.keys():
return self[key]
else:
raise AttributeError
"""
if key in self.__dict__.keys():
return getattr(self, key)
raise AttributeError("do NOT have attribute %s" % key)

def __setitem__(self, key, value):
"""
:param key: str, the name of the attribute
:param value: the value of this attribute
if key not in self.__dict__.keys():
self[key] will be added
else:
self[key] will be updated
"""
if key in self.__dict__.keys():
if not isinstance(value, type(getattr(self, key))):
raise AttributeError("attr %s except %s but got %s" %
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)

def __contains__(self, item):
"""
:param item: The key of item.
:return: True if the key in self.__dict__.keys() else False.
"""
return item in self.__dict__.keys()

def __eq__(self, other):
"""Overwrite the == operator

:param other: Another ConfigSection() object which to be compared.
:return: True if value of each key in each ConfigSection() object are equal to the other, else False.
"""
for k in self.__dict__.keys():
if k not in other.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

for k in other.__dict__.keys():
if k not in self.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False

return True

def __ne__(self, other):
"""Overwrite the != operator

:param other:
:return:
"""
return not self.__eq__(other)

@property
def data(self):
return self.__dict__


if __name__ == "__main__":
config = ConfigLoader('there is no data')

section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""

config.load_config("../../test/data_for_tests/config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr), type(getattr(section[s], attr)))

+ 43
- 38
fastNLP/io/dataset_loader.py View File

@@ -1,9 +1,8 @@
#TODO: need fix for current DataSet
import os import os


from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.io.base_loader import BaseLoader
from fastNLP.io.base_loader import DataLoaderRegister




def convert_seq_dataset(data): def convert_seq_dataset(data):
@@ -20,8 +19,7 @@ def convert_seq_dataset(data):
""" """
dataset = DataSet() dataset = DataSet()
for word_seq in data: for word_seq in data:
x = TextField(word_seq, is_target=False)
dataset.append(Instance(word_seq=x))
dataset.append(Instance(word_seq=word_seq))
return dataset return dataset




@@ -40,11 +38,7 @@ def convert_seq2tag_dataset(data):
""" """
dataset = DataSet() dataset = DataSet()
for sample in data: for sample in data:
word_seq, label = sample[0], sample[1]
ins = Instance()
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
.add_field("label", LabelField(label, is_target=True))
dataset.append(ins)
dataset.append(Instance(word_seq=sample[0], label=sample[1]))
return dataset return dataset




@@ -63,20 +57,13 @@ def convert_seq2seq_dataset(data):
""" """
dataset = DataSet() dataset = DataSet()
for sample in data: for sample in data:
word_seq, label_seq = sample[0], sample[1]
ins = Instance()
ins.add_field("word_seq", TextField(word_seq, is_target=False)) \
.add_field("label_seq", TextField(label_seq, is_target=True))
dataset.append(ins)
dataset.append(Instance(word_seq=sample[0], label_seq=sample[1]))
return dataset return dataset




class DataSetLoader(BaseLoader):
class DataSetLoader:
""""loader for data sets""" """"loader for data sets"""


def __init__(self):
super(DataSetLoader, self).__init__()

def load(self, path): def load(self, path):
""" load data in `path` into a dataset """ load data in `path` into a dataset
""" """
@@ -88,7 +75,20 @@ class DataSetLoader(BaseLoader):
raise NotImplementedError raise NotImplementedError




@DataSet.set_reader('read_raw')
class NativeDataSetLoader(DataSetLoader):
def __init__(self):
super(NativeDataSetLoader, self).__init__()

def load(self, path):
ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t")
ds.set_input("raw_sentence")
ds.set_target("label")
return ds


DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive')


class RawDataSetLoader(DataSetLoader): class RawDataSetLoader(DataSetLoader):
def __init__(self): def __init__(self):
super(RawDataSetLoader, self).__init__() super(RawDataSetLoader, self).__init__()
@@ -104,7 +104,9 @@ class RawDataSetLoader(DataSetLoader):
return convert_seq_dataset(data) return convert_seq_dataset(data)




@DataSet.set_reader('read_pos')
DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata')


class POSDataSetLoader(DataSetLoader): class POSDataSetLoader(DataSetLoader):
"""Dataset Loader for POS Tag datasets. """Dataset Loader for POS Tag datasets.


@@ -174,7 +176,9 @@ class POSDataSetLoader(DataSetLoader):
return convert_seq2seq_dataset(data) return convert_seq2seq_dataset(data)




@DataSet.set_reader('read_tokenize')
DataLoaderRegister.set_reader(POSDataSetLoader, 'read_pos')


class TokenizeDataSetLoader(DataSetLoader): class TokenizeDataSetLoader(DataSetLoader):
""" """
Data set loader for tokenization data sets Data set loader for tokenization data sets
@@ -234,7 +238,6 @@ class TokenizeDataSetLoader(DataSetLoader):
return convert_seq2seq_dataset(data) return convert_seq2seq_dataset(data)




@DataSet.set_reader('read_class')
class ClassDataSetLoader(DataSetLoader): class ClassDataSetLoader(DataSetLoader):
"""Loader for classification data sets""" """Loader for classification data sets"""


@@ -273,7 +276,6 @@ class ClassDataSetLoader(DataSetLoader):
return convert_seq2tag_dataset(data) return convert_seq2tag_dataset(data)




@DataSet.set_reader('read_conll')
class ConllLoader(DataSetLoader): class ConllLoader(DataSetLoader):
"""loader for conll format files""" """loader for conll format files"""


@@ -315,7 +317,6 @@ class ConllLoader(DataSetLoader):
pass pass




@DataSet.set_reader('read_lm')
class LMDataSetLoader(DataSetLoader): class LMDataSetLoader(DataSetLoader):
"""Language Model Dataset Loader """Language Model Dataset Loader


@@ -352,7 +353,6 @@ class LMDataSetLoader(DataSetLoader):
pass pass




@DataSet.set_reader('read_people_daily')
class PeopleDailyCorpusLoader(DataSetLoader): class PeopleDailyCorpusLoader(DataSetLoader):
""" """
People Daily Corpus: Chinese word segmentation, POS tag, NER People Daily Corpus: Chinese word segmentation, POS tag, NER
@@ -403,10 +403,19 @@ class PeopleDailyCorpusLoader(DataSetLoader):
pos_tag_examples.append([sent_words, sent_pos_tag]) pos_tag_examples.append([sent_words, sent_pos_tag])
ner_examples.append([sent_words, sent_ner]) ner_examples.append([sent_words, sent_ner])
# List[List[List[str], List[str]]] # List[List[List[str], List[str]]]
return pos_tag_examples, ner_examples
# ner_examples not used
return self.convert(pos_tag_examples)


def convert(self, data): def convert(self, data):
pass
data_set = DataSet()
for item in data:
sent_words, sent_pos_tag = item[0], item[1]
data_set.append(Instance(words=sent_words, tags=sent_pos_tag))
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
data_set.set_target("tags")
data_set.set_input("words")
data_set.set_input("seq_len")
return data_set




class SNLIDataSetLoader(DataSetLoader): class SNLIDataSetLoader(DataSetLoader):
@@ -462,17 +471,13 @@ class SNLIDataSetLoader(DataSetLoader):
for example in data: for example in data:
p, h, l = example p, h, l = example
# list, list, str # list, list, str
x1 = TextField(p, is_target=False)
x2 = TextField(h, is_target=False)
x1_len = TextField([1] * len(p), is_target=False)
x2_len = TextField([1] * len(h), is_target=False)
y = LabelField(l, is_target=True)
instance = Instance() instance = Instance()
instance.add_field("premise", x1)
instance.add_field("hypothesis", x2)
instance.add_field("premise_len", x1_len)
instance.add_field("hypothesis_len", x2_len)
instance.add_field("truth", y)
instance.add_field("premise", p)
instance.add_field("hypothesis", h)
instance.add_field("truth", l)
data_set.append(instance) data_set.append(instance)

data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len")
data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len")
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len")
data_set.set_target("truth")
return data_set return data_set
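
The loaders above now build `DataSet`s from plain Python values instead of `TextField`/`LabelField`; a minimal sketch of that pattern with toy data (not from the diff):

```python
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

data = [(["a", "cat"], "POS"), (["no", "way"], "NEG")]
ds = DataSet()
for words, label in data:
    ds.append(Instance(word_seq=words, label=label))
ds.apply(lambda ins: len(ins["word_seq"]), new_field_name="seq_len")
ds.set_input("word_seq", "seq_len")
ds.set_target("label")
```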

+ 54
- 23
fastNLP/io/embed_loader.py View File

@@ -1,3 +1,4 @@
import numpy as np
import torch import torch


from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.vocabulary import Vocabulary
@@ -26,7 +27,7 @@ class EmbedLoader(BaseLoader):
emb = {} emb = {}
with open(emb_file, 'r', encoding='utf-8') as f: with open(emb_file, 'r', encoding='utf-8') as f:
for line in f: for line in f:
line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
line = list(filter(lambda w: len(w) > 0, line.strip().split(' ')))
if len(line) > 2: if len(line) > 2:
emb[line[0]] = torch.Tensor(list(map(float, line[1:]))) emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
return emb return emb
@@ -35,9 +36,9 @@ class EmbedLoader(BaseLoader):
def _load_pretrain(emb_file, emb_type): def _load_pretrain(emb_file, emb_type):
"""Read txt data from embedding file and convert to np.array as pre-trained embedding """Read txt data from embedding file and convert to np.array as pre-trained embedding


:param emb_file: str, the pre-trained embedding file path
:param emb_type: str, the pre-trained embedding data format
:return dict: {str: np.array}
:param str emb_file: the pre-trained embedding file path
:param str emb_type: the pre-trained embedding data format
:return dict embedding: `{str: np.array}`
""" """
if emb_type == 'glove': if emb_type == 'glove':
return EmbedLoader._load_glove(emb_file) return EmbedLoader._load_glove(emb_file)
@@ -45,38 +46,68 @@ class EmbedLoader(BaseLoader):
raise Exception("embedding type {} not support yet".format(emb_type)) raise Exception("embedding type {} not support yet".format(emb_type))


@staticmethod @staticmethod
def load_embedding(emb_dim, emb_file, emb_type, vocab, emb_pkl):
def load_embedding(emb_dim, emb_file, emb_type, vocab):
"""Load the pre-trained embedding and combine with the given dictionary. """Load the pre-trained embedding and combine with the given dictionary.


:param emb_dim: int, the dimension of the embedding. Should be the same as pre-trained embedding.
:param emb_file: str, the pre-trained embedding file path.
:param emb_type: str, the pre-trained embedding format, support glove now
:param vocab: Vocabulary, a mapping from word to index, can be provided by user or built from pre-trained embedding
:param emb_pkl: str, the embedding pickle file.
:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
:param str emb_file: the pre-trained embedding file path.
:param str emb_type: the pre-trained embedding format, support glove now
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return embedding_tensor: Tensor of shape (len(word_dict), emb_dim) :return embedding_tensor: Tensor of shape (len(word_dict), emb_dim)
vocab: input vocab or vocab built by pre-train vocab: input vocab or vocab built by pre-train
TODO: fragile code
""" """
# If the embedding pickle exists, load it and return.
# if os.path.exists(emb_pkl):
# with open(emb_pkl, "rb") as f:
# embedding_tensor, vocab = _pickle.load(f)
# return embedding_tensor, vocab
# Otherwise, load the pre-trained embedding.
pretrain = EmbedLoader._load_pretrain(emb_file, emb_type) pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
if vocab is None: if vocab is None:
# build vocabulary from pre-trained embedding # build vocabulary from pre-trained embedding
vocab = Vocabulary() vocab = Vocabulary()
for w in pretrain.keys(): for w in pretrain.keys():
vocab.update(w)
vocab.add(w)
embedding_tensor = torch.randn(len(vocab), emb_dim) embedding_tensor = torch.randn(len(vocab), emb_dim)
for w, v in pretrain.items(): for w, v in pretrain.items():
if len(v.shape) > 1 or emb_dim != v.shape[0]: if len(v.shape) > 1 or emb_dim != v.shape[0]:
raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,)))
raise ValueError(
"Pretrained embedding dim is {}. Dimension mismatched. Required {}".format(v.shape, (emb_dim,)))
if vocab.has_word(w): if vocab.has_word(w):
embedding_tensor[vocab[w]] = v embedding_tensor[vocab[w]] = v

# save and return the result
# with open(emb_pkl, "wb") as f:
# _pickle.dump((embedding_tensor, vocab), f)
return embedding_tensor, vocab return embedding_tensor, vocab

@staticmethod
def parse_glove_line(line):
line = list(filter(lambda w: len(w) > 0, line.strip().split(" ")))
if len(line) <= 2:
raise RuntimeError("something went wrong when parsing a glove embedding line")
return line[0], torch.Tensor(list(map(float, line[1:])))

@staticmethod
def fast_load_embedding(emb_dim, emb_file, vocab):
"""Fast load the pre-trained embedding and combine with the given dictionary.
This loading method reads the embedding file line by line.

:param int emb_dim: the dimension of the embedding. Should be the same as pre-trained embedding.
:param str emb_file: the pre-trained embedding file path.
:param Vocabulary vocab: a mapping from word to index, can be provided by user or built from pre-trained embedding
:return numpy.ndarray embedding_matrix:

"""
if vocab is None:
raise RuntimeError("You must provide a vocabulary.")
embedding_matrix = np.zeros(shape=(len(vocab), emb_dim))
hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
with open(emb_file, "r", encoding="utf-8") as f:
for line in f:
word, vector = EmbedLoader.parse_glove_line(line)
if word in vocab:
if len(vector.shape) > 1 or emb_dim != vector.shape[0]:
raise ValueError("Pre-trained embedding dim is {}. Expect {}.".format(vector.shape, (emb_dim,)))
embedding_matrix[vocab[word]] = vector
hit_flags[vocab[word]] = 1

if np.sum(hit_flags) < len(vocab):
# some words from vocab are missing in pre-trained embedding
# we sample each missing dimension from a normal distribution fitted to the loaded vectors
vocab_embed = embedding_matrix[np.where(hit_flags)]
sampled_vectors = np.random.normal(vocab_embed.mean(axis=0), vocab_embed.std(axis=0),
size=(len(vocab) - np.sum(hit_flags), emb_dim))
embedding_matrix[np.where(1 - hit_flags)] = sampled_vectors
return embedding_matrix
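
A usage sketch of the new `EmbedLoader.fast_load_embedding`; the path points at the small GloVe-format test file added in this PR, but any GloVe-format file should work:

```python
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader

vocab = Vocabulary()
vocab.update(["the", "of", "and"])
matrix = EmbedLoader.fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
print(matrix.shape)  # (len(vocab), 50); words missing from the file get normally-sampled vectors
```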

fastNLP/io/model_saver.py → fastNLP/io/model_io.py View File

@@ -1,5 +1,32 @@
import torch import torch


from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
"""
Loader for models.
"""

def __init__(self):
super(ModelLoader, self).__init__()

@staticmethod
def load_pytorch(empty_model, model_path):
"""
Load model parameters from .pkl files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
:param model_path: str, the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))

@staticmethod
def load_pytorch_model(model_path):
"""Load the entire model.

"""
return torch.load(model_path)
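
A short sketch of the merged model I/O helpers (`ModelLoader` above together with the existing `ModelSaver` below), assuming `save_pytorch` writes a state dict that `load_pytorch` can read back; the path and model are placeholders:

```python
import torch

from fastNLP.io.model_io import ModelLoader, ModelSaver

model = torch.nn.Linear(10, 2)                       # stand-in for any fastNLP model
ModelSaver("./model_ckpt.pkl").save_pytorch(model)   # hypothetical path
ModelLoader.load_pytorch(model, "./model_ckpt.pkl")  # refills the parameters of an initialized model
```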



class ModelSaver(object): class ModelSaver(object):
"""Save a model """Save a model
@@ -8,6 +35,7 @@ class ModelSaver(object):
saver.save_pytorch(model) saver.save_pytorch(model)


""" """

def __init__(self, save_path): def __init__(self, save_path):
""" """



+ 0
- 28
fastNLP/io/model_loader.py View File

@@ -1,28 +0,0 @@
import torch

from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
"""
Loader for models.
"""

def __init__(self):
super(ModelLoader, self).__init__()

@staticmethod
def load_pytorch(empty_model, model_path):
"""
Load model parameters from .pkl files into the empty PyTorch model.
:param empty_model: a PyTorch model with initialized parameters.
:param model_path: str, the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))

@staticmethod
def load_pytorch_model(model_path):
"""Load the entire model.

"""
return torch.load(model_path)

+ 14
- 3
fastNLP/models/base_model.py View File

@@ -1,6 +1,6 @@
import torch import torch


from fastNLP.core.trainer import Trainer
from fastNLP.modules.decoder.MLP import MLP




class BaseModel(torch.nn.Module): class BaseModel(torch.nn.Module):
@@ -11,8 +11,19 @@ class BaseModel(torch.nn.Module):
super(BaseModel, self).__init__() super(BaseModel, self).__init__()


def fit(self, train_data, dev_data=None, **train_args): def fit(self, train_data, dev_data=None, **train_args):
trainer = Trainer(**train_args)
trainer.train(self, train_data, dev_data)
pass


def predict(self, *args, **kwargs): def predict(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError


class NaiveClassifier(BaseModel):
def __init__(self, in_feature_dim, out_feature_dim):
super(NaiveClassifier, self).__init__()
self.mlp = MLP([in_feature_dim, in_feature_dim, out_feature_dim])

def forward(self, x):
return {"predict": torch.sigmoid(self.mlp(x))}

def predict(self, x):
return {"predict": torch.sigmoid(self.mlp(x)) > 0.5}
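
The new `NaiveClassifier` is a small sigmoid MLP; a quick sanity check of its dict outputs on random data:

```python
import torch

from fastNLP.models.base_model import NaiveClassifier

model = NaiveClassifier(in_feature_dim=4, out_feature_dim=2)
x = torch.randn(8, 4)
print(model(x)["predict"].shape)          # torch.Size([8, 2]), sigmoid outputs in (0, 1)
print(model.predict(x)["predict"].shape)  # torch.Size([8, 2]), boolean mask thresholded at 0.5
```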

+ 5
- 29
fastNLP/models/cnn_text_classification.py View File

@@ -18,8 +18,8 @@ class CNNText(torch.nn.Module):
def __init__(self, embed_num, def __init__(self, embed_num,
embed_dim, embed_dim,
num_classes, num_classes,
kernel_nums=(3,4,5),
kernel_sizes=(3,4,5),
kernel_nums=(3, 4, 5),
kernel_sizes=(3, 4, 5),
padding=0, padding=0,
dropout=0.5): dropout=0.5):
super(CNNText, self).__init__() super(CNNText, self).__init__()
@@ -33,7 +33,6 @@ class CNNText(torch.nn.Module):
padding=padding) padding=padding)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
self.fc = encoder.Linear(sum(kernel_nums), num_classes) self.fc = encoder.Linear(sum(kernel_nums), num_classes)
self._loss = nn.CrossEntropyLoss()


def forward(self, word_seq): def forward(self, word_seq):
""" """
@@ -45,7 +44,7 @@ class CNNText(torch.nn.Module):
x = self.conv_pool(x) # [N,L,C] -> [N,C] x = self.conv_pool(x) # [N,L,C] -> [N,C]
x = self.dropout(x) x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class] x = self.fc(x) # [N,C] -> [N, N_class]
return {'output':x}
return {'pred': x}


def predict(self, word_seq): def predict(self, word_seq):
""" """
@@ -54,28 +53,5 @@ class CNNText(torch.nn.Module):
:return predict: dict of torch.LongTensor, [batch_size, seq_len] :return predict: dict of torch.LongTensor, [batch_size, seq_len]
""" """
output = self(word_seq) output = self(word_seq)
_, predict = output['output'].max(dim=1)
return {'predict': predict}

def get_loss(self, output, label_seq):
"""

:param output: output of forward(), [batch_size, seq_len]
:param label_seq: true label in DataSet, [batch_size, seq_len]
:return loss: torch.Tensor
"""
return self._loss(output, label_seq)

def evaluate(self, predict, label_seq):
"""

:param predict: iterable predict tensors
:param label_seq: iterable true label tensors
:return accuracy: dict of float
"""
predict, label_seq = torch.stack(tuple(predict), dim=0), torch.stack(tuple(label_seq), dim=0)
predict, label_seq = predict.squeeze(), label_seq.squeeze()
correct = (predict == label_seq).long().sum().item()
total = label_seq.size(0)
return {'acc': 1.0 * correct / total}

_, predict = output['pred'].max(dim=1)
return {'pred': predict}
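
With `get_loss`/`evaluate` removed above, `CNNText` now only returns a `pred` dict for the new losses/metrics to consume; a minimal forward/predict sketch with random indices:

```python
import torch

from fastNLP.models.cnn_text_classification import CNNText

model = CNNText(embed_num=1000, embed_dim=50, num_classes=5)
word_seq = torch.randint(0, 1000, (4, 20))    # (batch_size, seq_len) of word indices
print(model(word_seq)["pred"].shape)          # torch.Size([4, 5]) class scores
print(model.predict(word_seq)["pred"].shape)  # torch.Size([4]) argmax labels
```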

+ 1
- 1
fastNLP/modules/encoder/char_embedding.py View File

@@ -43,7 +43,7 @@ class ConvCharEmbedding(nn.Module):
# [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1] # [batch_size*sent_length, feature_maps[i], 1, width - kernels[i] + 1]
y = torch.squeeze(y, 2) y = torch.squeeze(y, 2)
# [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1] # [batch_size*sent_length, feature_maps[i], width - kernels[i] + 1]
y = F.tanh(y)
y = torch.tanh(y)
y, __ = torch.max(y, 2) y, __ = torch.max(y, 2)
# [batch_size*sent_length, feature_maps[i]] # [batch_size*sent_length, feature_maps[i]]
feats.append(y) feats.append(y)


+ 1
- 1
reproduction/Biaffine_parser/infer.py View File

@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])


from fastNLP.api.processor import * from fastNLP.api.processor import *
from fastNLP.models.biaffine_parser import BiaffineParser from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_io import ConfigSection, ConfigLoader


import _pickle as pickle import _pickle as pickle
import torch import torch


+ 2
- 3
reproduction/Biaffine_parser/run.py View File

@@ -13,11 +13,10 @@ from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.field import TextField, SeqLabelField from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.io.embed_loader import EmbedLoader from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.io.model_saver import ModelSaver


BOS = '<BOS>' BOS = '<BOS>'
EOS = '<EOS>' EOS = '<EOS>'


+ 2
- 2
reproduction/LSTM+self_attention_sentiment_analysis/main.py View File

@@ -2,8 +2,8 @@ import torch.nn.functional as F


from fastNLP.core.trainer import ClassificationTrainer from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.core.utils import ClassPreprocess as Preprocess from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_loader import ConfigLoader
from fastNLP.io.config_loader import ConfigSection
from fastNLP.io.config_io import ConfigLoader
from fastNLP.io.config_io import ConfigSection
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.models.base_model import BaseModel from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention from fastNLP.modules.aggregator.self_attention import SelfAttention


+ 2
- 3
reproduction/chinese_word_segment/run.py View File

@@ -3,12 +3,11 @@ import sys


sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader
from fastNLP.core.utils import load_pickle from fastNLP.core.utils import load_pickle
from fastNLP.io.model_saver import ModelSaver
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_io import ModelLoader, ModelSaver
from fastNLP.core.tester import SeqLabelTester from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.predictor import SeqLabelInfer from fastNLP.core.predictor import SeqLabelInfer


+ 1
- 1
requirements.txt View File

@@ -1,4 +1,4 @@
numpy>=1.14.2 numpy>=1.14.2
torch>=0.4.0 torch>=0.4.0
torchvision>=0.1.8
tensorboardX tensorboardX
tqdm>=4.28.1

+ 2
- 2
setup.py View File

@@ -12,12 +12,12 @@ with open('requirements.txt', encoding='utf-8') as f:
reqs = f.read() reqs = f.read()


setup( setup(
name='fastNLP',
name='FastNLP',
version='0.1.1', version='0.1.1',
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
long_description=readme, long_description=readme,
license=license, license=license,
author='fudanNLP',
author='FudanNLP',
python_requires='>=3.5', python_requires='>=3.5',
packages=find_packages(), packages=find_packages(),
install_requires=reqs.strip().split('\n'), install_requires=reqs.strip().split('\n'),


+ 12
- 0
test/api/test_processor.py View File

@@ -0,0 +1,12 @@
import unittest

from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.core.dataset import DataSet


class TestProcessor(unittest.TestCase):
def test_FullSpaceToHalfSpaceProcessor(self):
ds = DataSet({"word": ["00, u1, u), (u2, u2"]})
proc = FullSpaceToHalfSpaceProcessor("word")
ds = proc(ds)
self.assertTrue(ds.field_arrays["word"].content, ["00, u1, u), (u2, u2"])

+ 0
- 0
test/core/__init__.py View File


+ 2
- 2
test/core/test_batch.py View File

@@ -22,8 +22,8 @@ class TestCase1(unittest.TestCase):


def test_dataset_batching(self): def test_dataset_batching(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.set_input(x=True)
ds.set_target(y=True)
ds.set_input("x")
ds.set_target("y")
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
for x, y in iter: for x, y in iter:
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))


+ 128
- 3
test/core/test_dataset.py View File

@@ -1,6 +1,8 @@
import os
import unittest


from fastNLP.core.dataset import DataSet
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance




@@ -44,6 +46,9 @@ class TestDataSet(unittest.TestCase):
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)


with self.assertRaises(RuntimeError):
dd.add_field("??", [[1, 2]] * 40)

def test_delete_field(self):
dd = DataSet()
dd.add_field("x", [[1, 2, 3]] * 10)
@@ -55,7 +60,7 @@ class TestDataSet(unittest.TestCase):
def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1]
self.assertTrue(isinstance(ins_1, DataSet.Instance) and isinstance(ins_0, DataSet.Instance))
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
self.assertEqual(ins_1["x"], [1, 2, 3, 4]) self.assertEqual(ins_1["x"], [1, 2, 3, 4])
self.assertEqual(ins_1["y"], [5, 6]) self.assertEqual(ins_1["y"], [5, 6])
self.assertEqual(ins_0["x"], [1, 2, 3, 4]) self.assertEqual(ins_0["x"], [1, 2, 3, 4])
@@ -65,11 +70,131 @@ class TestDataSet(unittest.TestCase):
self.assertTrue(isinstance(sub_ds, DataSet)) self.assertTrue(isinstance(sub_ds, DataSet))
self.assertEqual(len(sub_ds), 10) self.assertEqual(len(sub_ds), 10)


field = ds["x"]
self.assertEqual(field, ds.field_arrays["x"])
def test_get_item_error(self):
with self.assertRaises(RuntimeError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds[40:]

with self.assertRaises(KeyError):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
_ = ds["kom"]

def test_len_(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertEqual(len(ds), 40)

ds = DataSet()
self.assertEqual(len(ds), 0)


def test_apply(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
self.assertTrue("rx" in ds.field_arrays)
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])

ds.apply(lambda ins: len(ins["y"]), new_field_name="y")
self.assertEqual(ds.field_arrays["y"].content[0], 2)

res = ds.apply(lambda ins: len(ins["x"]))
self.assertTrue(isinstance(res, list) and len(res) > 0)
self.assertTrue(res[0], 4)

def test_drop(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.drop(lambda ins: len(ins["y"]) < 3)
self.assertEqual(len(ds), 20)

def test_contains(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
self.assertTrue("x" in ds)
self.assertTrue("y" in ds)
self.assertFalse("z" in ds)

def test_rename_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.rename_field("x", "xx")
self.assertTrue("xx" in ds)
self.assertFalse("x" in ds)

with self.assertRaises(KeyError):
ds.rename_field("yyy", "oo")

def test_input_target(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.set_input("x")
ds.set_target("y")
self.assertTrue(ds.field_arrays["x"].is_input)
self.assertTrue(ds.field_arrays["y"].is_target)

with self.assertRaises(KeyError):
ds.set_input("xxx")
with self.assertRaises(KeyError):
ds.set_input("yyy")

def test_get_input_name(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])

def test_get_target_name(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])

def test_apply2(self):
def split_sent(ins):
return ins['raw_sentence'].split()

dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
sep='\t')
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
dataset.apply(split_sent, new_field_name='words', is_input=True)
# print(dataset)

def test_add_field(self):
ds = DataSet({"x": [3, 4]})
ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
# ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
print(ds)

def test_save_load(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ds.save("./my_ds.pkl")
self.assertTrue(os.path.exists("./my_ds.pkl"))

ds_1 = DataSet.load("./my_ds.pkl")
os.remove("my_ds.pkl")

def test_get_all_fields(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ans = ds.get_all_fields()
self.assertEqual(ans["x"].content, [[1, 2, 3, 4]] * 10)
self.assertEqual(ans["y"].content, [[5, 6]] * 10)

def test_get_field(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
ans = ds.get_field("x")
self.assertTrue(isinstance(ans, FieldArray))
self.assertEqual(ans.content, [[1, 2, 3, 4]] * 10)
ans = ds.get_field("y")
self.assertTrue(isinstance(ans, FieldArray))
self.assertEqual(ans.content, [[5, 6]] * 10)

def test_reader(self):
# just make sure it runs
ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)

ds = DataSet().read_pos("test/data_for_tests/people.txt")
self.assertTrue(isinstance(ds, DataSet))
self.assertTrue(len(ds) > 0)


class TestDataSetIter(unittest.TestCase):
def test__repr__(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
for iter in ds:
self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4],\n'y': [5, 6]}")
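Taken together, the new cases above document the core DataSet workflow (construct, append, apply, drop, rename_field, set_input/set_target). A short sketch assembled from those calls; the field names are illustrative only:

```python
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance

ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
ds.append(Instance(x=[0, 1, 2, 3], y=[1, 2, 3]))             # add a single sample

ds.apply(lambda ins: len(ins["y"]), new_field_name="y_len")  # derive a new field per instance
ds.drop(lambda ins: ins["y_len"] < 3)                        # filter out short instances

ds.rename_field("y_len", "target_len")
ds.set_input("x")
ds.set_target("target_len")
print(len(ds), ds[0])
```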

+ 77
- 0
test/core/test_fieldarray.py View File

@@ -20,3 +20,80 @@ class TestFieldArray(unittest.TestCase):
self.assertEqual(fa.get(0), 1)
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])

def test_type_conversion(self):
fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)

fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
fa.append(1.3333)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)

fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
fa.append(10)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)

fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True)
fa.append("e")
self.assertEqual(fa.dtype, np.str)
self.assertEqual(fa.pytype, str)

def test_support_np_array(self):
fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True)
self.assertEqual(fa.dtype, np.ndarray)
self.assertEqual(fa.pytype, np.ndarray)

fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
self.assertEqual(fa.dtype, np.ndarray)
self.assertEqual(fa.pytype, np.ndarray)

fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True)
# in this case, pytype is actually a float. We do not care about it.
self.assertEqual(fa.dtype, np.float64)

def test_nested_list(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True)
self.assertEqual(fa.pytype, float)
self.assertEqual(fa.dtype, np.float64)

def test_getitem_v1(self):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
ans = fa[[0, 1]]
self.assertTrue(isinstance(ans, np.ndarray))
self.assertTrue(isinstance(ans[0], np.ndarray))
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
self.assertEqual(ans.dtype, np.float64)

def test_getitem_v2(self):
x = np.random.rand(10, 5)
fa = FieldArray("my_field", x, is_input=True)
indices = [0, 1, 3, 4, 6]
for a, b in zip(fa[indices], x[indices]):
self.assertListEqual(a.tolist(), b.tolist())

def test_append(self):
with self.assertRaises(Exception):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa.append(0)

with self.assertRaises(Exception):
fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
fa.append([1, 2, 3, 4, 5])

with self.assertRaises(Exception):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa.append([])

with self.assertRaises(Exception):
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa.append(["str", 0, 0, 0, 1.89])

fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])

+ 6
- 0
test/core/test_instance.py View File

@@ -27,3 +27,9 @@ class TestCase(unittest.TestCase):
self.assertEqual(ins["x"], [1, 2, 3]) self.assertEqual(ins["x"], [1, 2, 3])
self.assertEqual(ins["y"], [4, 5, 6]) self.assertEqual(ins["y"], [4, 5, 6])
self.assertEqual(ins["z"], [1, 1, 1]) self.assertEqual(ins["z"], [1, 1, 1])

def test_repr(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
ins = Instance(**fields)
# simple print, that is enough.
print(ins)

+ 82
- 301
test/core/test_loss.py View File

@@ -1,306 +1,87 @@
import unittest


import fastNLP.core.loss as loss
import math
import torch as tc
import pdb
import torch
import torch.nn.functional as F


class TestLoss(unittest.TestCase):

def test_case_1(self):
#verify how NLL loss works

print (".----------------------------------")

loss_func = loss.Loss("nll")

#pdb.set_trace()

y = tc.Tensor(
[
[.3,.4,.3],
[.5,.3,.2],
[.3,.6,.1],
]
)

gy = tc.LongTensor(
[
0,
1,
2,
]
)


y = tc.log(y)
los = loss_func(y , gy)

r = -math.log(.3) - math.log(.3) - math.log(.1)
r /= 3
print ("loss = %f" % (los))
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_2(self):
#verify the correctness of squash()
print ("----------------------------------")

log = math.log

loss_func = loss.Loss("nll")

#pdb.set_trace()

y = tc.Tensor(
[
[[.3,.4,.3],[.3,.4,.3],],
[[.5,.3,.2],[.1,.2,.7],],
[[.3,.6,.1],[.2,.1,.7],],
]
)

gy = tc.LongTensor(
[
[0,2],
[1,2],
[2,1],
]
)


#pdb.set_trace()

y = tc.log(y)
los = loss_func(y , gy)
print ("loss = %f" % (los))

r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)
r /= 6
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_3(self):
#verify the correctness of pack_padded_sequence()
print ("----------------------------------")

log = math.log

loss_func = loss.Loss("nll")

#pdb.set_trace()

y = tc.Tensor(
[
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],],
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],],
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],],
]
)

gy = tc.LongTensor(
[
[0,2,1,],
[1,2,0,],
[2,0,0,],
]
)

lens = [3,2,1]

#pdb.set_trace()

y = tc.log(y)

yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data
gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data
los = loss_func(yy , gyy)
print ("loss = %f" % (los))


r = -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
r /= 6
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_4(self):
#verify the correctness of unpad()
print ("----------------------------------")

log = math.log

#pdb.set_trace()

y = tc.Tensor(
[
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],],
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],],
]
)

gy = tc.LongTensor(
[
[0,2,1,2,],
[1,2,0,0,],
[2,0,0,0,],
]
)

lens = [4,2,1]

#pdb.set_trace()

y = tc.log(y)
import fastNLP.core.losses as loss
from fastNLP.core.losses import squash, unpad


loss_func = loss.Loss("nll" , pre_pro = ["unpad"])
los = loss_func(y , gy , lens = lens)
print ("loss = %f" % (los))



r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
r /= 7
print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_5(self):
#verify the correctness of mask() and make_mask()
print ("----------------------------------")

log = math.log

#pdb.set_trace()

y = tc.Tensor(
[
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
[[.5,.4,.1],[.3,.2,.5],[.4,.5,.1,],[.6,.1,.3,],],
[[.3,.6,.1],[.3,.2,.5],[.0,.0,.0,],[.0,.0,.0,],],
]
)

gy = tc.LongTensor(
[
[1,2,0,0,],
[0,2,1,2,],
[2,1,0,0,],
]
)

mask = tc.ByteTensor(
[
[1,1,0,0,],
[1,1,1,1,],
[1,1,0,0,],
]
)

y = tc.log(y)

lens = [2,4,2]

loss_func = loss.Loss("nll" , pre_pro = ["mask"])
los = loss_func(y , gy , mask = mask)
print ("loss = %f" % (los))

los2 = loss_func(y , gy , mask = loss.make_mask(lens,gy.size()[-1]))
print ("loss2 = %f" % (los2))


r = -log(.3) -log(.7) - log(.5) - log(.5) - log(.5) - log(.3) - log(.1) - log(.2)
r /= 8
print ("r = %f" % (r))


self.assertEqual(int(los * 1000), int(r * 1000))
self.assertEqual(int(los2 * 1000), int(r * 1000))

def test_case_6(self):
#verify the correctness of unpad_mask()
print ("----------------------------------")

log = math.log

#pdb.set_trace()

y = tc.Tensor(
[
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],],
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],],
]
)

gy = tc.LongTensor(
[
[0,2,1,2,],
[1,2,0,0,],
[2,0,0,0,],
]
)

lens = [4,2,1]

#pdb.set_trace()

y = tc.log(y)

loss_func = loss.Loss("nll" , pre_pro = ["unpad_mask"])
los = loss_func(y , gy , lens = lens)
print ("loss = %f" % (los))


r = -log(.1) -log(.3) - log(.5) - log(.5) - log(.3) - log(.7) - log(.1)
r /= 7
print ("r = %f" % (r))

self.assertEqual(int(los * 1000), int(r * 1000))

def test_case_7(self):
#verify a few other things
print ("----------------------------------")

log = math.log

#pdb.set_trace()

y = tc.Tensor(
[
[[.3,.4,.3],[.3,.2,.5],[.4,.5,.1,],[.6,.3,.1,],],
[[.5,.3,.2],[.1,.2,.7],[.0,.0,.0,],[.0,.0,.0,],],
[[.3,.6,.1],[.0,.0,.0],[.0,.0,.0,],[.0,.0,.0,],],
]
)

gy = tc.LongTensor(
[
[0,2,1,2,],
[1,2,0,0,],
[2,0,0,0,],
]
)

lens = [4,2,1]

#pdb.set_trace()

y = tc.log(y)

loss_func = loss.Loss("nll" , pre_pro = [] , weight = tc.Tensor([1,1,0]))
loss_func.add_pre_pro("unpad_mask")
los = loss_func(y , gy , lens = lens)
print ("loss = %f" % (los))


r = - log(.3) - log(.5) - log(.3)
r /= 3
print ("r = %f" % (r))
self.assertEqual(int(los * 1000), int(r * 1000))

if __name__ == "__main__":
unittest.main()
class TestLoss(unittest.TestCase):
def test_CrossEntropyLoss(self):
ce = loss.CrossEntropyLoss(pred="my_predict", target="my_truth")
a = torch.randn(3, 5, requires_grad=False)
b = torch.empty(3, dtype=torch.long).random_(5)
ans = ce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))

def test_BCELoss(self):
bce = loss.BCELoss(pred="my_predict", target="my_truth")
a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
b = torch.randn((3, 5), requires_grad=False)
ans = bce({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))

def test_L1Loss(self):
l1 = loss.L1Loss(pred="my_predict", target="my_truth")
a = torch.randn(3, 5, requires_grad=False)
b = torch.randn(3, 5)
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))

def test_NLLLoss(self):
l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
b = torch.tensor([1, 0, 4])
ans = l1({"my_predict": a}, {"my_truth": b})
self.assertEqual(ans, torch.nn.functional.nll_loss(a, b))


class TestLosserError(unittest.TestCase):
def test_losser1(self):
# (1) only input, targets passed
pred_dict = {"pred": torch.zeros(4, 3)}
target_dict = {'target': torch.zeros(4).long()}
los = loss.CrossEntropyLoss()

print(los(pred_dict=pred_dict, target_dict=target_dict))

#
def test_losser2(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3)}
target_dict = {'target': torch.zeros(16, 3).long()}
los = loss.CrossEntropyLoss()

with self.assertRaises(RuntimeError):
print(los(pred_dict=pred_dict, target_dict=target_dict))

def test_losser3(self):
# (2) with corrupted size
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
target_dict = {'target': torch.zeros(16).long()}
los = loss.CrossEntropyLoss()

print(los(pred_dict=pred_dict, target_dict=target_dict))

def test_check_error(self):
l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
b = torch.tensor([1, 0, 4])
with self.assertRaises(Exception):
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b})

with self.assertRaises(Exception):
ans = l1({"my_predict": a}, {"truth": b, "my": a})


class TestLossUtils(unittest.TestCase):
def test_squash(self):
a, b = squash(torch.randn(3, 5), torch.randn(3, 5))
self.assertEqual(tuple(a.size()), (3, 5))
self.assertEqual(tuple(b.size()), (15,))

def test_unpad(self):
a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8))
self.assertEqual(tuple(a.size()), (5, 8, 3))
self.assertEqual(tuple(b.size()), (5, 8))
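The rewritten tests replace the old `Loss("nll", pre_pro=...)` wrapper with dict-based loss classes. A minimal sketch of the new interface, using only the constructor and call pattern shown above (the key names `output`/`label_seq` are arbitrary examples):

```python
import torch

from fastNLP.core.losses import CrossEntropyLoss

# Tell the loss which keys to read from the prediction dict and the target dict.
ce = CrossEntropyLoss(pred="output", target="label_seq")

pred_dict = {"output": torch.randn(8, 5)}               # e.g. what model.forward() returns
target_dict = {"label_seq": torch.randint(0, 5, (8,))}  # e.g. the DataSet's target field

loss_value = ce(pred_dict, target_dict)   # equivalent to F.cross_entropy(pred, target)
print(loss_value.item())
```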

+ 145
- 0
test/core/test_metrics.py View File

@@ -0,0 +1,145 @@
import unittest

import numpy as np
import torch

from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.metrics import pred_topk, accuracy_topk


class TestAccuracyMetric(unittest.TestCase):
def test_AccuracyMetric1(self):
# (1) only input, targets passed
pred_dict = {"pred": torch.zeros(4, 3)}
target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric()

metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric())

def test_AccuracyMetric2(self):
# (2) with corrupted size
try:
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric()

metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric())
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_AccuracyMetric3(self):
# (3) the second batch is corrupted size
try:
metric = AccuracyMetric()
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)

pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)}
metric(pred_dict=pred_dict, target_dict=target_dict)

print(metric.get_metric())
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_AccuaryMetric4(self):
# (5) check reset
metric = AccuracyMetric()
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1})

pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3) + 1}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 0})

def test_AccuaryMetric5(self):
# (5) check reset
metric = AccuracyMetric()
pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})

pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3) + 1}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 0.5})

def test_AccuaryMetric6(self):
# (6) check numpy array is not acceptable
try:
metric = AccuracyMetric()
pred_dict = {"pred": np.zeros((4, 3, 2))}
target_dict = {'target': np.zeros((4, 3))}
metric(pred_dict=pred_dict, target_dict=target_dict)
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_AccuaryMetric7(self):
# (7) check map, match
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"predictions": torch.zeros(4, 3, 2)}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1})

def test_AccuaryMetric8(self):
# (8) check map, does not match. use stop_fast_param to stop fast param map
try:
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param": 1}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict, )
self.assertDictEqual(metric.get_metric(), {'acc': 1})
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_AccuaryMetric9(self):
# (9) check map, include unused
try:
metric = AccuracyMetric(pred='prediction', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1})
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."

def test_AccuaryMetric10(self):
# (10) check _fast_metric
try:
metric = AccuracyMetric()
pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1})
except Exception as e:
print(e)
return
self.assertTrue(True, False), "No exception catches."


class TestUsefulFunctions(unittest.TestCase):
# test some utility functions in metrics.py that look useful
def test_case_1(self):
# multi-class
_ = accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
_ = pred_topk(np.random.randint(0, 3, size=(10, 1)))

# just make sure it runs
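These cases also show that AccuracyMetric accumulates over batches and, by default, resets when `get_metric()` is called. A small usage sketch under the same key-mapping behaviour exercised above:

```python
import torch

from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric(pred="predictions", target="targets")

# Feed the metric batch by batch; the last dimension of "predictions" holds class scores.
for _ in range(3):
    pred_dict = {"predictions": torch.zeros(4, 3, 2)}
    target_dict = {"targets": torch.zeros(4, 3)}
    metric(pred_dict=pred_dict, target_dict=target_dict)

print(metric.get_metric(reset=False))  # {'acc': 1}; keeps the accumulated counts
print(metric.get_metric())             # same value, and resets the internal counters
```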

+ 54
- 0
test/core/test_optimizer.py View File

@@ -0,0 +1,54 @@
import unittest

import torch

from fastNLP.core.optimizer import SGD, Adam


class TestOptim(unittest.TestCase):
def test_SGD(self):
optim = SGD(model_params=torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("momentum" in optim.__dict__["settings"])
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.SGD))

optim = SGD(lr=0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.SGD))

optim = SGD(lr=0.002, momentum=0.989)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)

optim = SGD(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.SGD))

with self.assertRaises(TypeError):
_ = SGD("???")
with self.assertRaises(TypeError):
_ = SGD(0.001, lr=0.002)

def test_Adam(self):
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters())
self.assertTrue("lr" in optim.__dict__["settings"])
self.assertTrue("weight_decay" in optim.__dict__["settings"])
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.Adam))

optim = Adam(lr=0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.Adam))

optim = Adam(lr=0.002, weight_decay=0.989)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)

optim = Adam(0.001)
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.Adam))
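As the assertions above suggest, SGD and Adam only store settings until `construct_from_pytorch` builds the real torch optimizer (presumably the step the Trainer performs internally). A minimal sketch:

```python
import torch

from fastNLP.core.optimizer import Adam, SGD

model = torch.nn.Linear(10, 3)

# Settings are stored first; the actual torch optimizer is built later.
sgd = SGD(lr=0.01, momentum=0.9)
torch_sgd = sgd.construct_from_pytorch(model.parameters())
print(isinstance(torch_sgd, torch.optim.SGD))    # True

adam = Adam(lr=0.001, weight_decay=0.0)
torch_adam = adam.construct_from_pytorch(model.parameters())
print(isinstance(torch_adam, torch.optim.Adam))  # True
```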

+ 29
- 1
test/core/test_predictor.py View File

@@ -1,6 +1,34 @@
import unittest


import numpy as np
import torch

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.predictor import Predictor
from fastNLP.modules.encoder.linear import Linear


def prepare_fake_dataset():
mean = np.array([-3, -3])
cov = np.array([[1, 0], [0, 1]])
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

mean = np.array([3, 3])
cov = np.array([[1, 0], [0, 1]])
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
return data_set



class TestPredictor(unittest.TestCase):
def test(self):
pass
predictor = Predictor()
model = Linear(2, 1)
data = prepare_fake_dataset()
data.set_input("x")
ans = predictor.predict(model, data)
self.assertEqual(len(ans), 2000)
self.assertTrue(isinstance(ans[0], torch.Tensor))

+ 11
- 1
test/core/test_sampler.py View File

@@ -1,9 +1,11 @@
import random
import unittest


import torch


from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
k_means_1d, k_means_bucketing, simple_sort_bucketing
k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler




class TestSampler(unittest.TestCase):
@@ -40,3 +42,11 @@ class TestSampler(unittest.TestCase):
def test_simple_sort_bucketing(self):
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
assert len(_) == 10

def test_BucketSampler(self):
sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len")
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")
indices = sampler(data_set)
self.assertEqual(len(indices), 10)
# just make sure it runs; the result is not verified
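The new BucketSampler groups instances by a length field before sampling. A sketch based on the constructor and call shown above; pairing it with Batch is an assumption, as is the `fastNLP.core.batch` import path:

```python
import random

from fastNLP.core.dataset import DataSet
from fastNLP.core.sampler import BucketSampler

ds = DataSet({"x": [[0] * random.randint(1, 10) for _ in range(100)]})
ds.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")  # the field the sampler buckets on

sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len")
print(len(sampler(ds)))   # a permutation of all 100 indices, grouped by similar length

# Presumably usable wherever a sampler is accepted,
# e.g. Batch(ds, batch_size=16, sampler=sampler) from fastNLP.core.batch.
```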

+ 59
- 1
test/core/test_tester.py View File

@@ -4,6 +4,64 @@ data_name = "pku_training.utf8"
pickle_path = "data_for_tests" pickle_path = "data_for_tests"




import numpy as np
import torch.nn.functional as F
from torch import nn
import time
from fastNLP.core.utils import CheckError
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.optimizer import SGD
from fastNLP.core.tester import Tester
from fastNLP.models.base_model import NaiveClassifier

def prepare_fake_dataset():
mean = np.array([-3, -3])
cov = np.array([[1, 0], [0, 1]])
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

mean = np.array([3, 3])
cov = np.array([[1, 0], [0, 1]])
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
return data_set


def prepare_fake_dataset2(*args, size=100):
ys = np.random.randint(4, size=100, dtype=np.int64)
data = {'y': ys}
for arg in args:
data[arg] = np.random.randn(size, 5)
return DataSet(data=data)

class TestTester(unittest.TestCase):
def test_case_1(self):
pass
# check that the error message gives the user correct guidance
dataset = prepare_fake_dataset2('x1', 'x_unused')
dataset.rename_field('x_unused', 'x2')
dataset.set_input('x1', 'x2')
dataset.set_target('y', 'x1')
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)
def forward(self, x1, x2):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
time.sleep(0.1)
# loss = F.cross_entropy(x, y)
return {'preds': x}

model = Model()
with self.assertRaises(NameError):
tester = Tester(
data=dataset,
model=model,
metrics=AccuracyMetric())
tester.test()

+ 239
- 3
test/core/test_trainer.py View File

@@ -1,6 +1,242 @@
import unittest


import numpy as np
import torch.nn.functional as F
from torch import nn
import time
from fastNLP.core.utils import CheckError
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.optimizer import SGD
from fastNLP.core.trainer import Trainer
from fastNLP.models.base_model import NaiveClassifier


class TestTrainer(unittest.TestCase):
def test_case_1(self):
pass

def prepare_fake_dataset():
mean = np.array([-3, -3])
cov = np.array([[1, 0], [0, 1]])
class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

mean = np.array([3, 3])
cov = np.array([[1, 0], [0, 1]])
class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
return data_set


def prepare_fake_dataset2(*args, size=100):
ys = np.random.randint(4, size=100, dtype=np.int64)
data = {'y': ys}
for arg in args:
data[arg] = np.random.randn(size, 5)
return DataSet(data=data)


class TrainerTestGround(unittest.TestCase):
def test_case(self):
data_set = prepare_fake_dataset()
data_set.set_input("x", flag=True)
data_set.set_target("y", flag=True)

train_set, dev_set = data_set.split(0.3)

model = NaiveClassifier(2, 1)

trainer = Trainer(train_set, model,
loss=BCELoss(pred="predict", target="y"),
metrics=AccuracyMetric(pred="predict", target="y"),
n_epochs=10,
batch_size=32,
print_every=50,
validate_every=-1,
dev_data=dev_set,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=True,
save_path=None)
trainer.train()
"""
# should run correctly
"""

def test_trainer_suggestion1(self):
# Check whether the error message gives the user correct guidance.
# The data required by forward() is not provided here; the trainer should tell the user how to set it.
dataset = prepare_fake_dataset2('x')

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)

def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'loss': loss}

model = Model()

with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model
)
"""
# the error message that should be produced
NameError:
The following problems occurred when calling Model.forward(self, x1, x2, y)
missing param: ['y', 'x1', 'x2']
Suggestion: (1). You might need to set ['y'] as input.
(2). You need to provide ['x1', 'x2'] in DataSet and set it as input.

"""

def test_trainer_suggestion2(self):
# Check whether the error message gives the user correct guidance.
# The data required by forward() is provided here; this should run.
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)

def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'loss': loss}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer.train()
"""
# should run correctly
"""

def test_trainer_suggestion3(self):
# Check whether the error message gives the user correct guidance.
# The data required by forward() is provided, but forward() does not return a 'loss' key.
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)

class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)

def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'wrong_loss_key': loss}

model = Model()
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer.train()

def test_trainer_suggestion4(self):
# Check whether the error message gives the user correct guidance.
# The data required by forward() is provided; check that unused fields are reported correctly.
dataset = prepare_fake_dataset2('x1', 'x2')
dataset.set_input('x1', 'x2', 'y', flag=True)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)
def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'losses': loss}

model = Model()
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)

def test_trainer_suggestion5(self):
# Check whether the error message gives the user correct guidance.
# Redundant parameters are passed in to create duplicates; since y is never used, this does not actually raise an error.
dataset = prepare_fake_dataset2('x1', 'x_unused')
dataset.rename_field('x_unused', 'x2')
dataset.set_input('x1', 'x2', 'y')
dataset.set_target('y')
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)
def forward(self, x1, x2, y):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
loss = F.cross_entropy(x, y)
return {'loss': loss}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)

def test_trainer_suggestion6(self):
# Check whether the error message gives the user correct guidance.
# Redundant parameters are passed in to create duplicates.
dataset = prepare_fake_dataset2('x1', 'x_unused')
dataset.rename_field('x_unused', 'x2')
dataset.set_input('x1', 'x2')
dataset.set_target('y', 'x1')
class Model(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(5, 4)
def forward(self, x1, x2):
x1 = self.fc(x1)
x2 = self.fc(x2)
x = x1 + x2
time.sleep(0.1)
# loss = F.cross_entropy(x, y)
return {'preds': x}

model = Model()
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
loss=CrossEntropyLoss(),
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2)

def test_case2(self):
# check metrics Wrong
data_set = prepare_fake_dataset2('x1', 'x2')

+ 35
- 8
test/core/test_vocabulary.py View File

@@ -10,36 +10,36 @@ counter = Counter(text)


class TestAdd(unittest.TestCase):
def test_add(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
for word in text:
vocab.add(word)
self.assertEqual(vocab.word_count, counter)


def test_add_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
for word in text:
vocab.add_word(word)
self.assertEqual(vocab.word_count, counter)


def test_add_word_lst(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.add_word_lst(text)
self.assertEqual(vocab.word_count, counter)


def test_update(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(vocab.word_count, counter)




class TestIndexing(unittest.TestCase):
def test_len(self):
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text)
self.assertEqual(len(vocab), len(counter))


def test_contains(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text)
self.assertTrue(text[-1] in vocab)
self.assertFalse("~!@#" in vocab)
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase):
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))


def test_index(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)
res = [vocab[w] for w in set(text)]
self.assertEqual(len(res), len(set(res)))
@@ -56,6 +56,33 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(len(res), len(set(res)))


def test_to_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])


class TestOther(unittest.TestCase):
def test_additional_update(self):
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)

_ = vocab["well"]
self.assertEqual(vocab.rebuild, False)

vocab.add("hahaha")
self.assertEqual(vocab.rebuild, True)

_ = vocab["hahaha"]
self.assertEqual(vocab.rebuild, False)
self.assertTrue("hahaha" in vocab)

def test_warning(self):
vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
vocab.update(text)
self.assertEqual(vocab.rebuild, True)
print(len(vocab))
self.assertEqual(vocab.rebuild, False)

vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
# this will print a warning
self.assertEqual(vocab.rebuild, True)
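These tests document the new Vocabulary constructor (no more `need_default`; special tokens are controlled via `unknown`/`padding`) and its lazy rebuild behaviour. A short sketch built from the same calls:

```python
from fastNLP.core.vocabulary import Vocabulary

# Pass unknown=None, padding=None to drop the special tokens (as in test_len above).
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(["the", "quick", "brown", "fox", "the"])

idx = vocab["the"]              # indexing triggers a (re)build when needed
print(idx, vocab.to_word(idx))  # round-trips back to "the"

vocab.add("lazy")               # adding after a build marks the vocab for rebuild
print(len(vocab))               # rebuilt lazily on the next lookup/len
```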

+ 1
- 7
test/data_for_tests/glove.6B.50d_test.txt View File

@@ -1,12 +1,6 @@
the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581 the 0.418 0.24968 -0.41242 0.1217 0.34527 -0.044457 -0.49688 -0.17862 -0.00066023 -0.6566 0.27843 -0.14767 -0.55677 0.14658 -0.0095095 0.011658 0.10204 -0.12792 -0.8443 -0.12181 -0.016801 -0.33279 -0.1552 -0.23131 -0.19181 -1.8823 -0.76746 0.099051 -0.42125 -0.19526 4.0071 -0.18594 -0.52287 -0.31681 0.00059213 0.0074449 0.17778 -0.15897 0.012041 -0.054223 -0.29871 -0.15749 -0.34758 -0.045637 -0.44251 0.18785 0.0027849 -0.18411 -0.11514 -0.78581
, 0.013441 0.23682 -0.16899 0.40951 0.63812 0.47709 -0.42852 -0.55641 -0.364 -0.23938 0.13001 -0.063734 -0.39575 -0.48162 0.23291 0.090201 -0.13324 0.078639 -0.41634 -0.15428 0.10068 0.48891 0.31226 -0.1252 -0.037512 -1.5179 0.12612 -0.02442 -0.042961 -0.28351 3.5416 -0.11956 -0.014533 -0.1499 0.21864 -0.33412 -0.13872 0.31806 0.70358 0.44858 -0.080262 0.63003 0.32111 -0.46765 0.22786 0.36034 -0.37818 -0.56657 0.044691 0.30392
. 0.15164 0.30177 -0.16763 0.17684 0.31719 0.33973 -0.43478 -0.31086 -0.44999 -0.29486 0.16608 0.11963 -0.41328 -0.42353 0.59868 0.28825 -0.11547 -0.041848 -0.67989 -0.25063 0.18472 0.086876 0.46582 0.015035 0.043474 -1.4671 -0.30384 -0.023441 0.30589 -0.21785 3.746 0.0042284 -0.18436 -0.46209 0.098329 -0.11907 0.23919 0.1161 0.41705 0.056763 -6.3681e-05 0.068987 0.087939 -0.10285 -0.13931 0.22314 -0.080803 -0.35652 0.016413 0.10216
of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375 of 0.70853 0.57088 -0.4716 0.18048 0.54449 0.72603 0.18157 -0.52393 0.10381 -0.17566 0.078852 -0.36216 -0.11829 -0.83336 0.11917 -0.16605 0.061555 -0.012719 -0.56623 0.013616 0.22851 -0.14396 -0.067549 -0.38157 -0.23698 -1.7037 -0.86692 -0.26704 -0.2589 0.1767 3.8676 -0.1613 -0.13273 -0.68881 0.18444 0.0052464 -0.33874 -0.078956 0.24185 0.36576 -0.34727 0.28483 0.075693 -0.062178 -0.38988 0.22902 -0.21617 -0.22562 -0.093918 -0.80375
to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044 to 0.68047 -0.039263 0.30186 -0.17792 0.42962 0.032246 -0.41376 0.13228 -0.29847 -0.085253 0.17118 0.22419 -0.10046 -0.43653 0.33418 0.67846 0.057204 -0.34448 -0.42785 -0.43275 0.55963 0.10032 0.18677 -0.26854 0.037334 -2.0932 0.22171 -0.39868 0.20912 -0.55725 3.8826 0.47466 -0.95658 -0.37788 0.20869 -0.32752 0.12751 0.088359 0.16351 -0.21634 -0.094375 0.018324 0.21048 -0.03088 -0.19722 0.082279 -0.09434 -0.073297 -0.064699 -0.26044
and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097 and 0.26818 0.14346 -0.27877 0.016257 0.11384 0.69923 -0.51332 -0.47368 -0.33075 -0.13834 0.2702 0.30938 -0.45012 -0.4127 -0.09932 0.038085 0.029749 0.10076 -0.25058 -0.51818 0.34558 0.44922 0.48791 -0.080866 -0.10121 -1.3777 -0.10866 -0.23201 0.012839 -0.46508 3.8463 0.31362 0.13643 -0.52244 0.3302 0.33707 -0.35601 0.32431 0.12041 0.3512 -0.069043 0.36885 0.25168 -0.24517 0.25381 0.1367 -0.31178 -0.6321 -0.25028 -0.38097
in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285 in 0.33042 0.24995 -0.60874 0.10923 0.036372 0.151 -0.55083 -0.074239 -0.092307 -0.32821 0.09598 -0.82269 -0.36717 -0.67009 0.42909 0.016496 -0.23573 0.12864 -1.0953 0.43334 0.57067 -0.1036 0.20422 0.078308 -0.42795 -1.7984 -0.27865 0.11954 -0.12689 0.031744 3.8631 -0.17786 -0.082434 -0.62698 0.26497 -0.057185 -0.073521 0.46103 0.30862 0.12498 -0.48609 -0.0080272 0.031184 -0.36576 -0.42699 0.42164 -0.11666 -0.50703 -0.027273 -0.53285
a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796
" 0.25769 0.45629 -0.76974 -0.37679 0.59272 -0.063527 0.20545 -0.57385 -0.29009 -0.13662 0.32728 1.4719 -0.73681 -0.12036 0.71354 -0.46098 0.65248 0.48887 -0.51558 0.039951 -0.34307 -0.014087 0.86488 0.3546 0.7999 -1.4995 -1.8153 0.41128 0.23921 -0.43139 3.6623 -0.79834 -0.54538 0.16943 -0.82017 -0.3461 0.69495 -1.2256 -0.17992 -0.057474 0.030498 -0.39543 -0.38515 -1.0002 0.087599 -0.31009 -0.34677 -0.31438 0.75004 0.97065
's 0.23727 0.40478 -0.20547 0.58805 0.65533 0.32867 -0.81964 -0.23236 0.27428 0.24265 0.054992 0.16296 -1.2555 -0.086437 0.44536 0.096561 -0.16519 0.058378 -0.38598 0.086977 0.0033869 0.55095 -0.77697 -0.62096 0.092948 -2.5685 -0.67739 0.10151 -0.48643 -0.057805 3.1859 -0.017554 -0.16138 0.055486 -0.25885 -0.33938 -0.19928 0.26049 0.10478 -0.55934 -0.12342 0.65961 -0.51802 -0.82995 -0.082739 0.28155 -0.423 -0.27378 -0.007901 -0.030231


a 0.21705 0.46515 -0.46757 0.10082 1.0135 0.74845 -0.53104 -0.26256 0.16812 0.13182 -0.24909 -0.44185 -0.21739 0.51004 0.13448 -0.43141 -0.03123 0.20674 -0.78138 -0.20148 -0.097401 0.16088 -0.61836 -0.18504 -0.12461 -2.2526 -0.22321 0.5043 0.32257 0.15313 3.9636 -0.71365 -0.67012 0.28388 0.21738 0.14433 0.25926 0.23434 0.4274 -0.44451 0.13813 0.36973 -0.64289 0.024142 -0.039315 -0.26037 0.12017 -0.043782 0.41013 0.1796

+ 77
- 0
test/data_for_tests/tutorial_sample_dataset.csv View File

@@ -0,0 +1,77 @@
A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story . 1
This quiet , introspective and entertaining independent is worth seeking . 4
Even fans of Ismail Merchant 's work , I suspect , would have a hard time sitting through this one . 1
A positively thrilling combination of ethnography and all the intrigue , betrayal , deceit and murder of a Shakespearean tragedy or a juicy soap opera . 3
Aggressive self-glorification and a manipulative whitewash . 1
A comedy-drama of nearly epic proportions rooted in a sincere performance by the title character undergoing midlife crisis . 4
Narratively , Trouble Every Day is a plodding mess . 1
The Importance of Being Earnest , so thick with wit it plays like a reading from Bartlett 's Familiar Quotations 3
But it does n't leave you with much . 1
You could hate it for the same reason . 1
There 's little to recommend Snow Dogs , unless one considers cliched dialogue and perverse escapism a source of high hilarity . 1
Kung Pow is Oedekerk 's realization of his childhood dream to be in a martial-arts flick , and proves that sometimes the dreams of youth should remain just that . 1
The performances are an absolute joy . 4
Fresnadillo has something serious to say about the ways in which extravagant chance can distort our perspective and throw us off the path of good sense . 3
I still like Moonlight Mile , better judgment be damned . 3
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3
a bilingual charmer , just like the woman who inspired it 3
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1
It 's everything you 'd expect -- but nothing more . 2
Best indie of the year , so far . 4
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2
The plot is romantic comedy boilerplate from start to finish . 2
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2
A film that clearly means to preach exclusively to the converted . 2
While The Importance of Being Earnest offers opportunities for occasional smiles and chuckles , it does n't give us a reason to be in the theater beyond Wilde 's wit and the actors ' performances . 1
The latest vapid actor 's exercise to appropriate the structure of Arthur Schnitzler 's Reigen . 1
More vaudeville show than well-constructed narrative , but on those terms it 's inoffensive and actually rather sweet . 2
Nothing more than a run-of-the-mill action flick . 2
Hampered -- no , paralyzed -- by a self-indulgent script ... that aims for poetry and ends up sounding like satire . 0
Ice Age is the first computer-generated feature cartoon to feel like other movies , and that makes for some glacial pacing early on . 2
There 's very little sense to what 's going on here , but the makers serve up the cliches with considerable dash . 2
Cattaneo should have followed the runaway success of his first film , The Full Monty , with something different . 2
They 're the unnamed , easily substitutable forces that serve as whatever terror the heroes of horror movies try to avoid . 1
It almost feels as if the movie is more interested in entertaining itself than in amusing us . 1
The movie 's progression into rambling incoherence gives new meaning to the phrase ` fatal script error . ' 0
I still like Moonlight Mile , better judgment be damned . 3
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3
a bilingual charmer , just like the woman who inspired it 3
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1
It 's everything you 'd expect -- but nothing more . 2
Best indie of the year , so far . 4
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2
The plot is romantic comedy boilerplate from start to finish . 2
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2
A film that clearly means to preach exclusively to the converted . 2
I still like Moonlight Mile , better judgment be damned . 3
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3
a bilingual charmer , just like the woman who inspired it 3
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1
It 's everything you 'd expect -- but nothing more . 2
Best indie of the year , so far . 4
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2
The plot is romantic comedy boilerplate from start to finish . 2
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2
A film that clearly means to preach exclusively to the converted . 2
I still like Moonlight Mile , better judgment be damned . 3
A welcome relief from baseball movies that try too hard to be mythic , this one is a sweet and modest and ultimately winning story . 3
a bilingual charmer , just like the woman who inspired it 3
Like a less dizzily gorgeous companion to Mr. Wong 's In the Mood for Love -- very much a Hong Kong movie despite its mainland setting . 2
As inept as big-screen remakes of The Avengers and The Wild Wild West . 1
It 's everything you 'd expect -- but nothing more . 2
Best indie of the year , so far . 4
Hatfield and Hicks make the oddest of couples , and in this sense the movie becomes a study of the gambles of the publishing world , offering a case study that exists apart from all the movie 's political ramifications . 3
It 's like going to a house party and watching the host defend himself against a frothing ex-girlfriend . 1
That the Chuck Norris `` grenade gag '' occurs about 7 times during Windtalkers is a good indication of how serious-minded the film is . 2
The plot is romantic comedy boilerplate from start to finish . 2
It arrives with an impeccable pedigree , mongrel pep , and almost indecipherable plot complications . 2
A film that clearly means to preach exclusively to the converted . 2

+ 0
- 0
test/io/__init__.py View File


+ 1
- 2
test/io/test_config_saver.py View File

@@ -1,8 +1,7 @@
import os
import unittest


from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_saver import ConfigSaver
from fastNLP.io.config_io import ConfigSection, ConfigLoader, ConfigSaver




class TestConfigSaver(unittest.TestCase):


+ 12
- 0
test/io/test_embed_loader.py View File

@@ -0,0 +1,12 @@
import unittest

from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader


class TestEmbedLoader(unittest.TestCase):
def test_case(self):
vocab = Vocabulary()
vocab.update(["the", "in", "I", "to", "of", "hahaha"])
embedding = EmbedLoader().fast_load_embedding(50, "test/data_for_tests/glove.6B.50d_test.txt", vocab)
self.assertEqual(tuple(embedding.shape), (len(vocab), 50))

+ 91
- 0
test/test_tutorial.py View File

@@ -0,0 +1,91 @@
import unittest

from fastNLP import DataSet
from fastNLP import Instance
from fastNLP import Tester
from fastNLP import Vocabulary
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
from fastNLP.models import CNNText


class TestTutorial(unittest.TestCase):
def test_tutorial(self):
# read data from csv into a DataSet
sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
sep='\t')
print(len(dataset))
print(dataset[0])

dataset.append(Instance(raw_sentence='fake data', label='0'))
dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
# convert label to int
dataset.apply(lambda x: int(x['label']), new_field_name='label')

# split sentences on whitespace
def split_sent(ins):
return ins['raw_sentence'].split()

dataset.apply(split_sent, new_field_name='words')
# add sequence length information
dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')
print(len(dataset))
print(dataset[0])

# DataSet.drop(func) filters out instances
dataset.drop(lambda x: x['seq_len'] <= 3)
print(len(dataset))

# specify which fields of the DataSet should be converted to tensors
# set target: the gold labels used for computing the loss and evaluating the model
dataset.set_target("label")
# set input: fields passed to the model's forward()
dataset.set_input("words")

# split into a test set and a training set
test_data, train_data = dataset.split(0.5)
print(len(test_data))
print(len(train_data))

# build the vocabulary with Vocabulary.add(word)
vocab = Vocabulary(min_freq=2)
train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
vocab.build_vocab()

# index the sentences with Vocabulary.to_index(word)
train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')
print(test_data[0])

model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)

from fastNLP import Trainer
from copy import deepcopy

# rename the corresponding DataSet fields so that they match the parameter names of the model's forward()
train_data.rename_field('words', 'word_seq')  # input field must match the forward() parameter name
train_data.rename_field('label', 'label_seq')
test_data.rename_field('words', 'word_seq')
test_data.rename_field('label', 'label_seq')

# instantiate a Trainer, pass in the model and data, and train
copy_model = deepcopy(model)
overfit_trainer = Trainer(train_data=test_data, model=copy_model,
loss=CrossEntropyLoss(pred="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
dev_data=test_data, save_path="./save")
overfit_trainer.train()

trainer = Trainer(train_data=train_data, model=model,
loss=CrossEntropyLoss(pred="output", target="label_seq"),
metrics=AccuracyMetric(pred="predict", target="label_seq"), n_epochs=10, batch_size=4,
dev_data=test_data, save_path="./save")
trainer.train()
print('Train finished!')

# evaluate with fastNLP's Tester
tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred="predict", target="label_seq"),
batch_size=4)
acc = tester.test()
print(acc)

+ 911
- 0
tutorials/fastnlp_10min_tutorial_v2.ipynb View File

@@ -0,0 +1,911 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastNLP上手教程\n",
"-------\n",
"\n",
"fastNLP提供方便的数据预处理,训练和测试模型的功能"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DataSet & Instance\n",
"------\n",
"\n",
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n",
"\n",
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8529"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from fastNLP import DataSet\n",
"from fastNLP import Instance\n",
"\n",
"# 从csv读取数据到DataSet\n",
"dataset = DataSet.read_csv('../sentence.csv', headers=('raw_sentence', 'label'), sep='\\t')\n",
"print(len(dataset))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 使用数字索引[k],获取第k个样本\n",
"print(dataset[0])\n",
"\n",
"# 索引也可以是负数\n",
"print(dataset[-3])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instance\n",
"Instance表示一个样本,由一个或多个field(域,属性,特征)组成,每个field有名字和值。\n",
"\n",
"在初始化Instance时即可定义它包含的域,使用 \"field_name=field_value\"的写法。"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': fake data,\n'label': 0}"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# DataSet.append(Instance)加入新数据\n",
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
"dataset[-1]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DataSet.apply方法\n",
"数据预处理利器"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 将所有数字转为小写\n",
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# label转int\n",
"dataset.apply(lambda x: int(x['label']), new_field_name='label')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 使用空格分割句子\n",
"def split_sent(ins):\n",
" return ins['raw_sentence'].split()\n",
"dataset.apply(split_sent, new_field_name='words')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1,\n'words': ['a', 'series', 'of', 'escapades', 'demonstrating', 'the', 'adage', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gander', ',', 'some', 'of', 'which', 'occasionally', 'amuses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.'],\n'seq_len': 37}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 增加长度信息\n",
"dataset.apply(lambda x: len(x['words']), new_field_name='seq_len')\n",
"print(dataset[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## DataSet.drop\n",
"筛选数据"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8358"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"dataset.drop(lambda x: x['seq_len'] <= 3)\n",
"print(len(dataset))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 配置DataSet\n",
"1. 哪些域是特征,哪些域是标签\n",
"2. 切分训练集/验证集"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"# 设置DataSet中,哪些field要转为tensor\n",
"\n",
"# set target,loss或evaluate中的golden,计算loss,模型评估时使用\n",
"dataset.set_target(\"label\")\n",
"# set input,模型forward时使用\n",
"dataset.set_input(\"words\")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"5851"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2507"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 分出测试集、训练集\n",
"\n",
"test_data, train_data = dataset.split(0.3)\n",
"print(len(test_data))\n",
"print(len(train_data))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Vocabulary\n",
"------\n",
"\n",
"fastNLP中的Vocabulary轻松构建词表,将词转成数字"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': the project 's filmmakers forgot to include anything even halfway scary as they poorly rejigger fatal attraction into a high school setting .,\n'label': 0,\n'words': [4, 423, 9, 316, 1, 8, 1, 312, 72, 1478, 885, 14, 86, 725, 1, 1913, 1431, 53, 5, 455, 736, 1, 2],\n'seq_len': 23}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from fastNLP import Vocabulary\n",
"\n",
"# 构建词表, Vocabulary.add(word)\n",
"vocab = Vocabulary(min_freq=2)\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
"vocab.build_vocab()\n",
"\n",
"# index句子, Vocabulary.to_index(word)\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='words')\n",
"\n",
"\n",
"print(test_data[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model\n",
"定义一个PyTorch模型"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"CNNText(\n (embed): Embedding(\n (embed): Embedding(3459, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fastNLP.models import CNNText\n",
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这是上述模型的forward方法。如果你不知道什么是forward方法,请参考我们的PyTorch教程。\n",
"\n",
"注意两点:\n",
"1. forward参数名字叫**word_seq**,请记住。\n",
"2. forward的返回值是一个**dict**,其中有个key的名字叫**output**。\n",
"\n",
"```Python\n",
" def forward(self, word_seq):\n",
" \"\"\"\n",
"\n",
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
" :return output: dict of torch.LongTensor, [batch_size, num_classes]\n",
" \"\"\"\n",
" x = self.embed(word_seq) # [N,L] -> [N,L,C]\n",
" x = self.conv_pool(x) # [N,L,C] -> [N,C]\n",
" x = self.dropout(x)\n",
" x = self.fc(x) # [N,C] -> [N, N_class]\n",
" return {'output': x}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"这是上述模型的predict方法,是用来直接输出该任务的预测结果,与forward目的不同。\n",
"\n",
"注意两点:\n",
"1. predict参数名也叫**word_seq**。\n",
"2. predict的返回值是也一个**dict**,其中有个key的名字叫**predict**。\n",
"\n",
"```\n",
" def predict(self, word_seq):\n",
" \"\"\"\n",
"\n",
" :param word_seq: torch.LongTensor, [batch_size, seq_len]\n",
" :return predict: dict of torch.LongTensor, [batch_size, seq_len]\n",
" \"\"\"\n",
" output = self(word_seq)\n",
" _, predict = output['output'].max(dim=1)\n",
" return {'predict': predict}\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trainer & Tester\n",
"------\n",
"\n",
"使用fastNLP的Trainer训练模型"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from copy import deepcopy\n",
"from fastNLP.core.losses import CrossEntropyLoss\n",
"from fastNLP.core.metrics import AccuracyMetric\n",
"\n",
"\n",
"# 更改DataSet中对应field的名称,与模型的forward的参数名一致\n",
"# 因为forward的参数叫word_seq, 所以要把原本叫words的field改名为word_seq\n",
"# 这里的演示是让你了解这种**命名规则**\n",
"train_data.rename_field('words', 'word_seq')\n",
"test_data.rename_field('words', 'word_seq')\n",
"\n",
"# 顺便把label换名为label_seq\n",
"train_data.rename_field('label', 'label_seq')\n",
"test_data.rename_field('label', 'label_seq')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### loss\n",
"训练模型需要提供一个损失函数\n",
"\n",
"下面提供了一个在分类问题中常用的交叉熵损失。注意它的**初始化参数**。\n",
"\n",
"pred参数对应的是模型的forward返回的dict的一个key的名字,这里是\"output\"。\n",
"\n",
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"loss = CrossEntropyLoss(pred=\"output\", target=\"label_seq\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metric\n",
"定义评价指标\n",
"\n",
"这里使用准确率。参数的“命名规则”跟上面类似。\n",
"\n",
"pred参数对应的是模型的predict方法返回的dict的一个key的名字,这里是\"predict\"。\n",
"\n",
"target参数对应的是dataset作为标签的field的名字,这里是\"label_seq\"。"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"metric = AccuracyMetric(pred=\"predict\", target=\"label_seq\")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-07 14:11:31"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=915), HTML(value='')), layout=Layout(display=…"
]
},
"execution_count": 0,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5. Step:183/915. AccuracyMetric: acc=0.350367"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5. Step:366/915. AccuracyMetric: acc=0.409332"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5. Step:549/915. AccuracyMetric: acc=0.572552"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5. Step:732/915. AccuracyMetric: acc=0.711331"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5. Step:915/915. AccuracyMetric: acc=0.801572"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
}
],
"source": [
"# 实例化Trainer,传入模型和数据,进行训练\n",
"# 先在test_data拟合\n",
"copy_model = deepcopy(model)\n",
"overfit_trainer = Trainer(model=copy_model, train_data=test_data, dev_data=test_data,\n",
" loss=loss,\n",
" metrics=metric,\n",
" save_path=None,\n",
" batch_size=32,\n",
" n_epochs=5)\n",
"overfit_trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-07 14:12:21"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=395), HTML(value='')), layout=Layout(display=…"
]
},
"execution_count": 0,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5. Step:79/395. AccuracyMetric: acc=0.250043"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5. Step:158/395. AccuracyMetric: acc=0.280807"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5. Step:237/395. AccuracyMetric: acc=0.280978"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5. Step:316/395. AccuracyMetric: acc=0.285592"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5. Step:395/395. AccuracyMetric: acc=0.278927"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
}
],
"source": [
"# 用train_data训练,在test_data验证\n",
"trainer = Trainer(model=model, train_data=train_data, dev_data=test_data,\n",
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
" metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
" save_path=None,\n",
" batch_size=32,\n",
" n_epochs=5)\n",
"trainer.train()\n",
"print('Train finished!')"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tester] \nAccuracyMetric: acc=0.280636"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'AccuracyMetric': {'acc': 0.280636}}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 调用Tester在test_data上评价效果\n",
"from fastNLP import Tester\n",
"\n",
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric(pred=\"predict\", target=\"label_seq\"),\n",
" batch_size=4)\n",
"acc = tester.test()\n",
"print(acc)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
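
The loss and metric cells in the notebook above describe a key-mapping rule: pred names a key in the dict returned by the model, and target names a field of the DataSet. The sketch below is not fastNLP's implementation (resolve_by_name is a made-up helper); it only illustrates what CrossEntropyLoss(pred="output", target="label_seq") amounts to:

import torch
import torch.nn.functional as F


def resolve_by_name(forward_out, target_fields, pred_key, target_key):
    """Made-up helper: pick the prediction and the gold labels by their configured names."""
    return forward_out[pred_key], target_fields[target_key]


# Toy batch: forward returned {'output': logits}; the target field is 'label_seq'.
forward_out = {"output": torch.randn(4, 5)}                # [batch_size, num_classes]
target_fields = {"label_seq": torch.tensor([0, 2, 1, 4])}  # gold class indices

pred, target = resolve_by_name(forward_out, target_fields, "output", "label_seq")
loss = F.cross_entropy(pred, target)   # roughly what CrossEntropyLoss(pred="output", target="label_seq") computes
print(loss.item())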

+ 860
- 0
tutorials/fastnlp_10tmin_tutorial.ipynb View File

@@ -0,0 +1,860 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"fastNLP上手教程\n",
"-------\n",
"\n",
"fastNLP提供方便的数据预处理,训练和测试模型的功能"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"DataSet & Instance\n",
"------\n",
"\n",
"fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。\n",
"\n",
"有一些read_*方法,可以轻松从文件读取数据,存成DataSet。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,\n'label': 1}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from fastNLP import DataSet\n",
"from fastNLP import Instance\n",
"\n",
"# 从csv读取数据到DataSet\n",
"win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n",
"dataset = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')\n",
"print(dataset[0])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'raw_sentence': fake data,\n'label': 0}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# DataSet.append(Instance)加入新数据\n",
"\n",
"dataset.append(Instance(raw_sentence='fake data', label='0'))\n",
"dataset[-1]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# DataSet.apply(func, new_field_name)对数据预处理\n",
"\n",
"# 将所有数字转为小写\n",
"dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
"# label转int\n",
"dataset.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
"# 使用空格分割句子\n",
"dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)\n",
"def split_sent(ins):\n",
" return ins['raw_sentence'].split()\n",
"dataset.apply(split_sent, new_field_name='words', is_input=True)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# DataSet.drop(func)筛除数据\n",
"# 删除低于某个长度的词语\n",
"dataset.drop(lambda x: len(x['words']) <= 3)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train size: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"54"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test size: "
]
}
],
"source": [
"# 分出测试集、训练集\n",
"\n",
"test_data, train_data = dataset.split(0.3)\n",
"print(\"Train size: \", len(test_data))\n",
"print(\"Test size: \", len(train_data))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Vocabulary\n",
"------\n",
"\n",
"fastNLP中的Vocabulary轻松构建词表,将词转成数字"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'raw_sentence': the plot is romantic comedy boilerplate from start to finish .,\n'label': 2,\n'label_seq': 2,\n'words': ['the', 'plot', 'is', 'romantic', 'comedy', 'boilerplate', 'from', 'start', 'to', 'finish', '.'],\n'word_seq': [2, 13, 9, 24, 25, 26, 15, 27, 11, 28, 3]}"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from fastNLP import Vocabulary\n",
"\n",
"# 构建词表, Vocabulary.add(word)\n",
"vocab = Vocabulary(min_freq=2)\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
"vocab.build_vocab()\n",
"\n",
"# index句子, Vocabulary.to_index(word)\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
"test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
"\n",
"\n",
"print(test_data[0])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"batch_x has: {'words': array([list(['this', 'kind', 'of', 'hands-on', 'storytelling', 'is', 'ultimately', 'what', 'makes', 'shanghai', 'ghetto', 'move', 'beyond', 'a', 'good', ',', 'dry', ',', 'reliable', 'textbook', 'and', 'what', 'allows', 'it', 'to', 'rank', 'with', 'its', 'worthy', 'predecessors', '.']),\n",
" list(['the', 'entire', 'movie', 'is', 'filled', 'with', 'deja', 'vu', 'moments', '.'])],\n",
" dtype=object), 'word_seq': tensor([[ 19, 184, 6, 1, 481, 9, 206, 50, 91, 1210, 1609, 1330,\n",
" 495, 5, 63, 4, 1269, 4, 1, 1184, 7, 50, 1050, 10,\n",
" 8, 1611, 16, 21, 1039, 1, 2],\n",
" [ 3, 711, 22, 9, 1282, 16, 2482, 2483, 200, 2, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0]])}\n",
"batch_y has: {'label_seq': tensor([3, 2])}\n"
]
}
],
"source": [
"# 假设你们需要做强化学习或者gan之类的项目,也许你们可以使用这里的dataset\n",
"from fastNLP.core.batch import Batch\n",
"from fastNLP.core.sampler import RandomSampler\n",
"\n",
"batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())\n",
"for batch_x, batch_y in batch_iterator:\n",
" print(\"batch_x has: \", batch_x)\n",
" print(\"batch_y has: \", batch_y)\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"CNNText(\n (embed): Embedding(\n (embed): Embedding(77, 50, padding_idx=0)\n (dropout): Dropout(p=0.0)\n )\n (conv_pool): ConvMaxpool(\n (convs): ModuleList(\n (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))\n (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))\n (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))\n )\n )\n (dropout): Dropout(p=0.1)\n (fc): Linear(\n (linear): Linear(in_features=12, out_features=5, bias=True)\n )\n)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 定义一个简单的Pytorch模型\n",
"\n",
"from fastNLP.models import CNNText\n",
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n",
"model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Trainer & Tester\n",
"------\n",
"\n",
"使用fastNLP的Trainer训练模型"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Trainer\n",
"from copy import deepcopy\n",
"from fastNLP import CrossEntropyLoss\n",
"from fastNLP import AccuracyMetric"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-07 14:07:20"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=20), HTML(value='')), layout=Layout(display='…"
]
},
"execution_count": 0,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/10. Step:2/20. AccuracyMetric: acc=0.037037"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/10. Step:4/20. AccuracyMetric: acc=0.296296"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/10. Step:6/20. AccuracyMetric: acc=0.333333"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/10. Step:8/20. AccuracyMetric: acc=0.555556"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/10. Step:10/20. AccuracyMetric: acc=0.611111"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 6/10. Step:12/20. AccuracyMetric: acc=0.481481"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 7/10. Step:14/20. AccuracyMetric: acc=0.62963"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 8/10. Step:16/20. AccuracyMetric: acc=0.685185"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 9/10. Step:18/20. AccuracyMetric: acc=0.722222"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 10/10. Step:20/20. AccuracyMetric: acc=0.777778"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
}
],
"source": [
"# 进行overfitting测试\n",
"copy_model = deepcopy(model)\n",
"overfit_trainer = Trainer(model=copy_model, \n",
" train_data=test_data, \n",
" dev_data=test_data,\n",
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
" metrics=AccuracyMetric(),\n",
" n_epochs=10,\n",
" save_path=None)\n",
"overfit_trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-07 14:08:10"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…"
]
},
"execution_count": 0,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.037037"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.037037"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.037037"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.185185"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.240741"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train finished!"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 实例化Trainer,传入模型和数据,进行训练\n",
"trainer = Trainer(model=model, \n",
" train_data=train_data, \n",
" dev_data=test_data,\n",
" loss=CrossEntropyLoss(pred=\"output\", target=\"label_seq\"),\n",
" metrics=AccuracyMetric(),\n",
" n_epochs=5)\n",
"trainer.train()\n",
"print('Train finished!')"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[tester] \nAccuracyMetric: acc=0.240741"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from fastNLP import Tester\n",
"\n",
"tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())\n",
"acc = tester.test()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# In summary\n",
"\n",
"## fastNLP Trainer的伪代码逻辑\n",
"### 1. 准备DataSet,假设DataSet中共有如下的fields\n",
" ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']\n",
" 通过\n",
" DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input\n",
" 通过\n",
" DataSet.set_target('label', flag=True)将'label'设置为target\n",
"### 2. 初始化模型\n",
" class Model(nn.Module):\n",
" def __init__(self):\n",
" xxx\n",
" def forward(self, word_seq1, word_seq2):\n",
" # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的\n",
" # (2) input field的数量可以多于这里的形参数量。但是不能少于。\n",
" xxxx\n",
" # 输出必须是一个dict\n",
"### 3. Trainer的训练过程\n",
" (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward\n",
" (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。\n",
" 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx}; \n",
" 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;\n",
" 为了解决以上的问题,我们的loss提供映射机制\n",
" 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target\n",
" 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可\n",
" (3) 对于Metric是同理的\n",
" Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值 \n",
" \n",
" \n",
"\n",
"## 一些问题.\n",
"### 1. DataSet中为什么需要设置input和target\n",
" 只有被设置为input或者target的数据才会在train的过程中被取出来\n",
" (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。\n",
" (1.2) 我们在传递值给losser或者metric的时候会使用来自: \n",
" (a)Model.forward的output\n",
" (b)被设置为target的field\n",
" \n",
"\n",
"### 2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数\n",
" (1.1) 构建模型过程中,\n",
" 例如:\n",
" DataSet中x,seq_lens是input,那么forward就应该是\n",
" def forward(self, x, seq_lens):\n",
" pass\n",
" 我们是通过形参名称进行匹配的field的\n",
" \n",
"\n",
"\n",
"### 1. 加载数据到DataSet\n",
"### 2. 使用apply操作对DataSet进行预处理\n",
" (2.1) 处理过程中将某些field设置为input,某些field设置为target\n",
"### 3. 构建模型\n",
" (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。\n",
" 例如:\n",
" DataSet中x,seq_lens是input,那么forward就应该是\n",
" def forward(self, x, seq_lens):\n",
" pass\n",
" 我们是通过形参名称进行匹配的field的\n",
" (3.2) 模型的forward的output需要是dict类型的。\n",
" 建议将输出设置为{\"pred\": xx}.\n",
" \n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
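
The summary cell above describes the Trainer's flow in pseudocode. A simplified plain-Python version of that loop, reusing the Batch and RandomSampler classes shown earlier in the notebook, might look like the following; this is a sketch, not fastNLP's actual Trainer:

import inspect

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import RandomSampler


def train_one_epoch(model, dataset, loss_fn, optimizer, batch_size=32,
                    pred_key="output", target_key="label_seq"):
    # Fields marked as input are bound to forward by parameter name, as the summary describes.
    forward_params = set(inspect.signature(model.forward).parameters)
    for batch_x, batch_y in Batch(dataset=dataset, batch_size=batch_size, sampler=RandomSampler()):
        inputs = {name: value for name, value in batch_x.items() if name in forward_params}
        output = model(**inputs)                                # forward returns a dict, e.g. {'output': logits}
        loss = loss_fn(output[pred_key], batch_y[target_key])   # resolve loss inputs by key, as described
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()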

+ 333
- 0
tutorials/fastnlp_1_minute_tutorial.ipynb View File

@@ -0,0 +1,333 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# FastNLP 1分钟上手教程"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## step 1\n",
"读取数据集"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import DataSet\n",
"# linux_path = \"../test/data_for_tests/tutorial_sample_dataset.csv\"\n",
"win_path = \"C:\\\\Users\\zyfeng\\Desktop\\FudanNLP\\\\fastNLP\\\\test\\\\data_for_tests\\\\tutorial_sample_dataset.csv\"\n",
"ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\\t')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## step 2\n",
"数据预处理\n",
"1. 类型转换\n",
"2. 切分验证集\n",
"3. 构建词典"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# 将所有数字转为小写\n",
"ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')\n",
"# label转int\n",
"ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)\n",
"\n",
"def split_sent(ins):\n",
" return ins['raw_sentence'].split()\n",
"ds.apply(split_sent, new_field_name='words', is_input=True)\n"
]
},
{
"cell_type": "code",
"execution_count": 60,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train size: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"54"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test size: "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"23"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# 分割训练集/验证集\n",
"train_data, dev_data = ds.split(0.3)\n",
"print(\"Train size: \", len(train_data))\n",
"print(\"Test size: \", len(dev_data))"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP import Vocabulary\n",
"vocab = Vocabulary(min_freq=2)\n",
"train_data.apply(lambda x: [vocab.add(word) for word in x['words']])\n",
"\n",
"# index句子, Vocabulary.to_index(word)\n",
"train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n",
"dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## step 3\n",
" 定义模型"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [],
"source": [
"from fastNLP.models import CNNText\n",
"model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## step 4\n",
"开始训练"
]
},
{
"cell_type": "code",
"execution_count": 63,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"training epochs started 2018-12-07 14:03:41"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"text/plain": [
"HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…"
]
},
"execution_count": 0,
"metadata": {},
"output_type": "execute_result"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train finished!"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric\n",
"trainer = Trainer(model=model, \n",
" train_data=train_data, \n",
" dev_data=dev_data,\n",
" loss=CrossEntropyLoss(),\n",
" metrics=AccuracyMetric()\n",
" )\n",
"trainer.train()\n",
"print('Train finished!')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 本教程结束。更多操作请参考进阶教程。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
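
The tutorials flag fields in two ways: this notebook passes is_input / is_target directly to apply, while the 10-minute tutorial calls set_input / set_target on existing fields. The sketch below shows both routes side by side; it assumes an empty DataSet() can be constructed directly, which these tutorials do not show:

from fastNLP import DataSet, Instance

# Assumed: DataSet() can be built empty and filled with Instances by hand.
ds = DataSet()
ds.append(Instance(raw_sentence="fake data", label="0"))

# Route 1: flag the field at creation time, as in this 1-minute tutorial.
ds.apply(lambda x: int(x["label"]), new_field_name="label_seq", is_target=True)

# Route 2: create the field first, then mark it explicitly, as in the 10-minute tutorial.
ds.apply(lambda x: x["raw_sentence"].split(), new_field_name="words")
ds.set_input("words")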

+ 101
- 0
tutorials/fastnlp_advanced_tutorial.ipynb View File

@@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"## FastNLP 进阶教程\n",
"本教程阅读时间平均30分钟"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据部分\n",
"### DataSet\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Instance"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Vocabulary"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 模型部分\n",
"### model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 训练测试部分\n",
"### Loss"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Metric"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Trainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tester"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
