|
|
@@ -1,4 +1,5 @@ |
|
|
|
import torch |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
class Field(object):
    """Base class for one named field of a data set.

    A field stores its samples in ``self.content``. The base class leaves
    ``content`` as ``None``; subclasses are expected to fill it in and to
    implement ``index`` and ``to_tensor``.
    """

    def __init__(self, name, is_target: bool):
        """
        :param name: name of the field
        :param is_target: bool, stored as-is on the instance
        """
        self.name = name
        self.is_target = is_target
        # Populated by subclasses; ``__len__`` refuses to run until it is set.
        self.content = None

    def index(self, vocab):
        """create index field

        :param vocab: vocabulary used for indexing
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def __len__(self):
        """number of samples

        :return: ``len(self.content)``
        :raises AssertionError: if ``content`` has not been set yet.
        """
        assert self.content is not None
        return len(self.content)

    def to_tensor(self, id_list):
        """convert batch of index to tensor

        :param id_list: ids of the samples that make up the batch
        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError
|
|
|
|
|
|
|
class TextField(Field):
    """Field holding raw text samples.

    The samples are kept in ``self.content`` and turned into an
    ``IndexField`` by ``index``.
    """

    def __init__(self, name, text, is_target):
        """
        :param name: name of the field
        :param text: list of strings
        :param is_target: bool
        """
        super(TextField, self).__init__(name, is_target)
        self.content = text

    def index(self, vocab):
        """Build the indexed counterpart of this field.

        :param vocab: vocabulary handed on to ``IndexField``
        :return: a new ``IndexField`` named ``<name>_idx`` built from this
            field's content, with the same ``is_target`` flag
        """
        idx_field = IndexField(self.name + '_idx', self.content, vocab, self.is_target)
        return idx_field
|
|
|
|
|
|
|
|
|
|
|
class IndexField(Field):
    """Field holding one index tensor per sample.

    Every sample is indexed through ``vocab.index_sent`` at construction time
    and stored in ``self.content``; ``to_tensor`` pads a selection of samples
    into a single batch tensor filled with ``vocab.padding_idx``.
    """

    def __init__(self, name, content, vocab, is_target):
        """
        :param name: name of the field
        :param content: iterable of samples, each passed to ``vocab.index_sent``
        :param vocab: object providing ``padding_idx`` and ``index_sent(sent)``
        :param is_target: bool
        :raises ValueError: if ``vocab.index_sent`` returns anything other
            than a list, a ``numpy.ndarray`` or a ``torch.Tensor``
        """
        super(IndexField, self).__init__(name, is_target)
        self.content = []
        self.padding_idx = vocab.padding_idx
        for sent in content:
            idx = vocab.index_sent(sent)
            if isinstance(idx, list):
                # Build an integer tensor explicitly: ``torch.Tensor(list)``
                # would produce float32 values for what are indices.
                idx = torch.tensor(idx, dtype=torch.long)
            elif isinstance(idx, np.ndarray):
                # ``np.array`` is a factory function, not a type; the array
                # type accepted by ``isinstance`` is ``np.ndarray``.
                idx = torch.from_numpy(idx)
            elif not isinstance(idx, torch.Tensor):
                raise ValueError(
                    "index_sent must return a list, numpy.ndarray or "
                    "torch.Tensor, got {}".format(type(idx).__name__)
                )
            self.content.append(idx)

    def to_tensor(self, id_list, sort_within_batch=False):
        """Pad the selected samples into one batch tensor.

        :param id_list: positions in ``self.content`` of the samples to batch
        :param sort_within_batch: if True, rows are ordered by decreasing
            sample length (ties keep their ``id_list`` order)
        :return: ``torch.long`` tensor of shape ``[len(id_list), max_len]``
            where ``max_len`` is the longest selected sample; shorter rows are
            right-padded with ``self.padding_idx``. An empty ``id_list``
            yields a ``[0, 0]`` tensor.
        """
        len_list = [(i, self.content[i].size(0)) for i in id_list]
        if sort_within_batch:
            len_list = sorted(len_list, key=lambda x: x[1], reverse=True)

        batch_size = len(len_list)
        # The padded width is the longest sample, not the largest sample id.
        max_len = max((length for _, length in len_list), default=0)
        tensor = torch.full((batch_size, max_len), self.padding_idx, dtype=torch.long)
        for row, (idx, length) in enumerate(len_list):
            # Covers the full row when length == max_len; the remainder of a
            # shorter row keeps the padding value.
            tensor[row, :length] = self.content[idx]
        return tensor
|
|
|
|
|
|
|
class LabelField(Field): |
|
|
|
"""The Field representing a single label. Can be a string or integer. |
|
|
|