@@ -0,0 +1,3 @@ | |||||
from .core import * | |||||
from . import models | |||||
from . import modules |
@@ -0,0 +1,10 @@ | |||||
from .batch import Batch | |||||
from .dataset import DataSet | |||||
from .fieldarray import FieldArray | |||||
from .instance import Instance | |||||
from .metrics import Evaluator, ClassifyEvaluator, SNLIEvaluator, SeqLabelEvaluator | |||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler | |||||
from .tester import Tester | |||||
from .trainer import Trainer | |||||
from .vocabulary import Vocabulary | |||||
@@ -9,7 +9,7 @@ class Batch(object): | |||||
""" | """ | ||||
def __init__(self, dataset, batch_size, sampler, use_cuda): | |||||
def __init__(self, dataset, batch_size, sampler, use_cuda=False): | |||||
""" | """ | ||||
:param dataset: a DataSet object | :param dataset: a DataSet object | ||||
@@ -54,9 +54,9 @@ class Batch(object): | |||||
for field_name, field in self.dataset.get_fields().items(): | for field_name, field in self.dataset.get_fields().items(): | ||||
if field.need_tensor: | if field.need_tensor: | ||||
batch = torch.from_numpy(field.get(indices)) | batch = torch.from_numpy(field.get(indices)) | ||||
if not field.need_tensor: | |||||
pass | |||||
elif field.is_target: | |||||
if self.use_cuda: | |||||
batch = batch.cuda() | |||||
if field.is_target: | |||||
batch_y[field_name] = batch | batch_y[field_name] = batch | ||||
else: | else: | ||||
batch_x[field_name] = batch | batch_x[field_name] = batch | ||||
@@ -88,10 +88,11 @@ class DataSet(object): | |||||
assert name in self.field_arrays | assert name in self.field_arrays | ||||
self.field_arrays[name].append(field) | self.field_arrays[name].append(field) | ||||
def add_field(self, name, fields, need_tensor=False, is_target=False): | |||||
def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False): | |||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
assert len(self) == len(fields) | assert len(self) == len(fields) | ||||
self.field_arrays[name] = FieldArray(name, fields, | self.field_arrays[name] = FieldArray(name, fields, | ||||
padding_val=padding_val, | |||||
need_tensor=need_tensor, | need_tensor=need_tensor, | ||||
is_target=is_target) | is_target=is_target) | ||||
@@ -104,6 +105,16 @@ class DataSet(object): | |||||
def __getitem__(self, name): | def __getitem__(self, name): | ||||
if isinstance(name, int): | if isinstance(name, int): | ||||
return self.Instance(self, idx=name) | return self.Instance(self, idx=name) | ||||
elif isinstance(name, slice): | |||||
ds = DataSet() | |||||
for field in self.field_arrays.values(): | |||||
ds.add_field(name=field.name, | |||||
fields=field.content[name], | |||||
padding_val=field.padding_val, | |||||
need_tensor=field.need_tensor, | |||||
is_target=field.is_target) | |||||
return ds | |||||
elif isinstance(name, str): | elif isinstance(name, str): | ||||
return self.field_arrays[name] | return self.field_arrays[name] | ||||
else: | else: | ||||
@@ -187,7 +198,15 @@ class DataSet(object): | |||||
for ins in self: | for ins in self: | ||||
results.append(func(ins)) | results.append(func(ins)) | ||||
if new_field_name is not None: | if new_field_name is not None: | ||||
self.add_field(new_field_name, results) | |||||
if new_field_name in self.field_arrays: | |||||
# overwrite the field, keep same attributes | |||||
old_field = self.field_arrays[new_field_name] | |||||
padding_val = old_field.padding_val | |||||
need_tensor = old_field.need_tensor | |||||
is_target = old_field.is_target | |||||
self.add_field(new_field_name, results, padding_val, need_tensor, is_target) | |||||
else: | |||||
self.add_field(new_field_name, results) | |||||
else: | else: | ||||
return results | return results | ||||
@@ -8,6 +8,7 @@ class FieldArray(object): | |||||
self.padding_val = padding_val | self.padding_val = padding_val | ||||
self.is_target = is_target | self.is_target = is_target | ||||
self.need_tensor = need_tensor | self.need_tensor = need_tensor | ||||
self.dtype = None | |||||
def __repr__(self): | def __repr__(self): | ||||
# TODO | # TODO | ||||
@@ -30,10 +31,14 @@ class FieldArray(object): | |||||
batch_size = len(idxes) | batch_size = len(idxes) | ||||
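# TODO: when this FieldArray holds single-valued content such as seq_length, no padding is needed; this needs further discussion | # TODO: when this FieldArray holds single-valued content such as seq_length, no padding is needed; this needs further discussion | ||||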
if isinstance(self.content[0], int) or isinstance(self.content[0], float): | if isinstance(self.content[0], int) or isinstance(self.content[0], float): | ||||
array = np.array([self.content[i] for i in idxes], dtype=type(self.content[0])) | |||||
if self.dtype is None: | |||||
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | |||||
array = np.array([self.content[i] for i in idxes], dtype=self.dtype) | |||||
else: | else: | ||||
if self.dtype is None: | |||||
self.dtype = np.int64 | |||||
max_len = max([len(self.content[i]) for i in idxes]) | max_len = max([len(self.content[i]) for i in idxes]) | ||||
array = np.full((batch_size, max_len), self.padding_val, dtype=np.int64) | |||||
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | |||||
for i, idx in enumerate(idxes): | for i, idx in enumerate(idxes): | ||||
array[i][:len(self.content[idx])] = self.content[idx] | array[i][:len(self.content[idx])] = self.content[idx] | ||||
@@ -0,0 +1,6 @@ | |||||
from .base_model import BaseModel | |||||
from .biaffine_parser import BiaffineParser, GraphParser | |||||
from .char_language_model import CharLM | |||||
from .cnn_text_classification import CNNText | |||||
from .sequence_modeling import SeqLabeling, AdvSeqLabel | |||||
from .snli import SNLI |
@@ -2,10 +2,15 @@ from . import aggregator | |||||
from . import decoder | from . import decoder | ||||
from . import encoder | from . import encoder | ||||
from . import interactor | from . import interactor | ||||
from .aggregator import * | |||||
from .decoder import * | |||||
from .encoder import * | |||||
from .dropout import TimestepDropout | |||||
__version__ = '0.0.0' | __version__ = '0.0.0' | ||||
__all__ = ['encoder', | __all__ = ['encoder', | ||||
'decoder', | 'decoder', | ||||
'aggregator', | 'aggregator', | ||||
'interactor'] | |||||
'interactor', | |||||
'TimestepDropout'] |
@@ -1,5 +1,7 @@ | |||||
from .max_pool import MaxPool | from .max_pool import MaxPool | ||||
from .avg_pool import AvgPool | |||||
from .kmax_pool import KMaxPool | |||||
from .attention import Attention | |||||
from .self_attention import SelfAttention | |||||
__all__ = [ | |||||
'MaxPool' | |||||
] |
@@ -21,6 +21,7 @@ class Attention(torch.nn.Module): | |||||
class DotAtte(nn.Module): | class DotAtte(nn.Module): | ||||
def __init__(self, key_size, value_size): | def __init__(self, key_size, value_size): | ||||
# TODO: never tested | |||||
super(DotAtte, self).__init__() | super(DotAtte, self).__init__() | ||||
self.key_size = key_size | self.key_size = key_size | ||||
self.value_size = value_size | self.value_size = value_size | ||||
@@ -42,6 +43,8 @@ class DotAtte(nn.Module): | |||||
class MultiHeadAtte(nn.Module): | class MultiHeadAtte(nn.Module): | ||||
def __init__(self, input_size, output_size, key_size, value_size, num_atte): | def __init__(self, input_size, output_size, key_size, value_size, num_atte): | ||||
raise NotImplementedError | |||||
# TODO: never tested | |||||
super(MultiHeadAtte, self).__init__() | super(MultiHeadAtte, self).__init__() | ||||
self.in_linear = nn.ModuleList() | self.in_linear = nn.ModuleList() | ||||
for i in range(num_atte * 3): | for i in range(num_atte * 3): | ||||