diff --git a/reproduction/Summarization/Baseline/data/dataloader.py b/reproduction/Summarization/Baseline/data/dataloader.py
index 57702904..47cd0856 100644
--- a/reproduction/Summarization/Baseline/data/dataloader.py
+++ b/reproduction/Summarization/Baseline/data/dataloader.py
@@ -56,7 +56,7 @@ class SummarizationLoader(JsonLoader):
return ds
- def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True):
+ def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
"""
:param paths: dict path for each dataset
:param vocab_size: int max_size for vocab
@@ -65,7 +65,7 @@ class SummarizationLoader(JsonLoader):
:param doc_max_timesteps: int max sentence number of the document
:param domain: bool build vocab for publication, use 'X' for unknown
:param tag: bool build vocab for tag, use 'X' for unknown
- :param load_vocab: bool build vocab (False) or load vocab (True)
+ :param load_vocab_file: bool build vocab (False) or load vocab (True)
:return: DataBundle
datasets: dict keys correspond to the paths dict
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
@@ -146,7 +146,7 @@ class SummarizationLoader(JsonLoader):
train_ds = datasets[key]
vocab_dict = {}
- if load_vocab == False:
+ if load_vocab_file == False:
logger.info("[INFO] Build new vocab from training dataset!")
if train_ds == None:
raise ValueError("Lack train file to build vocabulary!")
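For context, a minimal call sketch of process() with the renamed keyword. This is an assumption-laden illustration: the dataset paths and sizes are placeholders, not values from this patch, and the import path assumes the module layout reproduction/Summarization/Baseline/data/dataloader.py.

    from data.dataloader import SummarizationLoader

    loader = SummarizationLoader()
    # Hypothetical file paths and sizes, for illustration only.
    paths = {"train": "train.jsonl", "val": "val.jsonl", "test": "test.jsonl"}
    data_bundle = loader.process(
        paths=paths,
        vocab_size=50000,
        vocab_path="vocab_file",
        sent_max_len=100,
        doc_max_timesteps=50,
        load_vocab_file=False,  # build a new vocab from the training set instead of loading one
    )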
diff --git a/reproduction/Summarization/Baseline/tools/data.py b/reproduction/Summarization/Baseline/tools/data.py
index f7bbaddd..0cbfbb06 100644
--- a/reproduction/Summarization/Baseline/tools/data.py
+++ b/reproduction/Summarization/Baseline/tools/data.py
@@ -36,8 +36,8 @@ import pickle
from nltk.tokenize import sent_tokenize
-import utils
-from logger import *
+import tools.utils
+from tools.logger import *
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids.
SENTENCE_START = '<s>'
@@ -313,7 +313,8 @@ class Example(object):
for sent in article_sents:
article_words = sent.split()
self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding
- self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token
+ # self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token
+ self.enc_sent_input.append([vocab.word2id(w.lower()) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token
self._pad_encoder_input(vocab.word2id('[PAD]'))
# Store the original strings
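The switch to w.lower() matters when the vocabulary stores lower-cased tokens: without it, sentence-cased words map to the UNK id. A minimal sketch of the effect, using a hypothetical dict-based vocab rather than the project's Vocab class (real ids come from vocab.word2id()):

    # Hypothetical lower-cased vocabulary, for illustration only.
    word2id = {"[PAD]": 0, "[UNK]": 1, "the": 2, "president": 3}
    unk_id = word2id["[UNK]"]

    sent = "The President spoke"
    ids_raw   = [word2id.get(w, unk_id) for w in sent.split()]          # [1, 1, 1]  every word becomes UNK
    ids_lower = [word2id.get(w.lower(), unk_id) for w in sent.split()]  # [2, 3, 1]  only the true OOV stays UNK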
diff --git a/reproduction/Summarization/Baseline/train.py b/reproduction/Summarization/Baseline/train.py
index c3a92f67..b3170307 100644
--- a/reproduction/Summarization/Baseline/train.py
+++ b/reproduction/Summarization/Baseline/train.py
@@ -29,7 +29,7 @@ import torch.nn
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
-sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/')
+sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/')
from fastNLP.core.const import Const
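The new path still hard-codes a single machine's checkout. A sketch of a slightly more portable variant, assuming a hypothetical FASTNLP_PATH environment variable that is not part of this patch:

    import os
    import sys

    # Fall back to the hard-coded checkout from the patch if the variable is unset.
    sys.path.append(os.environ.get('FASTNLP_PATH', '/remote-home/dqwang/FastNLP/fastNLP_brxx/'))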