[new] Update Readme.md and fix bugs caused by the changes in fastNLP 0.4.5 (tags/v0.4.10)
@@ -56,7 +56,7 @@ class SummarizationLoader(JsonLoader): | |||||
return ds | return ds | ||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab=True): | |||||
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True): | |||||
""" | """ | ||||
:param paths: dict path for each dataset | :param paths: dict path for each dataset | ||||
:param vocab_size: int max_size for vocab | :param vocab_size: int max_size for vocab | ||||
@@ -65,7 +65,7 @@ class SummarizationLoader(JsonLoader): | |||||
:param doc_max_timesteps: int max sentence number of the document | :param doc_max_timesteps: int max sentence number of the document | ||||
:param domain: bool build vocab for publication, use 'X' for unknown | :param domain: bool build vocab for publication, use 'X' for unknown | ||||
:param tag: bool build vocab for tag, use 'X' for unknown | :param tag: bool build vocab for tag, use 'X' for unknown | ||||
:param load_vocab: bool build vocab (False) or load vocab (True) | |||||
:param load_vocab_file: bool build vocab (False) or load vocab (True) | |||||
:return: DataBundle | :return: DataBundle | ||||
datasets: dict keys correspond to the paths dict | datasets: dict keys correspond to the paths dict | ||||
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True) | ||||
@@ -146,7 +146,7 @@ class SummarizationLoader(JsonLoader): | |||||
train_ds = datasets[key] | train_ds = datasets[key] | ||||
vocab_dict = {} | vocab_dict = {} | ||||
if load_vocab == False: | |||||
if load_vocab_file == False: | |||||
logger.info("[INFO] Build new vocab from training dataset!") | logger.info("[INFO] Build new vocab from training dataset!") | ||||
if train_ds == None: | if train_ds == None: | ||||
raise ValueError("Lack train file to build vocabulary!") | raise ValueError("Lack train file to build vocabulary!") | ||||
@@ -36,8 +36,8 @@ import pickle | |||||
from nltk.tokenize import sent_tokenize | from nltk.tokenize import sent_tokenize | ||||
import utils | |||||
from logger import * | |||||
import tools.utils | |||||
from tools.logger import * | |||||
# <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | # <s> and </s> are used in the data files to segment the abstracts into sentences. They don't receive vocab ids. | ||||
SENTENCE_START = '<s>' | SENTENCE_START = '<s>' | ||||
@@ -313,7 +313,8 @@ class Example(object): | |||||
for sent in article_sents: | for sent in article_sents: | ||||
article_words = sent.split() | article_words = sent.split() | ||||
self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | self.enc_sent_len.append(len(article_words)) # store the length after truncation but before padding | ||||
self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
# self.enc_sent_input.append([vocab.word2id(w) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
self.enc_sent_input.append([vocab.word2id(w.lower()) for w in article_words]) # list of word ids; OOVs are represented by the id for UNK token | |||||
self._pad_encoder_input(vocab.word2id('[PAD]')) | self._pad_encoder_input(vocab.word2id('[PAD]')) | ||||
# Store the original strings | # Store the original strings | ||||
@@ -29,7 +29,7 @@ import torch.nn | |||||
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/' | ||||
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches' | ||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP/') | |||||
sys.path.append('/remote-home/dqwang/FastNLP/fastNLP_brxx/') | |||||
from fastNLP.core.const import Const | from fastNLP.core.const import Const | ||||
@@ -39,6 +39,54 @@ FastNLP中实现的模型包括: | |||||
### Evaluation | |||||
#### FastRougeMetric | |||||
FastRougeMetric 使用 Python 实现的非官方 ROUGE 库,在训练过程中快速计算 ROUGE 的近似值。 | |||||
源代码可见 [https://github.com/pltrdy/rouge](https://github.com/pltrdy/rouge) | |||||
在fastNLP中,该方法已经被包装成Metric.py中的FastRougeMetric类以供trainer直接使用。 | |||||
需要事先使用pip安装该rouge库。 | |||||
`pip install rouge` | |||||
**注意:由于实现细节的差异,该结果和官方ROUGE结果存在1-2个点的差异,仅可作为训练过程优化趋势的粗略估计。** | |||||
#### PyRougeMetric | |||||
PyRougeMetric 使用论文 [*ROUGE: A Package for Automatic Evaluation of Summaries*](https://www.aclweb.org/anthology/W04-1013) 提供的官方ROUGE 1.5.5评测库。 | |||||
由于原本的ROUGE使用perl解释器,[pyrouge](https://github.com/bheinzerling/pyrouge)对其进行了python包装,而PyRougeMetric将其进一步包装为trainer可以直接使用的Metric类。 | |||||
为了使用ROUGE 1.5.5,需要使用sudo权限安装一系列依赖库。 | |||||
1. ROUGE 本身在Ubuntu下的安装可以参考[博客](https://blog.csdn.net/Hay54/article/details/78744912) | |||||
2. 配置wordnet可参考: | |||||
```shell | |||||
$ cd ~/rouge/RELEASE-1.5.5/data/WordNet-2.0-Exceptions/ | |||||
$ ./buildExeptionDB.pl . exc WordNet-2.0.exc.db | |||||
$ cd ../ | |||||
$ ln -s WordNet-2.0-Exceptions/WordNet-2.0.exc.db WordNet-2.0.exc.db | |||||
``` | |||||
3. 安装pyrouge | |||||
```shell | |||||
$ git clone https://github.com/bheinzerling/pyrouge | |||||
$ cd pyrouge | |||||
$ python setup.py install | |||||
``` | |||||
4. 测试ROUGE安装是否正确 | |||||
```shell | |||||
$ pyrouge_set_rouge_path /absolute/path/to/ROUGE-1.5.5/directory | |||||
$ python -m pyrouge.test | |||||
``` | |||||
### Dataset_loader | ### Dataset_loader | ||||
- SummarizationLoader: 用于读取处理好的jsonl格式数据集,返回以下field | - SummarizationLoader: 用于读取处理好的jsonl格式数据集,返回以下field | ||||
@@ -56,6 +104,21 @@ FastNLP中实现的模型包括: | |||||
### Train Cmdline | |||||
#### Baseline | |||||
LSTM + Sequence Labeling | |||||
`python train.py --cuda --gpu <gpuid> --sentence_encoder deeplstm --sentence_decoder seqlab --save_root <savedir> --log_root <logdir> --lr_descent --grad_clip --max_grad_norm 10` | |||||
Transformer + Sequence Labeling | |||||
`python train.py --cuda --gpu <gpuid> --sentence_encoder transformer --sentence_decoder seqlab --save_root <savedir> --log_root <logdir> --lr_descent --grad_clip --max_grad_norm 10` | |||||
#### BertSum | |||||