
update reproduction

tags/v0.3.1^2
yunfan committed 6 years ago
parent commit de856fb8eb
3 changed files with 11 additions and 2 deletions
  1. fastNLP/io/embed_loader.py  +4 -1
  2. reproduction/Biaffine_parser/cfg.cfg  +1 -1
  3. reproduction/Biaffine_parser/run.py  +6 -0

fastNLP/io/embed_loader.py  +4 -1

@@ -101,9 +101,12 @@ class EmbedLoader(BaseLoader):
         """
         if vocab is None:
             raise RuntimeError("You must provide a vocabulary.")
-        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim))
+        embedding_matrix = np.zeros(shape=(len(vocab), emb_dim), dtype=np.float32)
         hit_flags = np.zeros(shape=(len(vocab),), dtype=int)
         with open(emb_file, "r", encoding="utf-8") as f:
+            startline = f.readline()
+            if len(startline.split()) > 2:
+                f.seek(0)
             for line in f:
                 word, vector = EmbedLoader.parse_glove_line(line)
                 if word in vocab:
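
The added lines detect an optional header: embeddings saved in word2vec text format begin with a "<vocab_size> <dim>" line (two tokens), while GloVe-style files start directly with a vector line (a word plus dim values, so more than two tokens). A minimal standalone sketch of the same detection logic, with a placeholder file name:

# Sketch of the header check added above; "emb.txt" is a hypothetical path.
with open("emb.txt", "r", encoding="utf-8") as f:
    first = f.readline()
    if len(first.split()) > 2:  # more than two tokens: a real vector line
        f.seek(0)               # rewind so the first word is not skipped
    # else: the first line was a "<count> <dim>" header and stays consumed
    for line in f:
        word, *values = line.rstrip().split()
        # ... fill this word's row of the embedding matrix if it is in the vocab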


reproduction/Biaffine_parser/cfg.cfg  +1 -1

@@ -26,7 +26,7 @@ arc_mlp_size = 500
 label_mlp_size = 100
 num_label = -1
 dropout = 0.3
-encoder="transformer"
+encoder="var-lstm"
 use_greedy_infer=false
 
 [optim]
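
Switching encoder from "transformer" to "var-lstm" selects the variational-LSTM encoder for the Biaffine parser. A sketch of how run.py might consume this section via the ConfigLoader/ConfigSection pair imported below (the section name "model" and the exact load_config signature are assumptions, not shown in this diff):

# Hypothetical sketch: read the cfg file into a ConfigSection.
from fastNLP.io.config_io import ConfigLoader, ConfigSection

model_args = ConfigSection()
ConfigLoader.load_config("cfg.cfg", {"model": model_args})  # section name assumed
print(model_args["encoder"])  # -> "var-lstm" after this change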


reproduction/Biaffine_parser/run.py  +6 -0

@@ -10,9 +10,13 @@ from fastNLP.core.trainer import Trainer
 from fastNLP.core.instance import Instance
 from fastNLP.api.pipeline import Pipeline
 from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.tester import Tester
 from fastNLP.io.config_io import ConfigLoader, ConfigSection
 from fastNLP.io.model_io import ModelLoader
+from fastNLP.io.embed_loader import EmbedLoader
 from fastNLP.io.model_io import ModelSaver
 from fastNLP.io.dataset_loader import ConllxDataLoader
 from fastNLP.api.processor import *
 from fastNLP.io.embed_loader import EmbedLoader
@@ -156,6 +160,8 @@ print('test len {}'.format(len(test_data)))
 def train(path):
     # test saving pipeline
     save_pipe(path)
+    embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
+    embed = torch.tensor(embed, dtype=torch.float32)
 
     # embed = EmbedLoader.fast_load_embedding(emb_dim=model_args['word_emb_dim'], emb_file=emb_file_name, vocab=word_v)
     # embed = torch.tensor(embed, dtype=torch.float32)
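
The two added lines activate the embedding loading that was previously commented out: fast_load_embedding builds the (len(vocab), word_emb_dim) float32 matrix prepared in the embed_loader.py change above, and torch.tensor converts it for the model. A sketch of the assumed downstream use (the from_pretrained step is illustrative, not part of this commit):

# embed is a (len(word_v), model_args['word_emb_dim']) float32 numpy array.
embed = EmbedLoader.fast_load_embedding(model_args['word_emb_dim'], emb_file_name, word_v)
embed = torch.tensor(embed, dtype=torch.float32)

# Assumed usage: seed the parser's word embedding layer with the matrix.
embedding = torch.nn.Embedding.from_pretrained(embed, freeze=False)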

