@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import StepLR
 from fastNLP import Tester
 from fastNLP import GradientClipCallback, LRScheduler
 import os
+from fastNLP import cache_results
 
 def set_random_seed(random_seed=666):
     import random, numpy, torch
@@ -39,43 +40,42 @@ label_mlp_size = 100
 batch_size = 32
 update_every = 4
 n_epochs = 100
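+# update_every=4 accumulates gradients over 4 batches before each optimizer step,
+# so the effective batch size is batch_size * update_every = 32 * 4 = 128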
 data_folder = ''  # fill in the folder containing the data; it should hold three files: train, dev, and test
 vector_folder = ''  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
 data_name = 'new_ctb7'
 ####################################################
 data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output'  # folder containing the train, dev, and test files
 vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors'  # folder containing the three pretrained n-gram vector files listed above
 
 set_random_seed(1234)
 device = 0
 
-# @cache_results('caches/{}.pkl'.format(data_name))
-# def get_data():
-data = CTBxJointLoader().process(data_folder)
-
-char_labels_vocab = data.vocabs['char_labels']
-
-pre_chars_vocab = data.vocabs['pre_chars']
-pre_bigrams_vocab = data.vocabs['pre_bigrams']
-pre_trigrams_vocab = data.vocabs['pre_trigrams']
-
-chars_vocab = data.vocabs['chars']
-bigrams_vocab = data.vocabs['bigrams']
-trigrams_vocab = data.vocabs['trigrams']
-
-pre_chars_embed = StaticEmbedding(pre_chars_vocab,
-                                  model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
-                                  init_method=uniform_init, normalize=False)
-pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std()
-pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
-                                    model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
-                                    init_method=uniform_init, normalize=False)
-pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std()
-pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
-                                     model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
-                                     init_method=uniform_init, normalize=False)
-pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std()
-
-# return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
-
-# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
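+# cache_results pickles get_data()'s return value to caches/new_ctb7.pkl on the first
+# run and reloads it on subsequent runs, skipping the slow preprocessing step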
+@cache_results('caches/{}.pkl'.format(data_name))
+def get_data():
+    data = CTBxJointLoader().process(data_folder)
+    char_labels_vocab = data.vocabs['char_labels']
+
+    pre_chars_vocab = data.vocabs['pre_chars']
+    pre_bigrams_vocab = data.vocabs['pre_bigrams']
+    pre_trigrams_vocab = data.vocabs['pre_trigrams']
+
+    chars_vocab = data.vocabs['chars']
+    bigrams_vocab = data.vocabs['bigrams']
+    trigrams_vocab = data.vocabs['trigrams']
+    pre_chars_embed = StaticEmbedding(pre_chars_vocab,
+                                      model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
+                                      init_method=uniform_init, normalize=False)
+    # rescale the pretrained weights to unit standard deviation
+    pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std()
+    pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
+                                        model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
+                                        init_method=uniform_init, normalize=False)
+    pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std()
+    pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
+                                         model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
+                                         init_method=uniform_init, normalize=False)
+    pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std()
+
+    return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
+
+chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
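+# cache_results also injects a _refresh kwarg: get_data(_refresh=True) ignores an
+# existing cache and re-runs the preprocessing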
 
 print(data)
 model = CharParser(char_vocab_size=len(chars_vocab),
@@ -104,11 +104,24 @@ optimizer = optim.Adam([param for param in model.parameters() if param.requires_
 
 sampler = BucketSampler(seq_len_field_name='seq_lens')
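+# BucketSampler batches sequences of similar length together to minimize padding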
 callbacks = []
 
-# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
-scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
-# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
-scheduler_callback = LRScheduler(scheduler)
-# callbacks.append(optim_callback)
+from fastNLP.core.callback import Callback
+from torch.optim.lr_scheduler import LambdaLR
+
+class SchedulerCallback(Callback):
+    def __init__(self, scheduler):
+        super().__init__()
+        self.scheduler = scheduler
+
+    def on_backward_end(self):
+        # advance the LR schedule once per optimizer update, not once per backward pass
+        if self.step % self.update_every == 0:
+            self.scheduler.step()
+
+scheduler = LambdaLR(optimizer, lr_lambda=lambda step: 0.75 ** (step // 5000))
+# scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
+scheduler_callback = SchedulerCallback(scheduler)
+# scheduler_callback = LRScheduler(scheduler)
 callbacks.append(scheduler_callback)
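+# clip every gradient value into [-5, 5] before each update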
 callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
@@ -119,6 +132,6 @@ callbacks.append(dev_callback)
 
 trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
                   validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
-                  check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
+                  check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True,
                   device=device, callbacks=callbacks, update_every=update_every)
 trainer.train()
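+# with save_path=None no checkpoint is written; the Trainer only reports the best
+# dev result, selected by metric_key='u_f1' (the parser metric's unlabeled F1)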