|
@@ -131,5 +131,37 @@ class SeqLabelField(Field): |
|
|
def contents(self):
    """Return a copy of the stored label sequence.

    The copy is made with ``self.label_seq.copy()``, so the caller gets a
    new container and the field's own ``label_seq`` object is not handed out.

    :return: the object returned by ``self.label_seq.copy()``.
    """
    return self.label_seq.copy()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Runs only when the module is executed directly: builds a non-target
    # TextField from the three whitespace-separated words of the sample string.
    tf = TextField("test the code".split(), is_target=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CharTextField(Field):
    """Field that stores a word sequence and indexes it character by character.

    Each word is turned into a list of character ids with exactly
    ``max_word_len`` entries (shorter words are right-padded with 0), so the
    indexed field is a rectangular grid of integers.

    :param text: the words of the sequence. ``contents`` calls ``.copy()`` on
        it and ``get_length`` calls ``len()``, so a list of words is expected.
    :param max_word_len: int, number of character slots reserved per word.
    :param is_target: bool, passed through unchanged to ``Field.__init__``.
    """

    def __init__(self, text, max_word_len, is_target=False):
        super(CharTextField, self).__init__(is_target)
        self.text = text
        self.max_word_len = max_word_len
        # Cache of per-word character ids; filled by index().
        self._index = []

    def get_length(self):
        """Return the number of words in the field."""
        return len(self.text)

    def contents(self):
        """Return a copy of the word sequence (``self.text.copy()``)."""
        return self.text.copy()

    def index(self, char_vocab):
        """Convert every word into a fixed-width list of character ids.

        The result is cached: once the cache is non-empty, later calls return
        it as is and do not consult ``char_vocab`` again.

        :param char_vocab: mapping looked up as ``char_vocab[ch]`` for every
            character of every word.
        :return: list with one row per word, each row ``max_word_len`` long
            and right-padded with 0.
        :raises RuntimeError: if a word has more than ``max_word_len``
            characters; the cache is left empty in that case.
        """
        if not self._index:
            # Build into a local list so a failure part-way through never
            # leaves a partially filled cache behind.
            rows = []
            for word in self.text:
                char_index = [char_vocab[ch] for ch in word]
                if len(char_index) > self.max_word_len:
                    raise RuntimeError("Word {} has more than {} characters. ".format(word, self.max_word_len))
                char_index += [0] * (self.max_word_len - len(char_index))
                rows.append(char_index)
            # Extend in place so the cache keeps its list identity.
            self._index.extend(rows)
        return self._index

    def to_tensor(self, padding_length):
        """Stack the indexed words into a LongTensor, padding with zero rows.

        :param padding_length: int, the padding length of the word sequence.
            If it is smaller than the number of words, no rows are added and
            nothing is truncated.
        :return : tensor of shape (padding_length, max_word_len); the first
            dimension is the number of words when that is larger than
            ``padding_length``.
        :raises RuntimeError: if the cached index does not hold one row per
            word, i.e. ``index`` has not been called successfully.
        """
        seq_len = self.get_length()
        if len(self._index) != seq_len:
            # Without this check the result would silently consist of pad
            # rows only, with the wrong number of rows.
            raise RuntimeError(
                "CharTextField must be indexed before it is converted to a "
                "tensor; call index(char_vocab) first."
            )
        # range() of a non-positive count is empty, matching the original
        # list-multiplication behaviour, but each pad row is its own list.
        pad_rows = [[0] * self.max_word_len for _ in range(padding_length - seq_len)]
        rows = self._index + pad_rows
        if not rows:
            # torch.LongTensor([]) is 1-D; keep the documented 2-D layout.
            return torch.LongTensor(rows).view(0, self.max_word_len)
        return torch.LongTensor(rows)