
1. Add an EvaluateCallback so that datasets other than dev can be evaluated during training; 2. Add an only_train_min_freq option to StaticEmbedding

tags/v0.4.10
yh 5 years ago
parent commit 58d7742b66
6 changed files with 103 additions and 26 deletions
  1. fastNLP/core/__init__.py  +1 -0
  2. fastNLP/core/callback.py  +80 -15
  3. fastNLP/core/trainer.py  +1 -1
  4. fastNLP/embeddings/bert_embedding.py  +8 -4
  5. fastNLP/embeddings/static_embedding.py  +9 -2
  6. fastNLP/io/file_utils.py  +4 -4

fastNLP/core/__init__.py  +1 -0

@@ -14,6 +14,7 @@ The core module implements the core framework of fastNLP; commonly used functionality can be imported from fa
"""
from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
from .callback import EvaluateCallback, FitlogCallback, SaveModelCallback
from .const import Const
from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder

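After this change the new callback is exported at the package level. A minimal import sketch (assuming fastNLP at this revision is installed):

from fastNLP.core import EvaluateCallback, FitlogCallback, SaveModelCallback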

fastNLP/core/callback.py  +80 -15

@@ -57,6 +57,7 @@ __all__ = [
"FitlogCallback",
"LRScheduler",
"ControlC",
"EvaluateCallback",
"CallbackException",
"EarlyStopError"
@@ -504,10 +505,9 @@ class FitlogCallback(Callback):
and writes the evaluation results to fitlog. The results on these datasets are reported according to the best result
on dev, i.e. if dev reaches its best at epoch 3, the results recorded in fitlog for these datasets are those from epoch 3.

:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet object; the metrics of the Trainer will be used to evaluate it. To pass
multiple DataSets, pass them as a dict; the dict keys are used as the dataset names reported to fitlog. If tester is not None, data must
be passed as a dict. If a single DataSet is passed, it is named test
:param ~fastNLP.Tester tester: a Tester object, called at on_valid_end. The DataSet inside the tester is named `test`
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet object; the metrics of the Trainer will be used to evaluate it. To pass
multiple DataSets, pass them as a dict; the dict keys are used as the dataset names reported to fitlog. Result names for data are prefixed with 'data'.
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: a Tester object, called at on_valid_end. Result names for tester are prefixed with 'tester'
:param int log_loss_every: record the loss every this many steps (the recorded value is the average loss over those batches). For large
datasets set this to a larger value, otherwise the log file becomes huge. Defaults to 0, i.e. the loss is not recorded.
:param int verbose: whether to print the evaluation results to the terminal; 0 means do not print.
@@ -521,20 +521,23 @@ class FitlogCallback(Callback):
self._log_exception = log_exception
assert isinstance(log_loss_every, int) and log_loss_every>=0
if tester is not None:
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
if data is not None:
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed."
setattr(tester, 'verbose', 0)
self.testers['test'] = tester
if isinstance(tester, dict):
for name, test in tester.items():
if not isinstance(test, Tester):
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
self.testers['tester-' + name] = test
if isinstance(tester, Tester):
self.testers['tester-test'] = tester
for tester in self.testers.values():
setattr(tester, 'verbose', 0)

if isinstance(data, dict):
for key, value in data.items():
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
for key, value in data.items():
self.datasets[key] = value
self.datasets['data-' + key] = value
elif isinstance(data, DataSet):
self.datasets['test'] = data
self.datasets['data-test'] = data
elif data is not None:
raise TypeError("data receives dict[DataSet] or DataSet object.")
@@ -548,8 +551,11 @@ class FitlogCallback(Callback):
if len(self.datasets) > 0:
for key, data in self.datasets.items():
tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics,
verbose=0)
tester = Tester(data=data, model=self.model,
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
metrics=self.trainer.metrics,
verbose=0,
use_tqdm=self.trainer.use_tqdm)
self.testers[key] = tester
fitlog.add_progress(total_steps=self.n_steps)
@@ -589,6 +595,65 @@ class FitlogCallback(Callback):
fitlog.add_other(repr(exception), name='except_info')


class EvaluateCallback(Callback):
"""
Alias: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback`

This callback lifts the limitation that the Trainer can only evaluate on the dev data during training.

:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: a DataSet object; the metrics of the Trainer will be used to evaluate it. To pass
multiple DataSets, pass them as a dict.
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: a Tester object, called at on_valid_end.
"""

def __init__(self, data=None, tester=None):
super().__init__()
self.datasets = {}
self.testers = {}
if tester is not None:
if isinstance(tester, dict):
for name, test in tester.items():
if not isinstance(test, Tester):
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
self.testers['tester-' + name] = test
if isinstance(tester, Tester):
self.testers['tester-test'] = tester
for tester in self.testers.values():
setattr(tester, 'verbose', 0)

if isinstance(data, dict):
for key, value in data.items():
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
for key, value in data.items():
self.datasets['data-' + key] = value
elif isinstance(data, DataSet):
self.datasets['data-test'] = data
elif data is not None:
raise TypeError("data receives dict[DataSet] or DataSet object.")

def on_train_begin(self):
if len(self.datasets) > 0 and self.trainer.dev_data is None:
raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.")

if len(self.datasets) > 0:
for key, data in self.datasets.items():
tester = Tester(data=data, model=self.model,
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
metrics=self.trainer.metrics, verbose=0,
use_tqdm=self.trainer.use_tqdm)
self.testers[key] = tester

def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
if len(self.testers) > 0:
for key, tester in self.testers.items():
try:
eval_result = tester.test()
self.pbar.write("Evaluation on {}:".format(key))
self.pbar.write(tester._format_eval_results(eval_result))
except Exception:
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))


class LRScheduler(Callback):
"""
Alias: :class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler`

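A minimal usage sketch for the new callback, assuming the fastNLP 0.4.x Trainer API; the model, datasets and metric below are placeholders:

from fastNLP import Trainer, AccuracyMetric
from fastNLP.core import EvaluateCallback

# train_data, dev_data and other_test_data stand in for existing DataSet objects,
# and model for a fastNLP-compatible model.
# Extra datasets are evaluated every time dev is evaluated; their results are printed
# under keys prefixed with 'data-' (here: 'data-other').
callback = EvaluateCallback(data={'other': other_test_data})

trainer = Trainer(train_data=train_data, model=model, metrics=AccuracyMetric(),
                  dev_data=dev_data,  # required: EvaluateCallback raises if no dev data is set
                  callbacks=[callback])
trainer.train()

If the Trainer stores a dev_batch_size keyword argument in its kwargs, the internal Testers created by both FitlogCallback and EvaluateCallback pick it up; otherwise they fall back to the training batch size.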

fastNLP/core/trainer.py  +1 -1

@@ -690,7 +690,7 @@ class Trainer(object):
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
and self.dev_data is not None:
eval_res = self._do_validation(epoch=epoch, step=self.step)
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
self.n_steps) + \
self.tester._format_eval_results(eval_res)
pbar.write(eval_str + '\n')


fastNLP/embeddings/bert_embedding.py  +8 -4

@@ -74,7 +74,7 @@ class BertEmbedding(ContextualEmbedding):

self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep,
pooled_cls=pooled_cls, auto_truncate=auto_truncate)
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)

self.requires_grad = requires_grad
self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
@@ -209,7 +209,7 @@ class BertWordPieceEncoder(nn.Module):

class _WordBertModel(nn.Module):
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False):
include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False, min_freq=2):
super().__init__()

self.tokenzier = BertTokenizer.from_pretrained(model_dir)
@@ -238,9 +238,12 @@ class _WordBertModel(nn.Module):
word_piece_dict = {'[CLS]':1, '[SEP]':1} # word pieces that are used, plus newly added ones
found_count = 0
self._has_sep_in_vocab = '[SEP]' in vocab # used to decide whether token_ids need to be generated for the input data
if '[sep]' in vocab:
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
if "[CLS]" in vocab:
warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the beginning "
"and end of the sentence automatically.")
"and end of the input automatically, make sure you don't add [CLS] and [SEP] at the beginning"
" and end.")
for word, index in vocab:
if index == vocab.padding_idx: # pad is a special symbol
word = '[PAD]'
@@ -250,7 +253,8 @@ class _WordBertModel(nn.Module):
if len(word_pieces)==1:
if not vocab._is_word_no_create_entry(word): # the word appears in train but was not found
if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # this word is not among the original BERT words
word_piece_dict[word] = 1 # add a new entry
if vocab.word_count[word]>=min_freq and not vocab._is_word_no_create_entry(word): # only add it if its count reaches min_freq
word_piece_dict[word] = 1 # add a new entry
continue
for word_piece in word_pieces:
word_piece_dict[word_piece] = 1

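A usage sketch showing where the new min_freq behaviour takes effect, assuming a Vocabulary built from tokenized training data; the toy corpus and the pretrained-model name are illustrative:

from fastNLP import Vocabulary
from fastNLP.embeddings import BertEmbedding

# Toy vocabulary standing in for one built from a real DataSet.
vocab = Vocabulary()
vocab.add_word_lst("a tiny corpus used only for illustration".split())

# BertEmbedding now constructs _WordBertModel with min_freq=2: a training word that the
# BERT tokenizer can only map to a single [UNK] piece is added to the word-piece
# vocabulary as a brand-new token only if it occurs at least twice; rarer ones stay [UNK].
embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')

The new warnings in the same hunk also flag a lower-cased [sep] in the vocabulary (it would not be recognised as [SEP]) and remind users not to add [CLS]/[SEP] themselves, since BertEmbedding inserts them automatically.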

fastNLP/embeddings/static_embedding.py  +9 -2

@@ -54,7 +54,7 @@ class StaticEmbedding(TokenEmbedding):
:param int min_freq: words whose frequency in the Vocabulary is below this value will be mapped to unk.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True,
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1):
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

# resolve cache_path
@@ -73,7 +73,7 @@ class StaticEmbedding(TokenEmbedding):
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

# shrink the vocab
# shrink the vocab according to min_freq
truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq)
if truncate_vocab:
truncated_vocab = deepcopy(vocab)
@@ -88,6 +88,13 @@ class StaticEmbedding(TokenEmbedding):
if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq:
truncated_vocab.add_word_lst([word]*(min_freq-word_count),
no_create_entry=truncated_vocab._is_word_no_create_entry(word))

# apply the min_freq filter only to words that appear in train
if kwargs.get('only_train_min_freq', False):
for word in truncated_vocab.word_count.keys():
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word]<min_freq:
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
no_create_entry=True)
truncated_vocab.build_vocab()
truncated_words_to_words = torch.arange(len(vocab)).long()
for word, index in vocab:

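A sketch of the new only_train_min_freq option, which is read from **kwargs; the pretrained-vector name and the toy vocabulary are illustrative:

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("words that occur in the training data".split() * 3)
vocab.add_word_lst(["dev_only_word"], no_create_entry=True)  # seen only outside train

# min_freq=2 maps words with fewer than 2 occurrences to unk; with only_train_min_freq=True
# that cut is applied only to words created from train, so no_create_entry words
# (dev/test-only words) are kept even though their count is below min_freq.
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-100d',
                        min_freq=2, only_train_min_freq=True)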

fastNLP/io/file_utils.py  +4 -4

@@ -307,15 +307,15 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
if not cache_path.exists():
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
fd, temp_filename = tempfile.mkstemp()
print("%s not found in cache, downloading to %s" % (url, temp_filename))

# GET file object
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"})
if req.status_code == 200:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
fd, temp_filename = tempfile.mkstemp()
print("%s not found in cache, downloading to %s" % (url, temp_filename))

with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024 * 16):
if chunk: # filter out keep-alive new chunks
@@ -373,7 +373,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
os.remove(temp_filename)
return get_filepath(cache_path)
else:
raise HTTPError(f"Fail to download from {url}.")
raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.")


def unzip_file(file: Path, to: Path):

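A sketch of the download path that this hunk reorders; the URL and cache directory are placeholders:

from pathlib import Path
from fastNLP.io.file_utils import get_from_cache

# The temporary file is now created only after the GET request returns status 200,
# and a non-200 response raises an HTTPError that includes the status code, so a
# failed request no longer leaves an unused temp file behind.
local_path = get_from_cache("https://example.com/some-archive.zip",
                            cache_dir=Path.home() / ".fastNLP")
print(local_path)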
