diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index 64aafdd3..1c5e7425 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -131,5 +131,37 @@ class SeqLabelField(Field):
     def contents(self):
         return self.label_seq.copy()
 
-if __name__ == "__main__":
-    tf = TextField("test the code".split(), is_target=False)
+
+class CharTextField(Field):
+    def __init__(self, text, max_word_len, is_target=False):
+        super(CharTextField, self).__init__(is_target)
+        self.text = text
+        self.max_word_len = max_word_len
+        self._index = []
+
+    def get_length(self):
+        return len(self.text)
+
+    def contents(self):
+        return self.text.copy()
+
+    def index(self, char_vocab):
+        if len(self._index) == 0:
+            for word in self.text:
+                char_index = [char_vocab[ch] for ch in word]
+                if self.max_word_len >= len(char_index):
+                    char_index += [0] * (self.max_word_len - len(char_index))
+                else:
+                    self._index.clear()
+                    raise RuntimeError("Word {} has more than {} characters. ".format(word, self.max_word_len))
+                self._index.append(char_index)
+        return self._index
+
+    def to_tensor(self, padding_length):
+        """
+
+        :param padding_length: int, the padding length of the word sequence.
+        :return: tensor of shape (padding_length, max_word_len)
+        """
+        pads = [[0] * self.max_word_len] * (padding_length - self.get_length())
+        return torch.LongTensor(self._index + pads)
diff --git a/test/core/test_field.py b/test/core/test_field.py
new file mode 100644
index 00000000..ccc36f49
--- /dev/null
+++ b/test/core/test_field.py
@@ -0,0 +1,23 @@
+import unittest
+
+from fastNLP.core.field import CharTextField
+
+
+class TestField(unittest.TestCase):
+    def test_case(self):
+        text = "PhD applicants must submit a Research Plan and a resume " \
+               "specify your class ranking written in English and a list of research" \
+               " publications if any".split()
+        max_word_len = max([len(w) for w in text])
+        field = CharTextField(text, max_word_len, is_target=False)
+        all_char = set()
+        for word in text:
+            all_char.update([ch for ch in word])
+        char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)}
+
+        self.assertEqual(field.index(char_vocab),
+                         [[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text])
+        self.assertEqual(field.get_length(), len(text))
+        self.assertEqual(field.contents(), text)
+        tensor = field.to_tensor(50)
+        self.assertEqual(tuple(tensor.shape), (50, max_word_len))
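
For review, a minimal usage sketch of the new CharTextField outside the patch, mirroring the added unit test. The example words, vocabulary, and padding length are illustrative assumptions, not part of the change; it assumes torch is available and that index 0 of the character vocabulary is left free for padding, as the implementation above expects.

# Usage sketch (assumed example data; mirrors test/core/test_field.py).
from fastNLP.core.field import CharTextField

words = "char level field example".split()
max_word_len = max(len(w) for w in words)      # 7, the length of "example"

# Toy character vocabulary; indices start at 1 so that 0 stays the pad value.
chars = sorted({ch for w in words for ch in w})
char_vocab = {ch: i + 1 for i, ch in enumerate(chars)}

field = CharTextField(words, max_word_len, is_target=False)
indices = field.index(char_vocab)   # one list of char indices per word, padded to max_word_len
tensor = field.to_tensor(10)        # pads the word dimension to 10 rows
assert tuple(tensor.shape) == (10, max_word_len)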