| @@ -0,0 +1,61 @@ | |||
| # Multi-Criteria-CWS | |||
| An implementation of [Multi-Criteria Chinese Word Segmentation with Transformer](http://arxiv.org/abs/1906.12035) with fastNLP. | |||
| ## Dataset | |||
| ### Overview | |||
We use the same datasets as listed in the paper:
- sighan2005
  - pku
  - msr
  - as
  - cityu
- sighan2008
  - ctb
  - ckip
  - cityu (combined with the data in sighan2005)
  - ncc
  - sxu
| ### Preprocess | |||
First, install the OpenCC Python package, which is used to convert Traditional Chinese text to Simplified Chinese.
| ``` shell | |||
| pip install opencc-python-reimplemented | |||
| ``` | |||
Then set the path where the processed data will be saved, and run the shell script to preprocess the data:
| ```shell | |||
| export DATA_DIR=path/to/processed-data | |||
| bash make_data.sh path/to/sighan2005 path/to/sighan2008 | |||
| ``` | |||
Preprocessing takes a few minutes.
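When it is done, `$DATA_DIR` should contain roughly the following layout (directory and file names are taken from the preprocessing scripts):
```
$DATA_DIR/
├── pku/ msr/ as/ cityu/ ctb/ ckip/ ncc/ sxu/   # per-dataset raw/ and bmes/ splits
├── joint-sighan2005/                           # combined sighan2005 corpora, each sentence wrapped in <dataset> tags
├── joint-sighan2008/                           # combined corpus of all eight datasets
├── total_dataset.pkl                           # binarized fastNLP datasets and vocabularies used by main.py
└── oovdict.pkl                                 # per-dataset training vocabularies, used to compute OOV recall
```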
| ## Model | |||
We use a Transformer encoder to build the model, as described in the paper.
| ## Train | |||
Finally, run the shell script to train the model.
`train.sh` takes one argument, a comma-separated list of GPU IDs to use, for example:
| ``` shell | |||
| bash train.sh 0,1 | |||
| ``` | |||
This command uses the GPUs with IDs 0 and 1.
Note: please refer to the paper for the hyper-parameter details, and modify the settings in `train.sh` to match your experiment environment.
Run
| ``` shell | |||
| python main.py --help | |||
| ``` | |||
to list all the arguments that can be specified for training.
| ## Performance | |||
F1 scores (%) on the test sets of the eight CWS datasets with multi-criteria learning.
| | Dataset | MSRA | AS | PKU | CTB | CKIP | CITYU | NCC | SXU | Avg. | | |||
| | -------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | | |||
| | Original paper | 98.05 | 96.44 | 96.41 | 96.99 | 96.51 | 96.91 | 96.04 | 97.61 | 96.87 | | |||
| | Ours | 96.92 | 95.71 | 95.65 | 95.96 | 96.00 | 96.09 | 94.61 | 96.64 | 95.95 | | |||
| @@ -0,0 +1,262 @@ | |||
| import os | |||
| import re | |||
| import argparse | |||
| from opencc import OpenCC | |||
| cc = OpenCC("t2s") | |||
| from utils import make_sure_path_exists, append_tags | |||
| sighan05_root = "" | |||
| sighan08_root = "" | |||
| data_path = "" | |||
| E_pun = u",.!?[]()<>\"\"''," | |||
| C_pun = u",。!?【】()《》“”‘’、" | |||
| Table = {ord(f): ord(t) for f, t in zip(C_pun, E_pun)} | |||
Table[12288] = 32  # full-width space -> half-width space
| def C_trans_to_E(string): | |||
| return string.translate(Table) | |||
| def normalize(ustring): | |||
| """全角转半角""" | |||
| rstring = "" | |||
| for uchar in ustring: | |||
| inside_code = ord(uchar) | |||
if inside_code == 12288:  # convert the full-width space directly
| inside_code = 32 | |||
| elif 65281 <= inside_code <= 65374: # 全角字符(除空格)根据关系转化 | |||
| inside_code -= 65248 | |||
| rstring += chr(inside_code) | |||
| return rstring | |||
| def preprocess(text): | |||
| rNUM = u"(-|\+)?\d+((\.|·)\d+)?%?" | |||
| rENG = u"[A-Za-z_]+.*" | |||
| sent = normalize(C_trans_to_E(text.strip())).split() | |||
| new_sent = [] | |||
| for word in sent: | |||
word = re.sub(r"\s+", "", word, flags=re.U)
| word = re.sub(rNUM, u"0", word, flags=re.U) | |||
| word = re.sub(rENG, u"X", word) | |||
| new_sent.append(word) | |||
| return new_sent | |||
| def to_sentence_list(text, split_long_sentence=False): | |||
| text = preprocess(text) | |||
| delimiter = set() | |||
| delimiter.update("。!?:;…、,(),;!?、,\"'") | |||
| delimiter.add("……") | |||
| sent_list = [] | |||
| sent = [] | |||
| sent_len = 0 | |||
| for word in text: | |||
| sent.append(word) | |||
| sent_len += len(word) | |||
| if word in delimiter or (split_long_sentence and sent_len >= 50): | |||
| sent_list.append(sent) | |||
| sent = [] | |||
| sent_len = 0 | |||
| if len(sent) > 0: | |||
| sent_list.append(sent) | |||
| return sent_list | |||
| def is_traditional(dataset): | |||
| return dataset in ["as", "cityu", "ckip"] | |||
| def convert_file( | |||
| src, des, need_cc=False, split_long_sentence=False, encode="utf-8-sig" | |||
| ): | |||
| with open(src, encoding=encode) as src, open(des, "w", encoding="utf-8") as des: | |||
| for line in src: | |||
| for sent in to_sentence_list(line, split_long_sentence): | |||
| line = " ".join(sent) + "\n" | |||
| if need_cc: | |||
| line = cc.convert(line) | |||
| des.write(line) | |||
| # if len(''.join(sent)) > 200: | |||
| # print(' '.join(sent)) | |||
| def split_train_dev(dataset): | |||
| root = data_path + "/" + dataset + "/raw/" | |||
| with open(root + "train-all.txt", encoding="UTF-8") as src, open( | |||
| root + "train.txt", "w", encoding="UTF-8" | |||
| ) as train, open(root + "dev.txt", "w", encoding="UTF-8") as dev: | |||
| lines = src.readlines() | |||
| idx = int(len(lines) * 0.9) | |||
| for line in lines[:idx]: | |||
| train.write(line) | |||
| for line in lines[idx:]: | |||
| dev.write(line) | |||
| def combine_files(one, two, out): | |||
| if os.path.exists(out): | |||
| os.remove(out) | |||
| with open(one, encoding="utf-8") as one, open(two, encoding="utf-8") as two, open( | |||
| out, "a", encoding="utf-8" | |||
| ) as out: | |||
| for line in one: | |||
| out.write(line) | |||
| for line in two: | |||
| out.write(line) | |||
| def bmes_tag(input_file, output_file): | |||
| with open(input_file, encoding="utf-8") as input_data, open( | |||
| output_file, "w", encoding="utf-8" | |||
| ) as output_data: | |||
| for line in input_data: | |||
| word_list = line.strip().split() | |||
| for word in word_list: | |||
| if len(word) == 1 or ( | |||
| len(word) > 2 and word[0] == "<" and word[-1] == ">" | |||
| ): | |||
| output_data.write(word + "\tS\n") | |||
| else: | |||
| output_data.write(word[0] + "\tB\n") | |||
| for w in word[1 : len(word) - 1]: | |||
| output_data.write(w + "\tM\n") | |||
| output_data.write(word[len(word) - 1] + "\tE\n") | |||
| output_data.write("\n") | |||
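# Example: the segmented line "中国人 很 好" becomes one tab-separated character/tag pair per line:
#   中 B, 国 M, 人 E, 很 S, 好 S, followed by a blank line separating sentences.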
| def make_bmes(dataset="pku"): | |||
| path = data_path + "/" + dataset + "/" | |||
| make_sure_path_exists(path + "bmes") | |||
| bmes_tag(path + "raw/train.txt", path + "bmes/train.txt") | |||
| bmes_tag(path + "raw/train-all.txt", path + "bmes/train-all.txt") | |||
| bmes_tag(path + "raw/dev.txt", path + "bmes/dev.txt") | |||
| bmes_tag(path + "raw/test.txt", path + "bmes/test.txt") | |||
| def convert_sighan2005_dataset(dataset): | |||
| global sighan05_root | |||
| root = os.path.join(data_path, dataset) | |||
| make_sure_path_exists(root) | |||
| make_sure_path_exists(root + "/raw") | |||
| file_path = "{}/{}_training.utf8".format(sighan05_root, dataset) | |||
| convert_file( | |||
| file_path, "{}/raw/train-all.txt".format(root), is_traditional(dataset), True | |||
| ) | |||
| if dataset == "as": | |||
| file_path = "{}/{}_testing_gold.utf8".format(sighan05_root, dataset) | |||
| else: | |||
| file_path = "{}/{}_test_gold.utf8".format(sighan05_root, dataset) | |||
| convert_file( | |||
| file_path, "{}/raw/test.txt".format(root), is_traditional(dataset), False | |||
| ) | |||
| split_train_dev(dataset) | |||
| def convert_sighan2008_dataset(dataset, utf=16): | |||
| global sighan08_root | |||
| root = os.path.join(data_path, dataset) | |||
| make_sure_path_exists(root) | |||
| make_sure_path_exists(root + "/raw") | |||
| convert_file( | |||
| "{}/{}_train_utf{}.seg".format(sighan08_root, dataset, utf), | |||
| "{}/raw/train-all.txt".format(root), | |||
| is_traditional(dataset), | |||
| True, | |||
| "utf-{}".format(utf), | |||
| ) | |||
| convert_file( | |||
| "{}/{}_seg_truth&resource/{}_truth_utf{}.seg".format( | |||
| sighan08_root, dataset, dataset, utf | |||
| ), | |||
| "{}/raw/test.txt".format(root), | |||
| is_traditional(dataset), | |||
| False, | |||
| "utf-{}".format(utf), | |||
| ) | |||
| split_train_dev(dataset) | |||
| def extract_conll(src, out): | |||
| words = [] | |||
| with open(src, encoding="utf-8") as src, open(out, "w", encoding="utf-8") as out: | |||
| for line in src: | |||
| line = line.strip() | |||
| if len(line) == 0: | |||
| out.write(" ".join(words) + "\n") | |||
| words = [] | |||
| continue | |||
| cells = line.split() | |||
| words.append(cells[1]) | |||
| def make_joint_corpus(datasets, joint): | |||
| parts = ["dev", "test", "train", "train-all"] | |||
| for part in parts: | |||
| old_file = "{}/{}/raw/{}.txt".format(data_path, joint, part) | |||
| if os.path.exists(old_file): | |||
| os.remove(old_file) | |||
| elif not os.path.exists(os.path.dirname(old_file)): | |||
| os.makedirs(os.path.dirname(old_file)) | |||
| for name in datasets: | |||
| append_tags( | |||
| os.path.join(data_path, name, "raw"), | |||
| os.path.dirname(old_file), | |||
| name, | |||
| part, | |||
| encode="utf-8", | |||
| ) | |||
| def convert_all_sighan2005(datasets): | |||
| for dataset in datasets: | |||
| print(("Converting sighan bakeoff 2005 corpus: {}".format(dataset))) | |||
| convert_sighan2005_dataset(dataset) | |||
| make_bmes(dataset) | |||
| def convert_all_sighan2008(datasets): | |||
| for dataset in datasets: | |||
| print(("Converting sighan bakeoff 2008 corpus: {}".format(dataset))) | |||
| convert_sighan2008_dataset(dataset, 16) | |||
| make_bmes(dataset) | |||
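# Example invocation (see make_data.sh):
#   python data-prepare.py --sighan05 /path/to/sighan2005 --sighan08 /path/to/sighan2008 --data_path ./data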
| if __name__ == "__main__": | |||
| parser = argparse.ArgumentParser() | |||
| # fmt: off | |||
| parser.add_argument("--sighan05", required=True, type=str, help="path to sighan2005 dataset") | |||
| parser.add_argument("--sighan08", required=True, type=str, help="path to sighan2008 dataset") | |||
| parser.add_argument("--data_path", required=True, type=str, help="path to save dataset") | |||
| # fmt: on | |||
| args, _ = parser.parse_known_args() | |||
| sighan05_root = args.sighan05 | |||
| sighan08_root = args.sighan08 | |||
| data_path = args.data_path | |||
| print("Converting sighan2005 Simplified Chinese corpus") | |||
| datasets = "pku", "msr", "as", "cityu" | |||
| convert_all_sighan2005(datasets) | |||
| print("Combining sighan2005 corpus to one joint Simplified Chinese corpus") | |||
| datasets = "pku", "msr", "as", "cityu" | |||
| make_joint_corpus(datasets, "joint-sighan2005") | |||
| make_bmes("joint-sighan2005") | |||
# Researchers who have access to the sighan2008 corpus should use the official corpora here.
| print("Converting sighan2008 Simplified Chinese corpus") | |||
| datasets = "ctb", "ckip", "cityu", "ncc", "sxu" | |||
| convert_all_sighan2008(datasets) | |||
| print("Combining those 8 sighan corpora to one joint corpus") | |||
| datasets = "pku", "msr", "as", "ctb", "ckip", "cityu", "ncc", "sxu" | |||
| make_joint_corpus(datasets, "joint-sighan2008") | |||
| make_bmes("joint-sighan2008") | |||
| @@ -0,0 +1,166 @@ | |||
| import os | |||
| import sys | |||
| import codecs | |||
| import argparse | |||
| from _pickle import load, dump | |||
| import collections | |||
| from utils import get_processing_word, is_dataset_tag, make_sure_path_exists, get_bmes | |||
| from fastNLP import Instance, DataSet, Vocabulary, Const | |||
| max_len = 0 | |||
| def expand(x): | |||
| sent = ["<sos>"] + x[1:] + ["<eos>"] | |||
| return [x + y for x, y in zip(sent[:-1], sent[1:])] | |||
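# expand() turns a sentence whose first token is a dataset tag into character bigrams,
# e.g. ori_words ["<pku>", "人", "们"] -> ["<sos>人", "人们", "们<eos>"].
# read_file() below expects the BMES files produced by data-prepare.py: one "char<TAB>tag"
# pair per line, a blank line between sentences, and a dataset tag such as "<pku>" as the
# first token of each sentence (the trailing "</pku>" token is dropped via words[:-1]).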
| def read_file(filename, processing_word=get_processing_word(lowercase=False)): | |||
| dataset = DataSet() | |||
| niter = 0 | |||
| with codecs.open(filename, "r", "utf-8-sig") as f: | |||
| words, tags = [], [] | |||
| for line in f: | |||
| line = line.strip() | |||
| if len(line) == 0 or line.startswith("-DOCSTART-"): | |||
| if len(words) != 0: | |||
| assert len(words) > 2 | |||
| if niter == 1: | |||
| print(words, tags) | |||
| niter += 1 | |||
| dataset.append(Instance(ori_words=words[:-1], ori_tags=tags[:-1])) | |||
| words, tags = [], [] | |||
| else: | |||
| word, tag = line.split() | |||
| word = processing_word(word) | |||
| words.append(word) | |||
| tags.append(tag.lower()) | |||
| dataset.apply_field(lambda x: [x[0]], field_name="ori_words", new_field_name="task") | |||
| dataset.apply_field( | |||
| lambda x: len(x), field_name="ori_tags", new_field_name="seq_len" | |||
| ) | |||
| dataset.apply_field( | |||
| lambda x: expand(x), field_name="ori_words", new_field_name="bi1" | |||
| ) | |||
| return dataset | |||
| def main(): | |||
| parser = argparse.ArgumentParser() | |||
| # fmt: off | |||
| parser.add_argument("--data_path", required=True, type=str, help="all of datasets pkl paths") | |||
| # fmt: on | |||
| options, _ = parser.parse_known_args() | |||
| train_set, test_set = DataSet(), DataSet() | |||
| input_dir = os.path.join(options.data_path, "joint-sighan2008/bmes") | |||
| options.output = os.path.join(options.data_path, "total_dataset.pkl") | |||
| print(input_dir, options.output) | |||
| for fn in os.listdir(input_dir): | |||
| if fn not in ["test.txt", "train-all.txt"]: | |||
| continue | |||
| print(fn) | |||
| abs_fn = os.path.join(input_dir, fn) | |||
| ds = read_file(abs_fn) | |||
| if "test.txt" == fn: | |||
| test_set = ds | |||
| else: | |||
| train_set = ds | |||
| print( | |||
| "num samples of total train, test: {}, {}".format(len(train_set), len(test_set)) | |||
| ) | |||
| uni_vocab = Vocabulary(min_freq=None).from_dataset( | |||
| train_set, test_set, field_name="ori_words" | |||
| ) | |||
| # bi_vocab = Vocabulary(min_freq=3, max_size=50000).from_dataset(train_set,test_set, field_name="bi1") | |||
| bi_vocab = Vocabulary(min_freq=3, max_size=None).from_dataset( | |||
| train_set, field_name="bi1", no_create_entry_dataset=[test_set] | |||
| ) | |||
| tag_vocab = Vocabulary(min_freq=None, padding="s", unknown=None).from_dataset( | |||
| train_set, field_name="ori_tags" | |||
| ) | |||
| task_vocab = Vocabulary(min_freq=None, padding=None, unknown=None).from_dataset( | |||
| train_set, field_name="task" | |||
| ) | |||
| def to_index(dataset): | |||
| uni_vocab.index_dataset(dataset, field_name="ori_words", new_field_name="uni") | |||
| tag_vocab.index_dataset(dataset, field_name="ori_tags", new_field_name="tags") | |||
| task_vocab.index_dataset(dataset, field_name="task", new_field_name="task") | |||
| dataset.apply_field(lambda x: x[1:], field_name="bi1", new_field_name="bi2") | |||
| dataset.apply_field(lambda x: x[:-1], field_name="bi1", new_field_name="bi1") | |||
| bi_vocab.index_dataset(dataset, field_name="bi1", new_field_name="bi1") | |||
| bi_vocab.index_dataset(dataset, field_name="bi2", new_field_name="bi2") | |||
| dataset.set_input("task", "uni", "bi1", "bi2", "seq_len") | |||
| dataset.set_target("tags") | |||
| return dataset | |||
| train_set = to_index(train_set) | |||
| test_set = to_index(test_set) | |||
| output = {} | |||
| output["train_set"] = train_set | |||
| output["test_set"] = test_set | |||
| output["uni_vocab"] = uni_vocab | |||
| output["bi_vocab"] = bi_vocab | |||
| output["tag_vocab"] = tag_vocab | |||
| output["task_vocab"] = task_vocab | |||
| print(tag_vocab.word2idx) | |||
| print(task_vocab.word2idx) | |||
| make_sure_path_exists(os.path.dirname(options.output)) | |||
| print("Saving dataset to {}".format(os.path.abspath(options.output))) | |||
| with open(options.output, "wb") as outfile: | |||
| dump(output, outfile) | |||
| print(len(task_vocab), len(tag_vocab), len(uni_vocab), len(bi_vocab)) | |||
| dic = {} | |||
| tokens = {} | |||
| def process(words): | |||
| name = words[0][1:-1] | |||
| if name not in dic: | |||
| dic[name] = set() | |||
| tokens[name] = 0 | |||
| tokens[name] += len(words[1:]) | |||
| dic[name].update(words[1:]) | |||
| train_set.apply_field(process, "ori_words", None) | |||
| for name in dic.keys(): | |||
| print(name, len(dic[name]), tokens[name]) | |||
| with open(os.path.join(os.path.dirname(options.output), "oovdict.pkl"), "wb") as f: | |||
| dump(dic, f) | |||
| def get_max_len(ds): | |||
| global max_len | |||
| max_len = 0 | |||
| def find_max_len(words): | |||
| global max_len | |||
| if max_len < len(words): | |||
| max_len = len(words) | |||
| ds.apply_field(find_max_len, "ori_words", None) | |||
| return max_len | |||
| print( | |||
| "train max len: {}, test max len: {}".format( | |||
| get_max_len(train_set), get_max_len(test_set) | |||
| ) | |||
| ) | |||
| if __name__ == "__main__": | |||
| main() | |||
| @@ -0,0 +1,506 @@ | |||
| import argparse | |||
| import collections | |||
| import logging | |||
| import math | |||
| import os | |||
| import pickle | |||
| import random | |||
| import sys | |||
| import time | |||
| from sys import maxsize | |||
| import fastNLP | |||
| import fastNLP.embeddings | |||
| import numpy as np | |||
| import torch | |||
| import torch.distributed as dist | |||
| import torch.nn as nn | |||
| from fastNLP import BucketSampler, DataSetIter, SequentialSampler, logger | |||
| from torch.nn.parallel import DistributedDataParallel | |||
| from torch.utils.data.distributed import DistributedSampler | |||
| import models | |||
| import optm | |||
| import utils | |||
| NONE_TAG = "<NONE>" | |||
| START_TAG = "<sos>" | |||
| END_TAG = "<eos>" | |||
| DEFAULT_WORD_EMBEDDING_SIZE = 100 | |||
| DEBUG_SCALE = 200 | |||
| # ===-----------------------------------------------------------------------=== | |||
| # Argument parsing | |||
| # ===-----------------------------------------------------------------------=== | |||
| # fmt: off | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument("--dataset", required=True, dest="dataset", help="processed data dir") | |||
| parser.add_argument("--word-embeddings", dest="word_embeddings", help="File from which to read in pretrained embeds") | |||
| parser.add_argument("--bigram-embeddings", dest="bigram_embeddings", help="File from which to read in pretrained embeds") | |||
| parser.add_argument("--crf", dest="crf", action="store_true", help="crf") | |||
| # parser.add_argument("--devi", default="0", dest="devi", help="gpu") | |||
| parser.add_argument("--step", default=0, dest="step", type=int,help="step") | |||
| parser.add_argument("--num-epochs", default=100, dest="num_epochs", type=int, | |||
| help="Number of full passes through training set") | |||
| parser.add_argument("--batch-size", default=128, dest="batch_size", type=int, | |||
| help="Minibatch size of training set") | |||
| parser.add_argument("--d_model", default=256, dest="d_model", type=int, help="d_model") | |||
| parser.add_argument("--d_ff", default=1024, dest="d_ff", type=int, help="d_ff") | |||
| parser.add_argument("--N", default=6, dest="N", type=int, help="N") | |||
| parser.add_argument("--h", default=4, dest="h", type=int, help="h") | |||
| parser.add_argument("--factor", default=2, dest="factor", type=float, help="Initial learning rate") | |||
| parser.add_argument("--dropout", default=0.2, dest="dropout", type=float, | |||
| help="Amount of dropout(not keep rate, but drop rate) to apply to embeddings part of graph") | |||
| parser.add_argument("--log-dir", default="result", dest="log_dir", | |||
| help="Directory where to write logs / serialized models") | |||
| parser.add_argument("--task-name", default=time.strftime("%Y-%m-%d-%H-%M-%S"), dest="task_name", | |||
| help="Name for this task, use a comprehensive one") | |||
| parser.add_argument("--no-model", dest="no_model", action="store_true", help="Don't serialize model") | |||
| parser.add_argument("--always-model", dest="always_model", action="store_true", | |||
| help="Always serialize model after every epoch") | |||
| parser.add_argument("--old-model", dest="old_model", help="Path to old model for incremental training") | |||
| parser.add_argument("--skip-dev", dest="skip_dev", action="store_true", help="Skip dev set, would save some time") | |||
| parser.add_argument("--freeze", dest="freeze", action="store_true", help="freeze pretrained embedding") | |||
| parser.add_argument("--only-task", dest="only_task", action="store_true", help="only train task embedding") | |||
| parser.add_argument("--subset", dest="subset", help="Only train and test on a subset of the whole dataset") | |||
| parser.add_argument("--seclude", dest="seclude", help="train and test except a subset") | |||
| parser.add_argument("--instances", default=None, dest="instances", type=int,help="num of instances of subset") | |||
| parser.add_argument("--seed", dest="python_seed", type=int, default=random.randrange(maxsize), | |||
| help="Random seed of Python and NumPy") | |||
| parser.add_argument("--debug", dest="debug", default=False, action="store_true", help="Debug mode") | |||
| parser.add_argument("--test", dest="test", action="store_true", help="Test mode") | |||
| parser.add_argument('--local_rank', type=int, default=None) | |||
| parser.add_argument('--init_method', type=str, default='env://') | |||
| # fmt: on | |||
| options, _ = parser.parse_known_args() | |||
| print("unknown args", _) | |||
| task_name = options.task_name | |||
| root_dir = "{}/{}".format(options.log_dir, task_name) | |||
| utils.make_sure_path_exists(root_dir) | |||
| if options.local_rank is not None: | |||
| torch.cuda.set_device(options.local_rank) | |||
| dist.init_process_group("nccl", init_method=options.init_method) | |||
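# NOTE: this script assumes a distributed launch (torch.distributed.launch, see train.sh);
# dist.get_rank() is called unconditionally below, so a process group must be initialized.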
| def init_logger(): | |||
| if not os.path.exists(root_dir): | |||
| os.mkdir(root_dir) | |||
| log_formatter = logging.Formatter("%(asctime)s - %(message)s") | |||
| logger = logging.getLogger() | |||
| file_handler = logging.FileHandler("{0}/info.log".format(root_dir), mode="w") | |||
| file_handler.setFormatter(log_formatter) | |||
| logger.addHandler(file_handler) | |||
| console_handler = logging.StreamHandler() | |||
| console_handler.setFormatter(log_formatter) | |||
| logger.addHandler(console_handler) | |||
| if options.local_rank is None or options.local_rank == 0: | |||
| logger.setLevel(logging.INFO) | |||
| else: | |||
| logger.setLevel(logging.WARNING) | |||
| return logger | |||
| # ===-----------------------------------------------------------------------=== | |||
| # Set up logging | |||
| # ===-----------------------------------------------------------------------=== | |||
| # logger = init_logger() | |||
| logger.add_file("{}/info.log".format(root_dir), "INFO") | |||
| logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING) | |||
| # ===-----------------------------------------------------------------------=== | |||
| # Log some stuff about this run | |||
| # ===-----------------------------------------------------------------------=== | |||
| logger.info(" ".join(sys.argv)) | |||
| logger.info("") | |||
| logger.info(options) | |||
| if options.debug: | |||
| logger.info("DEBUG MODE") | |||
| options.num_epochs = 2 | |||
| options.batch_size = 20 | |||
| random.seed(options.python_seed) | |||
| np.random.seed(options.python_seed % (2 ** 32 - 1)) | |||
| torch.cuda.manual_seed_all(options.python_seed) | |||
| logger.info("Python random seed: {}".format(options.python_seed)) | |||
| # ===-----------------------------------------------------------------------=== | |||
| # Read in dataset | |||
| # ===-----------------------------------------------------------------------=== | |||
| dataset = pickle.load(open(options.dataset + "/total_dataset.pkl", "rb")) | |||
| train_set = dataset["train_set"] | |||
| test_set = dataset["test_set"] | |||
| uni_vocab = dataset["uni_vocab"] | |||
| bi_vocab = dataset["bi_vocab"] | |||
| task_vocab = dataset["task_vocab"] | |||
| tag_vocab = dataset["tag_vocab"] | |||
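# Compatibility shim: some fastNLP versions only expose the word-to-index mapping as
# Vocabulary._word2idx, while the code below reads vocab.word2idx directly.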
| for v in (bi_vocab, uni_vocab, tag_vocab, task_vocab): | |||
| if hasattr(v, "_word2idx"): | |||
| v.word2idx = v._word2idx | |||
| for ds in (train_set, test_set): | |||
| ds.rename_field("ori_words", "words") | |||
| logger.info("{} {}".format(bi_vocab.to_word(0), tag_vocab.word2idx)) | |||
| logger.info(task_vocab.word2idx) | |||
| if options.skip_dev: | |||
| dev_set = test_set | |||
| else: | |||
| train_set, dev_set = train_set.split(0.1) | |||
| logger.info("{} {} {}".format(len(train_set), len(dev_set), len(test_set))) | |||
| if options.debug: | |||
| train_set = train_set[0:DEBUG_SCALE] | |||
| dev_set = dev_set[0:DEBUG_SCALE] | |||
| test_set = test_set[0:DEBUG_SCALE] | |||
| # ===-----------------------------------------------------------------------=== | |||
| # Build model and trainer | |||
| # ===-----------------------------------------------------------------------=== | |||
| # =============================== | |||
| if dist.get_rank() != 0: | |||
| dist.barrier() | |||
| if options.word_embeddings is None: | |||
| init_embedding = None | |||
| else: | |||
| # logger.info("Load: {}".format(options.word_embeddings)) | |||
| # init_embedding = utils.embedding_load_with_cache(options.word_embeddings, options.cache_dir, uni_vocab, normalize=False) | |||
| init_embedding = fastNLP.embeddings.StaticEmbedding( | |||
| uni_vocab, options.word_embeddings, word_drop=0.01 | |||
| ) | |||
| bigram_embedding = None | |||
| if options.bigram_embeddings: | |||
| # logger.info("Load: {}".format(options.bigram_embeddings)) | |||
| # bigram_embedding = utils.embedding_load_with_cache(options.bigram_embeddings, options.cache_dir, bi_vocab, normalize=False) | |||
| bigram_embedding = fastNLP.embeddings.StaticEmbedding( | |||
| bi_vocab, options.bigram_embeddings | |||
| ) | |||
| if dist.get_rank() == 0: | |||
| dist.barrier() | |||
| # =============================== | |||
| # select subset training | |||
| if options.seclude is not None: | |||
| setname = "<{}>".format(options.seclude) | |||
| logger.info("seclude {}".format(setname)) | |||
| train_set.drop(lambda x: x["words"][0] == setname, inplace=True) | |||
| test_set.drop(lambda x: x["words"][0] == setname, inplace=True) | |||
| dev_set.drop(lambda x: x["words"][0] == setname, inplace=True) | |||
| if options.subset is not None: | |||
| setname = "<{}>".format(options.subset) | |||
| logger.info("select {}".format(setname)) | |||
| train_set.drop(lambda x: x["words"][0] != setname, inplace=True) | |||
| test_set.drop(lambda x: x["words"][0] != setname, inplace=True) | |||
| dev_set.drop(lambda x: x["words"][0] != setname, inplace=True) | |||
| # build model and optimizer | |||
| i2t = None | |||
| if options.crf: | |||
| # i2t=utils.to_id_list(tag_vocab.word2idx) | |||
| i2t = {} | |||
| for x, y in tag_vocab.word2idx.items(): | |||
| i2t[y] = x | |||
| logger.info(i2t) | |||
| freeze = True if options.freeze else False | |||
| model = models.make_CWS( | |||
| d_model=options.d_model, | |||
| N=options.N, | |||
| h=options.h, | |||
| d_ff=options.d_ff, | |||
| dropout=options.dropout, | |||
| word_embedding=init_embedding, | |||
| bigram_embedding=bigram_embedding, | |||
| tag_size=len(tag_vocab), | |||
| task_size=len(task_vocab), | |||
| crf=i2t, | |||
| freeze=freeze, | |||
| ) | |||
| device = "cpu" | |||
| if torch.cuda.device_count() > 0: | |||
| if options.local_rank is not None: | |||
| device = "cuda:{}".format(options.local_rank) | |||
| # model=nn.DataParallel(model) | |||
| model = model.to(device) | |||
| model = torch.nn.parallel.DistributedDataParallel( | |||
| model, device_ids=[options.local_rank], output_device=options.local_rank | |||
| ) | |||
| else: | |||
| device = "cuda:0" | |||
| model.to(device) | |||
| if options.only_task and options.old_model is not None: | |||
| logger.info("fix para except task embedding") | |||
| for name, para in model.named_parameters(): | |||
| if name.find("task_embed") == -1: | |||
| para.requires_grad = False | |||
| else: | |||
| para.requires_grad = True | |||
| logger.info(name) | |||
| optimizer = optm.NoamOpt( | |||
| options.d_model, | |||
| options.factor, | |||
| 4000, | |||
| torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9), | |||
| ) | |||
| optimizer._step = options.step | |||
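# --step presets NoamOpt's internal step counter, so the warmup/decay learning-rate schedule
# can continue from a previous run (e.g. together with --old-model) instead of restarting at 0.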
| best_model_file_name = "{}/model.bin".format(root_dir) | |||
| if options.local_rank is None: | |||
| train_sampler = BucketSampler( | |||
| batch_size=options.batch_size, seq_len_field_name="seq_len" | |||
| ) | |||
| else: | |||
| train_sampler = DistributedSampler( | |||
| train_set, dist.get_world_size(), dist.get_rank() | |||
| ) | |||
| dev_sampler = SequentialSampler() | |||
| i2t = utils.to_id_list(tag_vocab.word2idx) | |||
| i2task = utils.to_id_list(task_vocab.word2idx) | |||
| dev_set.set_input("words") | |||
| test_set.set_input("words") | |||
| test_batch = DataSetIter(test_set, options.batch_size, num_workers=2) | |||
| word_dic = pickle.load(open(options.dataset + "/oovdict.pkl", "rb")) | |||
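# oovdict.pkl (written by data-process.py) maps each dataset name to its set of training words;
# CWS_OOV uses it to compute OOV recall during evaluation.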
| def batch_to_device(batch, device): | |||
| for k, v in batch.items(): | |||
| if torch.is_tensor(v): | |||
| batch[k] = v.to(device) | |||
| return batch | |||
| def tester(model, test_batch, write_out=False): | |||
| res = [] | |||
| prf = utils.CWSEvaluator(i2t) | |||
| prf_dataset = {} | |||
| oov_dataset = {} | |||
| logger.info("start evaluation") | |||
| # import ipdb; ipdb.set_trace() | |||
| with torch.no_grad(): | |||
| for batch_x, batch_y in test_batch: | |||
| batch_to_device(batch_x, device) | |||
| # batch_to_device(batch_y, device) | |||
| if bigram_embedding is not None: | |||
| out = model( | |||
| batch_x["task"], | |||
| batch_x["uni"], | |||
| batch_x["seq_len"], | |||
| batch_x["bi1"], | |||
| batch_x["bi2"], | |||
| ) | |||
| else: | |||
| out = model(batch_x["task"], batch_x["uni"], batch_x["seq_len"]) | |||
| out = out["pred"] | |||
| # print(out) | |||
| num = out.size(0) | |||
| out = out.detach().cpu().numpy() | |||
| for i in range(num): | |||
| length = int(batch_x["seq_len"][i]) | |||
| out_tags = out[i, 1:length].tolist() | |||
| sentence = batch_x["words"][i] | |||
| gold_tags = batch_y["tags"][i][1:length].numpy().tolist() | |||
| dataset_name = sentence[0] | |||
| sentence = sentence[1:] | |||
| # print(out_tags,gold_tags) | |||
| assert utils.is_dataset_tag(dataset_name), dataset_name | |||
| assert len(gold_tags) == len(out_tags) and len(gold_tags) == len( | |||
| sentence | |||
| ) | |||
| if dataset_name not in prf_dataset: | |||
| prf_dataset[dataset_name] = utils.CWSEvaluator(i2t) | |||
| oov_dataset[dataset_name] = utils.CWS_OOV( | |||
| word_dic[dataset_name[1:-1]] | |||
| ) | |||
| prf_dataset[dataset_name].add_instance(gold_tags, out_tags) | |||
| prf.add_instance(gold_tags, out_tags) | |||
| if write_out: | |||
| gold_strings = utils.to_tag_strings(i2t, gold_tags) | |||
| obs_strings = utils.to_tag_strings(i2t, out_tags) | |||
| word_list = utils.bmes_to_words(sentence, obs_strings) | |||
| oov_dataset[dataset_name].update( | |||
| utils.bmes_to_words(sentence, gold_strings), word_list | |||
| ) | |||
| raw_string = " ".join(word_list) | |||
| res.append(dataset_name + " " + raw_string + " " + dataset_name) | |||
| Ap = 0.0 | |||
| Ar = 0.0 | |||
| Af = 0.0 | |||
| Aoov = 0.0 | |||
| tot = 0 | |||
| nw = 0.0 | |||
| for dataset_name, performance in sorted(prf_dataset.items()): | |||
| p = performance.result() | |||
| if write_out: | |||
| nw = oov_dataset[dataset_name].oov() | |||
| # nw = 0 | |||
| logger.info( | |||
| "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||
| dataset_name, p[0], p[1], p[2], nw | |||
| ) | |||
| ) | |||
| else: | |||
| logger.info( | |||
| "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||
| dataset_name, p[0], p[1], p[2] | |||
| ) | |||
| ) | |||
| Ap += p[0] | |||
| Ar += p[1] | |||
| Af += p[2] | |||
| Aoov += nw | |||
| tot += 1 | |||
| prf = prf.result() | |||
| logger.info( | |||
| "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format("TOT", prf[0], prf[1], prf[2]) | |||
| ) | |||
| if not write_out: | |||
| logger.info( | |||
| "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||
| "AVG", Ap / tot, Ar / tot, Af / tot | |||
| ) | |||
| ) | |||
| else: | |||
| logger.info( | |||
| "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( | |||
| "AVG", Ap / tot, Ar / tot, Af / tot, Aoov / tot | |||
| ) | |||
| ) | |||
| return prf[-1], res | |||
| # start training | |||
| if not options.test: | |||
| if options.old_model: | |||
| # incremental training | |||
| logger.info("Incremental training from old model: {}".format(options.old_model)) | |||
| model.load_state_dict(torch.load(options.old_model, map_location="cuda:0")) | |||
| logger.info("Number training instances: {}".format(len(train_set))) | |||
| logger.info("Number dev instances: {}".format(len(dev_set))) | |||
| train_batch = DataSetIter( | |||
| batch_size=options.batch_size, | |||
| dataset=train_set, | |||
| sampler=train_sampler, | |||
| num_workers=4, | |||
| ) | |||
| dev_batch = DataSetIter( | |||
| batch_size=options.batch_size, | |||
| dataset=dev_set, | |||
| sampler=dev_sampler, | |||
| num_workers=4, | |||
| ) | |||
| best_f1 = 0.0 | |||
| for epoch in range(int(options.num_epochs)): | |||
| logger.info("Epoch {} out of {}".format(epoch + 1, options.num_epochs)) | |||
| train_loss = 0.0 | |||
| model.train() | |||
| tot = 0 | |||
| t1 = time.time() | |||
| for batch_x, batch_y in train_batch: | |||
| model.zero_grad() | |||
| if bigram_embedding is not None: | |||
| out = model( | |||
| batch_x["task"], | |||
| batch_x["uni"], | |||
| batch_x["seq_len"], | |||
| batch_x["bi1"], | |||
| batch_x["bi2"], | |||
| batch_y["tags"], | |||
| ) | |||
| else: | |||
| out = model( | |||
| batch_x["task"], batch_x["uni"], batch_x["seq_len"], batch_y["tags"] | |||
| ) | |||
| loss = out["loss"] | |||
| train_loss += loss.item() | |||
| tot += 1 | |||
| loss.backward() | |||
| # nn.utils.clip_grad_value_(model.parameters(), 1) | |||
| optimizer.step() | |||
| t2 = time.time() | |||
| train_loss = train_loss / tot | |||
| logger.info( | |||
| "time: {} loss: {} step: {}".format(t2 - t1, train_loss, optimizer._step) | |||
| ) | |||
| # Evaluate dev data | |||
| if options.skip_dev and dist.get_rank() == 0: | |||
| logger.info("Saving model to {}".format(best_model_file_name)) | |||
| torch.save(model.module.state_dict(), best_model_file_name) | |||
| continue | |||
| model.eval() | |||
| if dist.get_rank() == 0: | |||
| f1, _ = tester(model.module, dev_batch) | |||
| if f1 > best_f1: | |||
| best_f1 = f1 | |||
| logger.info("- new best score!") | |||
| if not options.no_model: | |||
| logger.info("Saving model to {}".format(best_model_file_name)) | |||
| torch.save(model.module.state_dict(), best_model_file_name) | |||
| elif options.always_model: | |||
| logger.info("Saving model to {}".format(best_model_file_name)) | |||
| torch.save(model.module.state_dict(), best_model_file_name) | |||
| dist.barrier() | |||
| # Evaluate test data (once) | |||
| logger.info("\nNumber test instances: {}".format(len(test_set))) | |||
| if not options.skip_dev: | |||
| if options.test: | |||
| model.module.load_state_dict( | |||
| torch.load(options.old_model, map_location="cuda:0") | |||
| ) | |||
| else: | |||
| model.module.load_state_dict( | |||
| torch.load(best_model_file_name, map_location="cuda:0") | |||
| ) | |||
| if dist.get_rank() == 0: | |||
| for name, para in model.named_parameters(): | |||
| if name.find("task_embed") != -1: | |||
| tm = para.detach().cpu().numpy() | |||
| logger.info(tm.shape) | |||
| np.save("{}/task.npy".format(root_dir), tm) | |||
| break | |||
| _, res = tester(model.module, test_batch, True) | |||
| if dist.get_rank() == 0: | |||
| with open("{}/testout.txt".format(root_dir), "w", encoding="utf-8") as raw_writer: | |||
| for sent in res: | |||
| raw_writer.write(sent) | |||
| raw_writer.write("\n") | |||
| @@ -0,0 +1,14 @@ | |||
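# Usage: bash make_data.sh path/to/sighan2005 path/to/sighan2008
# Processed data is written to $DATA_DIR (defaults to ./data).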
| if [ -z "$DATA_DIR" ] | |||
| then | |||
| DATA_DIR="./data" | |||
| fi | |||
| mkdir -vp $DATA_DIR | |||
| cmd="python -u ./data-prepare.py --sighan05 $1 --sighan08 $2 --data_path $DATA_DIR" | |||
| echo $cmd | |||
| eval $cmd | |||
| cmd="python -u ./data-process.py --data_path $DATA_DIR" | |||
| echo $cmd | |||
| eval $cmd | |||
| @@ -0,0 +1,200 @@ | |||
| import fastNLP | |||
| import torch | |||
| import math | |||
| from fastNLP.modules.encoder.transformer import TransformerEncoder | |||
| from fastNLP.modules.decoder.crf import ConditionalRandomField | |||
| from fastNLP import Const | |||
| import copy | |||
| import numpy as np | |||
| from torch.autograd import Variable | |||
| import torch.autograd as autograd | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import transformer | |||
| class PositionalEncoding(nn.Module): | |||
| "Implement the PE function." | |||
| def __init__(self, d_model, dropout, max_len=512): | |||
| super(PositionalEncoding, self).__init__() | |||
| self.dropout = nn.Dropout(p=dropout) | |||
| # Compute the positional encodings once in log space. | |||
| pe = torch.zeros(max_len, d_model).float() | |||
| position = torch.arange(0, max_len).unsqueeze(1).float() | |||
| div_term = torch.exp( | |||
| torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) | |||
| ) | |||
| pe[:, 0::2] = torch.sin(position * div_term) | |||
| pe[:, 1::2] = torch.cos(position * div_term) | |||
| pe = pe.unsqueeze(0) | |||
| self.register_buffer("pe", pe) | |||
| def forward(self, x): | |||
| x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False) | |||
| return self.dropout(x) | |||
| class Embedding(nn.Module): | |||
| def __init__( | |||
| self, | |||
| task_size, | |||
| d_model, | |||
| word_embedding=None, | |||
| bi_embedding=None, | |||
| word_size=None, | |||
| freeze=True, | |||
| ): | |||
| super(Embedding, self).__init__() | |||
| self.task_size = task_size | |||
| self.embed_dim = 0 | |||
| self.task_embed = nn.Embedding(task_size, d_model) | |||
| if word_embedding is not None: | |||
| # self.uni_embed = nn.Embedding.from_pretrained(torch.FloatTensor(word_embedding), freeze=freeze) | |||
| # self.embed_dim+=word_embedding.shape[1] | |||
| self.uni_embed = word_embedding | |||
| self.embed_dim += word_embedding.embedding_dim | |||
| else: | |||
| if bi_embedding is not None: | |||
| self.embed_dim += bi_embedding.shape[1] | |||
| else: | |||
| self.embed_dim = d_model | |||
| assert word_size is not None | |||
self.uni_embed = nn.Embedding(word_size, self.embed_dim)
| if bi_embedding is not None: | |||
| # self.bi_embed = nn.Embedding.from_pretrained(torch.FloatTensor(bi_embedding), freeze=freeze) | |||
| # self.embed_dim += bi_embedding.shape[1]*2 | |||
| self.bi_embed = bi_embedding | |||
| self.embed_dim += bi_embedding.embedding_dim * 2 | |||
| print("Trans Freeze", freeze, self.embed_dim) | |||
| if d_model != self.embed_dim: | |||
| self.F = nn.Linear(self.embed_dim, d_model) | |||
| else: | |||
| self.F = None | |||
| self.d_model = d_model | |||
| def forward(self, task, uni, bi1=None, bi2=None): | |||
| y_task = self.task_embed(task[:, 0:1]) | |||
| y = self.uni_embed(uni[:, 1:]) | |||
| if bi1 is not None: | |||
| assert self.bi_embed is not None | |||
| y = torch.cat([y, self.bi_embed(bi1), self.bi_embed(bi2)], dim=-1) | |||
| # y2=self.bi_embed(bi) | |||
| # y=torch.cat([y,y2[:,:-1,:],y2[:,1:,:]],dim=-1) | |||
| # y=torch.cat([y_task,y],dim=1) | |||
| if self.F is not None: | |||
| y = self.F(y) | |||
| y = torch.cat([y_task, y], dim=1) | |||
| return y * math.sqrt(self.d_model) | |||
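# The first position of each sentence holds the dataset tag, so its task embedding replaces
# the unigram embedding there; unigram (and optionally two bigram) embeddings are concatenated,
# projected to d_model if necessary, and scaled by sqrt(d_model) as in the standard Transformer.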
| def seq_len_to_mask(seq_len, max_len=None): | |||
| if isinstance(seq_len, np.ndarray): | |||
| assert ( | |||
| len(np.shape(seq_len)) == 1 | |||
| ), f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | |||
| if max_len is None: | |||
| max_len = int(seq_len.max()) | |||
| broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
| mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
| elif isinstance(seq_len, torch.Tensor): | |||
| assert ( | |||
| seq_len.dim() == 1 | |||
| ), f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | |||
| batch_size = seq_len.size(0) | |||
| if max_len is None: | |||
| max_len = seq_len.max().long() | |||
| broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) | |||
| mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | |||
| else: | |||
| raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | |||
| return mask | |||
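# e.g. seq_len_to_mask(torch.tensor([3, 2]), max_len=4) ->
#   tensor([[True, True, True, False],
#           [True, True, False, False]])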
| class CWSModel(nn.Module): | |||
| def __init__(self, encoder, src_embed, position, d_model, tag_size, crf=None): | |||
| super(CWSModel, self).__init__() | |||
| self.encoder = encoder | |||
| self.src_embed = src_embed | |||
| self.pos = copy.deepcopy(position) | |||
| self.proj = nn.Linear(d_model, tag_size) | |||
| self.tag_size = tag_size | |||
| if crf is None: | |||
| self.crf = None | |||
| self.loss_f = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100) | |||
| else: | |||
| print("crf") | |||
| trans = fastNLP.modules.decoder.crf.allowed_transitions( | |||
| crf, encoding_type="bmes" | |||
| ) | |||
| self.crf = ConditionalRandomField(tag_size, allowed_transitions=trans) | |||
| # self.norm=nn.LayerNorm(d_model) | |||
| def forward(self, task, uni, seq_len, bi1=None, bi2=None, tags=None): | |||
| # mask=fastNLP.core.utils.seq_len_to_mask(seq_len,uni.size(1)) # for dev 0.5.1 | |||
| mask = seq_len_to_mask(seq_len, uni.size(1)) | |||
| out = self.src_embed(task, uni, bi1, bi2) | |||
| out = self.pos(out) | |||
| # out=self.norm(out) | |||
| out = self.proj(self.encoder(out, mask.float())) | |||
| if self.crf is not None: | |||
| if tags is not None: | |||
| out = self.crf(out, tags, mask) | |||
| return {"loss": out} | |||
| else: | |||
| out, _ = self.crf.viterbi_decode(out, mask) | |||
| return {"pred": out} | |||
| else: | |||
| if tags is not None: | |||
| out = out.contiguous().view(-1, self.tag_size) | |||
| tags = tags.data.masked_fill_(mask == 0, -100).view(-1) | |||
| loss = self.loss_f(out, tags) | |||
| return {"loss": loss} | |||
| else: | |||
| out = torch.argmax(out, dim=-1) | |||
| return {"pred": out} | |||
| def make_CWS( | |||
| N=6, | |||
| d_model=256, | |||
| d_ff=1024, | |||
| h=4, | |||
| dropout=0.2, | |||
| tag_size=4, | |||
| task_size=8, | |||
| bigram_embedding=None, | |||
| word_embedding=None, | |||
| word_size=None, | |||
| crf=None, | |||
| freeze=True, | |||
| ): | |||
| c = copy.deepcopy | |||
| # encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout) | |||
| encoder = transformer.make_encoder( | |||
| N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff | |||
| ) | |||
| position = PositionalEncoding(d_model, dropout) | |||
| embed = Embedding( | |||
| task_size, d_model, word_embedding, bigram_embedding, word_size, freeze | |||
| ) | |||
| model = CWSModel(encoder, embed, position, d_model, tag_size, crf=crf) | |||
| for p in model.parameters(): | |||
| if p.dim() > 1 and p.requires_grad: | |||
| nn.init.xavier_uniform_(p) | |||
| return model | |||
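# Example (mirroring the default hyper-parameters in main.py):
#   model = make_CWS(d_model=256, N=6, h=4, d_ff=1024, dropout=0.2,
#                    tag_size=len(tag_vocab), task_size=len(task_vocab),
#                    word_embedding=init_embedding, bigram_embedding=bigram_embedding)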
| @@ -0,0 +1,49 @@ | |||
| import torch | |||
| import torch.optim as optim | |||
| class NoamOpt: | |||
| "Optim wrapper that implements rate." | |||
| def __init__(self, model_size, factor, warmup, optimizer): | |||
| self.optimizer = optimizer | |||
| self._step = 0 | |||
| self.warmup = warmup | |||
| self.factor = factor | |||
| self.model_size = model_size | |||
| self._rate = 0 | |||
| def step(self): | |||
| "Update parameters and rate" | |||
| self._step += 1 | |||
| rate = self.rate() | |||
| for p in self.optimizer.param_groups: | |||
| p["lr"] = rate | |||
| self._rate = rate | |||
| self.optimizer.step() | |||
| def rate(self, step=None): | |||
| "Implement `lrate` above" | |||
| if step is None: | |||
| step = self._step | |||
| lr = self.factor * ( | |||
| self.model_size ** (-0.5) | |||
| * min(step ** (-0.5), step * self.warmup ** (-1.5)) | |||
| ) | |||
| # if step>self.warmup: lr = max(1e-4,lr) | |||
| return lr | |||
| def get_std_opt(model): | |||
| return NoamOpt( | |||
| model.src_embed[0].d_model, | |||
| 2, | |||
| 4000, | |||
| torch.optim.Adam( | |||
| filter(lambda p: p.requires_grad, model.parameters()), | |||
| lr=0, | |||
| betas=(0.9, 0.98), | |||
| eps=1e-9, | |||
| ), | |||
| ) | |||
| @@ -0,0 +1,138 @@ | |||
| from fastNLP import (Trainer, Tester, Callback, GradientClipCallback, LRScheduler, SpanFPreRecMetric) | |||
| import torch | |||
| import torch.cuda | |||
| from torch.optim import Adam, SGD | |||
| from argparse import ArgumentParser | |||
| import logging | |||
from utils import set_seed
| class LoggingCallback(Callback): | |||
| def __init__(self, filepath=None): | |||
| super().__init__() | |||
| # create file handler and set level to debug | |||
| if filepath is not None: | |||
| file_handler = logging.FileHandler(filepath, "a") | |||
| else: | |||
| file_handler = logging.StreamHandler() | |||
| file_handler.setLevel(logging.DEBUG) | |||
| file_handler.setFormatter( | |||
| logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s', | |||
| datefmt='%m/%d/%Y %H:%M:%S')) | |||
| # create logger and set level to debug | |||
| logger = logging.getLogger() | |||
| logger.handlers = [] | |||
| logger.setLevel(logging.DEBUG) | |||
| logger.propagate = False | |||
| logger.addHandler(file_handler) | |||
| self.log_writer = logger | |||
| def on_backward_begin(self, loss): | |||
| if self.step % self.trainer.print_every == 0: | |||
| self.log_writer.info( | |||
| 'Step/Epoch {}/{}: Loss {}'.format(self.step, self.epoch, loss.item())) | |||
| def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
| self.log_writer.info( | |||
| 'Step/Epoch {}/{}: Eval result {}'.format(self.step, self.epoch, eval_result)) | |||
| def on_backward_end(self): | |||
| pass | |||
| def main(): | |||
| parser = ArgumentParser() | |||
| register_args(parser) | |||
| args = parser.parse_known_args()[0] | |||
| set_seed(args.seed) | |||
| if args.train: | |||
| train(args) | |||
| if args.eval: | |||
| evaluate(args) | |||
def get_optim(args, params):
| name = args.optim.strip().split(' ')[0].lower() | |||
| p = args.optim.strip() | |||
| l = p.find('(') | |||
| r = p.find(')') | |||
optim_args = eval('dict({})'.format(p[l + 1:r]))
| if name == 'sgd': | |||
return SGD(params, **optim_args)
| elif name == 'adam': | |||
return Adam(params, **optim_args)
| else: | |||
| raise ValueError(args.optim) | |||
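# e.g. the default --optim "adam (lr=2e-3, weight_decay=0.0)" is parsed into
# Adam(model.parameters(), lr=2e-3, weight_decay=0.0).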
| def load_model_from_path(args): | |||
| pass | |||
| def train(args): | |||
| data = get_data(args) | |||
| train_data = data['train'] | |||
| dev_data = data['dev'] | |||
| model = get_model(args) | |||
optimizer = get_optim(args, model.parameters())
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
| callbacks = [] | |||
| trainer = Trainer( | |||
| train_data=train_data, | |||
| model=model, | |||
| optimizer=optimizer, | |||
| loss=None, | |||
| batch_size=args.batch_size, | |||
| n_epochs=args.epochs, | |||
| num_workers=4, | |||
| metrics=SpanFPreRecMetric( | |||
| tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'], | |||
| ignore_labels=data['ignore_labels']), | |||
| metric_key='f1', | |||
| dev_data=dev_data, | |||
| save_path=args.save_path, | |||
| device=device, | |||
| callbacks=callbacks, | |||
| check_code_level=-1, | |||
| ) | |||
| print(trainer.train()) | |||
| def evaluate(args): | |||
| data = get_data(args) | |||
| test_data = data['test'] | |||
| model = load_model_from_path(args) | |||
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
| tester = Tester( | |||
| data=test_data, model=model, batch_size=args.batch_size, | |||
| num_workers=2, device=device, | |||
| metrics=SpanFPreRecMetric( | |||
| tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'], | |||
| ignore_labels=data['ignore_labels']), | |||
| ) | |||
| print(tester.test()) | |||
| def register_args(parser): | |||
| parser.add_argument('--optim', type=str, default='adam (lr=2e-3, weight_decay=0.0)') | |||
| parser.add_argument('--batch_size', type=int, default=128) | |||
| parser.add_argument('--epochs', type=int, default=10) | |||
| parser.add_argument('--save_path', type=str, default=None) | |||
| parser.add_argument('--data_path', type=str, required=True) | |||
| parser.add_argument('--log_path', type=str, default=None) | |||
| parser.add_argument('--model_config', type=str, required=True) | |||
| parser.add_argument('--load_path', type=str, default=None) | |||
| parser.add_argument('--train', action='store_true', default=False) | |||
| parser.add_argument('--eval', action='store_true', default=False) | |||
| parser.add_argument('--seed', type=int, default=42, help='rng seed') | |||
| def get_model(args): | |||
| pass | |||
| def get_data(args): | |||
| return torch.load(args.data_path) | |||
| if __name__ == '__main__': | |||
| main() | |||
| @@ -0,0 +1,26 @@ | |||
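# Usage: bash train.sh <comma-separated GPU ids>, e.g. bash train.sh 0,1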
| export EXP_NAME=release04 | |||
| export NGPU=2 | |||
| export PORT=9988 | |||
| export CUDA_DEVICE_ORDER=PCI_BUS_ID | |||
| export CUDA_VISIBLE_DEVICES=$1 | |||
| if [ -z "$DATA_DIR" ] | |||
| then | |||
| DATA_DIR="./data" | |||
| fi | |||
| echo $CUDA_VISIBLE_DEVICES | |||
| cmd=" | |||
python -m torch.distributed.launch --nproc_per_node=$NGPU --master_port $PORT \
| main.py \ | |||
| --word-embeddings cn-char-fastnlp-100d \ | |||
| --bigram-embeddings cn-bi-fastnlp-100d \ | |||
| --num-epochs 100 \ | |||
| --batch-size 256 \ | |||
| --seed 1234 \ | |||
| --task-name $EXP_NAME \ | |||
| --dataset $DATA_DIR \ | |||
| --freeze \ | |||
| " | |||
| echo $cmd | |||
| eval $cmd | |||
| @@ -0,0 +1,152 @@ | |||
| import numpy as np | |||
| import torch | |||
| import torch.autograd as autograd | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| import math, copy, time | |||
| from torch.autograd import Variable | |||
| # import matplotlib.pyplot as plt | |||
| def clones(module, N): | |||
| "Produce N identical layers." | |||
| return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) | |||
| def subsequent_mask(size): | |||
| "Mask out subsequent positions." | |||
| attn_shape = (1, size, size) | |||
| subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") | |||
| return torch.from_numpy(subsequent_mask) == 0 | |||
| def attention(query, key, value, mask=None, dropout=None): | |||
| "Compute 'Scaled Dot Product Attention'" | |||
| d_k = query.size(-1) | |||
| scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) | |||
| if mask is not None: | |||
| # print(scores.size(),mask.size()) # [bsz,1,1,len] | |||
| scores = scores.masked_fill(mask == 0, -1e9) | |||
| p_attn = F.softmax(scores, dim=-1) | |||
| if dropout is not None: | |||
| p_attn = dropout(p_attn) | |||
| return torch.matmul(p_attn, value), p_attn | |||
| class MultiHeadedAttention(nn.Module): | |||
| def __init__(self, h, d_model, dropout=0.1): | |||
| "Take in model size and number of heads." | |||
| super(MultiHeadedAttention, self).__init__() | |||
| assert d_model % h == 0 | |||
| # We assume d_v always equals d_k | |||
| self.d_k = d_model // h | |||
| self.h = h | |||
| self.linears = clones(nn.Linear(d_model, d_model), 4) | |||
| self.attn = None | |||
| self.dropout = nn.Dropout(p=dropout) | |||
| def forward(self, query, key, value, mask=None): | |||
| "Implements Figure 2" | |||
| if mask is not None: | |||
| # Same mask applied to all h heads. | |||
| mask = mask.unsqueeze(1) | |||
| nbatches = query.size(0) | |||
| # 1) Do all the linear projections in batch from d_model => h x d_k | |||
| query, key, value = [ | |||
| l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) | |||
| for l, x in zip(self.linears, (query, key, value)) | |||
| ] | |||
| # 2) Apply attention on all the projected vectors in batch. | |||
| x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) | |||
| # 3) "Concat" using a view and apply a final linear. | |||
| x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) | |||
| return self.linears[-1](x) | |||
| class LayerNorm(nn.Module): | |||
| "Construct a layernorm module (See citation for details)." | |||
| def __init__(self, features, eps=1e-6): | |||
| super(LayerNorm, self).__init__() | |||
| self.a_2 = nn.Parameter(torch.ones(features)) | |||
| self.b_2 = nn.Parameter(torch.zeros(features)) | |||
| self.eps = eps | |||
| def forward(self, x): | |||
| mean = x.mean(-1, keepdim=True) | |||
| std = x.std(-1, keepdim=True) | |||
| return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 | |||
| class PositionwiseFeedForward(nn.Module): | |||
| "Implements FFN equation." | |||
| def __init__(self, d_model, d_ff, dropout=0.1): | |||
| super(PositionwiseFeedForward, self).__init__() | |||
| self.w_1 = nn.Linear(d_model, d_ff) | |||
| self.w_2 = nn.Linear(d_ff, d_model) | |||
| self.dropout = nn.Dropout(dropout) | |||
| def forward(self, x): | |||
| return self.w_2(self.dropout(F.relu(self.w_1(x)))) | |||
| class SublayerConnection(nn.Module): | |||
| """ | |||
| A residual connection followed by a layer norm. | |||
| Note for code simplicity the norm is first as opposed to last. | |||
| """ | |||
| def __init__(self, size, dropout): | |||
| super(SublayerConnection, self).__init__() | |||
| self.norm = LayerNorm(size) | |||
| self.dropout = nn.Dropout(dropout) | |||
| def forward(self, x, sublayer): | |||
| "Apply residual connection to any sublayer with the same size." | |||
| return x + self.dropout(sublayer(self.norm(x))) | |||
| class EncoderLayer(nn.Module): | |||
| "Encoder is made up of self-attn and feed forward (defined below)" | |||
| def __init__(self, size, self_attn, feed_forward, dropout): | |||
| super(EncoderLayer, self).__init__() | |||
| self.self_attn = self_attn | |||
| self.feed_forward = feed_forward | |||
| self.sublayer = clones(SublayerConnection(size, dropout), 2) | |||
| self.size = size | |||
| def forward(self, x, mask): | |||
| "Follow Figure 1 (left) for connections." | |||
| x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) | |||
| return self.sublayer[1](x, self.feed_forward) | |||
| class Encoder(nn.Module): | |||
| "Core encoder is a stack of N layers" | |||
| def __init__(self, layer, N): | |||
| super(Encoder, self).__init__() | |||
| self.layers = clones(layer, N) | |||
| self.norm = LayerNorm(layer.size) | |||
def forward(self, x, mask):
"Pass the input (and mask) through each layer in turn."
# print(x.size(), mask.size())
| mask = mask.byte().unsqueeze(-2) | |||
| for layer in self.layers: | |||
| x = layer(x, mask) | |||
| return self.norm(x) | |||
| def make_encoder(N=6, d_model=512, d_ff=2048, h=8, dropout=0.1): | |||
| c = copy.deepcopy | |||
| attn = MultiHeadedAttention(h, d_model) | |||
| ff = PositionwiseFeedForward(d_model, d_ff, dropout) | |||
| return Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N) | |||
| @@ -0,0 +1,308 @@ | |||
| import numpy as np | |||
| import torch | |||
| import torch.cuda | |||
| import random | |||
| import os | |||
| import sys | |||
| import errno | |||
| import time | |||
| import codecs | |||
| import hashlib | |||
| import _pickle as pickle | |||
| import warnings | |||
| from fastNLP.io import EmbedLoader | |||
| UNK_TAG = "<unk>" | |||
| def set_seed(seed): | |||
| random.seed(seed) | |||
| np.random.seed(seed) | |||
| torch.manual_seed(seed) | |||
| torch.cuda.manual_seed_all(seed) | |||
| def bmes_to_words(chars, tags): | |||
| result = [] | |||
| if len(chars) == 0: | |||
| return result | |||
| word = chars[0] | |||
| for c, t in zip(chars[1:], tags[1:]): | |||
| if t.upper() == "B" or t.upper() == "S": | |||
| result.append(word) | |||
| word = "" | |||
| word += c | |||
| if len(word) != 0: | |||
| result.append(word) | |||
| return result | |||
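# e.g. bmes_to_words(["中", "国", "人", "很", "好"], ["B", "M", "E", "S", "S"]) -> ["中国人", "很", "好"]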
| def bmes_to_index(tags): | |||
| result = [] | |||
| if len(tags) == 0: | |||
| return result | |||
| word = (0, 0) | |||
| for i, t in enumerate(tags): | |||
| if i == 0: | |||
| word = (0, 0) | |||
| elif t.upper() == "B" or t.upper() == "S": | |||
| result.append(word) | |||
| word = (i, 0) | |||
| word = (word[0], word[1] + 1) | |||
| if word[1] != 0: | |||
| result.append(word) | |||
| return result | |||
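# e.g. bmes_to_index(["B", "M", "E", "S"]) -> [(0, 3), (3, 1)], i.e. (start, length) spans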
| def get_bmes(sent): | |||
| x = [] | |||
| y = [] | |||
| for word in sent: | |||
| length = len(word) | |||
| tag = ["m"] * length if length > 1 else ["s"] * length | |||
| if length > 1: | |||
| tag[0] = "b" | |||
| tag[-1] = "e" | |||
| x += list(word) | |||
| y += tag | |||
| return x, y | |||
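# e.g. get_bmes(["中国人", "很", "好"]) -> (["中", "国", "人", "很", "好"], ["b", "m", "e", "s", "s"])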
| class CWSEvaluator: | |||
| def __init__(self, i2t): | |||
| self.correct_preds = 0.0 | |||
| self.total_preds = 0.0 | |||
| self.total_correct = 0.0 | |||
| self.i2t = i2t | |||
| def add_instance(self, pred_tags, gold_tags): | |||
| pred_tags = [self.i2t[i] for i in pred_tags] | |||
| gold_tags = [self.i2t[i] for i in gold_tags] | |||
| # Evaluate PRF | |||
| lab_gold_chunks = set(bmes_to_index(gold_tags)) | |||
| lab_pred_chunks = set(bmes_to_index(pred_tags)) | |||
| self.correct_preds += len(lab_gold_chunks & lab_pred_chunks) | |||
| self.total_preds += len(lab_pred_chunks) | |||
| self.total_correct += len(lab_gold_chunks) | |||
| def result(self, percentage=True): | |||
| p = self.correct_preds / self.total_preds if self.correct_preds > 0 else 0 | |||
| r = self.correct_preds / self.total_correct if self.correct_preds > 0 else 0 | |||
| f1 = 2 * p * r / (p + r) if p + r > 0 else 0 | |||
| if percentage: | |||
| p *= 100 | |||
| r *= 100 | |||
| f1 *= 100 | |||
| return p, r, f1 | |||
| class CWS_OOV: | |||
| def __init__(self, dic): | |||
| self.dic = dic | |||
| self.recall = 0 | |||
| self.tot = 0 | |||
| def update(self, gold_sent, pred_sent): | |||
| i = 0 | |||
| j = 0 | |||
| id = 0 | |||
| for w in gold_sent: | |||
| if w not in self.dic: | |||
| self.tot += 1 | |||
| while i + len(pred_sent[id]) <= j: | |||
| i += len(pred_sent[id]) | |||
| id += 1 | |||
| if ( | |||
| i == j | |||
| and len(pred_sent[id]) == len(w) | |||
| and w.find(pred_sent[id]) != -1 | |||
| ): | |||
| self.recall += 1 | |||
| j += len(w) | |||
| # print(gold_sent,pred_sent,self.tot) | |||
| def oov(self, percentage=True): | |||
| ins = 1.0 * self.recall / self.tot | |||
| if percentage: | |||
| ins *= 100 | |||
| return ins | |||
| def get_processing_word( | |||
| vocab_words=None, vocab_chars=None, lowercase=False, chars=False | |||
| ): | |||
| def f(word): | |||
| # 0. get chars of words | |||
| if vocab_chars is not None and chars: | |||
| char_ids = [] | |||
| for char in word: | |||
| # ignore chars out of vocabulary | |||
| if char in vocab_chars: | |||
| char_ids += [vocab_chars[char]] | |||
| # 1. preprocess word | |||
| if lowercase: | |||
| word = word.lower() | |||
| if word.isdigit(): | |||
| word = "0" | |||
| # 2. get id of word | |||
| if vocab_words is not None: | |||
| if word in vocab_words: | |||
| word = vocab_words[word] | |||
| else: | |||
| word = vocab_words[UNK_TAG] | |||
| # 3. return tuple char ids, word id | |||
| if vocab_chars is not None and chars: | |||
| return char_ids, word | |||
| else: | |||
| return word | |||
| return f | |||
| def append_tags(src, des, name, part, encode="utf-16"): | |||
| with open("{}/{}.txt".format(src, part), encoding=encode) as input, open( | |||
| "{}/{}.txt".format(des, part), "a", encoding=encode | |||
| ) as output: | |||
| for line in input: | |||
| line = line.strip() | |||
| if len(line) > 0: | |||
| output.write("<{}> {} </{}>".format(name, line, name)) | |||
| output.write("\n") | |||
| def is_dataset_tag(word): | |||
| return len(word) > 2 and word[0] == "<" and word[-1] == ">" | |||
| def to_tag_strings(i2ts, tag_mapping, pos_separate_col=True): | |||
| senlen = len(tag_mapping) | |||
| key_value_strs = [] | |||
| for j in range(senlen): | |||
| val = i2ts[tag_mapping[j]] | |||
| pos_str = val | |||
| key_value_strs.append(pos_str) | |||
| return key_value_strs | |||
| def to_id_list(w2i): | |||
| i2w = [None] * len(w2i) | |||
| for w, i in w2i.items(): | |||
| i2w[i] = w | |||
| return i2w | |||
| def make_sure_path_exists(path): | |||
| try: | |||
| os.makedirs(path) | |||
| except OSError as exception: | |||
| if exception.errno != errno.EEXIST: | |||
| raise | |||
| def md5_for_file(fn): | |||
| md5 = hashlib.md5() | |||
| with open(fn, "rb") as f: | |||
| for chunk in iter(lambda: f.read(128 * md5.block_size), b""): | |||
| md5.update(chunk) | |||
| return md5.hexdigest() | |||
| def embedding_match_vocab( | |||
| vocab, | |||
| emb, | |||
| ori_vocab, | |||
| dtype=np.float32, | |||
| padding="<pad>", | |||
| unknown="<unk>", | |||
| normalize=True, | |||
| error="ignore", | |||
| init_method=None, | |||
| ): | |||
| dim = emb.shape[-1] | |||
| matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||
| hit_flags = np.zeros(len(vocab), dtype=bool) | |||
| if init_method: | |||
| matrix = init_method(matrix) | |||
| for word, idx in ori_vocab.word2idx.items(): | |||
| try: | |||
| if word == padding and vocab.padding is not None: | |||
| word = vocab.padding | |||
| elif word == unknown and vocab.unknown is not None: | |||
| word = vocab.unknown | |||
| if word in vocab: | |||
| index = vocab.to_index(word) | |||
| matrix[index] = emb[idx] | |||
| hit_flags[index] = True | |||
| except Exception as e: | |||
| if error == "ignore": | |||
| warnings.warn("Error occurred at the {} line.".format(idx)) | |||
| else: | |||
| print("Error occurred at the {} line.".format(idx)) | |||
| raise e | |||
| total_hits = np.sum(hit_flags) | |||
| print( | |||
| "Found {} out of {} words in the pre-training embedding.".format( | |||
| total_hits, len(vocab) | |||
| ) | |||
| ) | |||
| if init_method is None: | |||
| found_vectors = matrix[hit_flags] | |||
| if len(found_vectors) != 0: | |||
| mean = np.mean(found_vectors, axis=0, keepdims=True) | |||
| std = np.std(found_vectors, axis=0, keepdims=True) | |||
| unfound_vec_num = len(vocab) - total_hits | |||
| r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean | |||
| matrix[hit_flags == False] = r_vecs | |||
| if normalize: | |||
| matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||
| return matrix | |||
| def embedding_load_with_cache(emb_file, cache_dir, vocab, **kwargs): | |||
| def match_cache(file, cache_dir): | |||
| md5 = md5_for_file(file) | |||
| cache_files = os.listdir(cache_dir) | |||
| for fn in cache_files: | |||
| if md5 in fn.split("-")[-1]: | |||
| return os.path.join(cache_dir, fn), True | |||
| return ( | |||
| "{}-{}.pkl".format(os.path.join(cache_dir, os.path.basename(file)), md5), | |||
| False, | |||
| ) | |||
| def get_cache(file): | |||
| if not os.path.exists(file): | |||
| return None | |||
| with open(file, "rb") as f: | |||
| emb = pickle.load(f) | |||
| return emb | |||
| os.makedirs(cache_dir, exist_ok=True) | |||
| cache_fn, match = match_cache(emb_file, cache_dir) | |||
| if not match: | |||
| print("cache missed, re-generating cache at {}".format(cache_fn)) | |||
| emb, ori_vocab = EmbedLoader.load_without_vocab( | |||
| emb_file, padding=None, unknown=None, normalize=False | |||
| ) | |||
| with open(cache_fn, "wb") as f: | |||
| pickle.dump((emb, ori_vocab), f) | |||
| else: | |||
| print("cache matched at {}".format(cache_fn)) | |||
| # use cache | |||
| print("loading embeddings ...") | |||
| emb = get_cache(cache_fn) | |||
| assert emb is not None | |||
| return embedding_match_vocab(vocab, emb[0], emb[1], **kwargs) | |||