
update, fix bug

commit 1f680f24e5 by yunfan (6 years ago) · tags/v0.2.0
3 changed files with 38 additions and 6 deletions
  1. fastNLP/core/field.py  (+30, -0)
  2. fastNLP/core/vocabulary.py  (+2, -0)
  3. fastNLP/loader/embed_loader.py  (+6, -6)

fastNLP/core/field.py  (+30, -0)

@@ -93,5 +93,35 @@ class LabelField(Field):
         return torch.LongTensor([self._index])


+class SeqLabelField(Field):
+    def __init__(self, label_seq, is_target=True):
+        super(SeqLabelField, self).__init__(is_target)
+        self.label_seq = label_seq
+        self._index = None
+
+    def get_length(self):
+        return len(self.label_seq)
+
+    def index(self, vocab):
+        if self._index is None:
+            self._index = [vocab[c] for c in self.label_seq]
+        return self._index
+
+    def to_tensor(self, padding_length):
+        pads = [0] * (padding_length - self.get_length())
+        if self._index is None:
+            if self.get_length() == 0:
+                return pads
+            elif isinstance(self.label_seq[0], int):
+                return torch.LongTensor(self.label_seq + pads)
+            elif isinstance(self.label_seq[0], str):
+                raise RuntimeError("Field {} not indexed. Call index method.".format(self.label_seq))
+            else:
+                raise RuntimeError(
+                    "Unsupported type for SeqLabelField. Expect str or int, got {}.".format(type(self.label_seq)))
+        else:
+            return torch.LongTensor(self._index + pads)
+
+
 if __name__ == "__main__":
     tf = TextField("test the code".split(), is_target=False)
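The new SeqLabelField holds a per-token label sequence, indexes it against a Vocabulary, and zero-pads it to a fixed length in to_tensor. A minimal usage sketch under those assumptions; the tag set and padding length below are illustrative, not part of the commit:

    from fastNLP.core.field import SeqLabelField
    from fastNLP.core.vocabulary import Vocabulary

    # Illustrative tag vocabulary; update() and vocab[...] are the only
    # Vocabulary operations the new field relies on.
    vocab = Vocabulary()
    for tag in ["B", "I", "O"]:
        vocab.update(tag)

    field = SeqLabelField(["B", "I", "O", "O"], is_target=True)
    field.index(vocab)            # maps each str tag to an int id via vocab[c]
    tensor = field.to_tensor(6)   # pads with 0 up to length 6
    print(tensor)                 # LongTensor of length 6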

fastNLP/core/vocabulary.py  (+2, -0)

@@ -126,12 +126,14 @@ class Vocabulary(object):
         """
         return self[w]

+    @property
     @check_build_vocab
     def unknown_idx(self):
         if self.unknown_label is None:
             return None
         return self.word2idx[self.unknown_label]

+    @property
     @check_build_vocab
     def padding_idx(self):
         if self.padding_label is None:
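With the added @property decorators, unknown_idx and padding_idx are read as attributes rather than called as methods. A small sketch of the difference, assuming a vocabulary built with update(); either property returns None when the corresponding label was never set:

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary()
    for w in ["test", "the", "code"]:
        vocab.update(w)

    pad_id = vocab.padding_idx   # before this commit: vocab.padding_idx()
    unk_id = vocab.unknown_idx   # before this commit: vocab.unknown_idx()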


fastNLP/loader/embed_loader.py  (+6, -6)

@@ -1,7 +1,7 @@
 import _pickle
 import os

-import numpy as np
+import torch

 from fastNLP.loader.base_loader import BaseLoader
 from fastNLP.core.vocabulary import Vocabulary
@@ -30,7 +30,7 @@ class EmbedLoader(BaseLoader):
             for line in f:
                 line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
                 if len(line) > 0:
-                    emb[line[0]] = np.array(list(map(float, line[1:])))
+                    emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb

     @staticmethod
@@ -62,8 +62,8 @@ class EmbedLoader(BaseLoader):
         # If the embedding pickle exists, load it and return.
         if os.path.exists(emb_pkl):
             with open(emb_pkl, "rb") as f:
-                embedding_np = _pickle.load(f)
-            return embedding_np
+                embedding_np, vocab = _pickle.load(f)
+            return embedding_np, vocab
         # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
@@ -71,7 +71,7 @@ class EmbedLoader(BaseLoader):
             vocab = Vocabulary()
             for w in pretrain.keys():
                 vocab.update(w)
-        embedding_np = np.random.uniform(-1, 1, size=(len(vocab), emb_dim))
+        embedding_np = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
             if len(v.shape) > 1 or emb_dim != v.shape[0]:
                 raise ValueError('pretrained embedding dim is {}, mismatching the required {}'.format(v.shape, (emb_dim,)))
@@ -80,5 +80,5 @@ class EmbedLoader(BaseLoader):

         # save and return the result
         with open(emb_pkl, "wb") as f:
-            _pickle.dump(embedding_np, f)
+            _pickle.dump((embedding_np, vocab), f)
         return embedding_np, vocab
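Taken together, the embed_loader changes switch the pipeline from numpy arrays to torch tensors and make the pickle cache store the vocabulary alongside the embedding matrix, so reloading from emb_pkl returns the same (embedding, vocab) pair as a fresh load. A usage sketch; the method name load_embedding, the keyword names, and the file paths are assumptions based on the identifiers visible in the diff, not guaranteed by it:

    from fastNLP.loader.embed_loader import EmbedLoader

    # Hypothetical call; argument names mirror the local variables used in
    # the diff (emb_dim, emb_file, emb_type, vocab, emb_pkl).
    embedding, vocab = EmbedLoader.load_embedding(
        emb_dim=50,                      # must match the pretrained file's dimension
        emb_file="glove.6B.50d.txt",     # illustrative path
        emb_type="glove",                # illustrative embedding type
        vocab=None,                      # None: build the vocab from the pretrained file
        emb_pkl="emb.pkl",               # cache file; now stores (embedding, vocab)
    )
    print(type(embedding))               # torch.Tensor after this commit
    print(vocab.padding_idx)             # the cached Vocabulary travels with the matrix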
