Browse Source

Merge pull request #50 from h00Jiang/master

fix a bug (when restore the pickle_file , cannot restore dev.pkl)
tags/v0.1.0
Yige XU GitHub 6 years ago
parent
commit
9dc32f68a7
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 2 deletions
  1. +2
    -1
      fastNLP/core/preprocess.py
  2. +1
    -1
      fastNLP/modules/encoder/embedding.py

+ 2
- 1
fastNLP/core/preprocess.py View File

@@ -59,7 +59,6 @@ class BasePreprocess(object):

def run(self, train_dev_data, test_data=None, pickle_path="./", train_dev_split=0, cross_val=False, n_fold=10):
"""Main preprocessing pipeline.

:param train_dev_data: three-level list, with either single label or multiple labels in a sample.
:param test_data: three-level list, with either single label or multiple labels in a sample. (optional)
:param pickle_path: str, the path to save the pickle files.
@@ -98,6 +97,8 @@ class BasePreprocess(object):
save_pickle(data_train, pickle_path, "data_train.pkl")
else:
data_train = load_pickle(pickle_path, "data_train.pkl")
if pickle_exist(pickle_path, "data_dev.pkl"):
data_dev = load_pickle(pickle_path, "data_dev.pkl")
else:
# cross_val is True
if not pickle_exist(pickle_path, "data_train_0.pkl"):


+ 1
- 1
fastNLP/modules/encoder/embedding.py View File

@@ -15,7 +15,7 @@ class Embedding(nn.Module):
def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
super(Embedding, self).__init__()
self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
if init_emb:
if init_emb is not None:
self.embed.weight = nn.Parameter(init_emb)
self.dropout = nn.Dropout(dropout)



Loading…
Cancel
Save