@@ -0,0 +1,61 @@ | |||
# Multi-Criteria-CWS | |||
An implementation of [Multi-Criteria Chinese Word Segmentation with Transformer](http://arxiv.org/abs/1906.12035) with fastNLP. | |||
## Dataset | |||
### Overview | |||
We use the same datasets as listed in the paper.
- sighan2005 | |||
- pku | |||
- msr | |||
- as | |||
- cityu | |||
- sighan2008 | |||
- ctb | |||
- ckip | |||
- cityu (combined with data in sighan2005) | |||
- ncc | |||
- sxu | |||
### Preprocess | |||
First, install OpenCC, which is used to convert between Traditional Chinese and Simplified Chinese.
``` shell | |||
pip install opencc-python-reimplemented | |||
``` | |||
Then, set a path to save processed data, and run the shell script to process the data. | |||
```shell | |||
export DATA_DIR=path/to/processed-data | |||
bash make_data.sh path/to/sighan2005 path/to/sighan2008 | |||
``` | |||
It would take a few minutes to finish the process. | |||
## Model | |||
We use the Transformer to build the model, as described in the paper.
## Train | |||
Finally, to train the model, run the shell script. | |||
The `train.sh` takes one argument, the GPU-IDs to use, for example: | |||
``` shell | |||
bash train.sh 0,1 | |||
``` | |||
This command uses the GPUs with IDs 0 and 1.
Note: Please refer to the paper for details of hyper-parameters. And modify the settings in `train.sh` to match your experiment environment. | |||
Type | |||
``` shell | |||
python main.py --help | |||
``` | |||
to learn all arguments to be specified in training. | |||
## Performance | |||
Results on the test sets of eight CWS datasets with multi-criteria learning. | |||
| Dataset | MSRA | AS | PKU | CTB | CKIP | CITYU | NCC | SXU | Avg. | | |||
| -------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | | |||
| Original paper | 98.05 | 96.44 | 96.41 | 96.99 | 96.51 | 96.91 | 96.04 | 97.61 | 96.87 | | |||
| Ours | 96.92 | 95.71 | 95.65 | 95.96 | 96.00 | 96.09 | 94.61 | 96.64 | 95.95 | | |||
@@ -0,0 +1,262 @@ | |||
import os | |||
import re | |||
import argparse | |||
from opencc import OpenCC | |||
cc = OpenCC("t2s")  # OpenCC converter: Traditional -> Simplified Chinese ("t2s")
from utils import make_sure_path_exists, append_tags
# Corpus locations / output dir; filled in from CLI args in the __main__ block below.
sighan05_root = ""
sighan08_root = ""
data_path = ""
# Full-width (Chinese) punctuation and its half-width (ASCII) counterparts;
# Table maps each full-width punctuation code point to its ASCII equivalent.
E_pun = u",.!?[]()<>\"\"'',"
C_pun = u",。!?【】()《》“”‘’、"
Table = {ord(f): ord(t) for f, t in zip(C_pun, E_pun)}
Table[12288] = 32  # full-width (ideographic) space -> ASCII space
def C_trans_to_E(string):
    """Replace full-width Chinese punctuation in *string* with ASCII forms."""
    translated = string.translate(Table)
    return translated
def normalize(ustring):
    """Convert full-width characters to half-width (全角转半角)."""
    out = []
    for ch in ustring:
        code = ord(ch)
        if code == 12288:  # ideographic space maps straight to ASCII space
            code = 32
        elif 65281 <= code <= 65374:  # other full-width forms: fixed offset
            code -= 65248
        out.append(chr(code))
    return "".join(out)
def preprocess(text):
    """Tokenize a raw corpus line for CWS training.

    Normalizes punctuation and character width, then maps number tokens to
    "0" and Latin-letter tokens to "X". Returns the list of words.
    """
    # Fix: raw strings — the original non-raw strings contained invalid
    # escape sequences ("\d", "\+", "\s"), a SyntaxWarning on Python 3.12+.
    # The compiled patterns are byte-identical, so behavior is unchanged.
    rNUM = r"(-|\+)?\d+((\.|·)\d+)?%?"
    # NOTE(review): the unescaped "." makes this swallow everything after the
    # letters; presumably intended (whole token becomes "X") — verify.
    rENG = r"[A-Za-z_]+.*"
    sent = normalize(C_trans_to_E(text.strip())).split()
    new_sent = []
    for word in sent:
        word = re.sub(r"\s+", "", word, flags=re.U)
        word = re.sub(rNUM, u"0", word, flags=re.U)
        word = re.sub(rENG, u"X", word)
        new_sent.append(word)
    return new_sent
def to_sentence_list(text, split_long_sentence=False):
    """Split a preprocessed line into sentences at delimiter tokens.

    With *split_long_sentence*, a sentence is also flushed once it reaches
    50 characters.
    """
    words = preprocess(text)
    stops = set("。!?:;…、,(),;!?、,\"'")
    stops.add("……")
    sentences = []
    current, current_len = [], 0
    for token in words:
        current.append(token)
        current_len += len(token)
        if token in stops or (split_long_sentence and current_len >= 50):
            sentences.append(current)
            current, current_len = [], 0
    if current:
        sentences.append(current)
    return sentences
def is_traditional(dataset):
    """Whether *dataset* is distributed in Traditional Chinese."""
    return dataset in {"as", "cityu", "ckip"}
def convert_file(
    src, des, need_cc=False, split_long_sentence=False, encode="utf-8-sig"
):
    """Sentence-split every line of *src* and write one sentence per line
    to *des* (UTF-8), optionally converting to Simplified Chinese via OpenCC.
    """
    with open(src, encoding=encode) as fin, open(des, "w", encoding="utf-8") as fout:
        for raw_line in fin:
            for sent in to_sentence_list(raw_line, split_long_sentence):
                text = " ".join(sent) + "\n"
                if need_cc:
                    text = cc.convert(text)
                fout.write(text)
def split_train_dev(dataset):
    """Split raw/train-all.txt of *dataset* into 90% train.txt / 10% dev.txt."""
    root = data_path + "/" + dataset + "/raw/"
    with open(root + "train-all.txt", encoding="UTF-8") as src:
        lines = src.readlines()
    cut = int(len(lines) * 0.9)
    with open(root + "train.txt", "w", encoding="UTF-8") as train:
        train.writelines(lines[:cut])
    with open(root + "dev.txt", "w", encoding="UTF-8") as dev:
        dev.writelines(lines[cut:])
def combine_files(one, two, out):
    """Concatenate files *one* and *two* into *out*, replacing any old *out*."""
    if os.path.exists(out):
        os.remove(out)
    with open(out, "a", encoding="utf-8") as merged:
        for path in (one, two):
            with open(path, encoding="utf-8") as part:
                for line in part:
                    merged.write(line)
def bmes_tag(input_file, output_file):
    """Convert space-segmented text into character-level BMES tags.

    Writes one "char\\ttag" pair per line with a blank line between
    sentences. Single-character words and dataset markers like "<pku>"
    are tagged S; longer words get B, M..., E.
    """
    with open(input_file, encoding="utf-8") as fin, open(
        output_file, "w", encoding="utf-8"
    ) as fout:
        for sentence in fin:
            for word in sentence.strip().split():
                is_marker = len(word) > 2 and word[0] == "<" and word[-1] == ">"
                if len(word) == 1 or is_marker:
                    fout.write(word + "\tS\n")
                else:
                    fout.write(word[0] + "\tB\n")
                    for ch in word[1:-1]:
                        fout.write(ch + "\tM\n")
                    fout.write(word[-1] + "\tE\n")
            fout.write("\n")
def make_bmes(dataset="pku"):
    """Produce the bmes/ files from the raw/ files of *dataset*."""
    base = data_path + "/" + dataset + "/"
    make_sure_path_exists(base + "bmes")
    for part in ("train", "train-all", "dev", "test"):
        bmes_tag(base + "raw/" + part + ".txt", base + "bmes/" + part + ".txt")
def convert_sighan2005_dataset(dataset):
    """Convert one sighan2005 corpus into the unified raw/ layout."""
    root = os.path.join(data_path, dataset)
    make_sure_path_exists(root)
    make_sure_path_exists(root + "/raw")
    train_src = "{}/{}_training.utf8".format(sighan05_root, dataset)
    convert_file(
        train_src, "{}/raw/train-all.txt".format(root), is_traditional(dataset), True
    )
    # The AS release names its gold test file differently from the others.
    suffix = "testing_gold" if dataset == "as" else "test_gold"
    test_src = "{}/{}_{}.utf8".format(sighan05_root, dataset, suffix)
    convert_file(
        test_src, "{}/raw/test.txt".format(root), is_traditional(dataset), False
    )
    split_train_dev(dataset)
def convert_sighan2008_dataset(dataset, utf=16):
    """Convert one sighan2008 corpus (UTF-16 encoded by default) into the
    unified raw/ layout."""
    root = os.path.join(data_path, dataset)
    make_sure_path_exists(root)
    make_sure_path_exists(root + "/raw")
    encoding = "utf-{}".format(utf)
    train_src = "{}/{}_train_utf{}.seg".format(sighan08_root, dataset, utf)
    test_src = "{}/{}_seg_truth&resource/{}_truth_utf{}.seg".format(
        sighan08_root, dataset, dataset, utf
    )
    convert_file(
        train_src,
        "{}/raw/train-all.txt".format(root),
        is_traditional(dataset),
        True,
        encoding,
    )
    convert_file(
        test_src,
        "{}/raw/test.txt".format(root),
        is_traditional(dataset),
        False,
        encoding,
    )
    split_train_dev(dataset)
def extract_conll(src, out):
    """Extract the word column (2nd field) from a CoNLL-style file.

    Writes one space-joined sentence per line; sentences in the input are
    separated by blank lines.
    """
    words = []
    with open(src, encoding="utf-8") as fin, open(out, "w", encoding="utf-8") as fout:
        for line in fin:
            line = line.strip()
            if not line:
                fout.write(" ".join(words) + "\n")
                words = []
                continue
            words.append(line.split()[1])
        # Bug fix: flush the final sentence when the file does not end
        # with a blank line (the original silently dropped it).
        if words:
            fout.write(" ".join(words) + "\n")
def make_joint_corpus(datasets, joint):
    """Merge the per-dataset raw files (tagged via append_tags) into the
    *joint* corpus directory, one merged file per split."""
    for part in ["dev", "test", "train", "train-all"]:
        target = "{}/{}/raw/{}.txt".format(data_path, joint, part)
        target_dir = os.path.dirname(target)
        # Start from a clean target file; create the directory on first use.
        if os.path.exists(target):
            os.remove(target)
        elif not os.path.exists(target_dir):
            os.makedirs(target_dir)
        for name in datasets:
            append_tags(
                os.path.join(data_path, name, "raw"),
                target_dir,
                name,
                part,
                encode="utf-8",
            )
def convert_all_sighan2005(datasets):
    """Convert and BMES-tag every given sighan2005 corpus."""
    for name in datasets:
        print("Converting sighan bakeoff 2005 corpus: {}".format(name))
        convert_sighan2005_dataset(name)
        make_bmes(name)
def convert_all_sighan2008(datasets):
    """Convert and BMES-tag every given sighan2008 corpus (UTF-16 input)."""
    for name in datasets:
        print("Converting sighan bakeoff 2008 corpus: {}".format(name))
        convert_sighan2008_dataset(name, 16)
        make_bmes(name)
if __name__ == "__main__":
    # CLI entry point: convert sighan2005 + sighan2008 corpora and build the
    # two joint corpora used for multi-criteria training.
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("--sighan05", required=True, type=str, help="path to sighan2005 dataset")
    parser.add_argument("--sighan08", required=True, type=str, help="path to sighan2008 dataset")
    parser.add_argument("--data_path", required=True, type=str, help="path to save dataset")
    # fmt: on
    args, _ = parser.parse_known_args()
    # Fill the module-level globals consumed by the conversion helpers above.
    sighan05_root = args.sighan05
    sighan08_root = args.sighan08
    data_path = args.data_path
    print("Converting sighan2005 Simplified Chinese corpus")
    datasets = "pku", "msr", "as", "cityu"
    convert_all_sighan2005(datasets)
    print("Combining sighan2005 corpus to one joint Simplified Chinese corpus")
    datasets = "pku", "msr", "as", "cityu"
    make_joint_corpus(datasets, "joint-sighan2005")
    make_bmes("joint-sighan2005")
    # For researchers who have access to sighan2008 corpus, use official corpora please.
    print("Converting sighan2008 Simplified Chinese corpus")
    datasets = "ctb", "ckip", "cityu", "ncc", "sxu"
    convert_all_sighan2008(datasets)
    print("Combining those 8 sighan corpora to one joint corpus")
    datasets = "pku", "msr", "as", "ctb", "ckip", "cityu", "ncc", "sxu"
    make_joint_corpus(datasets, "joint-sighan2008")
    make_bmes("joint-sighan2008")
@@ -0,0 +1,166 @@ | |||
import os | |||
import sys | |||
import codecs | |||
import argparse | |||
from _pickle import load, dump | |||
import collections | |||
from utils import get_processing_word, is_dataset_tag, make_sure_path_exists, get_bmes | |||
from fastNLP import Instance, DataSet, Vocabulary, Const | |||
# Module-level scratch variable used by main()'s get_max_len helper below.
max_len = 0
def expand(x):
    """Return the character-bigram strings of *x*.

    The first token (the dataset tag) is replaced by <sos>, an <eos> pad is
    appended, and adjacent tokens are concatenated pairwise, producing
    len(x) + 1 bigram strings.
    """
    padded = ["<sos>"] + x[1:] + ["<eos>"]
    return [left + right for left, right in zip(padded, padded[1:])]
def read_file(filename, processing_word=get_processing_word(lowercase=False)):
    """Load a BMES-tagged file into a fastNLP DataSet.

    Each instance stores the words/tags minus the trailing token, plus a
    "task" field (the leading dataset tag), a "seq_len" field, and "bi1"
    bigram strings.
    """
    dataset = DataSet()
    seen = 0
    with codecs.open(filename, "r", "utf-8-sig") as f:
        words, tags = [], []
        for line in f:
            line = line.strip()
            if not line or line.startswith("-DOCSTART-"):
                if words:
                    assert len(words) > 2
                    if seen == 1:
                        # debug: echo the second sentence once
                        print(words, tags)
                    seen += 1
                    dataset.append(Instance(ori_words=words[:-1], ori_tags=tags[:-1]))
                    words, tags = [], []
            else:
                token, tag = line.split()
                words.append(processing_word(token))
                tags.append(tag.lower())
    dataset.apply_field(lambda x: [x[0]], field_name="ori_words", new_field_name="task")
    dataset.apply_field(
        lambda x: len(x), field_name="ori_tags", new_field_name="seq_len"
    )
    dataset.apply_field(
        lambda x: expand(x), field_name="ori_words", new_field_name="bi1"
    )
    return dataset
def main():
    """Build vocabularies over the joint sighan2008 corpus, index all fields,
    and dump everything needed for training into total_dataset.pkl (plus a
    per-dataset word dictionary into oovdict.pkl)."""
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument("--data_path", required=True, type=str, help="all of datasets pkl paths")
    # fmt: on
    options, _ = parser.parse_known_args()
    train_set, test_set = DataSet(), DataSet()
    input_dir = os.path.join(options.data_path, "joint-sighan2008/bmes")
    options.output = os.path.join(options.data_path, "total_dataset.pkl")
    print(input_dir, options.output)
    # Only the joint train-all/test splits are loaded; dev is carved out later
    # in main.py via train_set.split().
    for fn in os.listdir(input_dir):
        if fn not in ["test.txt", "train-all.txt"]:
            continue
        print(fn)
        abs_fn = os.path.join(input_dir, fn)
        ds = read_file(abs_fn)
        if "test.txt" == fn:
            test_set = ds
        else:
            train_set = ds
    print(
        "num samples of total train, test: {}, {}".format(len(train_set), len(test_set))
    )
    # Unigram vocab is built over both splits; the bigram vocab only creates
    # entries from the train split (test words become no-create entries).
    uni_vocab = Vocabulary(min_freq=None).from_dataset(
        train_set, test_set, field_name="ori_words"
    )
    # bi_vocab = Vocabulary(min_freq=3, max_size=50000).from_dataset(train_set,test_set, field_name="bi1")
    bi_vocab = Vocabulary(min_freq=3, max_size=None).from_dataset(
        train_set, field_name="bi1", no_create_entry_dataset=[test_set]
    )
    tag_vocab = Vocabulary(min_freq=None, padding="s", unknown=None).from_dataset(
        train_set, field_name="ori_tags"
    )
    task_vocab = Vocabulary(min_freq=None, padding=None, unknown=None).from_dataset(
        train_set, field_name="task"
    )
    def to_index(dataset):
        # Map words/tags/task to ids; "bi1" keeps the left-aligned bigrams,
        # "bi2" the right-aligned ones (shifted by one position).
        uni_vocab.index_dataset(dataset, field_name="ori_words", new_field_name="uni")
        tag_vocab.index_dataset(dataset, field_name="ori_tags", new_field_name="tags")
        task_vocab.index_dataset(dataset, field_name="task", new_field_name="task")
        dataset.apply_field(lambda x: x[1:], field_name="bi1", new_field_name="bi2")
        dataset.apply_field(lambda x: x[:-1], field_name="bi1", new_field_name="bi1")
        bi_vocab.index_dataset(dataset, field_name="bi1", new_field_name="bi1")
        bi_vocab.index_dataset(dataset, field_name="bi2", new_field_name="bi2")
        dataset.set_input("task", "uni", "bi1", "bi2", "seq_len")
        dataset.set_target("tags")
        return dataset
    train_set = to_index(train_set)
    test_set = to_index(test_set)
    output = {}
    output["train_set"] = train_set
    output["test_set"] = test_set
    output["uni_vocab"] = uni_vocab
    output["bi_vocab"] = bi_vocab
    output["tag_vocab"] = tag_vocab
    output["task_vocab"] = task_vocab
    print(tag_vocab.word2idx)
    print(task_vocab.word2idx)
    make_sure_path_exists(os.path.dirname(options.output))
    print("Saving dataset to {}".format(os.path.abspath(options.output)))
    with open(options.output, "wb") as outfile:
        dump(output, outfile)
    print(len(task_vocab), len(tag_vocab), len(uni_vocab), len(bi_vocab))
    # Per-dataset vocabulary/token statistics, later used for OOV evaluation.
    dic = {}
    tokens = {}
    def process(words):
        # words[0] is the dataset marker "<name>"; the rest is the sentence.
        name = words[0][1:-1]
        if name not in dic:
            dic[name] = set()
            tokens[name] = 0
        tokens[name] += len(words[1:])
        dic[name].update(words[1:])
    train_set.apply_field(process, "ori_words", None)
    for name in dic.keys():
        print(name, len(dic[name]), tokens[name])
    with open(os.path.join(os.path.dirname(options.output), "oovdict.pkl"), "wb") as f:
        dump(dic, f)
    def get_max_len(ds):
        # Scan *ds* and return the longest ori_words length (via global max_len).
        global max_len
        max_len = 0
        def find_max_len(words):
            global max_len
            if max_len < len(words):
                max_len = len(words)
        ds.apply_field(find_max_len, "ori_words", None)
        return max_len
    print(
        "train max len: {}, test max len: {}".format(
            get_max_len(train_set), get_max_len(test_set)
        )
    )
# Script entry point.
if __name__ == "__main__":
    main()
@@ -0,0 +1,506 @@ | |||
import _pickle as pickle | |||
import argparse | |||
import collections | |||
import logging | |||
import math | |||
import os | |||
import pickle | |||
import random | |||
import sys | |||
import time | |||
from sys import maxsize | |||
import fastNLP | |||
import fastNLP.embeddings | |||
import numpy as np | |||
import torch | |||
import torch.distributed as dist | |||
import torch.nn as nn | |||
from fastNLP import BucketSampler, DataSetIter, SequentialSampler, logger | |||
from torch.nn.parallel import DistributedDataParallel | |||
from torch.utils.data.distributed import DistributedSampler | |||
import models | |||
import optm | |||
import utils | |||
# NOTE(review): the imports above bring in both `_pickle as pickle` and plain
# `pickle`; the later plain import wins — consider removing one of them.
NONE_TAG = "<NONE>"  # placeholder tag
START_TAG = "<sos>"  # sentence-start marker
END_TAG = "<eos>"  # sentence-end marker
DEFAULT_WORD_EMBEDDING_SIZE = 100  # NOTE(review): not referenced in this file
DEBUG_SCALE = 200  # number of instances kept per split in --debug mode
# ===-----------------------------------------------------------------------===
# Argument parsing
# ===-----------------------------------------------------------------------===
# fmt: off
parser = argparse.ArgumentParser()
parser.add_argument("--dataset", required=True, dest="dataset", help="processed data dir")
parser.add_argument("--word-embeddings", dest="word_embeddings", help="File from which to read in pretrained embeds")
parser.add_argument("--bigram-embeddings", dest="bigram_embeddings", help="File from which to read in pretrained embeds")
parser.add_argument("--crf", dest="crf", action="store_true", help="crf")
# parser.add_argument("--devi", default="0", dest="devi", help="gpu")
parser.add_argument("--step", default=0, dest="step", type=int,help="step")
parser.add_argument("--num-epochs", default=100, dest="num_epochs", type=int,
                    help="Number of full passes through training set")
parser.add_argument("--batch-size", default=128, dest="batch_size", type=int,
                    help="Minibatch size of training set")
parser.add_argument("--d_model", default=256, dest="d_model", type=int, help="d_model")
parser.add_argument("--d_ff", default=1024, dest="d_ff", type=int, help="d_ff")
parser.add_argument("--N", default=6, dest="N", type=int, help="N")
parser.add_argument("--h", default=4, dest="h", type=int, help="h")
parser.add_argument("--factor", default=2, dest="factor", type=float, help="Initial learning rate")
parser.add_argument("--dropout", default=0.2, dest="dropout", type=float,
                    help="Amount of dropout(not keep rate, but drop rate) to apply to embeddings part of graph")
parser.add_argument("--log-dir", default="result", dest="log_dir",
                    help="Directory where to write logs / serialized models")
parser.add_argument("--task-name", default=time.strftime("%Y-%m-%d-%H-%M-%S"), dest="task_name",
                    help="Name for this task, use a comprehensive one")
parser.add_argument("--no-model", dest="no_model", action="store_true", help="Don't serialize model")
parser.add_argument("--always-model", dest="always_model", action="store_true",
                    help="Always serialize model after every epoch")
parser.add_argument("--old-model", dest="old_model", help="Path to old model for incremental training")
parser.add_argument("--skip-dev", dest="skip_dev", action="store_true", help="Skip dev set, would save some time")
parser.add_argument("--freeze", dest="freeze", action="store_true", help="freeze pretrained embedding")
parser.add_argument("--only-task", dest="only_task", action="store_true", help="only train task embedding")
parser.add_argument("--subset", dest="subset", help="Only train and test on a subset of the whole dataset")
parser.add_argument("--seclude", dest="seclude", help="train and test except a subset")
parser.add_argument("--instances", default=None, dest="instances", type=int,help="num of instances of subset")
parser.add_argument("--seed", dest="python_seed", type=int, default=random.randrange(maxsize),
                    help="Random seed of Python and NumPy")
parser.add_argument("--debug", dest="debug", default=False, action="store_true", help="Debug mode")
parser.add_argument("--test", dest="test", action="store_true", help="Test mode")
parser.add_argument('--local_rank', type=int, default=None)
parser.add_argument('--init_method', type=str, default='env://')
# fmt: on
options, _ = parser.parse_known_args()
print("unknown args", _)
# Each run gets its own directory under --log-dir, named after --task-name.
task_name = options.task_name
root_dir = "{}/{}".format(options.log_dir, task_name)
utils.make_sure_path_exists(root_dir)
# Initialize the NCCL process group when launched via torch.distributed
# (one process per GPU, identified by --local_rank).
if options.local_rank is not None:
    torch.cuda.set_device(options.local_rank)
    dist.init_process_group("nccl", init_method=options.init_method)
def init_logger():
    """Build a root logger writing to {root_dir}/info.log and the console.

    Non-zero distributed ranks are silenced to WARNING. (Unused: the script
    configures fastNLP's `logger` directly instead — see the call site that
    is commented out.)
    """
    if not os.path.exists(root_dir):
        os.mkdir(root_dir)
    formatter = logging.Formatter("%(asctime)s - %(message)s")
    log = logging.getLogger()
    handlers = (
        logging.FileHandler("{0}/info.log".format(root_dir), mode="w"),
        logging.StreamHandler(),
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        log.addHandler(handler)
    if options.local_rank in (None, 0):
        log.setLevel(logging.INFO)
    else:
        log.setLevel(logging.WARNING)
    return log
# ===-----------------------------------------------------------------------===
# Set up logging
# ===-----------------------------------------------------------------------===
# logger = init_logger()
logger.add_file("{}/info.log".format(root_dir), "INFO")
# Only rank 0 logs at INFO; other ranks are silenced to WARNING.
# NOTE(review): dist.get_rank() assumes init_process_group ran, i.e. that
# --local_rank was given — verify the non-distributed path.
logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING)
# ===-----------------------------------------------------------------------===
# Log some stuff about this run
# ===-----------------------------------------------------------------------===
logger.info(" ".join(sys.argv))
logger.info("")
logger.info(options)
if options.debug:
    # Shrink epochs/batches for a quick smoke run.
    logger.info("DEBUG MODE")
    options.num_epochs = 2
    options.batch_size = 20
# Seed Python / NumPy / CUDA for reproducibility.
random.seed(options.python_seed)
np.random.seed(options.python_seed % (2 ** 32 - 1))
torch.cuda.manual_seed_all(options.python_seed)
logger.info("Python random seed: {}".format(options.python_seed))
# ===-----------------------------------------------------------------------===
# Read in dataset
# ===-----------------------------------------------------------------------===
# Load everything data-process.py dumped: both splits and all four vocabs.
dataset = pickle.load(open(options.dataset + "/total_dataset.pkl", "rb"))
train_set = dataset["train_set"]
test_set = dataset["test_set"]
uni_vocab = dataset["uni_vocab"]
bi_vocab = dataset["bi_vocab"]
task_vocab = dataset["task_vocab"]
tag_vocab = dataset["tag_vocab"]
# Presumably maps newer fastNLP's private _word2idx back onto the old
# word2idx attribute used throughout this script — verify fastNLP version.
for v in (bi_vocab, uni_vocab, tag_vocab, task_vocab):
    if hasattr(v, "_word2idx"):
        v.word2idx = v._word2idx
for ds in (train_set, test_set):
    ds.rename_field("ori_words", "words")
logger.info("{} {}".format(bi_vocab.to_word(0), tag_vocab.word2idx))
logger.info(task_vocab.word2idx)
if options.skip_dev:
    # No held-out dev split: dev evaluation (if any) runs on the test set.
    dev_set = test_set
else:
    train_set, dev_set = train_set.split(0.1)
logger.info("{} {} {}".format(len(train_set), len(dev_set), len(test_set)))
if options.debug:
    # Keep only DEBUG_SCALE instances per split.
    train_set = train_set[0:DEBUG_SCALE]
    dev_set = dev_set[0:DEBUG_SCALE]
    test_set = test_set[0:DEBUG_SCALE]
# ===-----------------------------------------------------------------------===
# Build model and trainer
# ===-----------------------------------------------------------------------===
# ===============================
# Barrier dance: let rank 0 load/cache the pretrained embeddings first while
# the other ranks wait, then release everyone.
if dist.get_rank() != 0:
    dist.barrier()
if options.word_embeddings is None:
    init_embedding = None
else:
    # logger.info("Load: {}".format(options.word_embeddings))
    # init_embedding = utils.embedding_load_with_cache(options.word_embeddings, options.cache_dir, uni_vocab, normalize=False)
    init_embedding = fastNLP.embeddings.StaticEmbedding(
        uni_vocab, options.word_embeddings, word_drop=0.01
    )
bigram_embedding = None
if options.bigram_embeddings:
    # logger.info("Load: {}".format(options.bigram_embeddings))
    # bigram_embedding = utils.embedding_load_with_cache(options.bigram_embeddings, options.cache_dir, bi_vocab, normalize=False)
    bigram_embedding = fastNLP.embeddings.StaticEmbedding(
        bi_vocab, options.bigram_embeddings
    )
if dist.get_rank() == 0:
    dist.barrier()
# ===============================
# select subset training
if options.seclude is not None:
    # Drop every sentence whose leading marker matches the secluded dataset.
    setname = "<{}>".format(options.seclude)
    logger.info("seclude {}".format(setname))
    train_set.drop(lambda x: x["words"][0] == setname, inplace=True)
    test_set.drop(lambda x: x["words"][0] == setname, inplace=True)
    dev_set.drop(lambda x: x["words"][0] == setname, inplace=True)
if options.subset is not None:
    # Keep only sentences belonging to the selected dataset.
    setname = "<{}>".format(options.subset)
    logger.info("select {}".format(setname))
    train_set.drop(lambda x: x["words"][0] != setname, inplace=True)
    test_set.drop(lambda x: x["words"][0] != setname, inplace=True)
    dev_set.drop(lambda x: x["words"][0] != setname, inplace=True)
# build model and optimizer
i2t = None
if options.crf:
    # The CRF layer wants an id->tag mapping to build its transition mask.
    # i2t=utils.to_id_list(tag_vocab.word2idx)
    i2t = {}
    for x, y in tag_vocab.word2idx.items():
        i2t[y] = x
    logger.info(i2t)
freeze = True if options.freeze else False
model = models.make_CWS(
    d_model=options.d_model,
    N=options.N,
    h=options.h,
    d_ff=options.d_ff,
    dropout=options.dropout,
    word_embedding=init_embedding,
    bigram_embedding=bigram_embedding,
    tag_size=len(tag_vocab),
    task_size=len(task_vocab),
    crf=i2t,
    freeze=freeze,
)
# Place the model: DDP per-GPU when --local_rank is set, else single GPU/CPU.
device = "cpu"
if torch.cuda.device_count() > 0:
    if options.local_rank is not None:
        device = "cuda:{}".format(options.local_rank)
        # model=nn.DataParallel(model)
        model = model.to(device)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[options.local_rank], output_device=options.local_rank
        )
    else:
        device = "cuda:0"
        model.to(device)
# Domain-adaptation mode: freeze everything except the task embedding.
if options.only_task and options.old_model is not None:
    logger.info("fix para except task embedding")
    for name, para in model.named_parameters():
        if name.find("task_embed") == -1:
            para.requires_grad = False
        else:
            para.requires_grad = True
            logger.info(name)
# Noam schedule (lr handled by the wrapper, so Adam starts with lr=0).
optimizer = optm.NoamOpt(
    options.d_model,
    options.factor,
    4000,
    torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9),
)
optimizer._step = options.step  # resume the warmup schedule at --step
best_model_file_name = "{}/model.bin".format(root_dir)
# Bucket by sequence length locally; shard the data across ranks under DDP.
if options.local_rank is None:
    train_sampler = BucketSampler(
        batch_size=options.batch_size, seq_len_field_name="seq_len"
    )
else:
    train_sampler = DistributedSampler(
        train_set, dist.get_world_size(), dist.get_rank()
    )
dev_sampler = SequentialSampler()
i2t = utils.to_id_list(tag_vocab.word2idx)
i2task = utils.to_id_list(task_vocab.word2idx)
dev_set.set_input("words")
test_set.set_input("words")
test_batch = DataSetIter(test_set, options.batch_size, num_workers=2)
# Per-dataset training vocabulary, used for OOV recall in tester().
word_dic = pickle.load(open(options.dataset + "/oovdict.pkl", "rb"))
def batch_to_device(batch, device):
    """Move every tensor value of *batch* to *device* in place; return batch."""
    for key, value in batch.items():
        if torch.is_tensor(value):
            batch[key] = value.to(device)
    return batch
def tester(model, test_batch, write_out=False):
    """Evaluate *model* on *test_batch*; return (overall F-like score, lines).

    Logs per-dataset P/R/F (plus OOV recall when *write_out*), then TOT and
    AVG rows. Relies on module globals: device, bigram_embedding, i2t,
    word_dic, logger, and the utils helpers.
    """
    res = []
    prf = utils.CWSEvaluator(i2t)
    prf_dataset = {}
    oov_dataset = {}
    logger.info("start evaluation")
    # import ipdb; ipdb.set_trace()
    with torch.no_grad():
        for batch_x, batch_y in test_batch:
            batch_to_device(batch_x, device)
            # batch_to_device(batch_y, device)
            # Bigram features are only fed when bigram embeddings were built.
            if bigram_embedding is not None:
                out = model(
                    batch_x["task"],
                    batch_x["uni"],
                    batch_x["seq_len"],
                    batch_x["bi1"],
                    batch_x["bi2"],
                )
            else:
                out = model(batch_x["task"], batch_x["uni"], batch_x["seq_len"])
            out = out["pred"]
            # print(out)
            num = out.size(0)
            out = out.detach().cpu().numpy()
            for i in range(num):
                length = int(batch_x["seq_len"][i])
                # Position 0 holds the dataset marker; real tags start at 1.
                out_tags = out[i, 1:length].tolist()
                sentence = batch_x["words"][i]
                gold_tags = batch_y["tags"][i][1:length].numpy().tolist()
                dataset_name = sentence[0]
                sentence = sentence[1:]
                # print(out_tags,gold_tags)
                assert utils.is_dataset_tag(dataset_name), dataset_name
                assert len(gold_tags) == len(out_tags) and len(gold_tags) == len(
                    sentence
                )
                # Lazily create the per-dataset evaluators on first sight.
                if dataset_name not in prf_dataset:
                    prf_dataset[dataset_name] = utils.CWSEvaluator(i2t)
                    oov_dataset[dataset_name] = utils.CWS_OOV(
                        word_dic[dataset_name[1:-1]]
                    )
                prf_dataset[dataset_name].add_instance(gold_tags, out_tags)
                prf.add_instance(gold_tags, out_tags)
                if write_out:
                    # Reconstruct the segmented sentence for the output file.
                    gold_strings = utils.to_tag_strings(i2t, gold_tags)
                    obs_strings = utils.to_tag_strings(i2t, out_tags)
                    word_list = utils.bmes_to_words(sentence, obs_strings)
                    oov_dataset[dataset_name].update(
                        utils.bmes_to_words(sentence, gold_strings), word_list
                    )
                    raw_string = " ".join(word_list)
                    res.append(dataset_name + " " + raw_string + " " + dataset_name)
    # Macro-average the per-dataset scores.
    Ap = 0.0
    Ar = 0.0
    Af = 0.0
    Aoov = 0.0
    tot = 0
    nw = 0.0
    for dataset_name, performance in sorted(prf_dataset.items()):
        p = performance.result()
        if write_out:
            nw = oov_dataset[dataset_name].oov()
            # nw = 0
            logger.info(
                "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                    dataset_name, p[0], p[1], p[2], nw
                )
            )
        else:
            logger.info(
                "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                    dataset_name, p[0], p[1], p[2]
                )
            )
        Ap += p[0]
        Ar += p[1]
        Af += p[2]
        Aoov += nw
        tot += 1
    prf = prf.result()
    logger.info(
        "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format("TOT", prf[0], prf[1], prf[2])
    )
    if not write_out:
        logger.info(
            "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                "AVG", Ap / tot, Ar / tot, Af / tot
            )
        )
    else:
        logger.info(
            "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format(
                "AVG", Ap / tot, Ar / tot, Af / tot, Aoov / tot
            )
        )
    return prf[-1], res
# start training
if not options.test:
    if options.old_model:
        # incremental training
        # NOTE(review): model may be DDP-wrapped at this point, whose
        # state_dict keys carry a "module." prefix — verify this load path.
        logger.info("Incremental training from old model: {}".format(options.old_model))
        model.load_state_dict(torch.load(options.old_model, map_location="cuda:0"))
    logger.info("Number training instances: {}".format(len(train_set)))
    logger.info("Number dev instances: {}".format(len(dev_set)))
    train_batch = DataSetIter(
        batch_size=options.batch_size,
        dataset=train_set,
        sampler=train_sampler,
        num_workers=4,
    )
    dev_batch = DataSetIter(
        batch_size=options.batch_size,
        dataset=dev_set,
        sampler=dev_sampler,
        num_workers=4,
    )
    best_f1 = 0.0
    for epoch in range(int(options.num_epochs)):
        logger.info("Epoch {} out of {}".format(epoch + 1, options.num_epochs))
        train_loss = 0.0
        model.train()
        tot = 0
        t1 = time.time()
        for batch_x, batch_y in train_batch:
            model.zero_grad()
            # Gold tags are passed in so the model returns a loss.
            if bigram_embedding is not None:
                out = model(
                    batch_x["task"],
                    batch_x["uni"],
                    batch_x["seq_len"],
                    batch_x["bi1"],
                    batch_x["bi2"],
                    batch_y["tags"],
                )
            else:
                out = model(
                    batch_x["task"], batch_x["uni"], batch_x["seq_len"], batch_y["tags"]
                )
            loss = out["loss"]
            train_loss += loss.item()
            tot += 1
            loss.backward()
            # nn.utils.clip_grad_value_(model.parameters(), 1)
            optimizer.step()
        t2 = time.time()
        train_loss = train_loss / tot
        logger.info(
            "time: {} loss: {} step: {}".format(t2 - t1, train_loss, optimizer._step)
        )
        # Evaluate dev data
        # NOTE(review): on the skip-dev path only rank 0 hits `continue`,
        # while other ranks fall through to dist.barrier() — looks like a
        # potential multi-GPU deadlock; verify with >1 ranks.
        if options.skip_dev and dist.get_rank() == 0:
            logger.info("Saving model to {}".format(best_model_file_name))
            torch.save(model.module.state_dict(), best_model_file_name)
            continue
        model.eval()
        if dist.get_rank() == 0:
            # Rank 0 evaluates on dev and keeps the best checkpoint.
            f1, _ = tester(model.module, dev_batch)
            if f1 > best_f1:
                best_f1 = f1
                logger.info("- new best score!")
                if not options.no_model:
                    logger.info("Saving model to {}".format(best_model_file_name))
                    torch.save(model.module.state_dict(), best_model_file_name)
            elif options.always_model:
                logger.info("Saving model to {}".format(best_model_file_name))
                torch.save(model.module.state_dict(), best_model_file_name)
        dist.barrier()
# Evaluate test data (once)
logger.info("\nNumber test instances: {}".format(len(test_set)))
# Reload the checkpoint to evaluate: --old-model in test mode, else the best.
if not options.skip_dev:
    if options.test:
        model.module.load_state_dict(
            torch.load(options.old_model, map_location="cuda:0")
        )
    else:
        model.module.load_state_dict(
            torch.load(best_model_file_name, map_location="cuda:0")
        )
if dist.get_rank() == 0:
    # Dump the learned task embedding matrix for later analysis.
    for name, para in model.named_parameters():
        if name.find("task_embed") != -1:
            tm = para.detach().cpu().numpy()
            logger.info(tm.shape)
            np.save("{}/task.npy".format(root_dir), tm)
            break
_, res = tester(model.module, test_batch, True)
if dist.get_rank() == 0:
    # Write the segmented test output, one sentence per line.
    with open("{}/testout.txt".format(root_dir), "w", encoding="utf-8") as raw_writer:
        for sent in res:
            raw_writer.write(sent)
            raw_writer.write("\n")
@@ -0,0 +1,14 @@ | |||
# Usage: bash make_data.sh <path-to-sighan2005> <path-to-sighan2008>
# Honors $DATA_DIR as the output directory (default: ./data).
if [ -z "$DATA_DIR" ]
then
    DATA_DIR="./data"
fi
# FIX: quote every expansion so paths containing spaces survive word
# splitting; run the commands directly instead of eval on an unquoted string.
mkdir -vp "$DATA_DIR"
echo "python -u ./data-prepare.py --sighan05 $1 --sighan08 $2 --data_path $DATA_DIR"
python -u ./data-prepare.py --sighan05 "$1" --sighan08 "$2" --data_path "$DATA_DIR"
echo "python -u ./data-process.py --data_path $DATA_DIR"
python -u ./data-process.py --data_path "$DATA_DIR"
@@ -0,0 +1,200 @@ | |||
import fastNLP | |||
import torch | |||
import math | |||
from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
from fastNLP.modules.decoder.crf import ConditionalRandomField | |||
from fastNLP import Const | |||
import copy | |||
import numpy as np | |||
from torch.autograd import Variable | |||
import torch.autograd as autograd | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import transformer | |||
class PositionalEncoding(nn.Module):
    """Inject fixed sinusoidal position information (Vaswani et al., 2017).

    A non-learned table ``pe`` of shape (1, max_len, d_model) is precomputed
    in log space and registered as a buffer; ``forward`` adds the first
    ``x.size(1)`` rows to the input and applies dropout.
    """

    def __init__(self, d_model, dropout, max_len=512):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
        pe = pe.unsqueeze(0)  # leading dim broadcasts over the batch
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add positional encodings for the first x.size(1) positions."""
        # FIX: buffers never require grad, so the deprecated
        # Variable(..., requires_grad=False) wrapper was redundant —
        # slice the buffer directly (numerically identical).
        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
class Embedding(nn.Module):
    """Input embedding combining a task (criterion) embedding with character
    unigram (and optionally bigram) embeddings.

    Position 0 of each sequence carries the dataset/criterion tag and is
    embedded with ``task_embed``; positions 1: are embedded with
    ``uni_embed`` (plus left/right bigram context when supplied).  The
    concatenation is projected to ``d_model`` when the widths differ, and
    the result is scaled by sqrt(d_model) as in the Transformer.
    """

    def __init__(
        self,
        task_size,
        d_model,
        word_embedding=None,
        bi_embedding=None,
        word_size=None,
        freeze=True,
    ):
        super(Embedding, self).__init__()
        self.task_size = task_size
        self.embed_dim = 0
        self.task_embed = nn.Embedding(task_size, d_model)
        if word_embedding is not None:
            # Pretrained unigram embedding module supplied by the caller.
            self.uni_embed = word_embedding
            self.embed_dim += word_embedding.embedding_dim
        else:
            if bi_embedding is not None:
                # NOTE(review): this branch reads bi_embedding.shape[1] while
                # the branch below reads bi_embedding.embedding_dim — the two
                # expect different object kinds; confirm against callers.
                self.embed_dim += bi_embedding.shape[1]
            else:
                self.embed_dim = d_model
            assert word_size is not None
            # BUG FIX: this used to recursively instantiate this wrapper
            # class (Embedding(word_size, self.embed_dim)), which then failed
            # its own `word_size is not None` assertion; a plain lookup
            # table is what is meant here.
            self.uni_embed = nn.Embedding(word_size, self.embed_dim)
        if bi_embedding is not None:
            # Bigram embeddings are applied twice per position (left and
            # right context), hence the doubled contribution to embed_dim.
            self.bi_embed = bi_embedding
            self.embed_dim += bi_embedding.embedding_dim * 2
        print("Trans Freeze", freeze, self.embed_dim)
        if d_model != self.embed_dim:
            self.F = nn.Linear(self.embed_dim, d_model)
        else:
            self.F = None
        self.d_model = d_model

    def forward(self, task, uni, bi1=None, bi2=None):
        # Position 0 carries the dataset tag; positions 1: carry characters.
        y_task = self.task_embed(task[:, 0:1])
        y = self.uni_embed(uni[:, 1:])
        if bi1 is not None:
            # NOTE(review): self.bi_embed only exists when the module was
            # built with bi_embedding — callers must be consistent.
            assert self.bi_embed is not None
            y = torch.cat([y, self.bi_embed(bi1), self.bi_embed(bi2)], dim=-1)
        if self.F is not None:
            y = self.F(y)
        y = torch.cat([y_task, y], dim=1)
        return y * math.sqrt(self.d_model)
def seq_len_to_mask(seq_len, max_len=None):
    """Convert a 1-d array/tensor of sequence lengths into a boolean mask.

    Args:
        seq_len: 1-d ``numpy.ndarray`` or ``torch.Tensor`` of lengths.
        max_len: width of the mask; defaults to ``seq_len.max()``.

    Returns:
        A (len(seq_len), max_len) boolean mask, True where position < length.

    Raises:
        TypeError: for any other input type.
    """
    if isinstance(seq_len, np.ndarray):
        assert (
            len(np.shape(seq_len)) == 1
        ), f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
        if max_len is None:
            max_len = int(seq_len.max())
        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
    elif isinstance(seq_len, torch.Tensor):
        # BUG FIX: the message used to interpolate `seq_len.dim() == 1` (a
        # bool, always False on this path) instead of the offending
        # dimensionality itself.
        assert (
            seq_len.dim() == 1
        ), f"seq_len can only have one dimension, got {seq_len.dim()}."
        batch_size = seq_len.size(0)
        if max_len is None:
            max_len = seq_len.max().long()
        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
        mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
    else:
        raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
    return mask
class CWSModel(nn.Module):
    """Chinese word segmentation tagger.

    Pipeline: criterion-aware embedding -> positional encoding ->
    Transformer encoder -> linear projection to tag scores.  Scores are
    trained/decoded either with a CRF (when ``crf`` is given) or with
    per-position cross-entropy and argmax.
    """

    def __init__(self, encoder, src_embed, position, d_model, tag_size, crf=None):
        # ``crf`` is not a module: it is the tag vocabulary handed to
        # fastNLP's allowed_transitions() to build BMES transition
        # constraints.  ``position`` is deep-copied so this instance does
        # not share the module with other models.
        super(CWSModel, self).__init__()
        self.encoder = encoder
        self.src_embed = src_embed
        self.pos = copy.deepcopy(position)
        self.proj = nn.Linear(d_model, tag_size)
        self.tag_size = tag_size
        if crf is None:
            self.crf = None
            # ignore_index matches the -100 fill applied to padded tags in
            # forward() below.
            self.loss_f = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100)
        else:
            print("crf")
            trans = fastNLP.modules.decoder.crf.allowed_transitions(
                crf, encoding_type="bmes"
            )
            self.crf = ConditionalRandomField(tag_size, allowed_transitions=trans)
        # self.norm=nn.LayerNorm(d_model)

    def forward(self, task, uni, seq_len, bi1=None, bi2=None, tags=None):
        """Return {"loss": ...} when ``tags`` is given, else {"pred": ...}.

        task/uni/bi1/bi2 are id tensors consumed by the Embedding module
        (position 0 of each sequence carries the dataset tag); ``seq_len``
        is a 1-d tensor of true lengths used to build the padding mask.
        """
        # mask=fastNLP.core.utils.seq_len_to_mask(seq_len,uni.size(1)) # for dev 0.5.1
        mask = seq_len_to_mask(seq_len, uni.size(1))
        out = self.src_embed(task, uni, bi1, bi2)
        out = self.pos(out)
        # out=self.norm(out)
        out = self.proj(self.encoder(out, mask.float()))
        if self.crf is not None:
            if tags is not None:
                out = self.crf(out, tags, mask)
                return {"loss": out}
            else:
                out, _ = self.crf.viterbi_decode(out, mask)
                return {"pred": out}
        else:
            if tags is not None:
                out = out.contiguous().view(-1, self.tag_size)
                # NOTE: mutates ``tags`` in place — padded positions are set
                # to -100 so CrossEntropyLoss ignores them.
                tags = tags.data.masked_fill_(mask == 0, -100).view(-1)
                loss = self.loss_f(out, tags)
                return {"loss": loss}
            else:
                out = torch.argmax(out, dim=-1)
                return {"pred": out}
def make_CWS(
    N=6,
    d_model=256,
    d_ff=1024,
    h=4,
    dropout=0.2,
    tag_size=4,
    task_size=8,
    bigram_embedding=None,
    word_embedding=None,
    word_size=None,
    crf=None,
    freeze=True,
):
    """Build a CWSModel: an N-layer Transformer encoder (``h`` heads,
    ``d_ff`` feed-forward width, ``dropout``) over task/unigram/bigram
    embeddings, projecting to ``tag_size`` labels (CRF-decoded when ``crf``
    is given).

    All parameters with more than one dimension that require grad are
    Xavier-initialized after construction.
    """
    # FIX: dropped the unused local alias `c = copy.deepcopy` and the dead
    # commented-out fastNLP TransformerEncoder call.
    encoder = transformer.make_encoder(
        N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff
    )
    position = PositionalEncoding(d_model, dropout)
    embed = Embedding(
        task_size, d_model, word_embedding, bigram_embedding, word_size, freeze
    )
    model = CWSModel(encoder, embed, position, d_model, tag_size, crf=crf)
    for p in model.parameters():
        if p.dim() > 1 and p.requires_grad:
            nn.init.xavier_uniform_(p)
    return model
@@ -0,0 +1,49 @@ | |||
import torch | |||
import torch.optim as optim | |||
class NoamOpt:
    """Optimizer wrapper implementing the "Noam" learning-rate schedule from
    'Attention Is All You Need': linear warmup followed by inverse
    square-root decay, scaled by ``factor * model_size ** -0.5``.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    def step(self):
        """Advance one step: refresh every param group's lr, then delegate."""
        self._step += 1
        new_rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = new_rate
        self._rate = new_rate
        self.optimizer.step()

    def rate(self, step=None):
        """Learning rate at ``step`` (defaults to the current step count)."""
        current = self._step if step is None else step
        warm = current * self.warmup ** (-1.5)
        decay = current ** (-0.5)
        return self.factor * self.model_size ** (-0.5) * min(decay, warm)
def get_std_opt(model):
    """Standard NoamOpt setup from the annotated Transformer: factor 2,
    4000 warmup steps, Adam(betas=(0.9, 0.98), eps=1e-9) over the model's
    trainable parameters.

    # assumes model.src_embed[0] exposes a d_model attribute — matches the
    # annotated-Transformer layout; confirm for other model classes.
    """
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    adam = torch.optim.Adam(trainable, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(model.src_embed[0].d_model, 2, 4000, adam)
@@ -0,0 +1,138 @@ | |||
from fastNLP import (Trainer, Tester, Callback, GradientClipCallback, LRScheduler, SpanFPreRecMetric) | |||
import torch | |||
import torch.cuda | |||
from torch.optim import Adam, SGD | |||
from argparse import ArgumentParser | |||
import logging | |||
from .utils import set_seed | |||
class LoggingCallback(Callback):
    """fastNLP callback that mirrors training progress to the root logger,
    writing to ``filepath`` (append mode) when given, otherwise to a stream
    handler.
    """

    def __init__(self, filepath=None):
        super().__init__()
        # Route records either to an append-mode file or to a stream.
        handler = (
            logging.FileHandler(filepath, "a")
            if filepath is not None
            else logging.StreamHandler()
        )
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(
            logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                              datefmt='%m/%d/%Y %H:%M:%S'))
        # Take over the root logger: drop existing handlers, log everything
        # through ours only.
        root = logging.getLogger()
        root.handlers = []
        root.setLevel(logging.DEBUG)
        root.propagate = False
        root.addHandler(handler)
        self.log_writer = root

    def on_backward_begin(self, loss):
        # Log the loss once every `print_every` steps.
        if self.step % self.trainer.print_every == 0:
            self.log_writer.info(
                'Step/Epoch {}/{}: Loss {}'.format(self.step, self.epoch, loss.item()))

    def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
        self.log_writer.info(
            'Step/Epoch {}/{}: Eval result {}'.format(self.step, self.epoch, eval_result))

    def on_backward_end(self):
        pass
def main():
    """CLI entry point: parse known args, seed the RNGs, then run the
    requested phases (training and/or evaluation)."""
    parser = ArgumentParser()
    register_args(parser)
    args, _ = parser.parse_known_args()
    set_seed(args.seed)
    if args.train:
        train(args)
    if args.eval:
        evaluate(args)
def get_optim(args):
    """Build the optimizer described by ``args.optim``.

    The spec has the form "name (kw=val, ...)", e.g.
    "adam (lr=2e-3, weight_decay=0.0)".  Supported names
    (case-insensitive): "sgd", "adam".

    Raises:
        ValueError: if the optimizer name is not recognized.
    """
    spec = args.optim.strip()
    name = spec.split(' ')[0].lower()
    l = spec.find('(')
    r = spec.find(')')
    # BUG FIX: this used to index the string with a list, `spec[[l + 1, r]]`,
    # which raises TypeError; a slice of the text between the parentheses
    # is what is intended.
    # SECURITY NOTE: eval() runs the config text as code — only pass
    # trusted optimizer specs.
    optim_args = eval('dict({})'.format(spec[l + 1:r]))
    if name == 'sgd':
        return SGD(**optim_args)
    elif name == 'adam':
        return Adam(**optim_args)
    else:
        raise ValueError(args.optim)
def load_model_from_path(args):
    # TODO: not implemented — presumably meant to restore a trained model
    # from args.load_path for evaluate(); currently returns None.
    pass
def train(args):
    """Train the model with fastNLP's Trainer, validating on the dev split
    with span F1 and saving checkpoints to ``args.save_path``."""
    data = get_data(args)
    train_data = data['train']
    dev_data = data['dev']
    model = get_model(args)
    optimizer = get_optim(args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    span_metric = SpanFPreRecMetric(
        tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'],
        ignore_labels=data['ignore_labels'])
    trainer = Trainer(
        train_data=train_data,
        model=model,
        optimizer=optimizer,
        loss=None,
        batch_size=args.batch_size,
        n_epochs=args.epochs,
        num_workers=4,
        metrics=span_metric,
        metric_key='f1',
        dev_data=dev_data,
        save_path=args.save_path,
        device=device,
        callbacks=[],
        check_code_level=-1,
    )
    print(trainer.train())
def evaluate(args):
    """Run the test split through fastNLP's Tester and print span metrics."""
    data = get_data(args)
    test_data = data['test']
    model = load_model_from_path(args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    span_metric = SpanFPreRecMetric(
        tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'],
        ignore_labels=data['ignore_labels'])
    tester = Tester(
        data=test_data, model=model, batch_size=args.batch_size,
        num_workers=2, device=device, metrics=span_metric,
    )
    print(tester.test())
def register_args(parser):
    """Attach every command-line option understood by this script."""
    add = parser.add_argument
    # Optimization: --optim is "name (kwargs...)", parsed by get_optim.
    add('--optim', type=str, default='adam (lr=2e-3, weight_decay=0.0)')
    add('--batch_size', type=int, default=128)
    add('--epochs', type=int, default=10)
    # Paths: checkpoints, dataset (required), logs, model config (required),
    # and an optional model to restore.
    add('--save_path', type=str, default=None)
    add('--data_path', type=str, required=True)
    add('--log_path', type=str, default=None)
    add('--model_config', type=str, required=True)
    add('--load_path', type=str, default=None)
    # Phase switches: either or both of --train / --eval.
    add('--train', action='store_true', default=False)
    add('--eval', action='store_true', default=False)
    add('--seed', type=int, default=42, help='rng seed')
def get_model(args):
    # TODO: not implemented — presumably meant to build the network from
    # args.model_config; currently returns None, so train() cannot run as-is.
    pass
def get_data(args):
    """Load the preprocessed dataset bundle from ``args.data_path``.

    The rest of this script expects keys like 'train', 'dev', 'test',
    'tag_vocab', 'encoding_type', 'ignore_labels'.

    SECURITY NOTE: torch.load unpickles arbitrary objects — only load
    trusted files.
    """
    return torch.load(args.data_path)
# Script entry point.
if __name__ == '__main__':
    main()
@@ -0,0 +1,26 @@ | |||
# Experiment name, number of GPUs, and the master port for torch.distributed.
export EXP_NAME=release04
export NGPU=2
export PORT=9988
# Select GPUs by PCI bus order; the script's first argument ($1) picks the
# visible device IDs.
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export CUDA_VISIBLE_DEVICES=$1
# Default the processed-data directory when the caller has not set DATA_DIR.
if [ -z "$DATA_DIR" ]
then
    DATA_DIR="./data"
fi
echo $CUDA_VISIBLE_DEVICES
# Build the distributed launch command (one training process per GPU),
# echo it for the log, then run it.
cmd="
python -m torch.distributed.launch --nproc_per_node=$NGPU --master_port $PORT\
main.py \
--word-embeddings cn-char-fastnlp-100d \
--bigram-embeddings cn-bi-fastnlp-100d \
--num-epochs 100 \
--batch-size 256 \
--seed 1234 \
--task-name $EXP_NAME \
--dataset $DATA_DIR \
--freeze \
"
echo $cmd
eval $cmd
@@ -0,0 +1,152 @@ | |||
import numpy as np | |||
import torch | |||
import torch.autograd as autograd | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
import math, copy, time | |||
from torch.autograd import Variable | |||
# import matplotlib.pyplot as plt | |||
def clones(module, N):
    """Return a ModuleList holding N independent deep copies of ``module``."""
    copies = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(copies)
def subsequent_mask(size):
    """Boolean (1, size, size) mask, True at or below the diagonal: each
    position may attend only to itself and earlier positions."""
    upper = np.triu(np.ones((1, size, size)), k=1).astype("uint8")
    return torch.from_numpy(upper) == 0
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Returns (output, attention_weights); positions where ``mask == 0`` are
    excluded by setting their scores to -1e9 before the softmax.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights
class MultiHeadedAttention(nn.Module):
    """Multi-head attention: h parallel scaled-dot-product heads of width
    d_model // h, with separate input projections and one output projection."""

    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # d_v is taken equal to d_k.
        self.d_k = d_model // h
        self.h = h
        # Projections: [0]=query, [1]=key, [2]=value, [3]=output.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None  # last attention weights, kept for inspection
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        """Project inputs, attend per head, then concatenate and project."""
        if mask is not None:
            mask = mask.unsqueeze(1)  # same mask applied to every head
        batch = query.size(0)

        def split_heads(proj, x):
            # (batch, len, d_model) -> (batch, h, len, d_k)
            return proj(x).view(batch, -1, self.h, self.d_k).transpose(1, 2)

        q = split_heads(self.linears[0], query)
        k = split_heads(self.linears[1], key)
        v = split_heads(self.linears[2], value)
        x, self.attn = attention(q, k, v, mask=mask, dropout=self.dropout)
        # Merge heads back: (batch, h, len, d_k) -> (batch, len, h*d_k).
        x = x.transpose(1, 2).contiguous().view(batch, -1, self.h * self.d_k)
        return self.linears[-1](x)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, with learnable gain
    ``a_2`` and bias ``b_2``; ``eps`` is added to the std for stability."""

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(-1, keepdim=True)
        scale = x.std(-1, keepdim=True) + self.eps
        return self.a_2 * centered / scale + self.b_2
class PositionwiseFeedForward(nn.Module):
    """Position-wise two-layer feed-forward net: w_2(dropout(relu(w_1 x)))."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.dropout(F.relu(self.w_1(x)))
        return self.w_2(hidden)
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: x + dropout(sublayer(norm(x))).

    Note the norm is applied before the sublayer (pre-LN), not after.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply ``sublayer`` to the normalized input and add the residual."""
        residual = self.dropout(sublayer(self.norm(x)))
        return x + residual
class EncoderLayer(nn.Module):
    """One Transformer encoder layer: self-attention then feed-forward,
    each wrapped in a pre-norm residual SublayerConnection."""

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask):
        """Self-attention sublayer followed by the feed-forward sublayer."""
        attend = lambda t: self.self_attn(t, t, t, mask)
        x = self.sublayer[0](x, attend)
        return self.sublayer[1](x, self.feed_forward)
class Encoder(nn.Module):
    """Stack of N identical encoder layers followed by a final LayerNorm."""

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Run the input (and padding mask) through every layer in turn."""
        # Insert a broadcast dim so the (batch, len) mask applies across all
        # query positions of the attention score matrix.
        mask = mask.byte().unsqueeze(-2)
        for block in self.layers:
            x = block(x, mask)
        return self.norm(x)
def make_encoder(N=6, d_model=512, d_ff=2048, h=8, dropout=0.1):
    """Assemble an N-layer Transformer encoder; the attention and
    feed-forward prototypes are deep-copied into the layer."""
    dup = copy.deepcopy
    self_attn = MultiHeadedAttention(h, d_model)
    ff = PositionwiseFeedForward(d_model, d_ff, dropout)
    layer = EncoderLayer(d_model, dup(self_attn), dup(ff), dropout)
    return Encoder(layer, N)
@@ -0,0 +1,308 @@ | |||
import numpy as np | |||
import torch | |||
import torch.cuda | |||
import random | |||
import os | |||
import sys | |||
import errno | |||
import time | |||
import codecs | |||
import hashlib | |||
import _pickle as pickle | |||
import warnings | |||
from fastNLP.io import EmbedLoader | |||
# Vocabulary key used for out-of-vocabulary words (see get_processing_word).
UNK_TAG = "<unk>"
def set_seed(seed):
    """Seed Python's, NumPy's, and torch's (CPU + all CUDA) RNGs."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
def bmes_to_words(chars, tags):
    """Reassemble characters into words according to BMES tags.

    A tag of B or S (case-insensitive) starts a new word; the first
    character always begins the first word (its tag is not consulted).
    """
    if not chars:
        return []
    words = []
    current = chars[0]
    for ch, tag in zip(chars[1:], tags[1:]):
        if tag.upper() in ("B", "S"):
            words.append(current)
            current = ""
        current += ch
    if current:
        words.append(current)
    return words
def bmes_to_index(tags):
    """Convert a BMES tag sequence into (start, length) word spans.

    The first tag always opens a span; afterwards each B or S
    (case-insensitive) closes the current span and opens a new one.
    """
    spans = []
    start, length = 0, 0
    for i, tag in enumerate(tags):
        if i == 0:
            start, length = 0, 0
        elif tag.upper() in ("B", "S"):
            spans.append((start, length))
            start, length = i, 0
        length += 1
    if length:
        spans.append((start, length))
    return spans
def get_bmes(sent):
    """Flatten a segmented sentence (list of words) into characters plus
    their BMES tags: single-char words get 's', longer words get 'b m* e'."""
    chars, tags = [], []
    for word in sent:
        if len(word) > 1:
            word_tags = ["b"] + ["m"] * (len(word) - 2) + ["e"]
        else:
            word_tags = ["s"] * len(word)
        chars.extend(word)
        tags.extend(word_tags)
    return chars, tags
class CWSEvaluator:
    """Span-level precision/recall/F1 accumulator for CWS tagging.

    Predictions and gold labels arrive as tag-id sequences; ``i2t`` maps
    the ids back to BMES tags so both sides can be compared as word spans.
    """

    def __init__(self, i2t):
        self.correct_preds = 0.0
        self.total_preds = 0.0
        self.total_correct = 0.0
        self.i2t = i2t

    def add_instance(self, pred_tags, gold_tags):
        """Accumulate span counts for one sentence pair."""
        pred_spans = set(bmes_to_index([self.i2t[i] for i in pred_tags]))
        gold_spans = set(bmes_to_index([self.i2t[i] for i in gold_tags]))
        self.correct_preds += len(gold_spans & pred_spans)
        self.total_preds += len(pred_spans)
        self.total_correct += len(gold_spans)

    def result(self, percentage=True):
        """Return (precision, recall, F1), scaled to 0-100 by default."""
        if self.correct_preds > 0:
            p = self.correct_preds / self.total_preds
            r = self.correct_preds / self.total_correct
        else:
            p = r = 0
        f1 = 2 * p * r / (p + r) if p + r > 0 else 0
        if percentage:
            return p * 100, r * 100, f1 * 100
        return p, r, f1
class CWS_OOV:
    """Recall counter for out-of-vocabulary words: an OOV gold word counts
    as recalled when the prediction contains exactly that word starting at
    the same character offset."""

    def __init__(self, dic):
        self.dic = dic    # known-word dictionary
        self.recall = 0   # OOV gold words segmented correctly
        self.tot = 0      # total OOV gold words seen

    def update(self, gold_sent, pred_sent):
        """Scan one sentence pair (both are word lists over the same chars)."""
        pred_pos = 0  # char offset where predicted word `k` starts
        k = 0         # index into pred_sent
        gold_pos = 0  # char offset where the current gold word starts
        for w in gold_sent:
            if w not in self.dic:
                self.tot += 1
                # Advance past predicted words ending at or before the start
                # of this gold word.
                while pred_pos + len(pred_sent[k]) <= gold_pos:
                    pred_pos += len(pred_sent[k])
                    k += 1
                if (
                    pred_pos == gold_pos
                    and len(pred_sent[k]) == len(w)
                    and w.find(pred_sent[k]) != -1
                ):
                    self.recall += 1
            gold_pos += len(w)

    def oov(self, percentage=True):
        """OOV recall rate, scaled to 0-100 by default.

        Raises ZeroDivisionError if no OOV word has been seen.
        """
        rate = self.recall / self.tot
        if percentage:
            rate *= 100
        return rate
def get_processing_word(
    vocab_words=None, vocab_chars=None, lowercase=False, chars=False
):
    """Build a closure mapping a raw word to ids.

    The closure optionally lowercases, collapses digit strings to "0",
    looks the word up in ``vocab_words`` (falling back to the UNK entry),
    and — when ``vocab_chars`` and ``chars`` are both set — also returns
    the known character ids as ``(char_ids, word_id)``.
    """

    def f(word):
        # Character ids are computed on the raw (un-normalized) word;
        # out-of-vocabulary characters are dropped.
        if vocab_chars is not None and chars:
            char_ids = [vocab_chars[c] for c in word if c in vocab_chars]
        # Normalize the word form.
        if lowercase:
            word = word.lower()
        if word.isdigit():
            word = "0"
        # Map to a word id when a word vocabulary is available.
        if vocab_words is not None:
            word = vocab_words[word] if word in vocab_words else vocab_words[UNK_TAG]
        if vocab_chars is not None and chars:
            return char_ids, word
        return word

    return f
def append_tags(src, des, name, part, encode="utf-16"):
    """Append every non-empty line of ``src/part.txt`` to ``des/part.txt``,
    wrapped in dataset-tag markers: "<name> line </name>"."""
    src_path = "{}/{}.txt".format(src, part)
    des_path = "{}/{}.txt".format(des, part)
    with open(src_path, encoding=encode) as reader, open(
        des_path, "a", encoding=encode
    ) as writer:
        for raw in reader:
            stripped = raw.strip()
            if stripped:
                writer.write("<{}> {} </{}>".format(name, stripped, name))
                writer.write("\n")
def is_dataset_tag(word):
    """True for dataset-marker tokens like "<pku>": angle-bracketed, len > 2."""
    return word.startswith("<") and word.endswith(">") and len(word) > 2
def to_tag_strings(i2ts, tag_mapping, pos_separate_col=True):
    """Map a sequence of tag ids back to tag strings via ``i2ts``.

    ``pos_separate_col`` is kept for interface compatibility; it has never
    affected the output.
    """
    # Idiom fix: the manual index loop that appended one tag at a time is
    # just a list comprehension.
    return [i2ts[tag_id] for tag_id in tag_mapping]
def to_id_list(w2i):
    """Invert a word->index dict into an index->word list.

    Assumes the indices are exactly 0..len(w2i)-1.
    """
    inverse = [None] * len(w2i)
    for word, idx in w2i.items():
        inverse[idx] = word
    return inverse
def make_sure_path_exists(path):
    """Create ``path`` (including parents) if it does not already exist.

    Modernized to ``exist_ok=True`` (already used elsewhere in this file)
    instead of catching EEXIST by hand.  Note: this now raises
    FileExistsError when ``path`` exists but is not a directory, whereas
    the old version silently ignored that case.
    """
    os.makedirs(path, exist_ok=True)
def md5_for_file(fn):
    """Return the MD5 hex digest of the file at ``fn``, read in chunks."""
    digest = hashlib.md5()
    with open(fn, "rb") as f:
        while True:
            chunk = f.read(128 * digest.block_size)
            if not chunk:
                break
            digest.update(chunk)
    return digest.hexdigest()
def embedding_match_vocab(
    vocab,
    emb,
    ori_vocab,
    dtype=np.float32,
    padding="<pad>",
    unknown="<unk>",
    normalize=True,
    error="ignore",
    init_method=None,
):
    """Re-index a pretrained embedding matrix onto a new vocabulary.

    Rows of ``emb`` (indexed by ``ori_vocab``) are copied into a fresh
    (len(vocab), dim) matrix at the target vocabulary's indices; words
    missing from the pretrained table keep randomly drawn vectors.

    Args:
        vocab: target vocabulary (fastNLP-style: supports ``in``,
            ``to_index``, ``padding``/``unknown`` attributes, ``len``).
        emb: pretrained embedding matrix aligned with ``ori_vocab``.
        ori_vocab: vocabulary of the pretrained embedding (has ``word2idx``).
        dtype: dtype of the returned matrix.
        padding/unknown: token names in ``ori_vocab`` remapped onto the
            target vocab's own padding/unknown entries.
        normalize: L2-normalize every row of the result.
        error: "ignore" warns and continues on a bad row; any other value
            re-raises.
        init_method: optional callable applied to the random matrix before
            pretrained rows are copied in (also disables the mean/std
            redraw of unmatched rows below).

    Returns:
        np.ndarray of shape (len(vocab), emb.shape[-1]).
    """
    dim = emb.shape[-1]
    # Start from N(0, 1) vectors; rows found in the pretrained table are
    # overwritten below and tracked in hit_flags.
    matrix = np.random.randn(len(vocab), dim).astype(dtype)
    hit_flags = np.zeros(len(vocab), dtype=bool)
    if init_method:
        matrix = init_method(matrix)
    for word, idx in ori_vocab.word2idx.items():
        try:
            # Map the pretrained pad/unk tokens onto the target vocab's own.
            if word == padding and vocab.padding is not None:
                word = vocab.padding
            elif word == unknown and vocab.unknown is not None:
                word = vocab.unknown
            if word in vocab:
                index = vocab.to_index(word)
                matrix[index] = emb[idx]
                hit_flags[index] = True
        except Exception as e:
            if error == "ignore":
                warnings.warn("Error occurred at the {} line.".format(idx))
            else:
                print("Error occurred at the {} line.".format(idx))
                raise e
    total_hits = np.sum(hit_flags)
    print(
        "Found {} out of {} words in the pre-training embedding.".format(
            total_hits, len(vocab)
        )
    )
    if init_method is None:
        # Redraw unmatched rows from a Gaussian fitted to the matched rows'
        # per-dimension mean/std, so they live on the same scale.
        found_vectors = matrix[hit_flags]
        if len(found_vectors) != 0:
            mean = np.mean(found_vectors, axis=0, keepdims=True)
            std = np.std(found_vectors, axis=0, keepdims=True)
            unfound_vec_num = len(vocab) - total_hits
            r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean
            matrix[hit_flags == False] = r_vecs
    if normalize:
        # NOTE(review): an all-zero row would divide by zero here.
        matrix /= np.linalg.norm(matrix, axis=1, keepdims=True)
    return matrix
def embedding_load_with_cache(emb_file, cache_dir, vocab, **kwargs):
    """Load pretrained embeddings for ``vocab``, caching the parsed file.

    The raw embedding text file is parsed once and pickled into
    ``cache_dir`` under a name containing the file's MD5, so later runs
    with an unchanged file skip parsing.  The cached (matrix, vocabulary)
    pair is then re-indexed onto ``vocab`` via embedding_match_vocab
    (``kwargs`` are forwarded to it).

    SECURITY NOTE: the cache is read back with pickle — only point this at
    cache directories you trust.
    """
    def match_cache(file, cache_dir):
        # Look for an existing cache whose name ends with this file's MD5;
        # otherwise return the path a fresh cache should be written to.
        md5 = md5_for_file(file)
        cache_files = os.listdir(cache_dir)
        for fn in cache_files:
            if md5 in fn.split("-")[-1]:
                return os.path.join(cache_dir, fn), True
        return (
            "{}-{}.pkl".format(os.path.join(cache_dir, os.path.basename(file)), md5),
            False,
        )

    def get_cache(file):
        # Unpickle a previously written (embedding, vocab) pair, or None.
        if not os.path.exists(file):
            return None
        with open(file, "rb") as f:
            emb = pickle.load(f)
        return emb

    os.makedirs(cache_dir, exist_ok=True)
    cache_fn, match = match_cache(emb_file, cache_dir)
    if not match:
        print("cache missed, re-generating cache at {}".format(cache_fn))
        emb, ori_vocab = EmbedLoader.load_without_vocab(
            emb_file, padding=None, unknown=None, normalize=False
        )
        with open(cache_fn, "wb") as f:
            pickle.dump((emb, ori_vocab), f)
    else:
        print("cache matched at {}".format(cache_fn))
    # use cache
    print("loading embeddings ...")
    emb = get_cache(cache_fn)
    assert emb is not None
    return embedding_match_vocab(vocab, emb[0], emb[1], **kwargs)