@@ -34,7 +34,6 @@ class API:
        if os.path.exists(os.path.expanduser(path)):
            _dict = torch.load(path, map_location='cpu')
        else:
            print(os.path.expanduser(path))
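            # When `path` is not a local file it is treated as a URL; load_url() presumably
            # downloads and caches the checkpoint before unpickling it.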
            _dict = load_url(path, map_location='cpu')
        self.pipeline = _dict['pipeline']
        self._dict = _dict

@@ -58,7 +57,7 @@ class POS(API):
    def predict(self, content):
        """
        :param query: list of list of str. Each string is a token(word).
        :param content: list of list of str. Each string is a token(word).
        :return answer: list of list of str. Each string is a tag.
        """
        if not hasattr(self, 'pipeline'):

@@ -183,99 +182,64 @@ class CWS(API):
        return f1, pre, rec
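
# Parser wraps a pretrained dependency-parsing pipeline; predict() returns the predicted head index
# and dependency label for every token of each input sentence.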
class Parser(API):
    def __init__(self, model_path=None, device='cpu'):
        super(Parser, self).__init__()
        if model_path is None:
            model_path = model_urls['parser']
        self.load(model_path, device)

    def predict(self, content):
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first.")

        sentence_list = []
        # 1. Check the type of the input content
        if isinstance(content, str):
            sentence_list.append(content)
        elif isinstance(content, list):
            sentence_list = content

        # 2. Build the dataset
        dataset = DataSet()
        dataset.add_field('words', sentence_list)
        # dataset.add_field('tag', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)
        for ins in dataset:
            ins['heads'] = ins['heads'].tolist()

        return dataset['heads'], dataset['labels']

    def test(self, filepath):
        data = ConllxDataLoader().load(filepath)
        ds = DataSet()
        for ins1, ins2 in zip(add_seg_tag(data), data):
            ds.append(Instance(words=ins1[0], tag=ins1[1],
                               gold_words=ins2[0], gold_pos=ins2[1],
                               gold_heads=ins2[2], gold_head_tags=ins2[3]))

        pp = self.pipeline
        for p in pp:
            if p.field_name == 'word_list':
                p.field_name = 'gold_words'
            elif p.field_name == 'pos_list':
                p.field_name = 'gold_pos'
        pp(ds)

        head_cor, label_cor, total = 0, 0, 0
        for ins in ds:
            head_gold = ins['gold_heads']
            head_pred = ins['heads']
            length = len(head_gold)
            total += length
            for i in range(length):
                head_cor += 1 if head_pred[i] == head_gold[i] else 0
        uas = head_cor / total
        print('uas:{:.2f}'.format(uas))

        for p in pp:
            if p.field_name == 'gold_words':
                p.field_name = 'word_list'
            elif p.field_name == 'gold_pos':
                p.field_name = 'pos_list'

        return uas
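
# Analyzer bundles the individual APIs (CWS for segmentation, POS for tagging, and eventually the
# dependency parser) behind a single predict()/test() interface that returns a dict of outputs.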
class Analyzer:
    def __init__(self, seg=True, pos=True, parser=True, device='cpu'):
        self.seg = seg
        self.pos = pos
        self.parser = parser

        if self.seg:
            self.cws = CWS(device=device)
        if self.pos:
            self.pos = POS(device=device)
        if parser:
            self.parser = None

    def predict(self, content):
        output_dict = {}
        if self.seg:
            seg_output = self.cws.predict(content)
            output_dict['seg'] = seg_output
        if self.pos:
            pos_output = self.pos.predict(content)
            output_dict['pos'] = pos_output
        if self.parser:
            parser_output = self.parser.predict(content)
            output_dict['parser'] = parser_output

        return output_dict

    def test(self, filepath):
        output_dict = {}
        if self.seg:
            seg_output = self.cws.test(filepath)
            output_dict['seg'] = seg_output
        if self.pos:
            pos_output = self.pos.test(filepath)
            output_dict['pos'] = pos_output
        if self.parser:
            parser_output = self.parser.test(filepath)
            output_dict['parser'] = parser_output

        return output_dict
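
# Demo entry point exercising the pretrained POS / CWS / Parser checkpoints stored under /home/hyan: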
if __name__ == "__main__":
    # The paths below are on 102
    """
    pos_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/pos_crf-5e26d3b0.pkl'
    pos = POS(model_path=pos_model_path, device='cpu')
    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
    #print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
    print(pos.predict(s))
    """

    """
    cws_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/cws_crf-5a8a3e66.pkl'
    cws = CWS(model_path=cws_model_path, device='cuda:0')
    s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
    #print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
    cws.predict(s)
    """

    parser_model_path = "/home/hyan/fastNLP_models/upload-demo/upload/parser-d57cd5fc.pkl"
    parser = Parser(model_path=parser_model_path, device='cuda:0')
    # print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
    s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
    print(parser.predict(s))
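
# Leaner entry point that loads the default CWS model and checks it on a small local CoNLL sample: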
if __name__ == "__main__":
    # pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
    # pos = POS(device='cpu')
    # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
    #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
    #      '那么这款无人机到底有多厉害?']
    # print(pos.test('/Users/yh/Desktop/test_data/small_test.conll'))
    # print(pos.predict(s))

    # cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
    cws = CWS(device='cpu')
    s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
    print(cws.test('/Users/yh/Desktop/test_data/small_test.conll'))
    print(cws.predict(s))

@@ -1,47 +0,0 @@
import torch

from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.batch import Batch
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
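
# Evaluate a saved CWS pipeline + model ('models/test_context.pkl') on the raw PKU test set.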
def f1():
    ds_name = 'pku'

    test_dict = torch.load('models/test_context.pkl')
    pp = test_dict['pipeline']
    model = test_dict['model'].cuda()

    reader = NaiveCWSReader()
    te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/{}_raw_data/{}_raw_test.txt'.format(ds_name, ds_name,
                                                                                                     ds_name)
    te_dataset = reader.load(te_filename)
    pp(te_dataset)

    batch_size = 64
    te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)

    pre, rec, f1 = calculate_pre_rec_f1(model, te_batcher)
    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
                                                     pre * 100,
                                                     rec * 100))
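
# Segment the raw text of several CWS benchmarks with a saved CWS model and write out the
# segmented files for external scoring.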
def f2():
    from fastNLP.api.api import CWS

    cws = CWS('models/maml-cws.pkl')
    datasets = ['msr', 'as', 'pku', 'ctb', 'ncc', 'cityu', 'ckip', 'sxu']
    for dataset in datasets:
        print(dataset)
        with open('/hdd/fudanNLP/CWS/others/benchmark/raw_and_gold/{}_raw.txt'.format(dataset), 'r') as f:
            lines = f.readlines()
        results = cws.predict(lines)

        with open('/hdd/fudanNLP/CWS/others/benchmark/fastNLP_output/{}_seg.txt'.format(dataset), 'w', encoding='utf-8') as f:
            for line in results:
                f.write(line)


f1()

@@ -1,245 +0,0 @@
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import IndexerProcessor
from reproduction.chinese_word_segment.process.cws_processor import SpeicalSpanProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSSegAppTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabProcessor
from reproduction.chinese_word_segment.process.cws_processor import SeqLenProcessor
from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter
from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter
from reproduction.chinese_word_segment.process.span_converter import TimeConverter
from reproduction.chinese_word_segment.process.span_converter import MixNumAlphaConverter
from reproduction.chinese_word_segment.process.span_converter import EmailConverter
from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
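
# End-to-end training script for the BiLSTM "SegApp" CWS model on one dataset (PKU by default):
# build the preprocessing processors, train with focal loss, keep the best epoch on dev,
# evaluate on test, and serialize both a test pipeline and an inference pipeline.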
ds_name = 'pku'

# tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
#                                                                                             ds_name)
# dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
#                                                                                            ds_name)
tr_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
                                                                                           ds_name)
dev_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
                                                                                          ds_name)

reader = NaiveCWSReader()
tr_dataset = reader.load(tr_filename, cut_long_sent=True)
dev_dataset = reader.load(dev_filename)
# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc.add_span_converter(EmailConverter())
# sp_proc.add_span_converter(MixNumAlphaConverter())
# sp_proc.add_span_converter(AlphaSpanConverter())
# sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(TimeConverter())

char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')
tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')
bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')

char_vocab_proc = VocabProcessor('chars_list')
bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
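# The vocabulary processors collect the character and bigram vocabularies from the training data;
# min_count=4 is assumed to drop bigrams that occur fewer than four times.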
# 2. Apply the processors
fs2hs_proc(tr_dataset)
# sp_proc(tr_dataset)

char_proc(tr_dataset)
tag_proc(tr_dataset)
bigram_proc(tr_dataset)
char_vocab_proc(tr_dataset)
bigram_vocab_proc(tr_dataset)

char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'chars',
                                   delete_old_field=False)
bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list', 'bigrams',
                                     delete_old_field=True)
seq_len_proc = SeqLenProcessor('chars')

char_index_proc(tr_dataset)
bigram_index_proc(tr_dataset)
seq_len_proc(tr_dataset)

# 2.1 Process dev_dataset
fs2hs_proc(dev_dataset)
# sp_proc(dev_dataset)

char_proc(dev_dataset)
tag_proc(dev_dataset)
bigram_proc(dev_dataset)
char_index_proc(dev_dataset)
bigram_index_proc(dev_dataset)
seq_len_proc(dev_dataset)

print("Finish preparing data.")
print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), bigram_vocab_proc.get_vocab_size()))
# 3. The datasets are now ready for training
# TODO how are pretrained embeddings handled?

from reproduction.chinese_word_segment.utils import FocalLoss
from reproduction.chinese_word_segment.utils import seq_lens_to_mask
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.sampler import SequentialSampler
import torch
from torch import optim
import sys
from tqdm import tqdm
tag_size = tag_proc.tag_size

cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
                            bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
                            bigram_embed_dim=100, num_bigram_per_char=8,
                            hidden_size=200, bidirectional=True, embed_drop_p=None,
                            num_layers=1, tag_size=tag_size)
cws_model.cuda()

num_epochs = 3
loss_fn = FocalLoss(class_num=tag_size)
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)

print_every = 50
batch_size = 32
tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
num_batch_per_epoch = len(tr_dataset) // batch_size
best_f1 = 0
best_epoch = 0
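
# Training loop: the per-token focal loss is masked by the true sequence lengths, gradients are
# clipped to [-5, 5], and the state dict with the best dev F1 is cached and restored afterwards.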
for num_epoch in range(num_epochs):
    print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
    sys.stdout.flush()
    avg_loss = 0
    with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
        pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
        cws_model.train()
        for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
            optimizer.zero_grad()

            pred_dict = cws_model(**batch_x)  # B x L x tag_size

            seq_lens = pred_dict['seq_lens']
            masks = seq_lens_to_mask(seq_lens).float()
            tags = batch_y['tags'].long().to(seq_lens.device)

            loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
                                     tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
            # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())

            avg_loss += loss.item()
            loss.backward()
            for group in optimizer.param_groups:
                for param in group['params']:
                    param.grad.clamp_(-5, 5)
            optimizer.step()

            if batch_idx % print_every == 0:
                pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
                avg_loss = 0
                pbar.update(print_every)
    tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)

    # Validation set
    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
                                                     pre * 100,
                                                     rec * 100))
    if best_f1 < f1:
        best_f1 = f1
        # Cache the best parameters; they may be saved later
        best_state_dict = {
            key: value.clone() for key, value in
            cws_model.state_dict().items()
        }
        best_epoch = num_epoch

cws_model.load_state_dict(best_state_dict)
# 4. Assemble what needs to be saved
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)
pp.add_processor(bigram_index_proc)
pp.add_processor(seq_len_proc)

te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_dataset = reader.load(te_filename)
pp(te_dataset)

batch_size = 64
te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)

pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
                                                 pre * 100,
                                                 rec * 100))

# TODO it seems we need to distinguish the test pipeline from the inference pipeline
test_context_dict = {'pipeline': pp,
                     'model': cws_model}
torch.save(test_context_dict, 'models/test_context.pkl')
# 5. The pp for dev
# 4. Assemble what needs to be saved
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.process.cws_processor import SegApp2OutputProcessor

model_proc = ModelProcessor(cws_model)
output_proc = SegApp2OutputProcessor()
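# ModelProcessor runs the trained model inside the pipeline; SegApp2OutputProcessor is expected to
# convert the predicted tag sequences back into segmented text for the caller.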
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)
pp.add_processor(bigram_index_proc)
pp.add_processor(seq_len_proc)
pp.add_processor(model_proc)
pp.add_processor(output_proc)

# TODO it seems we need to distinguish the test pipeline from the inference pipeline
infer_context_dict = {'pipeline': pp}
torch.save(infer_context_dict, 'models/infer_cws.pkl')

# TODO still need to consider how to map the output back to the original text:
# 1. the case where special tags do not need to be substituted back
# 2. the case where special tags need to be substituted back

@@ -1,127 +0,0 @@
import copy
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
print(sys.path)

import torch

from fastNLP.core.dataset import DataSet
from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor
from fastNLP.core.instance import Instance
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import Trainer
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import PeopleDailyCorpusLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel
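
# Train the AdvSeqLabel POS tagger on the People's Daily corpus and save the resulting pipeline
# (indexer -> seq-len -> model -> index-to-tag decoder) to 'model_pp.pkl'.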
cfgfile = './pos_tag.cfg'
datadir = "/home/zyfeng/data/"
data_name = "CWS_POS_TAG_NER_people_daily.txt"
# datadir = "/home/zyfeng/env/fastnlp_v_2/test/data_for_tests"
# data_name = "people_daily_raw.txt"

pos_tag_data_path = os.path.join(datadir, data_name)
pickle_path = "save"
data_infer_path = os.path.join(datadir, "infer.utf8")
def train():
    # load config
    train_param = ConfigSection()
    model_param = ConfigSection()
    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
    print("config loaded")

    # Data Loader
    loader = PeopleDailyCorpusLoader()
    train_data, _ = loader.load(os.path.join(datadir, data_name))
    print("data loaded")

    dataset = DataSet()
    for data in train_data:
        instance = Instance()
        instance["words"] = data[0]
        instance["tag"] = data[1]
        dataset.append(instance)
    print("dataset transformed")

    # processor_1 = FullSpaceToHalfSpaceProcessor('words')
    # processor_1(dataset)

    word_vocab_proc = VocabProcessor('words')
    tag_vocab_proc = VocabProcessor("tag")
    word_vocab_proc(dataset)
    tag_vocab_proc(dataset)

    word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True)
    word_indexer(dataset)
    tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True)
    tag_indexer(dataset)
    seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len")
    seq_len_proc(dataset)

    # torch.save(dataset, "data_set.pkl")
    dev_set = copy.deepcopy(dataset)
    dev_set.set_is_target(truth=True)
    print("processors defined")
    # dataset.set_is_target(tag_ids=True)
    model_param["vocab_size"] = len(word_vocab_proc.get_vocab())
    model_param["num_classes"] = len(tag_vocab_proc.get_vocab())
    print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab())))

    # define a model
    model = AdvSeqLabel(model_param)

    # call trainer to train
    trainer = Trainer(epochs=train_param["epochs"],
                      batch_size=train_param["batch_size"],
                      validate=True,
                      optimizer=Optimizer("Adam", lr=0.01, weight_decay=0.9),
                      evaluator=SeqLabelEvaluator(),
                      use_cuda=True
                      )
    trainer.train(model, dataset, dev_set)

    model_proc = ModelProcessor(model, "word_seq_origin_len")
    dataset.set_is_target(truth=True)
    res = model_proc.process(dataset)

    decoder = Index2WordProcessor(tag_vocab_proc.get_vocab(), "predict", "outputs")

    # save model & pipeline
    pp = Pipeline([word_indexer, seq_len_proc, model_proc, decoder])
    save_dict = {"pipeline": pp}
    torch.save(save_dict, "model_pp.pkl")
def test():
    pass


def infer():
    pass


if __name__ == "__main__":
    train()
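    # An argparse-based CLI for choosing train / test / infer is kept below but currently disabled: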
""" | |||||
import argparse | |||||
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model') | |||||
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer']) | |||||
args = parser.parse_args() | |||||
if args.mode == 'train': | |||||
train() | |||||
elif args.mode == 'test': | |||||
test() | |||||
elif args.mode == 'infer': | |||||
infer() | |||||
else: | |||||
print('no mode specified for model!') | |||||
parser.print_help() | |||||
""" |