@@ -0,0 +1,61 @@ | |||||
# Multi-Criteria-CWS
An implementation of [Multi-Criteria Chinese Word Segmentation with Transformer](http://arxiv.org/abs/1906.12035) with fastNLP.
## Dataset
### Overview
We use the same datasets as listed in the paper.
- sighan2005
  - pku
  - msr
  - as
  - cityu
- sighan2008
  - ctb
  - ckip
  - cityu (combined with the cityu data from sighan2005)
  - ncc
  - sxu
### Preprocess
First, install OpenCC to convert between Traditional Chinese and Simplified Chinese:
```shell
pip install opencc-python-reimplemented
```
Then set a path for the processed data and run the shell script to process it:
```shell
export DATA_DIR=path/to/processed-data
bash make_data.sh path/to/sighan2005 path/to/sighan2008
```
The whole process takes a few minutes.
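`data-prepare.py` uses OpenCC to convert the Traditional Chinese corpora (AS, CITYU, CKIP) into Simplified Chinese. A quick sanity check that the package is installed correctly:
```python
from opencc import OpenCC

cc = OpenCC("t2s")  # Traditional -> Simplified, the same converter used in data-prepare.py
print(cc.convert("漢語分詞"))  # 汉语分词
```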
## Model
We build the model with a Transformer encoder, as described in the paper.
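For reference, here is a rough sketch of how the model is assembled with `models.make_CWS`; the pickle path assumes the default `DATA_DIR=./data` from `make_data.sh`, and `main.py` additionally plugs in pretrained fastNLP `StaticEmbedding`s for unigrams and bigrams instead of the random unigram embeddings used below:
```python
import pickle

import models  # models.make_CWS builds the Transformer-based multi-criteria CWS model

# dataset and vocabularies produced by the data-processing step (path is an assumption)
with open("data/total_dataset.pkl", "rb") as f:
    data = pickle.load(f)

model = models.make_CWS(
    d_model=256, N=6, h=4, d_ff=1024, dropout=0.2,
    word_size=len(data["uni_vocab"]),   # randomly initialized unigram embeddings in this sketch
    tag_size=len(data["tag_vocab"]),    # BMES tags
    task_size=len(data["task_vocab"]),  # one <dataset> token per segmentation criterion
)
```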
## Train
Finally, run the shell script to train the model.
`train.sh` takes one argument, the comma-separated IDs of the GPUs to use, for example:
```shell
bash train.sh 0,1
```
This command uses GPUs 0 and 1.
Note: please refer to the paper for the hyper-parameter details, and modify the settings in `train.sh` to match your experiment environment.
Type
```shell
python main.py --help
```
to see all arguments that can be specified for training.
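To evaluate a saved checkpoint on the test sets, `main.py` can be launched in test mode. A rough sketch, assuming the default `--log-dir result`, the task name used during training, and the same embeddings as in `train.sh`:
```shell
python -m torch.distributed.launch --nproc_per_node=1 main.py \
    --dataset $DATA_DIR \
    --word-embeddings cn-char-fastnlp-100d \
    --bigram-embeddings cn-bi-fastnlp-100d \
    --test --old-model result/<task-name>/model.bin
```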
## Performance
Results (F1, %) on the test sets of the eight CWS datasets with multi-criteria learning.

| Dataset        | MSRA  | AS    | PKU   | CTB   | CKIP  | CITYU | NCC   | SXU   | Avg.  |
| -------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
| Original paper | 98.05 | 96.44 | 96.41 | 96.99 | 96.51 | 96.91 | 96.04 | 97.61 | 96.87 |
| Ours           | 96.92 | 95.71 | 95.65 | 95.96 | 96.00 | 96.09 | 94.61 | 96.64 | 95.95 |
@@ -0,0 +1,262 @@ | |||||
import os | |||||
import re | |||||
import argparse | |||||
from opencc import OpenCC | |||||
cc = OpenCC("t2s") | |||||
from utils import make_sure_path_exists, append_tags | |||||
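# Overview: this script converts the SIGHAN 2005/2008 bakeoff corpora into a
# unified format under --data_path: it normalizes punctuation and digits,
# converts the Traditional Chinese corpora to Simplified Chinese with OpenCC,
# splits train/dev, writes BMES-tagged files, and builds the joint
# multi-criteria corpora ("joint-sighan2005", "joint-sighan2008").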
sighan05_root = "" | |||||
sighan08_root = "" | |||||
data_path = "" | |||||
E_pun = u",.!?[]()<>\"\"''," | |||||
C_pun = u",。!?【】()《》“”‘’、" | |||||
Table = {ord(f): ord(t) for f, t in zip(C_pun, E_pun)} | |||||
Table[12288] = 32 # 全半角空格 | |||||
def C_trans_to_E(string): | |||||
return string.translate(Table) | |||||
def normalize(ustring): | |||||
"""全角转半角""" | |||||
rstring = "" | |||||
for uchar in ustring: | |||||
inside_code = ord(uchar) | |||||
if inside_code == 12288: # 全角空格直接转换 | |||||
inside_code = 32 | |||||
elif 65281 <= inside_code <= 65374: # 全角字符(除空格)根据关系转化 | |||||
inside_code -= 65248 | |||||
rstring += chr(inside_code) | |||||
return rstring | |||||
def preprocess(text): | |||||
rNUM = u"(-|\+)?\d+((\.|·)\d+)?%?" | |||||
rENG = u"[A-Za-z_]+.*" | |||||
sent = normalize(C_trans_to_E(text.strip())).split() | |||||
new_sent = [] | |||||
for word in sent: | |||||
word = re.sub(u"\s+", "", word, flags=re.U) | |||||
word = re.sub(rNUM, u"0", word, flags=re.U) | |||||
word = re.sub(rENG, u"X", word) | |||||
new_sent.append(word) | |||||
return new_sent | |||||
def to_sentence_list(text, split_long_sentence=False): | |||||
text = preprocess(text) | |||||
delimiter = set() | |||||
delimiter.update("。!?:;…、,(),;!?、,\"'") | |||||
delimiter.add("……") | |||||
sent_list = [] | |||||
sent = [] | |||||
sent_len = 0 | |||||
for word in text: | |||||
sent.append(word) | |||||
sent_len += len(word) | |||||
if word in delimiter or (split_long_sentence and sent_len >= 50): | |||||
sent_list.append(sent) | |||||
sent = [] | |||||
sent_len = 0 | |||||
if len(sent) > 0: | |||||
sent_list.append(sent) | |||||
return sent_list | |||||
def is_traditional(dataset): | |||||
return dataset in ["as", "cityu", "ckip"] | |||||
def convert_file( | |||||
src, des, need_cc=False, split_long_sentence=False, encode="utf-8-sig" | |||||
): | |||||
with open(src, encoding=encode) as src, open(des, "w", encoding="utf-8") as des: | |||||
for line in src: | |||||
for sent in to_sentence_list(line, split_long_sentence): | |||||
line = " ".join(sent) + "\n" | |||||
if need_cc: | |||||
line = cc.convert(line) | |||||
des.write(line) | |||||
# if len(''.join(sent)) > 200: | |||||
# print(' '.join(sent)) | |||||
def split_train_dev(dataset): | |||||
root = data_path + "/" + dataset + "/raw/" | |||||
with open(root + "train-all.txt", encoding="UTF-8") as src, open( | |||||
root + "train.txt", "w", encoding="UTF-8" | |||||
) as train, open(root + "dev.txt", "w", encoding="UTF-8") as dev: | |||||
lines = src.readlines() | |||||
idx = int(len(lines) * 0.9) | |||||
for line in lines[:idx]: | |||||
train.write(line) | |||||
for line in lines[idx:]: | |||||
dev.write(line) | |||||
def combine_files(one, two, out): | |||||
if os.path.exists(out): | |||||
os.remove(out) | |||||
with open(one, encoding="utf-8") as one, open(two, encoding="utf-8") as two, open( | |||||
out, "a", encoding="utf-8" | |||||
) as out: | |||||
for line in one: | |||||
out.write(line) | |||||
for line in two: | |||||
out.write(line) | |||||
def bmes_tag(input_file, output_file): | |||||
with open(input_file, encoding="utf-8") as input_data, open( | |||||
output_file, "w", encoding="utf-8" | |||||
) as output_data: | |||||
for line in input_data: | |||||
word_list = line.strip().split() | |||||
for word in word_list: | |||||
if len(word) == 1 or ( | |||||
len(word) > 2 and word[0] == "<" and word[-1] == ">" | |||||
): | |||||
output_data.write(word + "\tS\n") | |||||
else: | |||||
output_data.write(word[0] + "\tB\n") | |||||
for w in word[1 : len(word) - 1]: | |||||
output_data.write(w + "\tM\n") | |||||
output_data.write(word[len(word) - 1] + "\tE\n") | |||||
output_data.write("\n") | |||||
def make_bmes(dataset="pku"): | |||||
path = data_path + "/" + dataset + "/" | |||||
make_sure_path_exists(path + "bmes") | |||||
bmes_tag(path + "raw/train.txt", path + "bmes/train.txt") | |||||
bmes_tag(path + "raw/train-all.txt", path + "bmes/train-all.txt") | |||||
bmes_tag(path + "raw/dev.txt", path + "bmes/dev.txt") | |||||
bmes_tag(path + "raw/test.txt", path + "bmes/test.txt") | |||||
def convert_sighan2005_dataset(dataset): | |||||
global sighan05_root | |||||
root = os.path.join(data_path, dataset) | |||||
make_sure_path_exists(root) | |||||
make_sure_path_exists(root + "/raw") | |||||
file_path = "{}/{}_training.utf8".format(sighan05_root, dataset) | |||||
convert_file( | |||||
file_path, "{}/raw/train-all.txt".format(root), is_traditional(dataset), True | |||||
) | |||||
if dataset == "as": | |||||
file_path = "{}/{}_testing_gold.utf8".format(sighan05_root, dataset) | |||||
else: | |||||
file_path = "{}/{}_test_gold.utf8".format(sighan05_root, dataset) | |||||
convert_file( | |||||
file_path, "{}/raw/test.txt".format(root), is_traditional(dataset), False | |||||
) | |||||
split_train_dev(dataset) | |||||
def convert_sighan2008_dataset(dataset, utf=16): | |||||
global sighan08_root | |||||
root = os.path.join(data_path, dataset) | |||||
make_sure_path_exists(root) | |||||
make_sure_path_exists(root + "/raw") | |||||
convert_file( | |||||
"{}/{}_train_utf{}.seg".format(sighan08_root, dataset, utf), | |||||
"{}/raw/train-all.txt".format(root), | |||||
is_traditional(dataset), | |||||
True, | |||||
"utf-{}".format(utf), | |||||
) | |||||
convert_file( | |||||
"{}/{}_seg_truth&resource/{}_truth_utf{}.seg".format( | |||||
sighan08_root, dataset, dataset, utf | |||||
), | |||||
"{}/raw/test.txt".format(root), | |||||
is_traditional(dataset), | |||||
False, | |||||
"utf-{}".format(utf), | |||||
) | |||||
split_train_dev(dataset) | |||||
def extract_conll(src, out): | |||||
words = [] | |||||
with open(src, encoding="utf-8") as src, open(out, "w", encoding="utf-8") as out: | |||||
for line in src: | |||||
line = line.strip() | |||||
if len(line) == 0: | |||||
out.write(" ".join(words) + "\n") | |||||
words = [] | |||||
continue | |||||
cells = line.split() | |||||
words.append(cells[1]) | |||||
def make_joint_corpus(datasets, joint): | |||||
parts = ["dev", "test", "train", "train-all"] | |||||
for part in parts: | |||||
old_file = "{}/{}/raw/{}.txt".format(data_path, joint, part) | |||||
if os.path.exists(old_file): | |||||
os.remove(old_file) | |||||
elif not os.path.exists(os.path.dirname(old_file)): | |||||
os.makedirs(os.path.dirname(old_file)) | |||||
for name in datasets: | |||||
append_tags( | |||||
os.path.join(data_path, name, "raw"), | |||||
os.path.dirname(old_file), | |||||
name, | |||||
part, | |||||
encode="utf-8", | |||||
) | |||||
def convert_all_sighan2005(datasets): | |||||
for dataset in datasets: | |||||
print(("Converting sighan bakeoff 2005 corpus: {}".format(dataset))) | |||||
convert_sighan2005_dataset(dataset) | |||||
make_bmes(dataset) | |||||
def convert_all_sighan2008(datasets): | |||||
for dataset in datasets: | |||||
print(("Converting sighan bakeoff 2008 corpus: {}".format(dataset))) | |||||
convert_sighan2008_dataset(dataset, 16) | |||||
make_bmes(dataset) | |||||
if __name__ == "__main__": | |||||
parser = argparse.ArgumentParser() | |||||
# fmt: off | |||||
parser.add_argument("--sighan05", required=True, type=str, help="path to sighan2005 dataset") | |||||
parser.add_argument("--sighan08", required=True, type=str, help="path to sighan2008 dataset") | |||||
parser.add_argument("--data_path", required=True, type=str, help="path to save dataset") | |||||
# fmt: on | |||||
args, _ = parser.parse_known_args() | |||||
sighan05_root = args.sighan05 | |||||
sighan08_root = args.sighan08 | |||||
data_path = args.data_path | |||||
print("Converting sighan2005 Simplified Chinese corpus") | |||||
datasets = "pku", "msr", "as", "cityu" | |||||
convert_all_sighan2005(datasets) | |||||
print("Combining sighan2005 corpus to one joint Simplified Chinese corpus") | |||||
datasets = "pku", "msr", "as", "cityu" | |||||
make_joint_corpus(datasets, "joint-sighan2005") | |||||
make_bmes("joint-sighan2005") | |||||
# For researchers who have access to sighan2008 corpus, use official corpora please. | |||||
print("Converting sighan2008 Simplified Chinese corpus") | |||||
datasets = "ctb", "ckip", "cityu", "ncc", "sxu" | |||||
convert_all_sighan2008(datasets) | |||||
print("Combining those 8 sighan corpora to one joint corpus") | |||||
datasets = "pku", "msr", "as", "ctb", "ckip", "cityu", "ncc", "sxu" | |||||
make_joint_corpus(datasets, "joint-sighan2008") | |||||
make_bmes("joint-sighan2008") | |||||
@@ -0,0 +1,166 @@ | |||||
import os | |||||
import sys | |||||
import codecs | |||||
import argparse | |||||
from _pickle import load, dump | |||||
import collections | |||||
from utils import get_processing_word, is_dataset_tag, make_sure_path_exists, get_bmes | |||||
from fastNLP import Instance, DataSet, Vocabulary, Const | |||||
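# Overview: this script reads the BMES-tagged files produced by data-prepare.py
# from <data_path>/joint-sighan2008/bmes, builds fastNLP DataSets together with
# the unigram/bigram/tag/task vocabularies, indexes every field, and dumps the
# result to <data_path>/total_dataset.pkl (plus a per-dataset OOV dictionary in
# oovdict.pkl).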
max_len = 0 | |||||
def expand(x): | |||||
sent = ["<sos>"] + x[1:] + ["<eos>"] | |||||
return [x + y for x, y in zip(sent[:-1], sent[1:])] | |||||
def read_file(filename, processing_word=get_processing_word(lowercase=False)): | |||||
dataset = DataSet() | |||||
niter = 0 | |||||
with codecs.open(filename, "r", "utf-8-sig") as f: | |||||
words, tags = [], [] | |||||
for line in f: | |||||
line = line.strip() | |||||
if len(line) == 0 or line.startswith("-DOCSTART-"): | |||||
if len(words) != 0: | |||||
assert len(words) > 2 | |||||
if niter == 1: | |||||
print(words, tags) | |||||
niter += 1 | |||||
dataset.append(Instance(ori_words=words[:-1], ori_tags=tags[:-1])) | |||||
words, tags = [], [] | |||||
else: | |||||
word, tag = line.split() | |||||
word = processing_word(word) | |||||
words.append(word) | |||||
tags.append(tag.lower()) | |||||
dataset.apply_field(lambda x: [x[0]], field_name="ori_words", new_field_name="task") | |||||
dataset.apply_field( | |||||
lambda x: len(x), field_name="ori_tags", new_field_name="seq_len" | |||||
) | |||||
dataset.apply_field( | |||||
lambda x: expand(x), field_name="ori_words", new_field_name="bi1" | |||||
) | |||||
return dataset | |||||
def main(): | |||||
parser = argparse.ArgumentParser() | |||||
# fmt: off | |||||
parser.add_argument("--data_path", required=True, type=str, help="all of datasets pkl paths") | |||||
# fmt: on | |||||
options, _ = parser.parse_known_args() | |||||
train_set, test_set = DataSet(), DataSet() | |||||
input_dir = os.path.join(options.data_path, "joint-sighan2008/bmes") | |||||
options.output = os.path.join(options.data_path, "total_dataset.pkl") | |||||
print(input_dir, options.output) | |||||
for fn in os.listdir(input_dir): | |||||
if fn not in ["test.txt", "train-all.txt"]: | |||||
continue | |||||
print(fn) | |||||
abs_fn = os.path.join(input_dir, fn) | |||||
ds = read_file(abs_fn) | |||||
if "test.txt" == fn: | |||||
test_set = ds | |||||
else: | |||||
train_set = ds | |||||
print( | |||||
"num samples of total train, test: {}, {}".format(len(train_set), len(test_set)) | |||||
) | |||||
uni_vocab = Vocabulary(min_freq=None).from_dataset( | |||||
train_set, test_set, field_name="ori_words" | |||||
) | |||||
# bi_vocab = Vocabulary(min_freq=3, max_size=50000).from_dataset(train_set,test_set, field_name="bi1") | |||||
bi_vocab = Vocabulary(min_freq=3, max_size=None).from_dataset( | |||||
train_set, field_name="bi1", no_create_entry_dataset=[test_set] | |||||
) | |||||
tag_vocab = Vocabulary(min_freq=None, padding="s", unknown=None).from_dataset( | |||||
train_set, field_name="ori_tags" | |||||
) | |||||
task_vocab = Vocabulary(min_freq=None, padding=None, unknown=None).from_dataset( | |||||
train_set, field_name="task" | |||||
) | |||||
def to_index(dataset): | |||||
uni_vocab.index_dataset(dataset, field_name="ori_words", new_field_name="uni") | |||||
tag_vocab.index_dataset(dataset, field_name="ori_tags", new_field_name="tags") | |||||
task_vocab.index_dataset(dataset, field_name="task", new_field_name="task") | |||||
dataset.apply_field(lambda x: x[1:], field_name="bi1", new_field_name="bi2") | |||||
dataset.apply_field(lambda x: x[:-1], field_name="bi1", new_field_name="bi1") | |||||
bi_vocab.index_dataset(dataset, field_name="bi1", new_field_name="bi1") | |||||
bi_vocab.index_dataset(dataset, field_name="bi2", new_field_name="bi2") | |||||
dataset.set_input("task", "uni", "bi1", "bi2", "seq_len") | |||||
dataset.set_target("tags") | |||||
return dataset | |||||
train_set = to_index(train_set) | |||||
test_set = to_index(test_set) | |||||
output = {} | |||||
output["train_set"] = train_set | |||||
output["test_set"] = test_set | |||||
output["uni_vocab"] = uni_vocab | |||||
output["bi_vocab"] = bi_vocab | |||||
output["tag_vocab"] = tag_vocab | |||||
output["task_vocab"] = task_vocab | |||||
print(tag_vocab.word2idx) | |||||
print(task_vocab.word2idx) | |||||
make_sure_path_exists(os.path.dirname(options.output)) | |||||
print("Saving dataset to {}".format(os.path.abspath(options.output))) | |||||
with open(options.output, "wb") as outfile: | |||||
dump(output, outfile) | |||||
print(len(task_vocab), len(tag_vocab), len(uni_vocab), len(bi_vocab)) | |||||
dic = {} | |||||
tokens = {} | |||||
def process(words): | |||||
name = words[0][1:-1] | |||||
if name not in dic: | |||||
dic[name] = set() | |||||
tokens[name] = 0 | |||||
tokens[name] += len(words[1:]) | |||||
dic[name].update(words[1:]) | |||||
train_set.apply_field(process, "ori_words", None) | |||||
for name in dic.keys(): | |||||
print(name, len(dic[name]), tokens[name]) | |||||
with open(os.path.join(os.path.dirname(options.output), "oovdict.pkl"), "wb") as f: | |||||
dump(dic, f) | |||||
def get_max_len(ds): | |||||
global max_len | |||||
max_len = 0 | |||||
def find_max_len(words): | |||||
global max_len | |||||
if max_len < len(words): | |||||
max_len = len(words) | |||||
ds.apply_field(find_max_len, "ori_words", None) | |||||
return max_len | |||||
print( | |||||
"train max len: {}, test max len: {}".format( | |||||
get_max_len(train_set), get_max_len(test_set) | |||||
) | |||||
) | |||||
if __name__ == "__main__": | |||||
main() |
@@ -0,0 +1,506 @@ | |||||
import argparse
import collections
import logging
import math
import os
import pickle
import random
import sys
import time
from sys import maxsize
import fastNLP | |||||
import fastNLP.embeddings | |||||
import numpy as np | |||||
import torch | |||||
import torch.distributed as dist | |||||
import torch.nn as nn | |||||
from fastNLP import BucketSampler, DataSetIter, SequentialSampler, logger | |||||
from torch.nn.parallel import DistributedDataParallel | |||||
from torch.utils.data.distributed import DistributedSampler | |||||
import models | |||||
import optm | |||||
import utils | |||||
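# Overview: distributed training / evaluation entry point. It loads
# total_dataset.pkl, builds the Transformer CWS model via models.make_CWS,
# wraps it in DistributedDataParallel when --local_rank is set, trains with the
# Noam learning-rate schedule (optm.NoamOpt), and reports per-dataset P/R/F via
# utils.CWSEvaluator; pass --test together with --old-model to evaluate a saved
# checkpoint.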
NONE_TAG = "<NONE>" | |||||
START_TAG = "<sos>" | |||||
END_TAG = "<eos>" | |||||
DEFAULT_WORD_EMBEDDING_SIZE = 100 | |||||
DEBUG_SCALE = 200 | |||||
# ===-----------------------------------------------------------------------=== | |||||
# Argument parsing | |||||
# ===-----------------------------------------------------------------------=== | |||||
# fmt: off | |||||
parser = argparse.ArgumentParser() | |||||
parser.add_argument("--dataset", required=True, dest="dataset", help="processed data dir") | |||||
parser.add_argument("--word-embeddings", dest="word_embeddings", help="File from which to read in pretrained embeds") | |||||
parser.add_argument("--bigram-embeddings", dest="bigram_embeddings", help="File from which to read in pretrained embeds") | |||||
parser.add_argument("--crf", dest="crf", action="store_true", help="crf") | |||||
# parser.add_argument("--devi", default="0", dest="devi", help="gpu") | |||||
parser.add_argument("--step", default=0, dest="step", type=int,help="step") | |||||
parser.add_argument("--num-epochs", default=100, dest="num_epochs", type=int, | |||||
help="Number of full passes through training set") | |||||
parser.add_argument("--batch-size", default=128, dest="batch_size", type=int, | |||||
help="Minibatch size of training set") | |||||
parser.add_argument("--d_model", default=256, dest="d_model", type=int, help="d_model") | |||||
parser.add_argument("--d_ff", default=1024, dest="d_ff", type=int, help="d_ff") | |||||
parser.add_argument("--N", default=6, dest="N", type=int, help="N") | |||||
parser.add_argument("--h", default=4, dest="h", type=int, help="h") | |||||
parser.add_argument("--factor", default=2, dest="factor", type=float, help="Initial learning rate") | |||||
parser.add_argument("--dropout", default=0.2, dest="dropout", type=float, | |||||
help="Amount of dropout(not keep rate, but drop rate) to apply to embeddings part of graph") | |||||
parser.add_argument("--log-dir", default="result", dest="log_dir", | |||||
help="Directory where to write logs / serialized models") | |||||
parser.add_argument("--task-name", default=time.strftime("%Y-%m-%d-%H-%M-%S"), dest="task_name", | |||||
help="Name for this task, use a comprehensive one") | |||||
parser.add_argument("--no-model", dest="no_model", action="store_true", help="Don't serialize model") | |||||
parser.add_argument("--always-model", dest="always_model", action="store_true", | |||||
help="Always serialize model after every epoch") | |||||
parser.add_argument("--old-model", dest="old_model", help="Path to old model for incremental training") | |||||
parser.add_argument("--skip-dev", dest="skip_dev", action="store_true", help="Skip dev set, would save some time") | |||||
parser.add_argument("--freeze", dest="freeze", action="store_true", help="freeze pretrained embedding") | |||||
parser.add_argument("--only-task", dest="only_task", action="store_true", help="only train task embedding") | |||||
parser.add_argument("--subset", dest="subset", help="Only train and test on a subset of the whole dataset") | |||||
parser.add_argument("--seclude", dest="seclude", help="train and test except a subset") | |||||
parser.add_argument("--instances", default=None, dest="instances", type=int,help="num of instances of subset") | |||||
parser.add_argument("--seed", dest="python_seed", type=int, default=random.randrange(maxsize), | |||||
help="Random seed of Python and NumPy") | |||||
parser.add_argument("--debug", dest="debug", default=False, action="store_true", help="Debug mode") | |||||
parser.add_argument("--test", dest="test", action="store_true", help="Test mode") | |||||
parser.add_argument('--local_rank', type=int, default=None) | |||||
parser.add_argument('--init_method', type=str, default='env://') | |||||
# fmt: on | |||||
options, _ = parser.parse_known_args() | |||||
print("unknown args", _) | |||||
task_name = options.task_name | |||||
root_dir = "{}/{}".format(options.log_dir, task_name) | |||||
utils.make_sure_path_exists(root_dir) | |||||
if options.local_rank is not None: | |||||
torch.cuda.set_device(options.local_rank) | |||||
dist.init_process_group("nccl", init_method=options.init_method) | |||||
def init_logger(): | |||||
if not os.path.exists(root_dir): | |||||
os.mkdir(root_dir) | |||||
log_formatter = logging.Formatter("%(asctime)s - %(message)s") | |||||
logger = logging.getLogger() | |||||
file_handler = logging.FileHandler("{0}/info.log".format(root_dir), mode="w") | |||||
file_handler.setFormatter(log_formatter) | |||||
logger.addHandler(file_handler) | |||||
console_handler = logging.StreamHandler() | |||||
console_handler.setFormatter(log_formatter) | |||||
logger.addHandler(console_handler) | |||||
if options.local_rank is None or options.local_rank == 0: | |||||
logger.setLevel(logging.INFO) | |||||
else: | |||||
logger.setLevel(logging.WARNING) | |||||
return logger | |||||
# ===-----------------------------------------------------------------------=== | |||||
# Set up logging | |||||
# ===-----------------------------------------------------------------------=== | |||||
# logger = init_logger() | |||||
logger.add_file("{}/info.log".format(root_dir), "INFO") | |||||
logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING) | |||||
# ===-----------------------------------------------------------------------=== | |||||
# Log some stuff about this run | |||||
# ===-----------------------------------------------------------------------=== | |||||
logger.info(" ".join(sys.argv)) | |||||
logger.info("") | |||||
logger.info(options) | |||||
if options.debug: | |||||
logger.info("DEBUG MODE") | |||||
options.num_epochs = 2 | |||||
options.batch_size = 20 | |||||
random.seed(options.python_seed) | |||||
np.random.seed(options.python_seed % (2 ** 32 - 1)) | |||||
torch.cuda.manual_seed_all(options.python_seed) | |||||
logger.info("Python random seed: {}".format(options.python_seed)) | |||||
# ===-----------------------------------------------------------------------=== | |||||
# Read in dataset | |||||
# ===-----------------------------------------------------------------------=== | |||||
dataset = pickle.load(open(options.dataset + "/total_dataset.pkl", "rb")) | |||||
train_set = dataset["train_set"] | |||||
test_set = dataset["test_set"] | |||||
uni_vocab = dataset["uni_vocab"] | |||||
bi_vocab = dataset["bi_vocab"] | |||||
task_vocab = dataset["task_vocab"] | |||||
tag_vocab = dataset["tag_vocab"] | |||||
for v in (bi_vocab, uni_vocab, tag_vocab, task_vocab): | |||||
if hasattr(v, "_word2idx"): | |||||
v.word2idx = v._word2idx | |||||
for ds in (train_set, test_set): | |||||
ds.rename_field("ori_words", "words") | |||||
logger.info("{} {}".format(bi_vocab.to_word(0), tag_vocab.word2idx)) | |||||
logger.info(task_vocab.word2idx) | |||||
if options.skip_dev: | |||||
dev_set = test_set | |||||
else: | |||||
train_set, dev_set = train_set.split(0.1) | |||||
logger.info("{} {} {}".format(len(train_set), len(dev_set), len(test_set))) | |||||
if options.debug: | |||||
train_set = train_set[0:DEBUG_SCALE] | |||||
dev_set = dev_set[0:DEBUG_SCALE] | |||||
test_set = test_set[0:DEBUG_SCALE] | |||||
# ===-----------------------------------------------------------------------=== | |||||
# Build model and trainer | |||||
# ===-----------------------------------------------------------------------=== | |||||
# =============================== | |||||
if dist.get_rank() != 0: | |||||
dist.barrier() | |||||
if options.word_embeddings is None: | |||||
init_embedding = None | |||||
else: | |||||
# logger.info("Load: {}".format(options.word_embeddings)) | |||||
# init_embedding = utils.embedding_load_with_cache(options.word_embeddings, options.cache_dir, uni_vocab, normalize=False) | |||||
init_embedding = fastNLP.embeddings.StaticEmbedding( | |||||
uni_vocab, options.word_embeddings, word_drop=0.01 | |||||
) | |||||
bigram_embedding = None | |||||
if options.bigram_embeddings: | |||||
# logger.info("Load: {}".format(options.bigram_embeddings)) | |||||
# bigram_embedding = utils.embedding_load_with_cache(options.bigram_embeddings, options.cache_dir, bi_vocab, normalize=False) | |||||
bigram_embedding = fastNLP.embeddings.StaticEmbedding( | |||||
bi_vocab, options.bigram_embeddings | |||||
) | |||||
if dist.get_rank() == 0: | |||||
dist.barrier() | |||||
# =============================== | |||||
# select subset training | |||||
if options.seclude is not None: | |||||
setname = "<{}>".format(options.seclude) | |||||
logger.info("seclude {}".format(setname)) | |||||
train_set.drop(lambda x: x["words"][0] == setname, inplace=True) | |||||
test_set.drop(lambda x: x["words"][0] == setname, inplace=True) | |||||
dev_set.drop(lambda x: x["words"][0] == setname, inplace=True) | |||||
if options.subset is not None: | |||||
setname = "<{}>".format(options.subset) | |||||
logger.info("select {}".format(setname)) | |||||
train_set.drop(lambda x: x["words"][0] != setname, inplace=True) | |||||
test_set.drop(lambda x: x["words"][0] != setname, inplace=True) | |||||
dev_set.drop(lambda x: x["words"][0] != setname, inplace=True) | |||||
# build model and optimizer | |||||
i2t = None | |||||
if options.crf: | |||||
# i2t=utils.to_id_list(tag_vocab.word2idx) | |||||
i2t = {} | |||||
for x, y in tag_vocab.word2idx.items(): | |||||
i2t[y] = x | |||||
logger.info(i2t) | |||||
freeze = True if options.freeze else False | |||||
model = models.make_CWS( | |||||
d_model=options.d_model, | |||||
N=options.N, | |||||
h=options.h, | |||||
d_ff=options.d_ff, | |||||
dropout=options.dropout, | |||||
word_embedding=init_embedding, | |||||
bigram_embedding=bigram_embedding, | |||||
tag_size=len(tag_vocab), | |||||
task_size=len(task_vocab), | |||||
crf=i2t, | |||||
freeze=freeze, | |||||
) | |||||
device = "cpu" | |||||
if torch.cuda.device_count() > 0: | |||||
if options.local_rank is not None: | |||||
device = "cuda:{}".format(options.local_rank) | |||||
# model=nn.DataParallel(model) | |||||
model = model.to(device) | |||||
model = torch.nn.parallel.DistributedDataParallel( | |||||
model, device_ids=[options.local_rank], output_device=options.local_rank | |||||
) | |||||
else: | |||||
device = "cuda:0" | |||||
model.to(device) | |||||
if options.only_task and options.old_model is not None: | |||||
logger.info("fix para except task embedding") | |||||
for name, para in model.named_parameters(): | |||||
if name.find("task_embed") == -1: | |||||
para.requires_grad = False | |||||
else: | |||||
para.requires_grad = True | |||||
logger.info(name) | |||||
optimizer = optm.NoamOpt( | |||||
options.d_model, | |||||
options.factor, | |||||
4000, | |||||
torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9), | |||||
) | |||||
optimizer._step = options.step | |||||
best_model_file_name = "{}/model.bin".format(root_dir) | |||||
if options.local_rank is None: | |||||
train_sampler = BucketSampler( | |||||
batch_size=options.batch_size, seq_len_field_name="seq_len" | |||||
) | |||||
else: | |||||
train_sampler = DistributedSampler( | |||||
train_set, dist.get_world_size(), dist.get_rank() | |||||
) | |||||
dev_sampler = SequentialSampler() | |||||
i2t = utils.to_id_list(tag_vocab.word2idx) | |||||
i2task = utils.to_id_list(task_vocab.word2idx) | |||||
dev_set.set_input("words") | |||||
test_set.set_input("words") | |||||
test_batch = DataSetIter(test_set, options.batch_size, num_workers=2) | |||||
word_dic = pickle.load(open(options.dataset + "/oovdict.pkl", "rb")) | |||||
def batch_to_device(batch, device): | |||||
for k, v in batch.items(): | |||||
if torch.is_tensor(v): | |||||
batch[k] = v.to(device) | |||||
return batch | |||||
def tester(model, test_batch, write_out=False): | |||||
res = [] | |||||
prf = utils.CWSEvaluator(i2t) | |||||
prf_dataset = {} | |||||
oov_dataset = {} | |||||
logger.info("start evaluation") | |||||
# import ipdb; ipdb.set_trace() | |||||
with torch.no_grad(): | |||||
for batch_x, batch_y in test_batch: | |||||
batch_to_device(batch_x, device) | |||||
# batch_to_device(batch_y, device) | |||||
if bigram_embedding is not None: | |||||
out = model( | |||||
batch_x["task"], | |||||
batch_x["uni"], | |||||
batch_x["seq_len"], | |||||
batch_x["bi1"], | |||||
batch_x["bi2"], | |||||
) | |||||
else: | |||||
out = model(batch_x["task"], batch_x["uni"], batch_x["seq_len"]) | |||||
out = out["pred"] | |||||
# print(out) | |||||
num = out.size(0) | |||||
out = out.detach().cpu().numpy() | |||||
for i in range(num): | |||||
length = int(batch_x["seq_len"][i]) | |||||
out_tags = out[i, 1:length].tolist() | |||||
sentence = batch_x["words"][i] | |||||
gold_tags = batch_y["tags"][i][1:length].numpy().tolist() | |||||
dataset_name = sentence[0] | |||||
sentence = sentence[1:] | |||||
# print(out_tags,gold_tags) | |||||
assert utils.is_dataset_tag(dataset_name), dataset_name | |||||
assert len(gold_tags) == len(out_tags) and len(gold_tags) == len( | |||||
sentence | |||||
) | |||||
if dataset_name not in prf_dataset: | |||||
prf_dataset[dataset_name] = utils.CWSEvaluator(i2t) | |||||
oov_dataset[dataset_name] = utils.CWS_OOV( | |||||
word_dic[dataset_name[1:-1]] | |||||
) | |||||
prf_dataset[dataset_name].add_instance(gold_tags, out_tags) | |||||
prf.add_instance(gold_tags, out_tags) | |||||
if write_out: | |||||
gold_strings = utils.to_tag_strings(i2t, gold_tags) | |||||
obs_strings = utils.to_tag_strings(i2t, out_tags) | |||||
word_list = utils.bmes_to_words(sentence, obs_strings) | |||||
oov_dataset[dataset_name].update( | |||||
utils.bmes_to_words(sentence, gold_strings), word_list | |||||
) | |||||
raw_string = " ".join(word_list) | |||||
res.append(dataset_name + " " + raw_string + " " + dataset_name) | |||||
Ap = 0.0 | |||||
Ar = 0.0 | |||||
Af = 0.0 | |||||
Aoov = 0.0 | |||||
tot = 0 | |||||
nw = 0.0 | |||||
for dataset_name, performance in sorted(prf_dataset.items()): | |||||
p = performance.result() | |||||
if write_out: | |||||
nw = oov_dataset[dataset_name].oov() | |||||
# nw = 0 | |||||
logger.info( | |||||
"{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||||
dataset_name, p[0], p[1], p[2], nw | |||||
) | |||||
) | |||||
else: | |||||
logger.info( | |||||
"{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||||
dataset_name, p[0], p[1], p[2] | |||||
) | |||||
) | |||||
Ap += p[0] | |||||
Ar += p[1] | |||||
Af += p[2] | |||||
Aoov += nw | |||||
tot += 1 | |||||
prf = prf.result() | |||||
logger.info( | |||||
"{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format("TOT", prf[0], prf[1], prf[2]) | |||||
) | |||||
if not write_out: | |||||
logger.info( | |||||
"{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||||
"AVG", Ap / tot, Ar / tot, Af / tot | |||||
) | |||||
) | |||||
else: | |||||
logger.info( | |||||
"{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||||
"AVG", Ap / tot, Ar / tot, Af / tot, Aoov / tot | |||||
) | |||||
) | |||||
return prf[-1], res | |||||
# start training | |||||
if not options.test: | |||||
if options.old_model: | |||||
# incremental training | |||||
logger.info("Incremental training from old model: {}".format(options.old_model)) | |||||
model.load_state_dict(torch.load(options.old_model, map_location="cuda:0")) | |||||
logger.info("Number training instances: {}".format(len(train_set))) | |||||
logger.info("Number dev instances: {}".format(len(dev_set))) | |||||
train_batch = DataSetIter( | |||||
batch_size=options.batch_size, | |||||
dataset=train_set, | |||||
sampler=train_sampler, | |||||
num_workers=4, | |||||
) | |||||
dev_batch = DataSetIter( | |||||
batch_size=options.batch_size, | |||||
dataset=dev_set, | |||||
sampler=dev_sampler, | |||||
num_workers=4, | |||||
) | |||||
best_f1 = 0.0 | |||||
for epoch in range(int(options.num_epochs)): | |||||
logger.info("Epoch {} out of {}".format(epoch + 1, options.num_epochs)) | |||||
train_loss = 0.0 | |||||
model.train() | |||||
tot = 0 | |||||
t1 = time.time() | |||||
for batch_x, batch_y in train_batch: | |||||
model.zero_grad() | |||||
if bigram_embedding is not None: | |||||
out = model( | |||||
batch_x["task"], | |||||
batch_x["uni"], | |||||
batch_x["seq_len"], | |||||
batch_x["bi1"], | |||||
batch_x["bi2"], | |||||
batch_y["tags"], | |||||
) | |||||
else: | |||||
out = model( | |||||
batch_x["task"], batch_x["uni"], batch_x["seq_len"], batch_y["tags"] | |||||
) | |||||
loss = out["loss"] | |||||
train_loss += loss.item() | |||||
tot += 1 | |||||
loss.backward() | |||||
# nn.utils.clip_grad_value_(model.parameters(), 1) | |||||
optimizer.step() | |||||
t2 = time.time() | |||||
train_loss = train_loss / tot | |||||
logger.info( | |||||
"time: {} loss: {} step: {}".format(t2 - t1, train_loss, optimizer._step) | |||||
) | |||||
# Evaluate dev data | |||||
if options.skip_dev and dist.get_rank() == 0: | |||||
logger.info("Saving model to {}".format(best_model_file_name)) | |||||
torch.save(model.module.state_dict(), best_model_file_name) | |||||
continue | |||||
model.eval() | |||||
if dist.get_rank() == 0: | |||||
f1, _ = tester(model.module, dev_batch) | |||||
if f1 > best_f1: | |||||
best_f1 = f1 | |||||
logger.info("- new best score!") | |||||
if not options.no_model: | |||||
logger.info("Saving model to {}".format(best_model_file_name)) | |||||
torch.save(model.module.state_dict(), best_model_file_name) | |||||
elif options.always_model: | |||||
logger.info("Saving model to {}".format(best_model_file_name)) | |||||
torch.save(model.module.state_dict(), best_model_file_name) | |||||
dist.barrier() | |||||
# Evaluate test data (once) | |||||
logger.info("\nNumber test instances: {}".format(len(test_set))) | |||||
if not options.skip_dev: | |||||
if options.test: | |||||
model.module.load_state_dict( | |||||
torch.load(options.old_model, map_location="cuda:0") | |||||
) | |||||
else: | |||||
model.module.load_state_dict( | |||||
torch.load(best_model_file_name, map_location="cuda:0") | |||||
) | |||||
if dist.get_rank() == 0: | |||||
for name, para in model.named_parameters(): | |||||
if name.find("task_embed") != -1: | |||||
tm = para.detach().cpu().numpy() | |||||
logger.info(tm.shape) | |||||
np.save("{}/task.npy".format(root_dir), tm) | |||||
break | |||||
_, res = tester(model.module, test_batch, True) | |||||
if dist.get_rank() == 0: | |||||
with open("{}/testout.txt".format(root_dir), "w", encoding="utf-8") as raw_writer: | |||||
for sent in res: | |||||
raw_writer.write(sent) | |||||
raw_writer.write("\n") | |||||
@@ -0,0 +1,14 @@ | |||||
if [ -z "$DATA_DIR" ]
then
    DATA_DIR="./data"
fi
mkdir -vp $DATA_DIR
cmd="python -u ./data-prepare.py --sighan05 $1 --sighan08 $2 --data_path $DATA_DIR"
echo $cmd
eval $cmd
cmd="python -u ./data-process.py --data_path $DATA_DIR"
echo $cmd
eval $cmd
@@ -0,0 +1,200 @@ | |||||
import fastNLP | |||||
import torch | |||||
import math | |||||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||||
from fastNLP import Const | |||||
import copy | |||||
import numpy as np | |||||
from torch.autograd import Variable | |||||
import torch.autograd as autograd | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import transformer | |||||
class PositionalEncoding(nn.Module): | |||||
"Implement the PE function." | |||||
def __init__(self, d_model, dropout, max_len=512): | |||||
super(PositionalEncoding, self).__init__() | |||||
self.dropout = nn.Dropout(p=dropout) | |||||
# Compute the positional encodings once in log space. | |||||
pe = torch.zeros(max_len, d_model).float() | |||||
position = torch.arange(0, max_len).unsqueeze(1).float() | |||||
div_term = torch.exp( | |||||
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) | |||||
) | |||||
pe[:, 0::2] = torch.sin(position * div_term) | |||||
pe[:, 1::2] = torch.cos(position * div_term) | |||||
pe = pe.unsqueeze(0) | |||||
self.register_buffer("pe", pe) | |||||
def forward(self, x): | |||||
x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False) | |||||
return self.dropout(x) | |||||
class Embedding(nn.Module): | |||||
def __init__( | |||||
self, | |||||
task_size, | |||||
d_model, | |||||
word_embedding=None, | |||||
bi_embedding=None, | |||||
word_size=None, | |||||
freeze=True, | |||||
): | |||||
super(Embedding, self).__init__() | |||||
self.task_size = task_size | |||||
self.embed_dim = 0 | |||||
self.task_embed = nn.Embedding(task_size, d_model) | |||||
if word_embedding is not None: | |||||
# self.uni_embed = nn.Embedding.from_pretrained(torch.FloatTensor(word_embedding), freeze=freeze) | |||||
# self.embed_dim+=word_embedding.shape[1] | |||||
self.uni_embed = word_embedding | |||||
self.embed_dim += word_embedding.embedding_dim | |||||
        else:
            if bi_embedding is not None:
                self.embed_dim += bi_embedding.shape[1]
            else:
                self.embed_dim = d_model
            assert word_size is not None
            # randomly initialized unigram embeddings when no pretrained vectors are given
            self.uni_embed = nn.Embedding(word_size, self.embed_dim)
if bi_embedding is not None: | |||||
# self.bi_embed = nn.Embedding.from_pretrained(torch.FloatTensor(bi_embedding), freeze=freeze) | |||||
# self.embed_dim += bi_embedding.shape[1]*2 | |||||
self.bi_embed = bi_embedding | |||||
self.embed_dim += bi_embedding.embedding_dim * 2 | |||||
print("Trans Freeze", freeze, self.embed_dim) | |||||
if d_model != self.embed_dim: | |||||
self.F = nn.Linear(self.embed_dim, d_model) | |||||
else: | |||||
self.F = None | |||||
self.d_model = d_model | |||||
def forward(self, task, uni, bi1=None, bi2=None): | |||||
y_task = self.task_embed(task[:, 0:1]) | |||||
y = self.uni_embed(uni[:, 1:]) | |||||
if bi1 is not None: | |||||
assert self.bi_embed is not None | |||||
y = torch.cat([y, self.bi_embed(bi1), self.bi_embed(bi2)], dim=-1) | |||||
# y2=self.bi_embed(bi) | |||||
# y=torch.cat([y,y2[:,:-1,:],y2[:,1:,:]],dim=-1) | |||||
# y=torch.cat([y_task,y],dim=1) | |||||
if self.F is not None: | |||||
y = self.F(y) | |||||
y = torch.cat([y_task, y], dim=1) | |||||
return y * math.sqrt(self.d_model) | |||||
def seq_len_to_mask(seq_len, max_len=None): | |||||
if isinstance(seq_len, np.ndarray): | |||||
assert ( | |||||
len(np.shape(seq_len)) == 1 | |||||
), f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | |||||
if max_len is None: | |||||
max_len = int(seq_len.max()) | |||||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||||
    elif isinstance(seq_len, torch.Tensor):
        assert (
            seq_len.dim() == 1
        ), f"seq_len can only have one dimension, got {seq_len.dim()}."
batch_size = seq_len.size(0) | |||||
if max_len is None: | |||||
max_len = seq_len.max().long() | |||||
broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) | |||||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | |||||
else: | |||||
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | |||||
return mask | |||||
class CWSModel(nn.Module): | |||||
def __init__(self, encoder, src_embed, position, d_model, tag_size, crf=None): | |||||
super(CWSModel, self).__init__() | |||||
self.encoder = encoder | |||||
self.src_embed = src_embed | |||||
self.pos = copy.deepcopy(position) | |||||
self.proj = nn.Linear(d_model, tag_size) | |||||
self.tag_size = tag_size | |||||
if crf is None: | |||||
self.crf = None | |||||
self.loss_f = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100) | |||||
else: | |||||
print("crf") | |||||
trans = fastNLP.modules.decoder.crf.allowed_transitions( | |||||
crf, encoding_type="bmes" | |||||
) | |||||
self.crf = ConditionalRandomField(tag_size, allowed_transitions=trans) | |||||
# self.norm=nn.LayerNorm(d_model) | |||||
def forward(self, task, uni, seq_len, bi1=None, bi2=None, tags=None): | |||||
# mask=fastNLP.core.utils.seq_len_to_mask(seq_len,uni.size(1)) # for dev 0.5.1 | |||||
mask = seq_len_to_mask(seq_len, uni.size(1)) | |||||
out = self.src_embed(task, uni, bi1, bi2) | |||||
out = self.pos(out) | |||||
# out=self.norm(out) | |||||
out = self.proj(self.encoder(out, mask.float())) | |||||
if self.crf is not None: | |||||
if tags is not None: | |||||
out = self.crf(out, tags, mask) | |||||
return {"loss": out} | |||||
else: | |||||
out, _ = self.crf.viterbi_decode(out, mask) | |||||
return {"pred": out} | |||||
else: | |||||
if tags is not None: | |||||
out = out.contiguous().view(-1, self.tag_size) | |||||
tags = tags.data.masked_fill_(mask == 0, -100).view(-1) | |||||
loss = self.loss_f(out, tags) | |||||
return {"loss": loss} | |||||
else: | |||||
out = torch.argmax(out, dim=-1) | |||||
return {"pred": out} | |||||
def make_CWS( | |||||
N=6, | |||||
d_model=256, | |||||
d_ff=1024, | |||||
h=4, | |||||
dropout=0.2, | |||||
tag_size=4, | |||||
task_size=8, | |||||
bigram_embedding=None, | |||||
word_embedding=None, | |||||
word_size=None, | |||||
crf=None, | |||||
freeze=True, | |||||
): | |||||
c = copy.deepcopy | |||||
# encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout) | |||||
encoder = transformer.make_encoder( | |||||
N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff | |||||
) | |||||
position = PositionalEncoding(d_model, dropout) | |||||
embed = Embedding( | |||||
task_size, d_model, word_embedding, bigram_embedding, word_size, freeze | |||||
) | |||||
model = CWSModel(encoder, embed, position, d_model, tag_size, crf=crf) | |||||
for p in model.parameters(): | |||||
if p.dim() > 1 and p.requires_grad: | |||||
nn.init.xavier_uniform_(p) | |||||
return model |
@@ -0,0 +1,49 @@ | |||||
import torch | |||||
import torch.optim as optim | |||||
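# NoamOpt applies the "Noam" learning-rate schedule from the Transformer paper:
# lr = factor * model_size^(-0.5) * min(step^(-0.5), step * warmup^(-1.5)),
# i.e. a linear warm-up for `warmup` steps followed by inverse-square-root decay.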
class NoamOpt:
    "Optim wrapper that implements rate."

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        lr = self.factor * (
            self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )
        # if step>self.warmup: lr = max(1e-4,lr)
        return lr


def get_std_opt(model):
    return NoamOpt(
        model.src_embed[0].d_model,
        2,
        4000,
        torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=0,
            betas=(0.9, 0.98),
            eps=1e-9,
        ),
    )
@@ -0,0 +1,138 @@ | |||||
from fastNLP import (Trainer, Tester, Callback, GradientClipCallback, LRScheduler, SpanFPreRecMetric) | |||||
import torch | |||||
import torch.cuda | |||||
from torch.optim import Adam, SGD | |||||
from argparse import ArgumentParser | |||||
import logging | |||||
from .utils import set_seed | |||||
class LoggingCallback(Callback): | |||||
def __init__(self, filepath=None): | |||||
super().__init__() | |||||
# create file handler and set level to debug | |||||
if filepath is not None: | |||||
file_handler = logging.FileHandler(filepath, "a") | |||||
else: | |||||
file_handler = logging.StreamHandler() | |||||
file_handler.setLevel(logging.DEBUG) | |||||
file_handler.setFormatter( | |||||
logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s', | |||||
datefmt='%m/%d/%Y %H:%M:%S')) | |||||
# create logger and set level to debug | |||||
logger = logging.getLogger() | |||||
logger.handlers = [] | |||||
logger.setLevel(logging.DEBUG) | |||||
logger.propagate = False | |||||
logger.addHandler(file_handler) | |||||
self.log_writer = logger | |||||
def on_backward_begin(self, loss): | |||||
if self.step % self.trainer.print_every == 0: | |||||
self.log_writer.info( | |||||
'Step/Epoch {}/{}: Loss {}'.format(self.step, self.epoch, loss.item())) | |||||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||||
self.log_writer.info( | |||||
'Step/Epoch {}/{}: Eval result {}'.format(self.step, self.epoch, eval_result)) | |||||
def on_backward_end(self): | |||||
pass | |||||
def main(): | |||||
parser = ArgumentParser() | |||||
register_args(parser) | |||||
args = parser.parse_known_args()[0] | |||||
set_seed(args.seed) | |||||
if args.train: | |||||
train(args) | |||||
if args.eval: | |||||
evaluate(args) | |||||
def get_optim(args):
    # parse strings like "adam (lr=2e-3, weight_decay=0.0)"
    name = args.optim.strip().split(' ')[0].lower()
    p = args.optim.strip()
    l = p.find('(')
    r = p.find(')')
    optim_args = eval('dict({})'.format(p[l + 1:r]))
    if name == 'sgd':
        return SGD(**optim_args)
    elif name == 'adam':
        return Adam(**optim_args)
    else:
        raise ValueError(args.optim)
def load_model_from_path(args): | |||||
pass | |||||
def train(args): | |||||
data = get_data(args) | |||||
train_data = data['train'] | |||||
dev_data = data['dev'] | |||||
model = get_model(args) | |||||
optimizer = get_optim(args) | |||||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||||
callbacks = [] | |||||
trainer = Trainer( | |||||
train_data=train_data, | |||||
model=model, | |||||
optimizer=optimizer, | |||||
loss=None, | |||||
batch_size=args.batch_size, | |||||
n_epochs=args.epochs, | |||||
num_workers=4, | |||||
metrics=SpanFPreRecMetric( | |||||
tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'], | |||||
ignore_labels=data['ignore_labels']), | |||||
metric_key='f1', | |||||
dev_data=dev_data, | |||||
save_path=args.save_path, | |||||
device=device, | |||||
callbacks=callbacks, | |||||
check_code_level=-1, | |||||
) | |||||
print(trainer.train()) | |||||
def evaluate(args): | |||||
data = get_data(args) | |||||
test_data = data['test'] | |||||
model = load_model_from_path(args) | |||||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||||
tester = Tester( | |||||
data=test_data, model=model, batch_size=args.batch_size, | |||||
num_workers=2, device=device, | |||||
metrics=SpanFPreRecMetric( | |||||
tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'], | |||||
ignore_labels=data['ignore_labels']), | |||||
) | |||||
print(tester.test()) | |||||
def register_args(parser): | |||||
parser.add_argument('--optim', type=str, default='adam (lr=2e-3, weight_decay=0.0)') | |||||
parser.add_argument('--batch_size', type=int, default=128) | |||||
parser.add_argument('--epochs', type=int, default=10) | |||||
parser.add_argument('--save_path', type=str, default=None) | |||||
parser.add_argument('--data_path', type=str, required=True) | |||||
parser.add_argument('--log_path', type=str, default=None) | |||||
parser.add_argument('--model_config', type=str, required=True) | |||||
parser.add_argument('--load_path', type=str, default=None) | |||||
parser.add_argument('--train', action='store_true', default=False) | |||||
parser.add_argument('--eval', action='store_true', default=False) | |||||
parser.add_argument('--seed', type=int, default=42, help='rng seed') | |||||
def get_model(args): | |||||
pass | |||||
def get_data(args): | |||||
return torch.load(args.data_path) | |||||
if __name__ == '__main__': | |||||
main() |
@@ -0,0 +1,26 @@ | |||||
export EXP_NAME=release04
export NGPU=2
export PORT=9988
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export CUDA_VISIBLE_DEVICES=$1
if [ -z "$DATA_DIR" ]
then
    DATA_DIR="./data"
fi
echo $CUDA_VISIBLE_DEVICES
cmd="
python -m torch.distributed.launch --nproc_per_node=$NGPU --master_port $PORT \
main.py \
--word-embeddings cn-char-fastnlp-100d \
--bigram-embeddings cn-bi-fastnlp-100d \
--num-epochs 100 \
--batch-size 256 \
--seed 1234 \
--task-name $EXP_NAME \
--dataset $DATA_DIR \
--freeze \
"
echo $cmd
eval $cmd
@@ -0,0 +1,152 @@ | |||||
import numpy as np | |||||
import torch | |||||
import torch.autograd as autograd | |||||
import torch.nn as nn | |||||
import torch.nn.functional as F | |||||
import math, copy, time | |||||
from torch.autograd import Variable | |||||
# import matplotlib.pyplot as plt | |||||
def clones(module, N): | |||||
"Produce N identical layers." | |||||
return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) | |||||
def subsequent_mask(size): | |||||
"Mask out subsequent positions." | |||||
attn_shape = (1, size, size) | |||||
subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") | |||||
return torch.from_numpy(subsequent_mask) == 0 | |||||
def attention(query, key, value, mask=None, dropout=None): | |||||
"Compute 'Scaled Dot Product Attention'" | |||||
d_k = query.size(-1) | |||||
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) | |||||
if mask is not None: | |||||
# print(scores.size(),mask.size()) # [bsz,1,1,len] | |||||
scores = scores.masked_fill(mask == 0, -1e9) | |||||
p_attn = F.softmax(scores, dim=-1) | |||||
if dropout is not None: | |||||
p_attn = dropout(p_attn) | |||||
return torch.matmul(p_attn, value), p_attn | |||||
class MultiHeadedAttention(nn.Module): | |||||
def __init__(self, h, d_model, dropout=0.1): | |||||
"Take in model size and number of heads." | |||||
super(MultiHeadedAttention, self).__init__() | |||||
assert d_model % h == 0 | |||||
# We assume d_v always equals d_k | |||||
self.d_k = d_model // h | |||||
self.h = h | |||||
self.linears = clones(nn.Linear(d_model, d_model), 4) | |||||
self.attn = None | |||||
self.dropout = nn.Dropout(p=dropout) | |||||
def forward(self, query, key, value, mask=None): | |||||
"Implements Figure 2" | |||||
if mask is not None: | |||||
# Same mask applied to all h heads. | |||||
mask = mask.unsqueeze(1) | |||||
nbatches = query.size(0) | |||||
# 1) Do all the linear projections in batch from d_model => h x d_k | |||||
query, key, value = [ | |||||
l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) | |||||
for l, x in zip(self.linears, (query, key, value)) | |||||
] | |||||
# 2) Apply attention on all the projected vectors in batch. | |||||
x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) | |||||
# 3) "Concat" using a view and apply a final linear. | |||||
x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) | |||||
return self.linears[-1](x) | |||||
class LayerNorm(nn.Module): | |||||
"Construct a layernorm module (See citation for details)." | |||||
def __init__(self, features, eps=1e-6): | |||||
super(LayerNorm, self).__init__() | |||||
self.a_2 = nn.Parameter(torch.ones(features)) | |||||
self.b_2 = nn.Parameter(torch.zeros(features)) | |||||
self.eps = eps | |||||
def forward(self, x): | |||||
mean = x.mean(-1, keepdim=True) | |||||
std = x.std(-1, keepdim=True) | |||||
return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 | |||||
class PositionwiseFeedForward(nn.Module): | |||||
"Implements FFN equation." | |||||
def __init__(self, d_model, d_ff, dropout=0.1): | |||||
super(PositionwiseFeedForward, self).__init__() | |||||
self.w_1 = nn.Linear(d_model, d_ff) | |||||
self.w_2 = nn.Linear(d_ff, d_model) | |||||
self.dropout = nn.Dropout(dropout) | |||||
def forward(self, x): | |||||
return self.w_2(self.dropout(F.relu(self.w_1(x)))) | |||||
class SublayerConnection(nn.Module): | |||||
""" | |||||
A residual connection followed by a layer norm. | |||||
Note for code simplicity the norm is first as opposed to last. | |||||
""" | |||||
def __init__(self, size, dropout): | |||||
super(SublayerConnection, self).__init__() | |||||
self.norm = LayerNorm(size) | |||||
self.dropout = nn.Dropout(dropout) | |||||
def forward(self, x, sublayer): | |||||
"Apply residual connection to any sublayer with the same size." | |||||
return x + self.dropout(sublayer(self.norm(x))) | |||||
class EncoderLayer(nn.Module): | |||||
"Encoder is made up of self-attn and feed forward (defined below)" | |||||
def __init__(self, size, self_attn, feed_forward, dropout): | |||||
super(EncoderLayer, self).__init__() | |||||
self.self_attn = self_attn | |||||
self.feed_forward = feed_forward | |||||
self.sublayer = clones(SublayerConnection(size, dropout), 2) | |||||
self.size = size | |||||
def forward(self, x, mask): | |||||
"Follow Figure 1 (left) for connections." | |||||
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) | |||||
return self.sublayer[1](x, self.feed_forward) | |||||
class Encoder(nn.Module): | |||||
"Core encoder is a stack of N layers" | |||||
def __init__(self, layer, N): | |||||
super(Encoder, self).__init__() | |||||
self.layers = clones(layer, N) | |||||
self.norm = LayerNorm(layer.size) | |||||
def forward(self, x, mask): | |||||
# print(x.size(),mask.size()) | |||||
"Pass the input (and mask) through each layer in turn." | |||||
mask = mask.byte().unsqueeze(-2) | |||||
for layer in self.layers: | |||||
x = layer(x, mask) | |||||
return self.norm(x) | |||||
def make_encoder(N=6, d_model=512, d_ff=2048, h=8, dropout=0.1): | |||||
c = copy.deepcopy | |||||
attn = MultiHeadedAttention(h, d_model) | |||||
ff = PositionwiseFeedForward(d_model, d_ff, dropout) | |||||
return Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N) |
@@ -0,0 +1,308 @@ | |||||
import numpy as np | |||||
import torch | |||||
import torch.cuda | |||||
import random | |||||
import os | |||||
import sys | |||||
import errno | |||||
import time | |||||
import codecs | |||||
import hashlib | |||||
import _pickle as pickle | |||||
import warnings | |||||
from fastNLP.io import EmbedLoader | |||||
UNK_TAG = "<unk>" | |||||
def set_seed(seed): | |||||
random.seed(seed) | |||||
np.random.seed(seed) | |||||
torch.manual_seed(seed) | |||||
torch.cuda.manual_seed_all(seed) | |||||
def bmes_to_words(chars, tags): | |||||
result = [] | |||||
if len(chars) == 0: | |||||
return result | |||||
word = chars[0] | |||||
for c, t in zip(chars[1:], tags[1:]): | |||||
if t.upper() == "B" or t.upper() == "S": | |||||
result.append(word) | |||||
word = "" | |||||
word += c | |||||
if len(word) != 0: | |||||
result.append(word) | |||||
return result | |||||
def bmes_to_index(tags): | |||||
result = [] | |||||
if len(tags) == 0: | |||||
return result | |||||
word = (0, 0) | |||||
for i, t in enumerate(tags): | |||||
if i == 0: | |||||
word = (0, 0) | |||||
elif t.upper() == "B" or t.upper() == "S": | |||||
result.append(word) | |||||
word = (i, 0) | |||||
word = (word[0], word[1] + 1) | |||||
if word[1] != 0: | |||||
result.append(word) | |||||
return result | |||||
def get_bmes(sent): | |||||
x = [] | |||||
y = [] | |||||
for word in sent: | |||||
length = len(word) | |||||
tag = ["m"] * length if length > 1 else ["s"] * length | |||||
if length > 1: | |||||
tag[0] = "b" | |||||
tag[-1] = "e" | |||||
x += list(word) | |||||
y += tag | |||||
return x, y | |||||
class CWSEvaluator: | |||||
def __init__(self, i2t): | |||||
self.correct_preds = 0.0 | |||||
self.total_preds = 0.0 | |||||
self.total_correct = 0.0 | |||||
self.i2t = i2t | |||||
def add_instance(self, pred_tags, gold_tags): | |||||
pred_tags = [self.i2t[i] for i in pred_tags] | |||||
gold_tags = [self.i2t[i] for i in gold_tags] | |||||
# Evaluate PRF | |||||
lab_gold_chunks = set(bmes_to_index(gold_tags)) | |||||
lab_pred_chunks = set(bmes_to_index(pred_tags)) | |||||
self.correct_preds += len(lab_gold_chunks & lab_pred_chunks) | |||||
self.total_preds += len(lab_pred_chunks) | |||||
self.total_correct += len(lab_gold_chunks) | |||||
def result(self, percentage=True): | |||||
p = self.correct_preds / self.total_preds if self.correct_preds > 0 else 0 | |||||
r = self.correct_preds / self.total_correct if self.correct_preds > 0 else 0 | |||||
f1 = 2 * p * r / (p + r) if p + r > 0 else 0 | |||||
if percentage: | |||||
p *= 100 | |||||
r *= 100 | |||||
f1 *= 100 | |||||
return p, r, f1 | |||||
class CWS_OOV: | |||||
def __init__(self, dic): | |||||
self.dic = dic | |||||
self.recall = 0 | |||||
self.tot = 0 | |||||
def update(self, gold_sent, pred_sent): | |||||
i = 0 | |||||
j = 0 | |||||
id = 0 | |||||
for w in gold_sent: | |||||
if w not in self.dic: | |||||
self.tot += 1 | |||||
while i + len(pred_sent[id]) <= j: | |||||
i += len(pred_sent[id]) | |||||
id += 1 | |||||
if ( | |||||
i == j | |||||
and len(pred_sent[id]) == len(w) | |||||
and w.find(pred_sent[id]) != -1 | |||||
): | |||||
self.recall += 1 | |||||
j += len(w) | |||||
# print(gold_sent,pred_sent,self.tot) | |||||
def oov(self, percentage=True): | |||||
ins = 1.0 * self.recall / self.tot | |||||
if percentage: | |||||
ins *= 100 | |||||
return ins | |||||
def get_processing_word( | |||||
vocab_words=None, vocab_chars=None, lowercase=False, chars=False | |||||
): | |||||
def f(word): | |||||
# 0. get chars of words | |||||
if vocab_chars is not None and chars: | |||||
char_ids = [] | |||||
for char in word: | |||||
# ignore chars out of vocabulary | |||||
if char in vocab_chars: | |||||
char_ids += [vocab_chars[char]] | |||||
# 1. preprocess word | |||||
if lowercase: | |||||
word = word.lower() | |||||
if word.isdigit(): | |||||
word = "0" | |||||
# 2. get id of word | |||||
if vocab_words is not None: | |||||
if word in vocab_words: | |||||
word = vocab_words[word] | |||||
else: | |||||
word = vocab_words[UNK_TAG] | |||||
# 3. return tuple char ids, word id | |||||
if vocab_chars is not None and chars: | |||||
return char_ids, word | |||||
else: | |||||
return word | |||||
return f | |||||
def append_tags(src, des, name, part, encode="utf-16"): | |||||
with open("{}/{}.txt".format(src, part), encoding=encode) as input, open( | |||||
"{}/{}.txt".format(des, part), "a", encoding=encode | |||||
) as output: | |||||
for line in input: | |||||
line = line.strip() | |||||
if len(line) > 0: | |||||
output.write("<{}> {} </{}>".format(name, line, name)) | |||||
output.write("\n") | |||||
def is_dataset_tag(word): | |||||
return len(word) > 2 and word[0] == "<" and word[-1] == ">" | |||||
def to_tag_strings(i2ts, tag_mapping, pos_separate_col=True): | |||||
senlen = len(tag_mapping) | |||||
key_value_strs = [] | |||||
for j in range(senlen): | |||||
val = i2ts[tag_mapping[j]] | |||||
pos_str = val | |||||
key_value_strs.append(pos_str) | |||||
return key_value_strs | |||||
def to_id_list(w2i): | |||||
i2w = [None] * len(w2i) | |||||
for w, i in w2i.items(): | |||||
i2w[i] = w | |||||
return i2w | |||||
def make_sure_path_exists(path): | |||||
try: | |||||
os.makedirs(path) | |||||
except OSError as exception: | |||||
if exception.errno != errno.EEXIST: | |||||
raise | |||||
def md5_for_file(fn): | |||||
md5 = hashlib.md5() | |||||
with open(fn, "rb") as f: | |||||
for chunk in iter(lambda: f.read(128 * md5.block_size), b""): | |||||
md5.update(chunk) | |||||
return md5.hexdigest() | |||||
def embedding_match_vocab( | |||||
vocab, | |||||
emb, | |||||
ori_vocab, | |||||
dtype=np.float32, | |||||
padding="<pad>", | |||||
unknown="<unk>", | |||||
normalize=True, | |||||
error="ignore", | |||||
init_method=None, | |||||
): | |||||
dim = emb.shape[-1] | |||||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||||
if init_method: | |||||
matrix = init_method(matrix) | |||||
for word, idx in ori_vocab.word2idx.items(): | |||||
try: | |||||
if word == padding and vocab.padding is not None: | |||||
word = vocab.padding | |||||
elif word == unknown and vocab.unknown is not None: | |||||
word = vocab.unknown | |||||
if word in vocab: | |||||
index = vocab.to_index(word) | |||||
matrix[index] = emb[idx] | |||||
hit_flags[index] = True | |||||
except Exception as e: | |||||
if error == "ignore": | |||||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||||
else: | |||||
print("Error occurred at the {} line.".format(idx)) | |||||
raise e | |||||
total_hits = np.sum(hit_flags) | |||||
print( | |||||
"Found {} out of {} words in the pre-training embedding.".format( | |||||
total_hits, len(vocab) | |||||
) | |||||
) | |||||
if init_method is None: | |||||
found_vectors = matrix[hit_flags] | |||||
if len(found_vectors) != 0: | |||||
mean = np.mean(found_vectors, axis=0, keepdims=True) | |||||
std = np.std(found_vectors, axis=0, keepdims=True) | |||||
unfound_vec_num = len(vocab) - total_hits | |||||
r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean | |||||
matrix[hit_flags == False] = r_vecs | |||||
if normalize: | |||||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||||
return matrix | |||||
def embedding_load_with_cache(emb_file, cache_dir, vocab, **kwargs): | |||||
def match_cache(file, cache_dir): | |||||
md5 = md5_for_file(file) | |||||
cache_files = os.listdir(cache_dir) | |||||
for fn in cache_files: | |||||
if md5 in fn.split("-")[-1]: | |||||
return os.path.join(cache_dir, fn), True | |||||
return ( | |||||
"{}-{}.pkl".format(os.path.join(cache_dir, os.path.basename(file)), md5), | |||||
False, | |||||
) | |||||
def get_cache(file): | |||||
if not os.path.exists(file): | |||||
return None | |||||
with open(file, "rb") as f: | |||||
emb = pickle.load(f) | |||||
return emb | |||||
os.makedirs(cache_dir, exist_ok=True) | |||||
cache_fn, match = match_cache(emb_file, cache_dir) | |||||
if not match: | |||||
print("cache missed, re-generating cache at {}".format(cache_fn)) | |||||
emb, ori_vocab = EmbedLoader.load_without_vocab( | |||||
emb_file, padding=None, unknown=None, normalize=False | |||||
) | |||||
with open(cache_fn, "wb") as f: | |||||
pickle.dump((emb, ori_vocab), f) | |||||
else: | |||||
print("cache matched at {}".format(cache_fn)) | |||||
# use cache | |||||
print("loading embeddings ...") | |||||
emb = get_cache(cache_fn) | |||||
assert emb is not None | |||||
return embedding_match_vocab(vocab, emb[0], emb[1], **kwargs) |