@@ -67,7 +67,7 @@ class FullSpaceToHalfSpaceProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
-            sentence = ins[self.field_name].text
+            sentence = ins[self.field_name]
             new_sentence = [None]*len(sentence)
             for idx, char in enumerate(sentence):
                 if char in self.convert_map:
@@ -78,12 +78,13 @@ class IndexerProcessor(Processor):
 class IndexerProcessor(Processor):
-    def __init__(self, vocab, field_name, new_added_field_name):
+    def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False):
         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))
 
         super(IndexerProcessor, self).__init__(field_name, new_added_field_name)
 
         self.vocab = vocab
+        self.delete_old_field = delete_old_field
 
     def set_vocab(self, vocab):
         assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab))

@@ -97,6 +98,11 @@ class IndexerProcessor(Processor):
             index = [self.vocab.to_index(token) for token in tokens]
             ins[self.new_added_field_name] = index
 
+        dataset.set_need_tensor(**{self.new_added_field_name:True})
+        if self.delete_old_field:
+            dataset.delete_field(self.field_name)
+
         return dataset
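
A minimal usage sketch of the updated IndexerProcessor (illustrative only; it mirrors the training script below and assumes the field names used there):

    char_vocab_proc = VocabProcessor('chars_list')
    char_vocab_proc(tr_dataset)
    char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list',
                                       delete_old_field=True)
    char_index_proc(tr_dataset)   # adds 'indexed_chars_list', marks it need_tensor, drops 'chars_list'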
@@ -55,14 +55,15 @@ class Batch(object):
             indices = self.idx_list[self.curidx:endidx]
-            for field_name, field in self.dataset.get_fields():
-                batch = torch.from_numpy(field.get(indices))
-                if not field.need_tensor:  # TODO to be revised
-                    pass
-                elif field.is_target:
-                    batch_y[field_name] = batch
-                else:
-                    batch_x[field_name] = batch
+            for field_name, field in self.dataset.get_fields().items():
+                # only fields marked need_tensor are converted to tensors and batched
+                if field.need_tensor:
+                    batch = torch.from_numpy(field.get(indices))
+                    if field.is_target:
+                        batch_y[field_name] = batch
+                    else:
+                        batch_x[field_name] = batch
             self.curidx = endidx
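
For reference, a sketch of how the revised Batch is consumed (mirrors the training script below; only fields marked need_tensor end up in the yielded batches):

    tr_batcher = Batch(tr_dataset, 32, RandomSampler(), use_cuda=False)
    for batch_x, batch_y in tr_batcher:
        # batch_x: non-target tensor fields, e.g. 'indexed_chars_list', 'seq_lens'
        # batch_y: target tensor fields, e.g. 'tags'
        pass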
@@ -75,11 +75,13 @@ class DataSet(object):
             assert len(self) == len(fields)
         self.field_arrays[name] = FieldArray(name, fields)
 
+    def delete_field(self, name):
+        self.field_arrays.pop(name)
+
     def get_fields(self):
         return self.field_arrays
 
     def __getitem__(self, name):
+        assert name in self.field_arrays
         return self.field_arrays[name]
 
     def __len__(self):
@@ -35,13 +35,6 @@ class CWSBiLSTMEncoder(BaseModel): | |||||
self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim) | self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim) | ||||
self.input_size += self.num_bigram_per_char*bigram_embed_dim | self.input_size += self.num_bigram_per_char*bigram_embed_dim | ||||
if self.num_criterion!=None: | |||||
if bidirectional: | |||||
self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion, | |||||
embedding_dim=self.hidden_size) | |||||
self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion, | |||||
embedding_dim=self.hidden_size) | |||||
if not self.embed_drop_p is None: | if not self.embed_drop_p is None: | ||||
self.embedding_drop = nn.Dropout(p=self.embed_drop_p) | self.embedding_drop = nn.Dropout(p=self.embed_drop_p) | ||||
@@ -102,13 +95,14 @@ class CWSBiLSTMSegApp(BaseModel):
         self.decoder_model = MLP(size_layer)
 
-    def forward(self, **kwargs):
-        chars = kwargs['chars']
-        if 'bigram' in kwargs:
-            bigrams = kwargs['bigrams']
+    def forward(self, batch_dict):
+        device = self.parameters().__next__().device
+        chars = batch_dict['indexed_chars_list'].to(device)
+        if 'indexed_bigrams_list' in batch_dict:
+            bigrams = batch_dict['indexed_bigrams_list'].to(device)
         else:
             bigrams = None
-        seq_lens = kwargs['seq_lens']
+        seq_lens = batch_dict['seq_lens'].to(device)
 
         feats = self.encoder_model(chars, bigrams, seq_lens)
         probs = self.decoder_model(feats)
@@ -119,6 +113,10 @@ class CWSBiLSTMSegApp(BaseModel):
         return pred_dict
 
+    def predict(self, batch_dict):
+        pass
+
     def loss_fn(self, pred_dict, true_dict):
         seq_lens = pred_dict['seq_lens']
         masks = seq_lens_to_mask(seq_lens).float()

@@ -131,5 +129,4 @@ class CWSBiLSTMSegApp(BaseModel):
                                true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks)
 
         return loss
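
With this signature the model fetches its inputs from the batch dict itself and moves them to its own device; a rough sketch of the calling convention (the 'pred_probs' output key follows decode_iterator in the training script below and should be treated as an assumption):

    for batch_x, batch_y in dev_batcher:
        pred_dict = cws_model(batch_x)     # uses 'indexed_chars_list', optional bigrams, 'seq_lens'
        _, pred_tags = pred_dict['pred_probs'].max(dim=-1)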
@@ -21,7 +21,7 @@ class SpeicalSpanProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
-            sentence = ins[self.field_name].text
+            sentence = ins[self.field_name]
             for span_converter in self.span_converters:
                 sentence = span_converter.find_certain_span_and_replace(sentence)
             ins[self.new_added_field_name] = sentence

@@ -42,10 +42,9 @@ class CWSCharSegProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
-            sentence = ins[self.field_name].text
+            sentence = ins[self.field_name]
             chars = self._split_sent_into_chars(sentence)
-            new_token_field = TokenListFiled(chars, is_target=False)
-            ins[self.new_added_field_name] = new_token_field
+            ins[self.new_added_field_name] = chars
 
         return dataset
@@ -109,10 +108,11 @@ class CWSTagProcessor(Processor):
     def process(self, dataset):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
         for ins in dataset:
-            sentence = ins[self.field_name].text
+            sentence = ins[self.field_name]
             tag_list = self._generate_tag(sentence)
             new_tag_field = SeqLabelField(tag_list)
             ins[self.new_added_field_name] = new_tag_field
+        dataset.set_is_target(**{self.new_added_field_name:True})
         return dataset
 
     def _tags_from_word_len(self, word_len):

@@ -123,6 +123,8 @@ class CWSSegAppTagProcessor(CWSTagProcessor):
     def __init__(self, field_name, new_added_field_name=None):
         super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name)
+        self.tag_size = 2
 
     def _tags_from_word_len(self, word_len):
         tag_list = []
         for _ in range(word_len-1):
@@ -140,10 +142,9 @@ class BigramProcessor(Processor):
         assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
 
         for ins in dataset:
-            characters = ins[self.field_name].content
+            characters = ins[self.field_name]
             bigrams = self._generate_bigram(characters)
-            new_token_field = TokenListFiled(bigrams)
-            ins[self.new_added_field_name] = new_token_field
+            ins[self.new_added_field_name] = bigrams
 
         return dataset

@@ -190,9 +191,26 @@ class VocabProcessor(Processor):
         for dataset in datasets:
             assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
             for ins in dataset:
-                tokens = ins[self.field_name].content
+                tokens = ins[self.field_name]
                 self.vocab.update(tokens)
 
     def get_vocab(self):
         self.vocab.build_vocab()
         return self.vocab
 
+    def get_vocab_size(self):
+        return len(self.vocab)
+
+
+class SeqLenProcessor(Processor):
+    def __init__(self, field_name, new_added_field_name='seq_lens'):
+        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            length = len(ins[self.field_name])
+            ins[self.new_added_field_name] = length
+        dataset.set_need_tensor(**{self.new_added_field_name:True})
+        return dataset
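
These two additions are what the updated training script below relies on, roughly:

    seq_len_proc = SeqLenProcessor('indexed_chars_list')   # adds a 'seq_lens' field and marks it need_tensor
    seq_len_proc(tr_dataset)
    char_vocab_size = char_vocab_proc.get_vocab_size()      # sizes the char embedding of CWSBiLSTMSegApp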
@@ -9,35 +9,22 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
 from reproduction.chinese_word_segment.process.cws_processor import CWSSegAppTagProcessor
 from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
 from reproduction.chinese_word_segment.process.cws_processor import VocabProcessor
+from reproduction.chinese_word_segment.process.cws_processor import SeqLenProcessor
 
 from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter
 from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter
-from reproduction.chinese_word_segment.io.cws_reader import NaiveCWSReader
+from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
 
 from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
 
-tr_filename = ''
-dev_filename = ''
+tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_train.txt'
+dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_dev.txt'
 
 reader = NaiveCWSReader()
 
-tr_sentences = reader.load(tr_filename, cut_long_sent=True)
-dev_sentences = reader.load(dev_filename)
+tr_dataset = reader.load(tr_filename, cut_long_sent=True)
+dev_dataset = reader.load(dev_filename)
 
-# TODO how do we assemble these into a DataSet?
-def construct_dataset(sentences):
-    dataset = DataSet()
-    for sentence in sentences:
-        instance = Instance()
-        instance['raw_sentence'] = sentence
-        dataset.append(instance)
-
-    return dataset
-
-tr_dataset = construct_dataset(tr_sentences)
-dev_dataset = construct_dataset(dev_sentences)
 
 # 1. prepare the processors
 fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
@@ -45,14 +32,14 @@ sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
 sp_proc.add_span_converter(AlphaSpanConverter())
 sp_proc.add_span_converter(DigitSpanConverter())
 
-char_proc = CWSCharSegProcessor('sentence', 'char_list')
+char_proc = CWSCharSegProcessor('sentence', 'chars_list')
 
-tag_proc = CWSSegAppTagProcessor('sentence', 'tag')
+tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
 
-bigram_proc = Pre2Post2BigramProcessor('char_list', 'bigram_list')
+bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
 
-char_vocab_proc = VocabProcessor('char_list')
-bigram_vocab_proc = VocabProcessor('bigram_list')
+char_vocab_proc = VocabProcessor('chars_list')
+bigram_vocab_proc = VocabProcessor('bigrams_list')
 
 # 2. apply the processors
 fs2hs_proc(tr_dataset)

@@ -66,15 +53,18 @@ bigram_proc(tr_dataset)
 char_vocab_proc(tr_dataset)
 bigram_vocab_proc(tr_dataset)
 
-char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list')
-bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list', 'indexed_bigrams_list')
+char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list',
+                                   delete_old_field=True)
+bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list', 'indexed_bigrams_list',
+                                     delete_old_field=True)
+seq_len_proc = SeqLenProcessor('indexed_chars_list')
 
 char_index_proc(tr_dataset)
 bigram_index_proc(tr_dataset)
+seq_len_proc(tr_dataset)
 
 # 2.1 process dev_dataset
 fs2hs_proc(dev_dataset)
 sp_proc(dev_dataset)
 char_proc(dev_dataset)
@@ -83,14 +73,148 @@ bigram_proc(dev_dataset)
 char_index_proc(dev_dataset)
 bigram_index_proc(dev_dataset)
+seq_len_proc(dev_dataset)
 
+print("Finish preparing data.")
 
 # 3. the datasets are now ready for training
-# TODO how are the pretrained embeddings handled?
-cws_model = CWSBiLSTMSegApp(vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
-                            hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2)
 
+from itertools import chain
+
+
+def refine_ys_on_seq_len(ys, seq_lens):
+    # truncate every predicted/true sequence to its real (unpadded) length
+    refined_ys = []
+    for b_idx, length in enumerate(seq_lens):
+        refined_ys.append(list(ys[b_idx][:length]))
+    return refined_ys
+
+
+def flat_nested_list(nested_list):
+    return list(chain(*nested_list))
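+
+# A worked example of the span-level metric below (illustrative only, not part of the
+# original script): with seg-app tags, 1 ends a word and 0 appends to the current word.
+# For true_ys=[1, 0, 1, 1] (three gold words) and pred_ys=[1, 1, 1, 1] (four predicted
+# words), only the first and last characters form correctly recovered words, so
+# cor_num=2, yp_wordnum=4, yt_wordnum=3, giving P=0.5, R=2/3 and F1 about 0.57.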
+
+def calculate_pre_rec_f1(model, batcher):
+    true_ys, pred_ys, seq_lens = decode_iterator(model, batcher)
+
+    refined_true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
+    refined_pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
+    true_ys = flat_nested_list(refined_true_ys)
+    pred_ys = flat_nested_list(refined_pred_ys)
+
+    cor_num = 0
+    yp_wordnum = pred_ys.count(1)
+    yt_wordnum = true_ys.count(1)
+    start = 0
+    for i in range(len(true_ys)):
+        if true_ys[i] == 1:
+            flag = True
+            for j in range(start, i + 1):
+                if true_ys[j] != pred_ys[j]:
+                    flag = False
+                    break
+            if flag:
+                cor_num += 1
+            start = i + 1
+    P = cor_num / (float(yp_wordnum) + 1e-6)
+    R = cor_num / (float(yt_wordnum) + 1e-6)
+    F = 2 * P * R / (P + R + 1e-6)
+    return P, R, F
+
+
+def decode_iterator(model, batcher):
+    true_ys = []
+    pred_ys = []
+    seq_lens = []
+    with torch.no_grad():
+        model.eval()
+        for batch_x, batch_y in batcher:
+            pred_dict = model(batch_x)
+            seq_len = pred_dict['seq_lens'].cpu().numpy()
+            probs = pred_dict['pred_probs']
+            _, pred_y = probs.max(dim=-1)
+            true_y = batch_y['tags']
+
+            pred_y = pred_y.cpu().numpy()
+            true_y = true_y.cpu().numpy()
+
+            true_ys.extend(list(true_y))
+            pred_ys.extend(list(pred_y))
+            seq_lens.extend(list(seq_len))
+        model.train()
+    return true_ys, pred_ys, seq_lens
+
+# TODO how are the pretrained embeddings handled?
+from reproduction.chinese_word_segment.utils import FocalLoss
+from reproduction.chinese_word_segment.utils import seq_lens_to_mask
+from fastNLP.core.batch import Batch
+from fastNLP.core.sampler import RandomSampler
+from fastNLP.core.sampler import SequentialSampler
+
+import torch
+from torch import optim
+import sys
+from tqdm import tqdm
+
+
+tag_size = tag_proc.tag_size
+
+cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100,
+                            bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
+                            bigram_embed_dim=100, num_bigram_per_char=8,
+                            hidden_size=200, bidirectional=True, embed_drop_p=None,
+                            num_layers=1, tag_size=tag_size)
+
+num_epochs = 3
+loss_fn = FocalLoss(class_num=tag_size)
+optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
+
+print_every = 50
+batch_size = 32
+tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False)
+dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
+num_batch_per_epoch = len(tr_dataset) // batch_size
+
+best_f1 = 0
+best_epoch = 0
+for num_epoch in range(num_epochs):
+    print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
+    sys.stdout.flush()
+    avg_loss = 0
+    with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
+        pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
+        cws_model.train()
+        for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
+            pred_dict = cws_model(batch_x)  # B x L x tag_size
+            seq_lens = batch_x['seq_lens']
+            masks = seq_lens_to_mask(seq_lens)
+            tags = batch_y['tags']
+
+            loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size),
+                                     tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
+            # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
+
+            avg_loss += loss.item()
+            loss.backward()
+            for group in optimizer.param_groups:
+                for param in group['params']:
+                    param.grad.clamp_(-5, 5)
+            optimizer.step()
+
+            if batch_idx % print_every == 0:
+                pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
+                avg_loss = 0
+                pbar.update(print_every)
+
+    # evaluate on the dev set
+    pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher)
+    print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+                                                     pre * 100,
+                                                     rec * 100))
+    if best_f1 < f1:
+        best_f1 = f1
+        # cache the best parameters; they may be saved to disk later
+        best_state_dict = {
+            key: value.clone() for key, value in
+            cws_model.state_dict().items()
+        }
+        best_epoch = num_epoch
+
 
 # 4. assemble the content that needs to be saved