
1. Add GradientClip to callback; 2. Remove _print_train() and _tqdm_train() from Trainer, merging both into _train()

tags/v0.3.0^2
yh 5 years ago
parent commit 400552971c
7 changed files with 212 additions and 121 deletions
  1. +64 -27  fastNLP/api/api.py
  2. +77 -10  fastNLP/core/callback.py
  3. +29 -68  fastNLP/core/trainer.py
  4. +27 -0   fastNLP/core/utils.py
  5. +2  -2   reproduction/chinese_word_segment/models/cws_model.py
  6. +1  -1   reproduction/chinese_word_segment/process/cws_processor.py
  7. +12 -13  reproduction/chinese_word_segment/train_context.py

fastNLP/api/api.py  +64 -27

@@ -23,7 +23,7 @@ from fastNLP.api.processor import IndexerProcessor


# TODO add pretrain urls
 model_urls = {
+    'cws': "http://123.206.98.91:8888/download/cws_crf-69e357c9.pkl"
 }




@@ -139,6 +139,12 @@ class POS(API):


 class CWS(API):
     def __init__(self, model_path=None, device='cpu'):
+        """
+        High-level interface for Chinese word segmentation.
+
+        :param model_path: if None, the model at the default location is used, and it is downloaded automatically
+            if not present there.
+        :param device: str, e.g. 'cpu', 'cuda' or 'cuda:0'. The model is loaded onto this device for inference.
+        """
         super(CWS, self).__init__()
         if model_path is None:
             model_path = model_urls['cws']
@@ -146,7 +152,13 @@ class CWS(API):
         self.load(model_path, device)

     def predict(self, content):
+        """
+        Word segmentation interface.
+
+        :param content: str or List[str]. For example, passing "中文分词很重要!" returns "中文 分词 很 重要 !".
+            If a List[str] such as ["中文分词很重要!", ...] is passed in, the result is ["中文 分词 很 重要 !", ...].
+        :return: str or List[str], matching the type of the input.
+        """
         if not hasattr(self, 'pipeline'):
             raise ValueError("You have to load model first.")


@@ -162,7 +174,10 @@ class CWS(API):
         dataset.add_field('raw_sentence', sentence_list)

         # 3. apply the pipeline
-        self.pipeline(dataset)
+        pipeline = self.pipeline.pipeline[:-3] + self.pipeline.pipeline[-2:]
+        pp = Pipeline(pipeline)
+        pp(dataset)
+        # self.pipeline(dataset)

         output = dataset['output'].content
         if isinstance(content, str):
@@ -171,10 +186,28 @@ class CWS(API):
         return output

     def test(self, filepath):
-
-        tag_proc = self._dict['tag_indexer']
+        """
+        Given the path to a word segmentation data file, return f1, precision and recall on that dataset.
+        The file should look like:
+            1 编者按 编者按 NN O 11 nmod:topic
+            2 : : PU O 11 punct
+            3 7月 7月 NT DATE 4 compound:nn
+            4 12日 12日 NT DATE 11 nmod:tmod
+            5 , , PU O 11 punct
+
+            1 这 这 DT O 3 det
+            2 款 款 M O 1 mark:clf
+            3 飞行 飞行 NN O 8 nsubj
+            4 从 从 P O 5 case
+            5 外型 外型 NN O 8 nmod:prep
+        Sentences are separated by an empty line, and every non-empty line has 7 columns.
+
+        :param filepath: str, path to the data file.
+        :return: float, float, float: f1, precision and recall, respectively.
+        """
+        tag_proc = self._dict['tag_proc']
         cws_model = self.pipeline.pipeline[-2].model
-        pipeline = self.pipeline.pipeline[:5]
+        pipeline = self.pipeline.pipeline[:-2]

         pipeline.insert(1, tag_proc)
         pp = Pipeline(pipeline)
@@ -185,12 +218,16 @@ class CWS(API):
         te_dataset = reader.load(filepath)
         pp(te_dataset)

-        batch_size = 64
-        te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
-        pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
-        f1 = round(f1 * 100, 2)
-        pre = round(pre * 100, 2)
-        rec = round(rec * 100, 2)
+        from fastNLP.core.tester import Tester
+        from fastNLP.core.metrics import BMESF1PreRecMetric
+
+        tester = Tester(data=te_dataset, model=cws_model, metrics=BMESF1PreRecMetric(target='target'), batch_size=64,
+                        verbose=0)
+        eval_res = tester.test()
+
+        f1 = eval_res['BMESF1PreRecMetric']['f']
+        pre = eval_res['BMESF1PreRecMetric']['pre']
+        rec = eval_res['BMESF1PreRecMetric']['rec']
         # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

         return f1, pre, rec
@@ -301,25 +338,25 @@ class Analyzer:




 if __name__ == "__main__":
-    pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl'
-    pos = POS(pos_model_path, device='cpu')
-    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
-         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
-         '那么这款无人机到底有多厉害?']
-    print(pos.test("/home/zyfeng/data/sample.conllx"))
+    # pos_model_path = '/home/zyfeng/fastnlp/reproduction/pos_tag_model/model_pp.pkl'
+    # pos = POS(pos_model_path, device='cpu')
+    # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
+    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
+    #      '那么这款无人机到底有多厉害?']
+    # print(pos.test("/home/zyfeng/data/sample.conllx"))
     # print(pos.predict(s))

-    # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
-    # cws = CWS(device='cpu')
-    # s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
-    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
-    #      '那么这款无人机到底有多厉害?']
-    # print(cws.test('/Users/yh/Desktop/test_data/cws_test.conll'))
-    # print(cws.predict(s))
+    cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
+    cws = CWS(model_path=cws_model_path, device='cuda:0')
+    s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
+         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
+         '那么这款无人机到底有多厉害?']
+    # print(cws.test('/home/hyan/ctb3/test.conllx'))
+    print(cws.predict(s))

     # parser = Parser(device='cpu')
     # print(parser.test('/Users/yh/Desktop/test_data/parser_test2.conll'))
-    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
-         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
-         '那么这款无人机到底有多厉害?']
+    # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
+    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
+    #      '那么这款无人机到底有多厉害?']
     # print(parser.predict(s))
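A minimal usage sketch of the new CWS interface, pieced together from the docstrings above (the conllx path is a placeholder; with model_path=None the model is fetched from model_urls['cws']):

# sketch, not part of the commit
from fastNLP.api.api import CWS

cws = CWS(device='cpu')                   # downloads the default model if absent
print(cws.predict('中文分词很重要!'))      # -> '中文 分词 很 重要 !'
print(cws.predict(['中文分词很重要!']))    # -> ['中文 分词 很 重要 !']
# test() expects a blank-line-separated, 7-column .conllx file and returns
# f1, precision and recall as computed by BMESF1PreRecMetric:
# f1, pre, rec = cws.test('/path/to/test.conllx')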

fastNLP/core/callback.py  +77 -10

@@ -8,38 +8,76 @@ class Callback(object):
     def __init__(self):
         super(Callback, self).__init__()

-    def before_train(self, *args):
+    def before_train(self):
         # before the main training loop
         pass

-    def before_epoch(self, *args):
+    def before_epoch(self, cur_epoch, total_epoch):
         # at the beginning of each epoch
         pass

-    def before_batch(self, *args):
+    def before_batch(self, batch_x, batch_y, indices):
         # at the beginning of each step/mini-batch
         pass

-    def before_loss(self, *args):
+    def before_loss(self, batch_y, predict_y):
         # after data_forward, and before loss computation
         pass

-    def before_backward(self, *args):
+    def before_backward(self, loss, model):
         # after loss computation, and before gradient backward
         pass

+    def after_backward(self, model):
+        pass
+
+    def after_step(self, optimizer):
+        pass
+
     def after_batch(self, *args):
         # at the end of each step/mini-batch
         pass

-    def after_epoch(self, *args):
-        # at the end of each epoch
+    def after_valid(self, eval_result, metric_key, optimizer):
+        """
+        Called after each evaluation on the dev set; eval_result is passed in.
+
+        :param eval_result: Dict[str: Dict[str: float]], the evaluation result
+        :param metric_key: str
+        :param optimizer:
+        :return:
+        """
+        pass
+
+    def after_epoch(self, cur_epoch, n_epoch, optimizer):
+        """
+        Called at the end of each epoch.
+
+        :param cur_epoch: int, the current epoch, starting from 1
+        :param n_epoch: int, the total number of epochs
+        :param optimizer: the optimizer passed to the Trainer
+        :return:
+        """
         pass

-    def after_train(self, *args):
-        # after training loop
+    def after_train(self, model):
+        """
+        Called once training has finished.
+
+        :param model: nn.Module, the model passed to the Trainer
+        :return:
+        """
         pass

+    def on_exception(self, exception, model, indices):
+        """
+        Called when an exception is raised during training.
+
+        :param exception: an Exception of some type, e.g. KeyboardInterrupt
+        :param model: the model passed to the Trainer
+        :param indices: the indices of the current batch
+        :return:
+        """
+        pass
+

 def transfer(func):
     """Decorator that forwards calls on the CallbackManager to each Callback subclass.
@@ -111,7 +149,7 @@ class CallbackManager(Callback):
         pass

     @transfer
-    def after_step(self):
+    def after_step(self, optimizer):
         pass

     @transfer
@@ -169,6 +207,35 @@ class EchoCallback(Callback):
     def after_train(self):
         print("after_train")


+class GradientClipCallback(Callback):
+    def __init__(self, parameters=None, clip_value=1, clip_type='norm'):
+        """
+        After each backward pass, clip the gradients of the parameters to a given range.
+
+        :param parameters: None, torch.Tensor or List[torch.Tensor], typically obtained through model.parameters().
+            If None, all parameters of the Trainer's model are clipped by default.
+        :param clip_value: float, gradients are clipped to [-clip_value, clip_value]; clip_value should be positive.
+        :param clip_type: str, one of 'norm', 'value'.
+            (1) 'norm': rescale the gradients so that their norm does not exceed clip_value.
+            (2) 'value': clamp the gradients to [-clip_value, clip_value]; gradients smaller than -clip_value are set
+                to -clip_value, and gradients larger than clip_value are set to clip_value.
+        """
+        super().__init__()
+
+        from torch import nn
+        if clip_type == 'norm':
+            self.clip_fun = nn.utils.clip_grad_norm_
+        elif clip_type == 'value':
+            self.clip_fun = nn.utils.clip_grad_value_
+        else:
+            raise ValueError("Only supports `norm` or `value` right now.")
+        self.parameters = parameters
+        self.clip_value = clip_value
+
+    def after_backward(self, model):
+        # fall back to the model's parameters when no parameters were given explicitly
+        parameters = self.parameters if self.parameters is not None else model.parameters()
+        self.clip_fun(parameters, self.clip_value)


 if __name__ == "__main__":
     manager = CallbackManager(env={"n_epoch": 3}, callbacks=[DummyCallback(), DummyCallback()])
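To show how the new hooks are consumed, here is a sketch of registering the new GradientClipCallback together with a hypothetical custom callback on a Trainer (the Trainer's other arguments are unchanged by this commit and elided):

# sketch, not part of the commit
from fastNLP.core.callback import Callback, GradientClipCallback

class PrintValidCallback(Callback):
    # hypothetical subclass, for illustration only
    def after_valid(self, eval_result, metric_key, optimizer):
        # eval_result: Dict[str, Dict[str, float]] as returned by Tester.test()
        print("dev results:", eval_result)

    def on_exception(self, exception, model, indices):
        print("batch indices at failure:", indices)

callbacks = [GradientClipCallback(clip_value=5, clip_type='value'),
             PrintValidCallback()]
# trainer = Trainer(..., callbacks=callbacks)   # forwarded through CallbackManager
# trainer.train()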


fastNLP/core/trainer.py  +29 -68

@@ -7,7 +7,11 @@ import numpy as np
 import torch
 from tensorboardX import SummaryWriter
 from torch import nn
-from tqdm.autonotebook import tqdm
+
+try:
+    from tqdm.autonotebook import tqdm
+except:
+    from fastNLP.core.utils import pseudo_tqdm as tqdm

 from fastNLP.core.batch import Batch
 from fastNLP.core.callback import CallbackManager
@@ -108,7 +112,7 @@ class Trainer(object):
         self.use_cuda = bool(use_cuda)
         self.save_path = save_path
         self.print_every = int(print_every)
-        self.validate_every = int(validate_every)
+        self.validate_every = int(validate_every) if validate_every != 0 else -1
         self.best_metric_indicator = None
         self.sampler = sampler
         self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks)
@@ -119,11 +123,7 @@ class Trainer(object):
             self.optimizer = optimizer.construct_from_pytorch(self.model.parameters())

         self.use_tqdm = use_tqdm
-        if self.use_tqdm:
-            tester_verbose = 0
-            self.print_every = abs(self.print_every)
-        else:
-            tester_verbose = 1
+        self.print_every = abs(self.print_every)

         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
@@ -131,7 +131,7 @@ class Trainer(object):
                                  metrics=self.metrics,
                                  batch_size=self.batch_size,
                                  use_cuda=self.use_cuda,
-                                 verbose=tester_verbose)
+                                 verbose=0)

         self.step = 0
         self.start_time = None  # start timestamp
@@ -199,10 +199,7 @@ class Trainer(object):
             self._summary_writer = SummaryWriter(path)

             self.callback_manager.before_train()
-            if self.use_tqdm:
-                self._tqdm_train()
-            else:
-                self._print_train()
+            self._train()
             self.callback_manager.after_train(self.model)

             if self.dev_data is not None:
@@ -225,12 +222,16 @@ class Trainer(object):


return results return results


def _tqdm_train(self):
def _train(self):
if not self.use_tqdm:
from fastNLP.core.utils import pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0 self.step = 0
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
as_numpy=False)
start = time.time()
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False)
total_steps = data_iterator.num_batches * self.n_epochs total_steps = data_iterator.num_batches * self.n_epochs
with tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0 avg_loss = 0
for epoch in range(1, self.n_epochs+1): for epoch in range(1, self.n_epochs+1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
@@ -265,18 +266,26 @@ class Trainer(object):
                             # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                             # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
                     if (self.step+1) % self.print_every == 0:
-                        pbar.set_postfix_str("loss:{0:<6.5f}".format(avg_loss / self.print_every))
+                        if self.use_tqdm:
+                            print_output = "loss:{0:<6.5f}".format(avg_loss / self.print_every)
+                            pbar.update(self.print_every)
+                        else:
+                            end = time.time()
+                            diff = timedelta(seconds=round(end - start))
+                            print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
+                                epoch, self.step, avg_loss, diff)
+                        pbar.set_postfix_str(print_output)
                         avg_loss = 0
-                        pbar.update(self.print_every)
                     self.step += 1
                     # do nothing
                     self.callback_manager.after_batch()

                     if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
-                        (self.validate_every < 0 and self.step % self.batch_size == len(data_iterator))) \
+                        (self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
                             and self.dev_data is not None:
                         eval_res = self._do_validation(epoch=epoch, step=self.step)
-                        eval_str = "Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, total_steps) + \
+                        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
+                                                                                    total_steps) + \
                             self.tester._format_eval_results(eval_res)
                         pbar.write(eval_str)


@@ -292,54 +301,6 @@ class Trainer(object):
                 self.callback_manager.after_epoch(epoch, self.n_epochs, self.optimizer)
             pbar.close()

-    def _print_train(self):
-        epoch = 1
-        start = time.time()
-        while epoch <= self.n_epochs:
-            self.callback_manager.before_epoch()
-
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler,
-                                  as_numpy=False)
-
-            for batch_x, batch_y in data_iterator:
-                self.callback_manager.before_batch()
-                # TODO this may break if the user moves prediction onto a different device inside the model
-                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
-                prediction = self._data_forward(self.model, batch_x)
-
-                self.callback_manager.before_loss()
-                loss = self._compute_loss(prediction, batch_y)
-
-                self.callback_manager.before_backward()
-                self._grad_backward(loss)
-                self._update()
-
-                self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
-                for name, param in self.model.named_parameters():
-                    if param.requires_grad:
-                        self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
-                        # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
-                        # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-                if self.print_every > 0 and self.step % self.print_every == 0:
-                    end = time.time()
-                    diff = timedelta(seconds=round(end - start))
-                    print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
-                        epoch, self.step, loss.data, diff)
-                    print(print_output)
-
-                if (self.validate_every > 0 and self.step % self.validate_every == 0 and
-                        self.dev_data is not None):
-                    self._do_validation(epoch=epoch, step=self.step)
-
-                self.step += 1
-                self.callback_manager.after_batch()
-
-            # validate_every override validation at end of epochs
-            if self.dev_data and self.validate_every <= 0:
-                self._do_validation(epoch=epoch, step=self.step)
-            epoch += 1
-            self.callback_manager.after_epoch()
-
     def _do_validation(self, epoch, step):
         res = self.tester.test()
         for name, metric in res.items():
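The merged _train() keeps both validation schedules behind one condition. A pure-function restatement of my reading of that trigger (not code from the commit):

# validate_every > 0  -> validate every `validate_every` global steps
# validate_every <= 0 -> validate once per epoch: __init__ coerces 0 to -1, and
#                        step % len(data_iterator) == 0 holds on each epoch's last batch
def should_validate(validate_every, step, steps_per_epoch, has_dev_data):
    if not has_dev_data:
        return False
    if validate_every > 0:
        return step % validate_every == 0
    return step % steps_per_epoch == 0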


fastNLP/core/utils.py  +27 -0

@@ -430,3 +430,30 @@ def seq_mask(seq_len, max_len):
     seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
     seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
+
+
+class pseudo_tqdm:
+    """
+    Used to print progress information when tqdm cannot be imported, or when use_tqdm is set to False in the Trainer.
+    """
+
+    def __init__(self, **kwargs):
+        pass
+
+    def write(self, info):
+        print(info)
+
+    def set_postfix_str(self, info):
+        print(info)
+
+    def __getattr__(self, item):
+        def pass_func(*args, **kwargs):
+            pass
+
+        return pass_func
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        del self
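pseudo_tqdm implements exactly the slice of the tqdm API that _train() touches; any other attribute lookup falls through __getattr__ to a no-op. A quick demonstration using only the class above:

from fastNLP.core.utils import pseudo_tqdm

with pseudo_tqdm(total=100, postfix='loss:{0:<6.5f}', leave=False) as pbar:
    pbar.set_postfix_str("loss:0.12345")    # defined above -> printed
    pbar.write("Evaluation at Epoch 1/5.")  # defined above -> printed
    pbar.update(10)                         # not defined -> silently ignored
    pbar.set_description_str("Epoch 1/5")   # not defined -> silently ignored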

reproduction/chinese_word_segment/models/cws_model.py  +2 -2

@@ -65,7 +65,7 @@ class CWSBiLSTMEncoder(BaseModel):


         x_tensor = self.char_embedding(chars)

-        if not bigrams is None:
+        if hasattr(self, 'bigram_embedding'):
             bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
             x_tensor = torch.cat([x_tensor, bigram_tensor], dim=2)
         x_tensor = self.embedding_drop(x_tensor)
@@ -185,5 +185,5 @@ class CWSBiLSTMCRF(BaseModel):
         feats = self.decoder_model(feats)
         probs = self.crf.viterbi_decode(feats, masks, get_score=False)

-        return {'pred': probs}
+        return {'pred': probs, 'seq_lens': seq_lens}
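The encoder guard changed from inspecting the bigrams argument to checking for the submodule itself, so a model built without a bigram embedding skips the branch even if a bigrams tensor is passed in. A toy restatement of that pattern (not the real CWSBiLSTMEncoder):

import torch
from torch import nn

class ToyEncoder(nn.Module):
    def __init__(self, vocab_size=100, embed_dim=8, bigram_vocab_num=None):
        super().__init__()
        self.char_embedding = nn.Embedding(vocab_size, embed_dim)
        if bigram_vocab_num is not None:
            # the attribute only exists when bigrams were configured
            self.bigram_embedding = nn.Embedding(bigram_vocab_num, embed_dim)

    def forward(self, chars, bigrams=None):
        x = self.char_embedding(chars)
        if hasattr(self, 'bigram_embedding'):  # the guard introduced by this commit
            batch_size, max_len = chars.size(0), chars.size(1)
            bigram_tensor = self.bigram_embedding(bigrams).view(batch_size, max_len, -1)
            x = torch.cat([x, bigram_tensor], dim=2)
        return x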



reproduction/chinese_word_segment/process/cws_processor.py  +1 -1

@@ -378,7 +378,7 @@ class BMES2OutputProcessor(Processor):
     a prediction of BSEMS will be treated as SSSSS.

     """
-    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output',
+    def __init__(self, chars_field_name='chars_list', tag_field_name='pred', new_added_field_name='output',
                  b_idx = 0, m_idx = 1, e_idx = 2, s_idx = 3):
         """




reproduction/chinese_word_segment/train_context.py  +12 -13

@@ -11,7 +11,6 @@ from reproduction.chinese_word_segment.process.cws_processor import InputTargetP
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
 from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF

-from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

 ds_name = 'msr'


@@ -39,8 +38,6 @@ bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='b


 seq_len_proc = SeqLenProcessor('chars')

-input_target_proc = InputTargetProcessor(input_fields=['chars', 'bigrams', 'seq_lens', "target"],
-                                         target_fields=['target', 'seq_lens'])
 # 2. apply the processors
 fs2hs_proc(tr_dataset)


@@ -63,15 +60,15 @@ char_vocab_proc(dev_dataset)
 bigram_vocab_proc(dev_dataset)
 seq_len_proc(dev_dataset)

-input_target_proc(tr_dataset)
-input_target_proc(dev_dataset)
+dev_dataset.set_input('target')
+tr_dataset.set_input('target')

 print("Finish preparing data.")

 # 3. the datasets are now ready for training
 # TODO how should pretrained embeddings be handled?

+import torch
 from torch import optim




@@ -79,8 +76,8 @@ tag_size = tag_proc.tag_size


 cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
                          bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
-                         bigram_embed_dim=100, num_bigram_per_char=8,
-                         hidden_size=200, bidirectional=True, embed_drop_p=0.2,
+                         bigram_embed_dim=30, num_bigram_per_char=8,
+                         hidden_size=200, bidirectional=True, embed_drop_p=0.3,
                          num_layers=1, tag_size=tag_size)
 cws_model.cuda()


@@ -108,7 +105,7 @@ pp.add_processor(bigram_proc)
 pp.add_processor(char_vocab_proc)
 pp.add_processor(bigram_vocab_proc)
 pp.add_processor(seq_len_proc)
-pp.add_processor(input_target_proc)
+# pp.add_processor(input_target_proc)

 # te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
 te_filename = '/home/hyan/ctb3/test.conllx'
@@ -142,7 +139,7 @@ from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor

 model_proc = ModelProcessor(cws_model)
-output_proc = BMES2OutputProcessor()
+output_proc = BMES2OutputProcessor(tag_field_name='pred')

 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
@@ -158,9 +155,11 @@ pp.add_processor(output_proc)




 # TODO it seems the test pipeline and the infer pipeline need to be distinguished here

-infer_context_dict = {'pipeline': pp}
-# torch.save(infer_context_dict, 'models/cws_crf.pkl')
+import torch
+import datetime
+now = datetime.datetime.now()
+infer_context_dict = {'pipeline': pp, 'tag_proc': tag_proc}
+torch.save(infer_context_dict, 'models/cws_crf_{}_{}.pkl'.format(now.month, now.day))


 # TODO still need to consider how to map the output back onto the original text
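The saved dict mirrors what the CWS API consumes: 'pipeline' is what predict() runs over raw sentences, and 'tag_proc' is what the reworked test() reads from self._dict['tag_proc']. A sketch of loading it back (the file name is a placeholder; the real one is stamped with month and day above):

import torch

context = torch.load('models/cws_crf_1_1.pkl')   # placeholder file name
pp = context['pipeline']
tag_proc = context['tag_proc']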

