Browse Source

init new field

tags/v0.2.0
yunfan 6 years ago
parent
commit
9b25de3ff3
2 changed files with 51 additions and 48 deletions
  1. +2
    -8
      fastNLP/core/dataset.py
  2. +49
    -40
      fastNLP/core/field.py

+ 2
- 8
fastNLP/core/dataset.py View File

@@ -14,17 +14,11 @@ class DataSet(list):

"""

def __init__(self, name="", instances=None):
def __init__(self, fields=None):
"""

:param name: str, the name of the dataset. (default: "")
:param instances: list of Instance objects. (default: None)
"""
list.__init__([])
self.name = name
self.origin_len = None
if instances is not None:
self.extend(instances)
pass

def index_all(self, vocab):
for ins in self:


+ 49
- 40
fastNLP/core/field.py View File

@@ -1,4 +1,5 @@
import torch
import numpy as np


class Field(object):
@@ -6,61 +7,69 @@ class Field(object):

"""

def __init__(self, is_target: bool):
def __init__(self, name, is_target: bool):
self.name = name
self.is_target = is_target
self.content = None

def index(self, vocab):
"""create index field
"""
raise NotImplementedError

def get_length(self):
raise NotImplementedError
def to_tensor(self, padding_length):
raise NotImplementedError
def __len__(self):
"""number of samples
"""
assert self.content is not None
return len(self.content)

def contents(self):
def to_tensor(self, id_list):
"""convert batch of index to tensor
"""
raise NotImplementedError

class TextField(Field):
def __init__(self, text, is_target):
def __init__(self, name, text, is_target):
"""
:param text: list of strings
:param is_target: bool
"""
super(TextField, self).__init__(is_target)
self.text = text
self._index = None
super(TextField, self).__init__(name, is_target)
self.content = text

def index(self, vocab):
if self._index is None:
self._index = [vocab[c] for c in self.text]
else:
raise RuntimeError("Replicate indexing of this field.")
return self._index

def get_length(self):
"""Fetch the length of the text field.

:return length: int, the length of the text.

"""
return len(self.text)

def to_tensor(self, padding_length: int):
"""Convert text field to tensor.

:param padding_length: int
:return tensor: torch.LongTensor, of shape [padding_length, ]
"""
pads = []
if self._index is None:
raise RuntimeError("Indexing not done before to_tensor in TextField.")
if padding_length > self.get_length():
pads = [0] * (padding_length - self.get_length())
return torch.LongTensor(self._index + pads)

def contents(self):
return self.text.copy()
idx_field = IndexField(self.name+'_idx', self.content, vocab, self.is_target)
return idx_field


class IndexField(Field):
    """Field holding one index tensor per sample, built through a vocabulary.

    Every element of ``content`` is passed to ``vocab.index_sent`` and the
    result is stored as a ``torch.Tensor`` in ``self.content``. The padding
    value used when batching is read from ``vocab.padding_idx``.
    """

    def __init__(self, name, content, vocab, is_target):
        """Index every sample in ``content`` with ``vocab``.

        :param name: name of this field, forwarded to ``Field.__init__``
        :param content: iterable of samples; each one is handed to
            ``vocab.index_sent``
        :param vocab: object providing ``padding_idx`` and ``index_sent``
        :param is_target: bool, forwarded to ``Field.__init__``
        :raises ValueError: if ``vocab.index_sent`` returns something that is
            not a list, a numpy array or a torch tensor
        """
        super(IndexField, self).__init__(name, is_target)
        self.content = []
        self.padding_idx = vocab.padding_idx
        for sent in content:
            idx = vocab.index_sent(sent)
            if isinstance(idx, list):
                # Build integer indices directly; torch.Tensor(list) would
                # create a floating-point tensor instead.
                idx = torch.tensor(idx, dtype=torch.long)
            elif isinstance(idx, np.ndarray):
                # np.ndarray is the array type; np.array is only a factory
                # function and cannot be used with isinstance().
                idx = torch.from_numpy(idx)
            elif not isinstance(idx, torch.Tensor):
                raise ValueError(
                    "index_sent must return a list, numpy.ndarray or "
                    "torch.Tensor, got {}".format(type(idx).__name__))
            self.content.append(idx)

    def to_tensor(self, id_list, sort_within_batch=False):
        """Convert a batch of sample ids to one padded index tensor.

        :param id_list: sample positions into ``self.content``
        :param sort_within_batch: if True, rows are ordered by decreasing
            sequence length instead of following the order of ``id_list``
        :return tensor: torch.LongTensor of shape [batch_size, max_len], where
            max_len is the longest sequence in the batch; shorter rows are
            filled with ``self.padding_idx``
        """
        len_list = [(i, self.content[i].size(0)) for i in id_list]
        # The padded width is the longest sequence in the batch, not the
        # largest sample id. An empty batch yields an empty (0, 0) tensor.
        max_len = max((length for _, length in len_list), default=0)
        batch_size = len(len_list)
        tensor = torch.full((batch_size, max_len), self.padding_idx,
                            dtype=torch.long)
        if sort_within_batch:
            len_list = sorted(len_list, key=lambda x: x[1], reverse=True)
        for row, (idx, length) in enumerate(len_list):
            # A full-length sequence fills the whole row, so one slice
            # assignment covers both the padded and the unpadded case.
            tensor[row, :length] = self.content[idx]
        return tensor

class LabelField(Field):
"""The Field representing a single label. Can be a string or integer.


Loading…
Cancel
Save