@@ -1,7 +1,7 @@
 import _pickle
 import os
 
-import numpy as np
+import torch
 
 from fastNLP.loader.base_loader import BaseLoader
 from fastNLP.core.vocabulary import Vocabulary
@@ -30,7 +30,7 @@ class EmbedLoader(BaseLoader):
         for line in f:
             line = list(filter(lambda w: len(w)>0, line.strip().split(' ')))
             if len(line) > 0:
-                emb[line[0]] = np.array(list(map(float, line[1:])))
+                emb[line[0]] = torch.Tensor(list(map(float, line[1:])))
         return emb
 
     @staticmethod
@@ -62,8 +62,8 @@ class EmbedLoader(BaseLoader):
         # If the embedding pickle exists, load it and return.
         if os.path.exists(emb_pkl):
             with open(emb_pkl, "rb") as f:
-                embedding_np = _pickle.load(f)
-            return embedding_np
+                embedding_np, vocab = _pickle.load(f)
+            return embedding_np, vocab
         # Otherwise, load the pre-trained embedding.
         pretrain = EmbedLoader._load_pretrain(emb_file, emb_type)
         if vocab is None:
@@ -71,7 +71,7 @@ class EmbedLoader(BaseLoader):
             vocab = Vocabulary()
             for w in pretrain.keys():
                 vocab.update(w)
-        embedding_np = np.random.uniform(-1, 1, size=(len(vocab), emb_dim))
+        embedding_np = torch.randn(len(vocab), emb_dim)
         for w, v in pretrain.items():
             if len(v.shape) > 1 or emb_dim != v.shape[0]:
                 raise ValueError('pretrian embedding dim is {}, dismatching required {}'.format(v.shape, (emb_dim,)))
@@ -80,5 +80,5 @@ class EmbedLoader(BaseLoader):
 
         # save and return the result
         with open(emb_pkl, "wb") as f:
-            _pickle.dump(embedding_np, f)
+            _pickle.dump((embedding_np, vocab), f)
         return embedding_np, vocab
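
For reference, a minimal usage sketch of the loader after this change. The staticmethod call style, the argument order of load_embedding, the module path fastNLP.loader.embed_loader, and the "glove" emb_type value are assumptions inferred from the names visible in the hunks, not confirmed by this diff:

    # Hypothetical usage sketch; names and argument order are inferred
    # from the diff, not taken from a documented API.
    from fastNLP.core.vocabulary import Vocabulary
    from fastNLP.loader.embed_loader import EmbedLoader  # assumed module path

    vocab = Vocabulary()
    for w in ["the", "quick", "brown", "fox"]:
        vocab.update(w)

    # load_embedding now returns (embedding, vocab) and pickles both to
    # emb_pkl, so a later call restores the word-to-row mapping along
    # with the weights instead of the weights alone.
    embedding, vocab = EmbedLoader.load_embedding(
        emb_dim=50, emb_file="glove.6B.50d.txt", emb_type="glove",
        vocab=vocab, emb_pkl="emb.pkl")
    assert embedding.shape == (len(vocab), 50)

One behavioral note: torch.randn draws from the standard normal distribution, while the removed np.random.uniform(-1, 1, size=...) drew uniformly from [-1, 1]; if the old initialization range matters, torch.empty(len(vocab), emb_dim).uniform_(-1, 1) would reproduce it in torch.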