From 8ee94eb6d530e9bb5955afc6464d846c3ac4b7dd Mon Sep 17 00:00:00 2001
From: yunfan
Date: Mon, 19 Nov 2018 23:10:37 +0800
Subject: [PATCH] make imports more friendly; DataSet supports slicing.

---
 fastNLP/__init__.py                     |  3 +++
 fastNLP/core/__init__.py                | 10 ++++++++++
 fastNLP/core/batch.py                   |  8 ++++----
 fastNLP/core/dataset.py                 | 23 +++++++++++++++++++++--
 fastNLP/core/fieldarray.py              |  9 +++++++--
 fastNLP/models/__init__.py              |  6 ++++++
 fastNLP/modules/__init__.py             |  7 ++++++-
 fastNLP/modules/aggregator/__init__.py  |  8 +++++---
 fastNLP/modules/aggregator/attention.py |  3 +++
 9 files changed, 65 insertions(+), 12 deletions(-)

diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index e69de29b..0f6da45f 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -0,0 +1,3 @@
+from .core import *
+from . import models
+from . import modules
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index e69de29b..03f284d5 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -0,0 +1,10 @@
+from .batch import Batch
+from .dataset import DataSet
+from .fieldarray import FieldArray
+from .instance import Instance
+from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
+from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
+from .tester import Tester
+from .trainer import Trainer
+from .vocabulary import Vocabulary
+
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 29ed4c8a..b047081a 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -9,7 +9,7 @@ class Batch(object):
 
     """
 
-    def __init__(self, dataset, batch_size, sampler, use_cuda):
+    def __init__(self, dataset, batch_size, sampler, use_cuda=False):
         """
 
         :param dataset: a DataSet object
@@ -54,9 +54,9 @@ class Batch(object):
         for field_name, field in self.dataset.get_fields().items():
             if field.need_tensor:
                 batch = torch.from_numpy(field.get(indices))
-                if not field.need_tensor:
-                    pass
-                elif field.is_target:
+                if self.use_cuda:
+                    batch = batch.cuda()
+                if field.is_target:
                     batch_y[field_name] = batch
                 else:
                     batch_x[field_name] = batch
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index d8ae4087..684bd18d 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -88,10 +88,11 @@ class DataSet(object):
             assert name in self.field_arrays
             self.field_arrays[name].append(field)
 
-    def add_field(self, name, fields, need_tensor=False, is_target=False):
+    def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields,
+                                             padding_val=padding_val,
                                              need_tensor=need_tensor,
                                              is_target=is_target)
 
@@ -104,6 +105,16 @@ class DataSet(object):
     def __getitem__(self, name):
         if isinstance(name, int):
             return self.Instance(self, idx=name)
+        elif isinstance(name, slice):
+            ds = DataSet()
+            for field in self.field_arrays.values():
+                ds.add_field(name=field.name,
+                             fields=field.content[name],
+                             padding_val=field.padding_val,
+                             need_tensor=field.need_tensor,
+                             is_target=field.is_target)
+            return ds
+
         elif isinstance(name, str):
             return self.field_arrays[name]
         else:
@@ -187,7 +198,15 @@ class DataSet(object):
         for ins in self:
             results.append(func(ins))
         if new_field_name is not None:
-            self.add_field(new_field_name, results)
+            if new_field_name in self.field_arrays:
+                # overwrite the field, keeping the same attributes
+                old_field = self.field_arrays[new_field_name]
+                padding_val = old_field.padding_val
+                need_tensor = old_field.need_tensor
+                is_target = old_field.is_target
+                self.add_field(new_field_name, results, padding_val, need_tensor, is_target)
+            else:
+                self.add_field(new_field_name, results)
         else:
             return results
diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index 82eecf84..7ead3a64 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -8,6 +8,7 @@ class FieldArray(object):
         self.padding_val = padding_val
         self.is_target = is_target
         self.need_tensor = need_tensor
+        self.dtype = None
 
     def __repr__(self):
         # TODO
@@ -30,10 +31,14 @@ class FieldArray(object):
         batch_size = len(idxes)
         # TODO: when this FieldArray holds scalar content such as seq_length, no padding is needed; to be discussed further
         if isinstance(self.content[0], int) or isinstance(self.content[0], float):
-            array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
+            if self.dtype is None:
+                self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
+            array = np.array([self.content[i] for i in idxes], dtype=self.dtype)
         else:
+            if self.dtype is None:
+                self.dtype = np.int64
             max_len = max([len(self.content[i]) for i in idxes])
-            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int64)
+            array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
 
             for i, idx in enumerate(idxes):
                 array[i][:len(self.content[idx])] = self.content[idx]
diff --git a/fastNLP/models/__init__.py b/fastNLP/models/__init__.py
index e69de29b..5bb2bc3d 100644
--- a/fastNLP/models/__init__.py
+++ b/fastNLP/models/__init__.py
@@ -0,0 +1,6 @@
+from .base_model import BaseModel
+from .biaffine_parser import BiaffineParser, GraphParser
+from .char_language_model import CharLM
+from .cnn_text_classification import CNNText
+from .sequence_modeling import SeqLabeling, AdvSeqLabel
+from .snli import SNLI
diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py
index 21cb2886..3af1ebad 100644
--- a/fastNLP/modules/__init__.py
+++ b/fastNLP/modules/__init__.py
@@ -2,10 +2,15 @@ from . import aggregator
 from . import decoder
 from . import encoder
 from . import interactor
+from .aggregator import *
+from .decoder import *
+from .encoder import *
+from .dropout import TimestepDropout
 
 __version__ = '0.0.0'
 
 __all__ = ['encoder',
            'decoder',
            'aggregator',
-           'interactor']
+           'interactor',
+           'TimestepDropout']
diff --git a/fastNLP/modules/aggregator/__init__.py b/fastNLP/modules/aggregator/__init__.py
index 3c57625b..dbc36abc 100644
--- a/fastNLP/modules/aggregator/__init__.py
+++ b/fastNLP/modules/aggregator/__init__.py
@@ -1,5 +1,7 @@
 from .max_pool import MaxPool
+from .avg_pool import AvgPool
+from .kmax_pool import KMaxPool
+
+from .attention import Attention
+from .self_attention import SelfAttention
 
-__all__ = [
-    'MaxPool'
-]
diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py
index 69c5fdf6..882807f8 100644
--- a/fastNLP/modules/aggregator/attention.py
+++ b/fastNLP/modules/aggregator/attention.py
@@ -21,6 +21,7 @@ class Attention(torch.nn.Module):
 
 class DotAtte(nn.Module):
     def __init__(self, key_size, value_size):
+        # TODO: never tested
         super(DotAtte, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
@@ -42,6 +43,8 @@ class DotAtte(nn.Module):
 
 class MultiHeadAtte(nn.Module):
     def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        raise NotImplementedError
+        # TODO: never tested
         super(MultiHeadAtte, self).__init__()
         self.in_linear = nn.ModuleList()
         for i in range(num_atte * 3):
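
A quick sketch of what the new top-level imports and slicing support look like in use. It assumes Instance takes fields as keyword arguments and DataSet.append takes an Instance, per the API at this revision:

    from fastNLP import DataSet, Instance  # top-level imports enabled by the new fastNLP/__init__.py

    ds = DataSet()
    for i in range(10):
        ds.append(Instance(words=[1, 2, 3, 4], label=i))

    sub = ds[2:5]    # slicing a DataSet now returns a new DataSet
    print(len(sub))  # 3; each field keeps its padding_val, need_tensor, and is_target settings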
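
With the apply() change, writing results back into an existing field preserves that field's attributes instead of resetting them to the add_field defaults. A hedged sketch; the 'words' field and dict-style Instance access are illustrative assumptions:

    # suppose 'words' was added with need_tensor=True, padding_val=0
    ds.apply(lambda ins: ins['words'][:5], new_field_name='words')
    # 'words' is rebuilt from the function's results but keeps
    # need_tensor=True and padding_val=0, so batching is unaffected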
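
Batch now defaults to CPU (use_cuda=False) and calls .cuda() only on request. Assuming Batch is iterable and yields (batch_x, batch_y) dicts, as the fetch loop above fills them:

    from fastNLP import Batch, SequentialSampler

    for batch_x, batch_y in Batch(ds, batch_size=4, sampler=SequentialSampler()):
        pass  # batch_x / batch_y map field names to torch tensors, on CPU by default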
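
The padding behavior in FieldArray.get, restated as a standalone numpy snippet that mirrors the patched logic (values illustrative):

    import numpy as np

    content = [[1, 2, 3], [4, 5]]                 # one sequence field's content
    max_len = max(len(row) for row in content)
    array = np.full((len(content), max_len), 0,   # padding_val = 0
                    dtype=np.int64)               # the cached self.dtype
    for i, row in enumerate(content):
        array[i][:len(row)] = row                 # left-align, pad the tail
    # array == [[1, 2, 3],
    #           [4, 5, 0]]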