Browse Source

Make imports more friendly; DataSet supports slicing.

tags/v0.2.0
yunfan 5 years ago
parent
commit
8ee94eb6d5
9 changed files with 65 additions and 12 deletions
  1. +3
    -0
      fastNLP/__init__.py
  2. +10
    -0
      fastNLP/core/__init__.py
  3. +4
    -4
      fastNLP/core/batch.py
  4. +21
    -2
      fastNLP/core/dataset.py
  5. +7
    -2
      fastNLP/core/fieldarray.py
  6. +6
    -0
      fastNLP/models/__init__.py
  7. +6
    -1
      fastNLP/modules/__init__.py
  8. +5
    -3
      fastNLP/modules/aggregator/__init__.py
  9. +3
    -0
      fastNLP/modules/aggregator/attention.py

+ 3
- 0
fastNLP/__init__.py View File

@@ -0,0 +1,3 @@
from .core import *
from . import models
from . import modules

+ 10
- 0
fastNLP/core/__init__.py View File

@@ -0,0 +1,10 @@
from .batch import Batch
from .dataset import DataSet
from .fieldarray import FieldArray
from .instance import Instance
from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary


+ 4
- 4
fastNLP/core/batch.py View File

@@ -9,7 +9,7 @@ class Batch(object):

"""

def __init__(self, dataset, batch_size, sampler, use_cuda):
def __init__(self, dataset, batch_size, sampler, use_cuda=False):
"""

:param dataset: a DataSet object
@@ -54,9 +54,9 @@ class Batch(object):
for field_name, field in self.dataset.get_fields().items():
if field.need_tensor:
batch = torch.from_numpy(field.get(indices))
if not field.need_tensor:
pass
elif field.is_target:
if self.use_cuda:
batch = batch.cuda()
if field.is_target:
batch_y[field_name] = batch
else:
batch_x[field_name] = batch


+ 21
- 2
fastNLP/core/dataset.py View File

@@ -88,10 +88,11 @@ class DataSet(object):
assert name in self.field_arrays
self.field_arrays[name].append(field)

def add_field(self, name, fields, need_tensor=False, is_target=False):
def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
if len(self.field_arrays) != 0:
assert len(self) == len(fields)
self.field_arrays[name] = FieldArray(name, fields,
padding_val=padding_val,
need_tensor=need_tensor,
is_target=is_target)

@@ -104,6 +105,16 @@ class DataSet(object):
def __getitem__(self, name):
if isinstance(name, int):
return self.Instance(self, idx=name)
elif isinstance(name, slice):
ds = DataSet()
for field in self.field_arrays.values():
ds.add_field(name=field.name,
fields=field.content[name],
padding_val=field.padding_val,
need_tensor=field.need_tensor,
is_target=field.is_target)
return ds

elif isinstance(name, str):
return self.field_arrays[name]
else:
@@ -187,7 +198,15 @@ class DataSet(object):
for ins in self:
results.append(func(ins))
if new_field_name is not None:
self.add_field(new_field_name, results)
if new_field_name in self.field_arrays:
# overwrite the field, keep same attributes
old_field = self.field_arrays[new_field_name]
padding_val = old_field.padding_val
need_tensor = old_field.need_tensor
is_target = old_field.is_target
self.add_field(new_field_name, results, padding_val, need_tensor, is_target)
else:
self.add_field(new_field_name, results)
else:
return results



+ 7
- 2
fastNLP/core/fieldarray.py View File

@@ -8,6 +8,7 @@ class FieldArray(object):
self.padding_val = padding_val
self.is_target = is_target
self.need_tensor = need_tensor
self.dtype = None

def __repr__(self):
# TODO
@@ -30,10 +31,14 @@ class FieldArray(object):
batch_size = len(idxes)
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下
if isinstance(self.content[0], int) or isinstance(self.content[0], float):
array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0]))
if self.dtype is None:
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
array = np.array([self.content[i] for i in idxes], dtype=self.dtype)
else:
if self.dtype is None:
self.dtype = np.int64
max_len = max([len(self.content[i]) for i in idxes])
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int64)
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)

for i, idx in enumerate(idxes):
array[i][:len(self.content[idx])] = self.content[idx]


+ 6
- 0
fastNLP/models/__init__.py View File

@@ -0,0 +1,6 @@
from .base_model import BaseModel
from .biaffine_parser import BiaffineParser, GraphParser
from .char_language_model import CharLM
from .cnn_text_classification import CNNText
from .sequence_modeling import SeqLabeling, AdvSeqLabel
from .snli import SNLI

+ 6
- 1
fastNLP/modules/__init__.py View File

@@ -2,10 +2,15 @@ from . import aggregator
from . import decoder
from . import encoder
from . import interactor
from .aggregator import *
from .decoder import *
from .encoder import *
from .dropout import TimestepDropout

__version__ = '0.0.0'

__all__ = ['encoder',
'decoder',
'aggregator',
'interactor']
'interactor',
'TimestepDropout']

+ 5
- 3
fastNLP/modules/aggregator/__init__.py View File

@@ -1,5 +1,7 @@
from .max_pool import MaxPool
from .avg_pool import AvgPool
from .kmax_pool import KMaxPool

from .attention import Attention
from .self_attention import SelfAttention

__all__ = [
'MaxPool'
]

+ 3
- 0
fastNLP/modules/aggregator/attention.py View File

@@ -21,6 +21,7 @@ class Attention(torch.nn.Module):

class DotAtte(nn.Module):
def __init__(self, key_size, value_size):
# TODO never test
super(DotAtte, self).__init__()
self.key_size = key_size
self.value_size = value_size
@@ -42,6 +43,8 @@ class DotAtte(nn.Module):

class MultiHeadAtte(nn.Module):
def __init__(self, input_size, output_size, key_size, value_size, num_atte):
raise NotImplementedError
# TODO never test
super(MultiHeadAtte, self).__init__()
self.in_linear = nn.ModuleList()
for i in range(num_atte * 3):


Loading…
Cancel
Save