Browse Source

add character field

tags/v0.2.0
FengZiYjun 5 years ago
parent
commit
5133fe67b4
2 changed files with 57 additions and 2 deletions
  1. +34
    -2
      fastNLP/core/field.py
  2. +23
    -0
      test/core/test_field.py

+ 34
- 2
fastNLP/core/field.py View File

@@ -131,5 +131,37 @@ class SeqLabelField(Field):
def contents(self): def contents(self):
return self.label_seq.copy() return self.label_seq.copy()


# Manual smoke check: when this module is executed directly as a script,
# build a TextField from a whitespace-split sample sentence.
if __name__ == "__main__":
    tf = TextField("test the code".split(), is_target=False)

class CharTextField(Field):
    """Field that indexes a word sequence at character level.

    Each word is converted to a list of character ids and right-padded with 0
    up to ``max_word_len``, so every word occupies a row of identical width.
    """

    def __init__(self, text, max_word_len, is_target=False):
        """
        :param text: list of words; each word is iterated character by character.
        :param max_word_len: int, the number of character slots reserved per word.
        :param is_target: bool, passed through unchanged to ``Field.__init__``.
        """
        super(CharTextField, self).__init__(is_target)
        self.text = text
        self.max_word_len = max_word_len
        # Cached character ids, one row per word; empty until index() succeeds.
        self._index = []

    def get_length(self):
        """
        :return: int, the number of words in the sequence.
        """
        return len(self.text)

    def contents(self):
        """
        :return: a shallow copy of the word list.
        """
        return self.text.copy()

    def index(self, char_vocab):
        """Convert every word into a fixed-width list of character ids.

        The result is cached: once a non-empty index has been built, later
        calls return it without consulting ``char_vocab`` again.

        :param char_vocab: object supporting ``char_vocab[ch]`` lookup that
            yields the id of a character. 0 is used as the padding id here, so
            the vocabulary presumably reserves it (confirm against the caller).
        :return: list of lists of ids, one row of length ``max_word_len`` per word.
        :raises RuntimeError: if a word has more than ``max_word_len`` characters.
        """
        if len(self._index) == 0:
            # Build into a local list and publish it only on success, so that a
            # failure part-way through (over-long word, or a lookup error raised
            # by char_vocab) can never leave a truncated index in the cache.
            rows = []
            for word in self.text:
                char_index = [char_vocab[ch] for ch in word]
                if len(char_index) > self.max_word_len:
                    raise RuntimeError("Word {} has more than {} characters. ".format(word, self.max_word_len))
                char_index += [0] * (self.max_word_len - len(char_index))
                rows.append(char_index)
            self._index = rows
        return self._index

    def to_tensor(self, padding_length):
        """Pad the indexed words with all-zero rows and pack them into a tensor.

        :param padding_length: int, the padding length of the word sequence.
        :return: LongTensor of shape (padding_length, max_word_len). If
            padding_length is smaller than the number of words, nothing is
            truncated and the tensor keeps one row per word.
        :raises RuntimeError: if the field has not been indexed yet.
        """
        if len(self._index) != len(self.text):
            # Without this guard the result would consist of padding rows only
            # and silently have the wrong shape.
            raise RuntimeError("CharTextField must be indexed with index() before calling to_tensor().")
        # range() of a non-positive count is empty, so no rows are added when
        # the sequence is already at least padding_length long.
        pads = [[0] * self.max_word_len for _ in range(padding_length - self.get_length())]
        return torch.LongTensor(self._index + pads)

+ 23
- 0
test/core/test_field.py View File

@@ -0,0 +1,23 @@
import unittest

from fastNLP.core.field import CharTextField


class TestField(unittest.TestCase):
    """Checks for the character-level text field."""

    def test_case(self):
        """Index a sample sentence and verify indices, length, contents and tensor shape."""
        words = ("PhD applicants must submit a Research Plan and a resume "
                 "specify your class ranking written in English and a list of research"
                 " publications if any").split()
        longest = max(len(token) for token in words)
        field = CharTextField(words, longest, is_target=False)

        # Collect the alphabet; ids start at 1 because 0 is the padding value
        # expected in the indexed rows below.
        alphabet = set()
        for token in words:
            alphabet.update(token)
        char_vocab = {symbol: position + 1 for position, symbol in enumerate(alphabet)}

        # Expected index: character ids of each word, right-padded with zeros.
        expected = []
        for token in words:
            row = [char_vocab[symbol] for symbol in token]
            row.extend([0] * (longest - len(token)))
            expected.append(row)

        self.assertEqual(field.index(char_vocab), expected)
        self.assertEqual(field.get_length(), len(words))
        self.assertEqual(field.contents(), words)
        tensor = field.to_tensor(50)
        self.assertEqual(tuple(tensor.shape), (50, longest))

Loading…
Cancel
Save