
make imports more friendly; add slice support to DataSet.

tags/v0.2.0
yunfan 5 years ago
commit 8ee94eb6d5
9 changed files with 65 additions and 12 deletions
  1. fastNLP/__init__.py (+3, -0)
  2. fastNLP/core/__init__.py (+10, -0)
  3. fastNLP/core/batch.py (+4, -4)
  4. fastNLP/core/dataset.py (+21, -2)
  5. fastNLP/core/fieldarray.py (+7, -2)
  6. fastNLP/models/__init__.py (+6, -0)
  7. fastNLP/modules/__init__.py (+6, -1)
  8. fastNLP/modules/aggregator/__init__.py (+5, -3)
  9. fastNLP/modules/aggregator/attention.py (+3, -0)

fastNLP/__init__.py (+3, -0)

@@ -0,0 +1,3 @@
+from .core import *
+from . import models
+from . import modules
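Taken together with the submodule __init__ files below, this is what makes the top-level import friendlier. A small usage sketch, assuming only the re-exports shown elsewhere in this commit:

    from fastNLP import DataSet, Instance, Trainer, Tester, Vocabulary
    from fastNLP import models, modules

    # names re-exported further down in this commit
    print(models.CNNText, modules.TimestepDropout)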

fastNLP/core/__init__.py (+10, -0)

@@ -0,0 +1,10 @@
+from .batch import Batch
+from .dataset import DataSet
+from .fieldarray import FieldArray
+from .instance import Instance
+from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
+from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
+from .tester import Tester
+from .trainer import Trainer
+from .vocabulary import Vocabulary


fastNLP/core/batch.py (+4, -4)

@@ -9,7 +9,7 @@ class Batch(object):
     """
 
-    def __init__(self, dataset, batch_size, sampler, use_cuda):
+    def __init__(self, dataset, batch_size, sampler, use_cuda=False):
         """
 
         :param dataset: a DataSet object
@@ -54,9 +54,9 @@ class Batch(object):
             for field_name, field in self.dataset.get_fields().items():
                 if field.need_tensor:
                     batch = torch.from_numpy(field.get(indices))
-                    if not field.need_tensor:
-                        pass
-                    elif field.is_target:
+                    if self.use_cuda:
+                        batch = batch.cuda()
+                    if field.is_target:
                         batch_y[field_name] = batch
                     else:
                         batch_x[field_name] = batch
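With the new default, CPU users can construct a Batch without mentioning use_cuda at all, while use_cuda=True presumably routes tensors through the new .cuda() call. A hedged sketch of the iteration pattern implied by the loop above (ds is a placeholder DataSet):

    from fastNLP import Batch, SequentialSampler

    data_iterator = Batch(dataset=ds, batch_size=16, sampler=SequentialSampler())
    for batch_x, batch_y in data_iterator:
        # batch_x: dict of input tensors, batch_y: dict of target tensors,
        # split by each field's is_target flag as in the hunk above
        pass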


fastNLP/core/dataset.py (+21, -2)

@@ -88,10 +88,11 @@ class DataSet(object):
         assert name in self.field_arrays
         self.field_arrays[name].append(field)
 
-    def add_field(self, name, fields, need_tensor=False, is_target=False):
+    def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields,
+                                             padding_val=padding_val,
                                              need_tensor=need_tensor,
                                              is_target=is_target)


@@ -104,6 +105,16 @@ class DataSet(object):
     def __getitem__(self, name):
         if isinstance(name, int):
             return self.Instance(self, idx=name)
+        elif isinstance(name, slice):
+            ds = DataSet()
+            for field in self.field_arrays.values():
+                ds.add_field(name=field.name,
+                             fields=field.content[name],
+                             padding_val=field.padding_val,
+                             need_tensor=field.need_tensor,
+                             is_target=field.is_target)
+            return ds
+
         elif isinstance(name, str):
             return self.field_arrays[name]
         else:
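The new slice branch builds a fresh DataSet and copies each field's padding_val, need_tensor, and is_target flags, so a split keeps its configuration. A minimal sketch (the field name "words" is hypothetical):

    train_data = ds[:800]    # new DataSet holding the first 800 instances
    dev_data = ds[800:]      # new DataSet holding the rest
    ins = ds[0]              # an int index still returns a single instance view
    words = ds["words"]      # a str key still returns the FieldArray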
@@ -187,7 +198,15 @@ class DataSet(object):
         for ins in self:
             results.append(func(ins))
         if new_field_name is not None:
-            self.add_field(new_field_name, results)
+            if new_field_name in self.field_arrays:
+                # overwrite the field, keep same attributes
+                old_field = self.field_arrays[new_field_name]
+                padding_val = old_field.padding_val
+                need_tensor = old_field.need_tensor
+                is_target = old_field.is_target
+                self.add_field(new_field_name, results, padding_val, need_tensor, is_target)
+            else:
+                self.add_field(new_field_name, results)
         else:
             return results
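The other dataset.py change makes apply() preserve a field's attributes when new_field_name overwrites an existing field, instead of resetting them to the add_field defaults. A hedged sketch (the "words" field is an assumption):

    # overwrite in place: "words" keeps its padding_val / need_tensor / is_target
    ds.apply(lambda ins: [w.lower() for w in ins["words"]], new_field_name="words")

    # without new_field_name, apply() just returns the collected results
    seq_len = ds.apply(lambda ins: len(ins["words"]))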




fastNLP/core/fieldarray.py (+7, -2)

@@ -8,6 +8,7 @@ class FieldArray(object):
         self.padding_val = padding_val
         self.is_target = is_target
         self.need_tensor = need_tensor
+        self.dtype = None
 
     def __repr__(self):
         # TODO
@@ -30,10 +31,14 @@ class FieldArray(object):
         batch_size = len(idxes)
         # TODO when this FieldArray holds single-value content such as seq_length, no padding is needed; to be discussed
         if isinstance(self.content[0], int) or isinstance(self.content[0], float):
-            array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
+            if self.dtype is None:
+                self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
+            array = np.array([self.content[i] for i in idxes], dtype=self.dtype)
         else:
+            if self.dtype is None:
+                self.dtype = np.int64
             max_len = max([len(self.content[i]) for i in idxes])
-            array = np.full((batch_size, max_len), self.padding_val, dtype=np.int64)
+            array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
 
             for i, idx in enumerate(idxes):
                 array[i][:len(self.content[idx])] = self.content[idx]
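The cached dtype means get() decides int64 vs. double once, from the first element it sees, and then pads every batch with that dtype. A small sketch of the behavior under these assumptions:

    from fastNLP import FieldArray

    fa = FieldArray("words", [[1, 2, 3], [4, 5]], padding_val=0, need_tensor=True)
    batch = fa.get([0, 1])
    # first call caches fa.dtype = np.int64, then pads to the batch max length:
    # array([[1, 2, 3],
    #        [4, 5, 0]])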


fastNLP/models/__init__.py (+6, -0)

@@ -0,0 +1,6 @@
+from .base_model import BaseModel
+from .biaffine_parser import BiaffineParser, GraphParser
+from .char_language_model import CharLM
+from .cnn_text_classification import CNNText
+from .sequence_modeling import SeqLabeling, AdvSeqLabel
+from .snli import SNLI

fastNLP/modules/__init__.py (+6, -1)

@@ -2,10 +2,15 @@ from . import aggregator
 from . import decoder
 from . import encoder
 from . import interactor
+from .aggregator import *
+from .decoder import *
+from .encoder import *
+from .dropout import TimestepDropout
 
 __version__ = '0.0.0'
 
 __all__ = ['encoder',
            'decoder',
            'aggregator',
-           'interactor']
+           'interactor',
+           'TimestepDropout']
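TimestepDropout is now importable straight from fastNLP.modules (and, via the star-imports, so are the encoder/decoder/aggregator names). A hedged sketch; the shape convention is an assumption:

    import torch
    from fastNLP.modules import TimestepDropout

    drop = TimestepDropout(p=0.5)
    x = torch.randn(4, 10, 8)   # (batch, seq_len, hidden) -- assumed layout
    y = drop(x)                 # presumably one dropout mask per sequence,
                                # shared across the time dimension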

fastNLP/modules/aggregator/__init__.py (+5, -3)

@@ -1,5 +1,7 @@
 from .max_pool import MaxPool
-from .avg_pool import AvgPool
-from .kmax_pool import KMaxPool
 from .attention import Attention
-from .self_attention import SelfAttention
+from .self_attention import SelfAttention
+
+__all__ = [
+    'MaxPool'
+]

fastNLP/modules/aggregator/attention.py (+3, -0)

@@ -21,6 +21,7 @@ class Attention(torch.nn.Module):
 
 class DotAtte(nn.Module):
     def __init__(self, key_size, value_size):
+        # TODO never test
         super(DotAtte, self).__init__()
         self.key_size = key_size
         self.value_size = value_size
@@ -42,6 +43,8 @@ class DotAtte(nn.Module):
 
 class MultiHeadAtte(nn.Module):
     def __init__(self, input_size, output_size, key_size, value_size, num_atte):
+        raise NotImplementedError
+        # TODO never test
         super(MultiHeadAtte, self).__init__()
         self.in_linear = nn.ModuleList()
         for i in range(num_atte * 3):
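Since raise NotImplementedError comes before super().__init__(), MultiHeadAtte now fails fast instead of running untested code. A minimal sketch of the resulting behavior (argument values are placeholders):

    from fastNLP.modules.aggregator.attention import MultiHeadAtte

    try:
        MultiHeadAtte(input_size=128, output_size=128,
                      key_size=32, value_size=32, num_atte=4)
    except NotImplementedError:
        print("MultiHeadAtte is marked unfinished in this version")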

