diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index b57b9bb6..ebf8f2ea 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -93,5 +93,35 @@ class LabelField(Field):
         return torch.LongTensor([self._index])
 
 
+class SeqLabelField(Field):
+    def __init__(self, label_seq, is_target=True):
+        super(SeqLabelField, self).__init__(is_target)
+        self.label_seq = label_seq
+        self._index = None
+
+    def get_length(self):
+        return len(self.label_seq)
+
+    def index(self, vocab):
+        if self._index is None:
+            self._index = [vocab[c] for c in self.label_seq]
+        return self._index
+
+    def to_tensor(self, padding_length):
+        pads = [0] * (padding_length - self.get_length())
+        if self._index is None:
+            if self.get_length() == 0:
+                return torch.LongTensor(pads)
+            elif isinstance(self.label_seq[0], int):
+                return torch.LongTensor(self.label_seq + pads)
+            elif isinstance(self.label_seq[0], str):
+                raise RuntimeError("Field {} not indexed. Call index method.".format(self.label_seq))
+            else:
+                raise RuntimeError(
+                    "Unsupported type for SeqLabelField. Expected str or int, got {}.".format(type(self.label_seq)))
+        else:
+            return torch.LongTensor(self._index + pads)
+
+
 if __name__ == "__main__":
     tf = TextField("test the code".split(), is_target=False)
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index d2ed4564..77b27b92 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -126,12 +126,14 @@ class Vocabulary(object):
         """
         return self[w]
 
+    @property
     @check_build_vocab
     def unknown_idx(self):
         if self.unknown_label is None:
             return None
         return self.word2idx[self.unknown_label]
 
+    @property
     @check_build_vocab
     def padding_idx(self):
         if self.padding_label is None:
diff --git a/fastNLP/loader/embed_loader.py b/fastNLP/loader/embed_loader.py
index 6de83cee..b44c9851 100644
--- a/fastNLP/loader/embed_loader.py
+++ b/fastNLP/loader/embed_loader.py
@@ -1,7 +1,7 @@
 import _pickle
 import os
 
-import numpy as np
+import torch
 
 from fastNLP.loader.base_loader import BaseLoader
 from fastNLP.core.vocabulary import Vocabulary
@@ -30,7 +30,7 @@ class EmbedLoader(BaseLoader):
             for line in f:
                 line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
                 if len(line) > 0:
-                    emb[line[0]] = np.array(list(map(float, line[1:])))
+                    emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
 
     @staticmethod
@@ -62,8 +62,8 @@ class EmbedLoader(BaseLoader):
         # If the embedding pickle exists, load it and return.
         if os.path.exists(emb_pkl):
             with open(emb_pkl, "rb") as f:
-                embedding_np = _pickle.load(f)
-                return embedding_np
+                embedding_np, vocab = _pickle.load(f)
+                return embedding_np, vocab
         # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
@@ -71,7 +71,7 @@ class EmbedLoader(BaseLoader):
             vocab = Vocabulary()
             for w in pretrain.keys():
                 vocab.update(w)
-        embedding_np = np.random.uniform(-1, 1, size=(len(vocab), emb_dim))
+        embedding_np = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
            if len(v.shape) > 1 or emb_dim != v.shape[0]:
                 raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,)))
@@ -80,5 +80,5 @@ class EmbedLoader(BaseLoader):
 
         # save and return the result
         with open(emb_pkl, "wb") as f:
-            _pickle.dump(embedding_np, f)
+            _pickle.dump((embedding_np, vocab), f)
         return embedding_np, vocab
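
For reference, a minimal usage sketch of the APIs this patch touches. The tag strings and padding length below are made up for illustration, and it assumes a default `Vocabulary()` that defines padding and unknown labels; only the calls shown in the diff (`Vocabulary()`, `vocab.update(w)`, `vocab[w]`, `SeqLabelField`) are taken as given.

```python
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.field import SeqLabelField

# Build a small tag vocabulary word by word, mirroring how
# EmbedLoader.load_embedding calls vocab.update(w) above.
vocab = Vocabulary()
for tag in ["B", "I", "O"]:          # hypothetical tag set for illustration
    vocab.update(tag)

# After this patch, unknown_idx / padding_idx are read as properties,
# not called as methods.
pad_id = vocab.padding_idx           # was vocab.padding_idx() before the patch
unk_id = vocab.unknown_idx           # was vocab.unknown_idx() before the patch

# SeqLabelField holds a label sequence; index() maps it through the
# vocabulary once and caches the ids, to_tensor() right-pads with 0.
field = SeqLabelField(["B", "I", "O", "O"], is_target=True)
field.index(vocab)
labels = field.to_tensor(padding_length=6)   # LongTensor of length 6
```

Note that callers of `EmbedLoader.load_embedding` should now unpack a `(embedding, vocab)` pair even when the cached pickle is hit, since both objects are pickled together, and the embedding is a torch tensor rather than a numpy array.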