diff --git a/reproduction/multi-criteria-cws/README.md b/reproduction/multi-criteria-cws/README.md new file mode 100644 index 00000000..0f4ab8d8 --- /dev/null +++ b/reproduction/multi-criteria-cws/README.md @@ -0,0 +1,61 @@ + + +# Multi-Criteria-CWS + +An implementation of [Multi-Criteria Chinese Word Segmentation with Transformer](http://arxiv.org/abs/1906.12035) with fastNLP. + +## Dataset +### Overview +We use the same datasets listed in paper. +- sighan2005 + - pku + - msr + - as + - cityu +- sighan2008 + - ctb + - ckip + - cityu (combined with data in sighan2005) + - ncc + - sxu + +### Preprocess +First, download OpenCC to convert between Traditional Chinese and Simplified Chinese. +``` shell +pip install opencc-python-reimplemented +``` +Then, set a path to save processed data, and run the shell script to process the data. +```shell +export DATA_DIR=path/to/processed-data +bash make_data.sh path/to/sighan2005 path/to/sighan2008 +``` +It would take a few minutes to finish the process. + +## Model +We use transformer to build the model, as described in paper. + +## Train +Finally, to train the model, run the shell script. +The `train.sh` takes one argument, the GPU-IDs to use, for example: +``` shell +bash train.sh 0,1 +``` +This command use GPUs with ID 0 and 1. + +Note: Please refer to the paper for details of hyper-parameters. And modify the settings in `train.sh` to match your experiment environment. + +Type +``` shell +python main.py --help +``` +to learn all arguments to be specified in training. + +## Performance + +Results on the test sets of eight CWS datasets with multi-criteria learning. + +| Dataset | MSRA | AS | PKU | CTB | CKIP | CITYU | NCC | SXU | Avg. | +| -------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | ----- | +| Original paper | 98.05 | 96.44 | 96.41 | 96.99 | 96.51 | 96.91 | 96.04 | 97.61 | 96.87 | +| Ours | 96.92 | 95.71 | 95.65 | 95.96 | 96.00 | 96.09 | 94.61 | 96.64 | 95.95 | + diff --git a/reproduction/multi-criteria-cws/data-prepare.py b/reproduction/multi-criteria-cws/data-prepare.py new file mode 100644 index 00000000..1d6e89b5 --- /dev/null +++ b/reproduction/multi-criteria-cws/data-prepare.py @@ -0,0 +1,262 @@ +import os +import re +import argparse +from opencc import OpenCC + +cc = OpenCC("t2s") + +from utils import make_sure_path_exists, append_tags + +sighan05_root = "" +sighan08_root = "" +data_path = "" + +E_pun = u",.!?[]()<>\"\"''," +C_pun = u",。!?【】()《》“”‘’、" +Table = {ord(f): ord(t) for f, t in zip(C_pun, E_pun)} +Table[12288] = 32 # 全半角空格 + + +def C_trans_to_E(string): + return string.translate(Table) + + +def normalize(ustring): + """全角转半角""" + rstring = "" + for uchar in ustring: + inside_code = ord(uchar) + if inside_code == 12288: # 全角空格直接转换 + inside_code = 32 + elif 65281 <= inside_code <= 65374: # 全角字符(除空格)根据关系转化 + inside_code -= 65248 + + rstring += chr(inside_code) + return rstring + + +def preprocess(text): + rNUM = u"(-|\+)?\d+((\.|·)\d+)?%?" + rENG = u"[A-Za-z_]+.*" + sent = normalize(C_trans_to_E(text.strip())).split() + new_sent = [] + for word in sent: + word = re.sub(u"\s+", "", word, flags=re.U) + word = re.sub(rNUM, u"0", word, flags=re.U) + word = re.sub(rENG, u"X", word) + new_sent.append(word) + return new_sent + + +def to_sentence_list(text, split_long_sentence=False): + text = preprocess(text) + delimiter = set() + delimiter.update("。!?:;…、,(),;!?、,\"'") + delimiter.add("……") + sent_list = [] + sent = [] + sent_len = 0 + for word in text: + sent.append(word) + sent_len += len(word) + if word in delimiter or (split_long_sentence and sent_len >= 50): + sent_list.append(sent) + sent = [] + sent_len = 0 + + if len(sent) > 0: + sent_list.append(sent) + + return sent_list + + +def is_traditional(dataset): + return dataset in ["as", "cityu", "ckip"] + + +def convert_file( + src, des, need_cc=False, split_long_sentence=False, encode="utf-8-sig" +): + with open(src, encoding=encode) as src, open(des, "w", encoding="utf-8") as des: + for line in src: + for sent in to_sentence_list(line, split_long_sentence): + line = " ".join(sent) + "\n" + if need_cc: + line = cc.convert(line) + des.write(line) + # if len(''.join(sent)) > 200: + # print(' '.join(sent)) + + +def split_train_dev(dataset): + root = data_path + "/" + dataset + "/raw/" + with open(root + "train-all.txt", encoding="UTF-8") as src, open( + root + "train.txt", "w", encoding="UTF-8" + ) as train, open(root + "dev.txt", "w", encoding="UTF-8") as dev: + lines = src.readlines() + idx = int(len(lines) * 0.9) + for line in lines[:idx]: + train.write(line) + for line in lines[idx:]: + dev.write(line) + + +def combine_files(one, two, out): + if os.path.exists(out): + os.remove(out) + with open(one, encoding="utf-8") as one, open(two, encoding="utf-8") as two, open( + out, "a", encoding="utf-8" + ) as out: + for line in one: + out.write(line) + for line in two: + out.write(line) + + +def bmes_tag(input_file, output_file): + with open(input_file, encoding="utf-8") as input_data, open( + output_file, "w", encoding="utf-8" + ) as output_data: + for line in input_data: + word_list = line.strip().split() + for word in word_list: + if len(word) == 1 or ( + len(word) > 2 and word[0] == "<" and word[-1] == ">" + ): + output_data.write(word + "\tS\n") + else: + output_data.write(word[0] + "\tB\n") + for w in word[1 : len(word) - 1]: + output_data.write(w + "\tM\n") + output_data.write(word[len(word) - 1] + "\tE\n") + output_data.write("\n") + + +def make_bmes(dataset="pku"): + path = data_path + "/" + dataset + "/" + make_sure_path_exists(path + "bmes") + bmes_tag(path + "raw/train.txt", path + "bmes/train.txt") + bmes_tag(path + "raw/train-all.txt", path + "bmes/train-all.txt") + bmes_tag(path + "raw/dev.txt", path + "bmes/dev.txt") + bmes_tag(path + "raw/test.txt", path + "bmes/test.txt") + + +def convert_sighan2005_dataset(dataset): + global sighan05_root + root = os.path.join(data_path, dataset) + make_sure_path_exists(root) + make_sure_path_exists(root + "/raw") + file_path = "{}/{}_training.utf8".format(sighan05_root, dataset) + convert_file( + file_path, "{}/raw/train-all.txt".format(root), is_traditional(dataset), True + ) + if dataset == "as": + file_path = "{}/{}_testing_gold.utf8".format(sighan05_root, dataset) + else: + file_path = "{}/{}_test_gold.utf8".format(sighan05_root, dataset) + convert_file( + file_path, "{}/raw/test.txt".format(root), is_traditional(dataset), False + ) + split_train_dev(dataset) + + +def convert_sighan2008_dataset(dataset, utf=16): + global sighan08_root + root = os.path.join(data_path, dataset) + make_sure_path_exists(root) + make_sure_path_exists(root + "/raw") + convert_file( + "{}/{}_train_utf{}.seg".format(sighan08_root, dataset, utf), + "{}/raw/train-all.txt".format(root), + is_traditional(dataset), + True, + "utf-{}".format(utf), + ) + convert_file( + "{}/{}_seg_truth&resource/{}_truth_utf{}.seg".format( + sighan08_root, dataset, dataset, utf + ), + "{}/raw/test.txt".format(root), + is_traditional(dataset), + False, + "utf-{}".format(utf), + ) + split_train_dev(dataset) + + +def extract_conll(src, out): + words = [] + with open(src, encoding="utf-8") as src, open(out, "w", encoding="utf-8") as out: + for line in src: + line = line.strip() + if len(line) == 0: + out.write(" ".join(words) + "\n") + words = [] + continue + cells = line.split() + words.append(cells[1]) + + +def make_joint_corpus(datasets, joint): + parts = ["dev", "test", "train", "train-all"] + for part in parts: + old_file = "{}/{}/raw/{}.txt".format(data_path, joint, part) + if os.path.exists(old_file): + os.remove(old_file) + elif not os.path.exists(os.path.dirname(old_file)): + os.makedirs(os.path.dirname(old_file)) + for name in datasets: + append_tags( + os.path.join(data_path, name, "raw"), + os.path.dirname(old_file), + name, + part, + encode="utf-8", + ) + + +def convert_all_sighan2005(datasets): + for dataset in datasets: + print(("Converting sighan bakeoff 2005 corpus: {}".format(dataset))) + convert_sighan2005_dataset(dataset) + make_bmes(dataset) + + +def convert_all_sighan2008(datasets): + for dataset in datasets: + print(("Converting sighan bakeoff 2008 corpus: {}".format(dataset))) + convert_sighan2008_dataset(dataset, 16) + make_bmes(dataset) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + # fmt: off + parser.add_argument("--sighan05", required=True, type=str, help="path to sighan2005 dataset") + parser.add_argument("--sighan08", required=True, type=str, help="path to sighan2008 dataset") + parser.add_argument("--data_path", required=True, type=str, help="path to save dataset") + # fmt: on + + args, _ = parser.parse_known_args() + sighan05_root = args.sighan05 + sighan08_root = args.sighan08 + data_path = args.data_path + + print("Converting sighan2005 Simplified Chinese corpus") + datasets = "pku", "msr", "as", "cityu" + convert_all_sighan2005(datasets) + + print("Combining sighan2005 corpus to one joint Simplified Chinese corpus") + datasets = "pku", "msr", "as", "cityu" + make_joint_corpus(datasets, "joint-sighan2005") + make_bmes("joint-sighan2005") + + # For researchers who have access to sighan2008 corpus, use official corpora please. + print("Converting sighan2008 Simplified Chinese corpus") + datasets = "ctb", "ckip", "cityu", "ncc", "sxu" + convert_all_sighan2008(datasets) + print("Combining those 8 sighan corpora to one joint corpus") + datasets = "pku", "msr", "as", "ctb", "ckip", "cityu", "ncc", "sxu" + make_joint_corpus(datasets, "joint-sighan2008") + make_bmes("joint-sighan2008") + diff --git a/reproduction/multi-criteria-cws/data-process.py b/reproduction/multi-criteria-cws/data-process.py new file mode 100644 index 00000000..829580ef --- /dev/null +++ b/reproduction/multi-criteria-cws/data-process.py @@ -0,0 +1,166 @@ +import os +import sys + +import codecs +import argparse +from _pickle import load, dump +import collections +from utils import get_processing_word, is_dataset_tag, make_sure_path_exists, get_bmes +from fastNLP import Instance, DataSet, Vocabulary, Const + +max_len = 0 + + +def expand(x): + sent = [""] + x[1:] + [""] + return [x + y for x, y in zip(sent[:-1], sent[1:])] + + +def read_file(filename, processing_word=get_processing_word(lowercase=False)): + dataset = DataSet() + niter = 0 + with codecs.open(filename, "r", "utf-8-sig") as f: + words, tags = [], [] + for line in f: + line = line.strip() + if len(line) == 0 or line.startswith("-DOCSTART-"): + if len(words) != 0: + assert len(words) > 2 + if niter == 1: + print(words, tags) + niter += 1 + dataset.append(Instance(ori_words=words[:-1], ori_tags=tags[:-1])) + words, tags = [], [] + else: + word, tag = line.split() + word = processing_word(word) + words.append(word) + tags.append(tag.lower()) + + dataset.apply_field(lambda x: [x[0]], field_name="ori_words", new_field_name="task") + dataset.apply_field( + lambda x: len(x), field_name="ori_tags", new_field_name="seq_len" + ) + dataset.apply_field( + lambda x: expand(x), field_name="ori_words", new_field_name="bi1" + ) + return dataset + + +def main(): + parser = argparse.ArgumentParser() + # fmt: off + parser.add_argument("--data_path", required=True, type=str, help="all of datasets pkl paths") + # fmt: on + + options, _ = parser.parse_known_args() + + train_set, test_set = DataSet(), DataSet() + + input_dir = os.path.join(options.data_path, "joint-sighan2008/bmes") + options.output = os.path.join(options.data_path, "total_dataset.pkl") + print(input_dir, options.output) + + for fn in os.listdir(input_dir): + if fn not in ["test.txt", "train-all.txt"]: + continue + print(fn) + abs_fn = os.path.join(input_dir, fn) + ds = read_file(abs_fn) + if "test.txt" == fn: + test_set = ds + else: + train_set = ds + + print( + "num samples of total train, test: {}, {}".format(len(train_set), len(test_set)) + ) + + uni_vocab = Vocabulary(min_freq=None).from_dataset( + train_set, test_set, field_name="ori_words" + ) + # bi_vocab = Vocabulary(min_freq=3, max_size=50000).from_dataset(train_set,test_set, field_name="bi1") + bi_vocab = Vocabulary(min_freq=3, max_size=None).from_dataset( + train_set, field_name="bi1", no_create_entry_dataset=[test_set] + ) + tag_vocab = Vocabulary(min_freq=None, padding="s", unknown=None).from_dataset( + train_set, field_name="ori_tags" + ) + task_vocab = Vocabulary(min_freq=None, padding=None, unknown=None).from_dataset( + train_set, field_name="task" + ) + + def to_index(dataset): + uni_vocab.index_dataset(dataset, field_name="ori_words", new_field_name="uni") + tag_vocab.index_dataset(dataset, field_name="ori_tags", new_field_name="tags") + task_vocab.index_dataset(dataset, field_name="task", new_field_name="task") + + dataset.apply_field(lambda x: x[1:], field_name="bi1", new_field_name="bi2") + dataset.apply_field(lambda x: x[:-1], field_name="bi1", new_field_name="bi1") + bi_vocab.index_dataset(dataset, field_name="bi1", new_field_name="bi1") + bi_vocab.index_dataset(dataset, field_name="bi2", new_field_name="bi2") + + dataset.set_input("task", "uni", "bi1", "bi2", "seq_len") + dataset.set_target("tags") + return dataset + + train_set = to_index(train_set) + test_set = to_index(test_set) + + output = {} + output["train_set"] = train_set + output["test_set"] = test_set + output["uni_vocab"] = uni_vocab + output["bi_vocab"] = bi_vocab + output["tag_vocab"] = tag_vocab + output["task_vocab"] = task_vocab + + print(tag_vocab.word2idx) + print(task_vocab.word2idx) + + make_sure_path_exists(os.path.dirname(options.output)) + + print("Saving dataset to {}".format(os.path.abspath(options.output))) + with open(options.output, "wb") as outfile: + dump(output, outfile) + + print(len(task_vocab), len(tag_vocab), len(uni_vocab), len(bi_vocab)) + dic = {} + tokens = {} + + def process(words): + name = words[0][1:-1] + if name not in dic: + dic[name] = set() + tokens[name] = 0 + tokens[name] += len(words[1:]) + dic[name].update(words[1:]) + + train_set.apply_field(process, "ori_words", None) + for name in dic.keys(): + print(name, len(dic[name]), tokens[name]) + + with open(os.path.join(os.path.dirname(options.output), "oovdict.pkl"), "wb") as f: + dump(dic, f) + + def get_max_len(ds): + global max_len + max_len = 0 + + def find_max_len(words): + global max_len + if max_len < len(words): + max_len = len(words) + + ds.apply_field(find_max_len, "ori_words", None) + return max_len + + print( + "train max len: {}, test max len: {}".format( + get_max_len(train_set), get_max_len(test_set) + ) + ) + + +if __name__ == "__main__": + main() diff --git a/reproduction/multi-criteria-cws/main.py b/reproduction/multi-criteria-cws/main.py new file mode 100644 index 00000000..049a1974 --- /dev/null +++ b/reproduction/multi-criteria-cws/main.py @@ -0,0 +1,506 @@ +import _pickle as pickle +import argparse +import collections +import logging +import math +import os +import pickle +import random +import sys +import time +from sys import maxsize + +import fastNLP +import fastNLP.embeddings +import numpy as np +import torch +import torch.distributed as dist +import torch.nn as nn +from fastNLP import BucketSampler, DataSetIter, SequentialSampler, logger +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data.distributed import DistributedSampler + +import models +import optm +import utils + +NONE_TAG = "" +START_TAG = "" +END_TAG = "" + +DEFAULT_WORD_EMBEDDING_SIZE = 100 +DEBUG_SCALE = 200 + +# ===-----------------------------------------------------------------------=== +# Argument parsing +# ===-----------------------------------------------------------------------=== +# fmt: off +parser = argparse.ArgumentParser() +parser.add_argument("--dataset", required=True, dest="dataset", help="processed data dir") +parser.add_argument("--word-embeddings", dest="word_embeddings", help="File from which to read in pretrained embeds") +parser.add_argument("--bigram-embeddings", dest="bigram_embeddings", help="File from which to read in pretrained embeds") +parser.add_argument("--crf", dest="crf", action="store_true", help="crf") +# parser.add_argument("--devi", default="0", dest="devi", help="gpu") +parser.add_argument("--step", default=0, dest="step", type=int,help="step") +parser.add_argument("--num-epochs", default=100, dest="num_epochs", type=int, + help="Number of full passes through training set") +parser.add_argument("--batch-size", default=128, dest="batch_size", type=int, + help="Minibatch size of training set") +parser.add_argument("--d_model", default=256, dest="d_model", type=int, help="d_model") +parser.add_argument("--d_ff", default=1024, dest="d_ff", type=int, help="d_ff") +parser.add_argument("--N", default=6, dest="N", type=int, help="N") +parser.add_argument("--h", default=4, dest="h", type=int, help="h") +parser.add_argument("--factor", default=2, dest="factor", type=float, help="Initial learning rate") +parser.add_argument("--dropout", default=0.2, dest="dropout", type=float, + help="Amount of dropout(not keep rate, but drop rate) to apply to embeddings part of graph") +parser.add_argument("--log-dir", default="result", dest="log_dir", + help="Directory where to write logs / serialized models") +parser.add_argument("--task-name", default=time.strftime("%Y-%m-%d-%H-%M-%S"), dest="task_name", + help="Name for this task, use a comprehensive one") +parser.add_argument("--no-model", dest="no_model", action="store_true", help="Don't serialize model") +parser.add_argument("--always-model", dest="always_model", action="store_true", + help="Always serialize model after every epoch") +parser.add_argument("--old-model", dest="old_model", help="Path to old model for incremental training") +parser.add_argument("--skip-dev", dest="skip_dev", action="store_true", help="Skip dev set, would save some time") +parser.add_argument("--freeze", dest="freeze", action="store_true", help="freeze pretrained embedding") +parser.add_argument("--only-task", dest="only_task", action="store_true", help="only train task embedding") +parser.add_argument("--subset", dest="subset", help="Only train and test on a subset of the whole dataset") +parser.add_argument("--seclude", dest="seclude", help="train and test except a subset") +parser.add_argument("--instances", default=None, dest="instances", type=int,help="num of instances of subset") + +parser.add_argument("--seed", dest="python_seed", type=int, default=random.randrange(maxsize), + help="Random seed of Python and NumPy") +parser.add_argument("--debug", dest="debug", default=False, action="store_true", help="Debug mode") +parser.add_argument("--test", dest="test", action="store_true", help="Test mode") +parser.add_argument('--local_rank', type=int, default=None) +parser.add_argument('--init_method', type=str, default='env://') +# fmt: on + +options, _ = parser.parse_known_args() +print("unknown args", _) +task_name = options.task_name +root_dir = "{}/{}".format(options.log_dir, task_name) +utils.make_sure_path_exists(root_dir) + +if options.local_rank is not None: + torch.cuda.set_device(options.local_rank) + dist.init_process_group("nccl", init_method=options.init_method) + + +def init_logger(): + if not os.path.exists(root_dir): + os.mkdir(root_dir) + log_formatter = logging.Formatter("%(asctime)s - %(message)s") + logger = logging.getLogger() + file_handler = logging.FileHandler("{0}/info.log".format(root_dir), mode="w") + file_handler.setFormatter(log_formatter) + logger.addHandler(file_handler) + console_handler = logging.StreamHandler() + console_handler.setFormatter(log_formatter) + logger.addHandler(console_handler) + if options.local_rank is None or options.local_rank == 0: + logger.setLevel(logging.INFO) + else: + logger.setLevel(logging.WARNING) + return logger + + +# ===-----------------------------------------------------------------------=== +# Set up logging +# ===-----------------------------------------------------------------------=== +# logger = init_logger() +logger.add_file("{}/info.log".format(root_dir), "INFO") +logger.setLevel(logging.INFO if dist.get_rank() == 0 else logging.WARNING) + +# ===-----------------------------------------------------------------------=== +# Log some stuff about this run +# ===-----------------------------------------------------------------------=== +logger.info(" ".join(sys.argv)) +logger.info("") +logger.info(options) + +if options.debug: + logger.info("DEBUG MODE") + options.num_epochs = 2 + options.batch_size = 20 + +random.seed(options.python_seed) +np.random.seed(options.python_seed % (2 ** 32 - 1)) +torch.cuda.manual_seed_all(options.python_seed) +logger.info("Python random seed: {}".format(options.python_seed)) + +# ===-----------------------------------------------------------------------=== +# Read in dataset +# ===-----------------------------------------------------------------------=== +dataset = pickle.load(open(options.dataset + "/total_dataset.pkl", "rb")) +train_set = dataset["train_set"] +test_set = dataset["test_set"] +uni_vocab = dataset["uni_vocab"] +bi_vocab = dataset["bi_vocab"] +task_vocab = dataset["task_vocab"] +tag_vocab = dataset["tag_vocab"] +for v in (bi_vocab, uni_vocab, tag_vocab, task_vocab): + if hasattr(v, "_word2idx"): + v.word2idx = v._word2idx +for ds in (train_set, test_set): + ds.rename_field("ori_words", "words") + +logger.info("{} {}".format(bi_vocab.to_word(0), tag_vocab.word2idx)) +logger.info(task_vocab.word2idx) +if options.skip_dev: + dev_set = test_set +else: + train_set, dev_set = train_set.split(0.1) + +logger.info("{} {} {}".format(len(train_set), len(dev_set), len(test_set))) + +if options.debug: + train_set = train_set[0:DEBUG_SCALE] + dev_set = dev_set[0:DEBUG_SCALE] + test_set = test_set[0:DEBUG_SCALE] + +# ===-----------------------------------------------------------------------=== +# Build model and trainer +# ===-----------------------------------------------------------------------=== + +# =============================== +if dist.get_rank() != 0: + dist.barrier() + +if options.word_embeddings is None: + init_embedding = None +else: + # logger.info("Load: {}".format(options.word_embeddings)) + # init_embedding = utils.embedding_load_with_cache(options.word_embeddings, options.cache_dir, uni_vocab, normalize=False) + init_embedding = fastNLP.embeddings.StaticEmbedding( + uni_vocab, options.word_embeddings, word_drop=0.01 + ) + +bigram_embedding = None +if options.bigram_embeddings: + # logger.info("Load: {}".format(options.bigram_embeddings)) + # bigram_embedding = utils.embedding_load_with_cache(options.bigram_embeddings, options.cache_dir, bi_vocab, normalize=False) + bigram_embedding = fastNLP.embeddings.StaticEmbedding( + bi_vocab, options.bigram_embeddings + ) + +if dist.get_rank() == 0: + dist.barrier() +# =============================== + +# select subset training +if options.seclude is not None: + setname = "<{}>".format(options.seclude) + logger.info("seclude {}".format(setname)) + train_set.drop(lambda x: x["words"][0] == setname, inplace=True) + test_set.drop(lambda x: x["words"][0] == setname, inplace=True) + dev_set.drop(lambda x: x["words"][0] == setname, inplace=True) + +if options.subset is not None: + setname = "<{}>".format(options.subset) + logger.info("select {}".format(setname)) + train_set.drop(lambda x: x["words"][0] != setname, inplace=True) + test_set.drop(lambda x: x["words"][0] != setname, inplace=True) + dev_set.drop(lambda x: x["words"][0] != setname, inplace=True) + +# build model and optimizer +i2t = None +if options.crf: + # i2t=utils.to_id_list(tag_vocab.word2idx) + i2t = {} + for x, y in tag_vocab.word2idx.items(): + i2t[y] = x + logger.info(i2t) + +freeze = True if options.freeze else False +model = models.make_CWS( + d_model=options.d_model, + N=options.N, + h=options.h, + d_ff=options.d_ff, + dropout=options.dropout, + word_embedding=init_embedding, + bigram_embedding=bigram_embedding, + tag_size=len(tag_vocab), + task_size=len(task_vocab), + crf=i2t, + freeze=freeze, +) + +device = "cpu" + +if torch.cuda.device_count() > 0: + if options.local_rank is not None: + device = "cuda:{}".format(options.local_rank) + # model=nn.DataParallel(model) + model = model.to(device) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[options.local_rank], output_device=options.local_rank + ) + else: + device = "cuda:0" + model.to(device) + + +if options.only_task and options.old_model is not None: + logger.info("fix para except task embedding") + for name, para in model.named_parameters(): + if name.find("task_embed") == -1: + para.requires_grad = False + else: + para.requires_grad = True + logger.info(name) + +optimizer = optm.NoamOpt( + options.d_model, + options.factor, + 4000, + torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9), +) + +optimizer._step = options.step + +best_model_file_name = "{}/model.bin".format(root_dir) + +if options.local_rank is None: + train_sampler = BucketSampler( + batch_size=options.batch_size, seq_len_field_name="seq_len" + ) +else: + train_sampler = DistributedSampler( + train_set, dist.get_world_size(), dist.get_rank() + ) +dev_sampler = SequentialSampler() + +i2t = utils.to_id_list(tag_vocab.word2idx) +i2task = utils.to_id_list(task_vocab.word2idx) +dev_set.set_input("words") +test_set.set_input("words") +test_batch = DataSetIter(test_set, options.batch_size, num_workers=2) + +word_dic = pickle.load(open(options.dataset + "/oovdict.pkl", "rb")) + + +def batch_to_device(batch, device): + for k, v in batch.items(): + if torch.is_tensor(v): + batch[k] = v.to(device) + return batch + + +def tester(model, test_batch, write_out=False): + res = [] + prf = utils.CWSEvaluator(i2t) + prf_dataset = {} + oov_dataset = {} + + logger.info("start evaluation") + # import ipdb; ipdb.set_trace() + with torch.no_grad(): + for batch_x, batch_y in test_batch: + batch_to_device(batch_x, device) + # batch_to_device(batch_y, device) + if bigram_embedding is not None: + out = model( + batch_x["task"], + batch_x["uni"], + batch_x["seq_len"], + batch_x["bi1"], + batch_x["bi2"], + ) + else: + out = model(batch_x["task"], batch_x["uni"], batch_x["seq_len"]) + out = out["pred"] + # print(out) + num = out.size(0) + out = out.detach().cpu().numpy() + for i in range(num): + length = int(batch_x["seq_len"][i]) + + out_tags = out[i, 1:length].tolist() + sentence = batch_x["words"][i] + gold_tags = batch_y["tags"][i][1:length].numpy().tolist() + dataset_name = sentence[0] + sentence = sentence[1:] + # print(out_tags,gold_tags) + assert utils.is_dataset_tag(dataset_name), dataset_name + assert len(gold_tags) == len(out_tags) and len(gold_tags) == len( + sentence + ) + + if dataset_name not in prf_dataset: + prf_dataset[dataset_name] = utils.CWSEvaluator(i2t) + oov_dataset[dataset_name] = utils.CWS_OOV( + word_dic[dataset_name[1:-1]] + ) + + prf_dataset[dataset_name].add_instance(gold_tags, out_tags) + prf.add_instance(gold_tags, out_tags) + + if write_out: + gold_strings = utils.to_tag_strings(i2t, gold_tags) + obs_strings = utils.to_tag_strings(i2t, out_tags) + + word_list = utils.bmes_to_words(sentence, obs_strings) + oov_dataset[dataset_name].update( + utils.bmes_to_words(sentence, gold_strings), word_list + ) + + raw_string = " ".join(word_list) + res.append(dataset_name + " " + raw_string + " " + dataset_name) + + Ap = 0.0 + Ar = 0.0 + Af = 0.0 + Aoov = 0.0 + tot = 0 + nw = 0.0 + for dataset_name, performance in sorted(prf_dataset.items()): + p = performance.result() + if write_out: + nw = oov_dataset[dataset_name].oov() + # nw = 0 + logger.info( + "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( + dataset_name, p[0], p[1], p[2], nw + ) + ) + else: + logger.info( + "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( + dataset_name, p[0], p[1], p[2] + ) + ) + Ap += p[0] + Ar += p[1] + Af += p[2] + Aoov += nw + tot += 1 + + prf = prf.result() + logger.info( + "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format("TOT", prf[0], prf[1], prf[2]) + ) + if not write_out: + logger.info( + "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( + "AVG", Ap / tot, Ar / tot, Af / tot + ) + ) + else: + logger.info( + "{}\t{:04.2f}\t{:04.2f}\t{:04.2f}\t{:04.2f}".format( + "AVG", Ap / tot, Ar / tot, Af / tot, Aoov / tot + ) + ) + return prf[-1], res + + +# start training +if not options.test: + if options.old_model: + # incremental training + logger.info("Incremental training from old model: {}".format(options.old_model)) + model.load_state_dict(torch.load(options.old_model, map_location="cuda:0")) + + logger.info("Number training instances: {}".format(len(train_set))) + logger.info("Number dev instances: {}".format(len(dev_set))) + + train_batch = DataSetIter( + batch_size=options.batch_size, + dataset=train_set, + sampler=train_sampler, + num_workers=4, + ) + dev_batch = DataSetIter( + batch_size=options.batch_size, + dataset=dev_set, + sampler=dev_sampler, + num_workers=4, + ) + + best_f1 = 0.0 + for epoch in range(int(options.num_epochs)): + logger.info("Epoch {} out of {}".format(epoch + 1, options.num_epochs)) + train_loss = 0.0 + model.train() + tot = 0 + t1 = time.time() + for batch_x, batch_y in train_batch: + model.zero_grad() + if bigram_embedding is not None: + out = model( + batch_x["task"], + batch_x["uni"], + batch_x["seq_len"], + batch_x["bi1"], + batch_x["bi2"], + batch_y["tags"], + ) + else: + out = model( + batch_x["task"], batch_x["uni"], batch_x["seq_len"], batch_y["tags"] + ) + loss = out["loss"] + train_loss += loss.item() + tot += 1 + loss.backward() + # nn.utils.clip_grad_value_(model.parameters(), 1) + optimizer.step() + + t2 = time.time() + train_loss = train_loss / tot + logger.info( + "time: {} loss: {} step: {}".format(t2 - t1, train_loss, optimizer._step) + ) + # Evaluate dev data + if options.skip_dev and dist.get_rank() == 0: + logger.info("Saving model to {}".format(best_model_file_name)) + torch.save(model.module.state_dict(), best_model_file_name) + continue + + model.eval() + if dist.get_rank() == 0: + f1, _ = tester(model.module, dev_batch) + if f1 > best_f1: + best_f1 = f1 + logger.info("- new best score!") + if not options.no_model: + logger.info("Saving model to {}".format(best_model_file_name)) + torch.save(model.module.state_dict(), best_model_file_name) + + elif options.always_model: + logger.info("Saving model to {}".format(best_model_file_name)) + torch.save(model.module.state_dict(), best_model_file_name) + dist.barrier() + +# Evaluate test data (once) +logger.info("\nNumber test instances: {}".format(len(test_set))) + + +if not options.skip_dev: + if options.test: + model.module.load_state_dict( + torch.load(options.old_model, map_location="cuda:0") + ) + else: + model.module.load_state_dict( + torch.load(best_model_file_name, map_location="cuda:0") + ) + +if dist.get_rank() == 0: + for name, para in model.named_parameters(): + if name.find("task_embed") != -1: + tm = para.detach().cpu().numpy() + logger.info(tm.shape) + np.save("{}/task.npy".format(root_dir), tm) + break + +_, res = tester(model.module, test_batch, True) + +if dist.get_rank() == 0: + with open("{}/testout.txt".format(root_dir), "w", encoding="utf-8") as raw_writer: + for sent in res: + raw_writer.write(sent) + raw_writer.write("\n") + diff --git a/reproduction/multi-criteria-cws/make_data.sh b/reproduction/multi-criteria-cws/make_data.sh new file mode 100644 index 00000000..9c2b09d8 --- /dev/null +++ b/reproduction/multi-criteria-cws/make_data.sh @@ -0,0 +1,14 @@ +if [ -z "$DATA_DIR" ] +then + DATA_DIR="./data" +fi + +mkdir -vp $DATA_DIR + +cmd="python -u ./data-prepare.py --sighan05 $1 --sighan08 $2 --data_path $DATA_DIR" +echo $cmd +eval $cmd + +cmd="python -u ./data-process.py --data_path $DATA_DIR" +echo $cmd +eval $cmd diff --git a/reproduction/multi-criteria-cws/model.py b/reproduction/multi-criteria-cws/model.py new file mode 100644 index 00000000..e69de29b diff --git a/reproduction/multi-criteria-cws/models.py b/reproduction/multi-criteria-cws/models.py new file mode 100644 index 00000000..965da651 --- /dev/null +++ b/reproduction/multi-criteria-cws/models.py @@ -0,0 +1,200 @@ +import fastNLP +import torch +import math +from fastNLP.modules.encoder.transformer import TransformerEncoder +from fastNLP.modules.decoder.crf import ConditionalRandomField +from fastNLP import Const +import copy +import numpy as np +from torch.autograd import Variable +import torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F +import transformer + + +class PositionalEncoding(nn.Module): + "Implement the PE function." + + def __init__(self, d_model, dropout, max_len=512): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model).float() + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False) + return self.dropout(x) + + +class Embedding(nn.Module): + def __init__( + self, + task_size, + d_model, + word_embedding=None, + bi_embedding=None, + word_size=None, + freeze=True, + ): + super(Embedding, self).__init__() + self.task_size = task_size + self.embed_dim = 0 + + self.task_embed = nn.Embedding(task_size, d_model) + if word_embedding is not None: + # self.uni_embed = nn.Embedding.from_pretrained(torch.FloatTensor(word_embedding), freeze=freeze) + # self.embed_dim+=word_embedding.shape[1] + self.uni_embed = word_embedding + self.embed_dim += word_embedding.embedding_dim + else: + if bi_embedding is not None: + self.embed_dim += bi_embedding.shape[1] + else: + self.embed_dim = d_model + assert word_size is not None + self.uni_embed = Embedding(word_size, self.embed_dim) + + if bi_embedding is not None: + # self.bi_embed = nn.Embedding.from_pretrained(torch.FloatTensor(bi_embedding), freeze=freeze) + # self.embed_dim += bi_embedding.shape[1]*2 + self.bi_embed = bi_embedding + self.embed_dim += bi_embedding.embedding_dim * 2 + + print("Trans Freeze", freeze, self.embed_dim) + + if d_model != self.embed_dim: + self.F = nn.Linear(self.embed_dim, d_model) + else: + self.F = None + + self.d_model = d_model + + def forward(self, task, uni, bi1=None, bi2=None): + y_task = self.task_embed(task[:, 0:1]) + y = self.uni_embed(uni[:, 1:]) + if bi1 is not None: + assert self.bi_embed is not None + + y = torch.cat([y, self.bi_embed(bi1), self.bi_embed(bi2)], dim=-1) + # y2=self.bi_embed(bi) + # y=torch.cat([y,y2[:,:-1,:],y2[:,1:,:]],dim=-1) + + # y=torch.cat([y_task,y],dim=1) + if self.F is not None: + y = self.F(y) + y = torch.cat([y_task, y], dim=1) + return y * math.sqrt(self.d_model) + + +def seq_len_to_mask(seq_len, max_len=None): + if isinstance(seq_len, np.ndarray): + assert ( + len(np.shape(seq_len)) == 1 + ), f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." + if max_len is None: + max_len = int(seq_len.max()) + broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) + mask = broad_cast_seq_len < seq_len.reshape(-1, 1) + + elif isinstance(seq_len, torch.Tensor): + assert ( + seq_len.dim() == 1 + ), f"seq_len can only have one dimension, got {seq_len.dim() == 1}." + batch_size = seq_len.size(0) + if max_len is None: + max_len = seq_len.max().long() + broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len) + mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) + else: + raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") + + return mask + + +class CWSModel(nn.Module): + def __init__(self, encoder, src_embed, position, d_model, tag_size, crf=None): + super(CWSModel, self).__init__() + self.encoder = encoder + self.src_embed = src_embed + self.pos = copy.deepcopy(position) + self.proj = nn.Linear(d_model, tag_size) + self.tag_size = tag_size + if crf is None: + self.crf = None + self.loss_f = nn.CrossEntropyLoss(reduction="mean", ignore_index=-100) + else: + print("crf") + trans = fastNLP.modules.decoder.crf.allowed_transitions( + crf, encoding_type="bmes" + ) + self.crf = ConditionalRandomField(tag_size, allowed_transitions=trans) + # self.norm=nn.LayerNorm(d_model) + + def forward(self, task, uni, seq_len, bi1=None, bi2=None, tags=None): + # mask=fastNLP.core.utils.seq_len_to_mask(seq_len,uni.size(1)) # for dev 0.5.1 + mask = seq_len_to_mask(seq_len, uni.size(1)) + out = self.src_embed(task, uni, bi1, bi2) + out = self.pos(out) + # out=self.norm(out) + out = self.proj(self.encoder(out, mask.float())) + + if self.crf is not None: + if tags is not None: + out = self.crf(out, tags, mask) + return {"loss": out} + else: + out, _ = self.crf.viterbi_decode(out, mask) + return {"pred": out} + else: + if tags is not None: + out = out.contiguous().view(-1, self.tag_size) + tags = tags.data.masked_fill_(mask == 0, -100).view(-1) + loss = self.loss_f(out, tags) + return {"loss": loss} + else: + out = torch.argmax(out, dim=-1) + return {"pred": out} + + +def make_CWS( + N=6, + d_model=256, + d_ff=1024, + h=4, + dropout=0.2, + tag_size=4, + task_size=8, + bigram_embedding=None, + word_embedding=None, + word_size=None, + crf=None, + freeze=True, +): + c = copy.deepcopy + # encoder=TransformerEncoder(num_layers=N,model_size=d_model,inner_size=d_ff,key_size=d_model//h,value_size=d_model//h,num_head=h,dropout=dropout) + encoder = transformer.make_encoder( + N=N, d_model=d_model, h=h, dropout=dropout, d_ff=d_ff + ) + + position = PositionalEncoding(d_model, dropout) + + embed = Embedding( + task_size, d_model, word_embedding, bigram_embedding, word_size, freeze + ) + model = CWSModel(encoder, embed, position, d_model, tag_size, crf=crf) + + for p in model.parameters(): + if p.dim() > 1 and p.requires_grad: + nn.init.xavier_uniform_(p) + + return model diff --git a/reproduction/multi-criteria-cws/optm.py b/reproduction/multi-criteria-cws/optm.py new file mode 100644 index 00000000..a2b68de5 --- /dev/null +++ b/reproduction/multi-criteria-cws/optm.py @@ -0,0 +1,49 @@ +import torch +import torch.optim as optim + + +class NoamOpt: + "Optim wrapper that implements rate." + + def __init__(self, model_size, factor, warmup, optimizer): + self.optimizer = optimizer + self._step = 0 + self.warmup = warmup + self.factor = factor + self.model_size = model_size + self._rate = 0 + + def step(self): + "Update parameters and rate" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + "Implement `lrate` above" + if step is None: + step = self._step + lr = self.factor * ( + self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + # if step>self.warmup: lr = max(1e-4,lr) + return lr + + +def get_std_opt(model): + return NoamOpt( + model.src_embed[0].d_model, + 2, + 4000, + torch.optim.Adam( + filter(lambda p: p.requires_grad, model.parameters()), + lr=0, + betas=(0.9, 0.98), + eps=1e-9, + ), + ) + diff --git a/reproduction/multi-criteria-cws/train.py b/reproduction/multi-criteria-cws/train.py new file mode 100644 index 00000000..fce914a1 --- /dev/null +++ b/reproduction/multi-criteria-cws/train.py @@ -0,0 +1,138 @@ +from fastNLP import (Trainer, Tester, Callback, GradientClipCallback, LRScheduler, SpanFPreRecMetric) +import torch +import torch.cuda +from torch.optim import Adam, SGD +from argparse import ArgumentParser +import logging +from .utils import set_seed + + +class LoggingCallback(Callback): + def __init__(self, filepath=None): + super().__init__() + # create file handler and set level to debug + if filepath is not None: + file_handler = logging.FileHandler(filepath, "a") + else: + file_handler = logging.StreamHandler() + + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter( + logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt='%m/%d/%Y %H:%M:%S')) + + # create logger and set level to debug + logger = logging.getLogger() + logger.handlers = [] + logger.setLevel(logging.DEBUG) + logger.propagate = False + logger.addHandler(file_handler) + self.log_writer = logger + + def on_backward_begin(self, loss): + if self.step % self.trainer.print_every == 0: + self.log_writer.info( + 'Step/Epoch {}/{}: Loss {}'.format(self.step, self.epoch, loss.item())) + + def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): + self.log_writer.info( + 'Step/Epoch {}/{}: Eval result {}'.format(self.step, self.epoch, eval_result)) + + def on_backward_end(self): + pass + + +def main(): + parser = ArgumentParser() + register_args(parser) + args = parser.parse_known_args()[0] + + set_seed(args.seed) + if args.train: + train(args) + if args.eval: + evaluate(args) + +def get_optim(args): + name = args.optim.strip().split(' ')[0].lower() + p = args.optim.strip() + l = p.find('(') + r = p.find(')') + optim_args = eval('dict({})'.format(p[[l+1,r]])) + if name == 'sgd': + return SGD(**optim_args) + elif name == 'adam': + return Adam(**optim_args) + else: + raise ValueError(args.optim) + +def load_model_from_path(args): + pass + +def train(args): + data = get_data(args) + train_data = data['train'] + dev_data = data['dev'] + model = get_model(args) + optimizer = get_optim(args) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + callbacks = [] + trainer = Trainer( + train_data=train_data, + model=model, + optimizer=optimizer, + loss=None, + batch_size=args.batch_size, + n_epochs=args.epochs, + num_workers=4, + metrics=SpanFPreRecMetric( + tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'], + ignore_labels=data['ignore_labels']), + metric_key='f1', + dev_data=dev_data, + save_path=args.save_path, + device=device, + callbacks=callbacks, + check_code_level=-1, + ) + + print(trainer.train()) + + + +def evaluate(args): + data = get_data(args) + test_data = data['test'] + model = load_model_from_path(args) + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + tester = Tester( + data=test_data, model=model, batch_size=args.batch_size, + num_workers=2, device=device, + metrics=SpanFPreRecMetric( + tag_vocab=data['tag_vocab'], encoding_type=data['encoding_type'], + ignore_labels=data['ignore_labels']), + ) + print(tester.test()) + +def register_args(parser): + parser.add_argument('--optim', type=str, default='adam (lr=2e-3, weight_decay=0.0)') + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=10) + parser.add_argument('--save_path', type=str, default=None) + parser.add_argument('--data_path', type=str, required=True) + parser.add_argument('--log_path', type=str, default=None) + parser.add_argument('--model_config', type=str, required=True) + parser.add_argument('--load_path', type=str, default=None) + parser.add_argument('--train', action='store_true', default=False) + parser.add_argument('--eval', action='store_true', default=False) + parser.add_argument('--seed', type=int, default=42, help='rng seed') + +def get_model(args): + pass + +def get_data(args): + return torch.load(args.data_path) + +if __name__ == '__main__': + main() diff --git a/reproduction/multi-criteria-cws/train.sh b/reproduction/multi-criteria-cws/train.sh new file mode 100644 index 00000000..aa47b8af --- /dev/null +++ b/reproduction/multi-criteria-cws/train.sh @@ -0,0 +1,26 @@ +export EXP_NAME=release04 +export NGPU=2 +export PORT=9988 +export CUDA_DEVICE_ORDER=PCI_BUS_ID +export CUDA_VISIBLE_DEVICES=$1 + +if [ -z "$DATA_DIR" ] +then + DATA_DIR="./data" +fi + +echo $CUDA_VISIBLE_DEVICES +cmd=" +python -m torch.distributed.launch --nproc_per_node=$NGPU --master_port $PORT\ + main.py \ + --word-embeddings cn-char-fastnlp-100d \ + --bigram-embeddings cn-bi-fastnlp-100d \ + --num-epochs 100 \ + --batch-size 256 \ + --seed 1234 \ + --task-name $EXP_NAME \ + --dataset $DATA_DIR \ + --freeze \ +" +echo $cmd +eval $cmd diff --git a/reproduction/multi-criteria-cws/transformer.py b/reproduction/multi-criteria-cws/transformer.py new file mode 100644 index 00000000..fc352e44 --- /dev/null +++ b/reproduction/multi-criteria-cws/transformer.py @@ -0,0 +1,152 @@ +import numpy as np +import torch +import torch.autograd as autograd +import torch.nn as nn +import torch.nn.functional as F +import math, copy, time +from torch.autograd import Variable + +# import matplotlib.pyplot as plt + + +def clones(module, N): + "Produce N identical layers." + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + + +def subsequent_mask(size): + "Mask out subsequent positions." + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype("uint8") + return torch.from_numpy(subsequent_mask) == 0 + + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) + if mask is not None: + # print(scores.size(),mask.size()) # [bsz,1,1,len] + scores = scores.masked_fill(mask == 0, -1e9) + p_attn = F.softmax(scores, dim=-1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if mask is not None: + # Same mask applied to all h heads. + mask = mask.unsqueeze(1) + + nbatches = query.size(0) + + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = [ + l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for l, x in zip(self.linears, (query, key, value)) + ] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout) + + # 3) "Concat" using a view and apply a final linear. + x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k) + return self.linears[-1](x) + + +class LayerNorm(nn.Module): + "Construct a layernorm module (See citation for details)." + + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." + + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward (defined below)" + + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "Follow Figure 1 (left) for connections." + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + + +class Encoder(nn.Module): + "Core encoder is a stack of N layers" + + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask): + # print(x.size(),mask.size()) + "Pass the input (and mask) through each layer in turn." + mask = mask.byte().unsqueeze(-2) + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + + +def make_encoder(N=6, d_model=512, d_ff=2048, h=8, dropout=0.1): + c = copy.deepcopy + attn = MultiHeadedAttention(h, d_model) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + return Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N) diff --git a/reproduction/multi-criteria-cws/utils.py b/reproduction/multi-criteria-cws/utils.py new file mode 100644 index 00000000..aeb7e43c --- /dev/null +++ b/reproduction/multi-criteria-cws/utils.py @@ -0,0 +1,308 @@ +import numpy as np +import torch +import torch.cuda +import random +import os +import sys +import errno +import time +import codecs +import hashlib +import _pickle as pickle +import warnings +from fastNLP.io import EmbedLoader + +UNK_TAG = "" + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def bmes_to_words(chars, tags): + result = [] + if len(chars) == 0: + return result + word = chars[0] + + for c, t in zip(chars[1:], tags[1:]): + if t.upper() == "B" or t.upper() == "S": + result.append(word) + word = "" + word += c + if len(word) != 0: + result.append(word) + + return result + + +def bmes_to_index(tags): + result = [] + if len(tags) == 0: + return result + word = (0, 0) + + for i, t in enumerate(tags): + if i == 0: + word = (0, 0) + elif t.upper() == "B" or t.upper() == "S": + result.append(word) + word = (i, 0) + word = (word[0], word[1] + 1) + if word[1] != 0: + result.append(word) + return result + + +def get_bmes(sent): + x = [] + y = [] + for word in sent: + length = len(word) + tag = ["m"] * length if length > 1 else ["s"] * length + if length > 1: + tag[0] = "b" + tag[-1] = "e" + x += list(word) + y += tag + return x, y + + +class CWSEvaluator: + def __init__(self, i2t): + self.correct_preds = 0.0 + self.total_preds = 0.0 + self.total_correct = 0.0 + self.i2t = i2t + + def add_instance(self, pred_tags, gold_tags): + pred_tags = [self.i2t[i] for i in pred_tags] + gold_tags = [self.i2t[i] for i in gold_tags] + # Evaluate PRF + lab_gold_chunks = set(bmes_to_index(gold_tags)) + lab_pred_chunks = set(bmes_to_index(pred_tags)) + self.correct_preds += len(lab_gold_chunks & lab_pred_chunks) + self.total_preds += len(lab_pred_chunks) + self.total_correct += len(lab_gold_chunks) + + def result(self, percentage=True): + p = self.correct_preds / self.total_preds if self.correct_preds > 0 else 0 + r = self.correct_preds / self.total_correct if self.correct_preds > 0 else 0 + f1 = 2 * p * r / (p + r) if p + r > 0 else 0 + if percentage: + p *= 100 + r *= 100 + f1 *= 100 + return p, r, f1 + + +class CWS_OOV: + def __init__(self, dic): + self.dic = dic + self.recall = 0 + self.tot = 0 + + def update(self, gold_sent, pred_sent): + i = 0 + j = 0 + id = 0 + for w in gold_sent: + if w not in self.dic: + self.tot += 1 + while i + len(pred_sent[id]) <= j: + i += len(pred_sent[id]) + id += 1 + if ( + i == j + and len(pred_sent[id]) == len(w) + and w.find(pred_sent[id]) != -1 + ): + self.recall += 1 + j += len(w) + # print(gold_sent,pred_sent,self.tot) + + def oov(self, percentage=True): + ins = 1.0 * self.recall / self.tot + if percentage: + ins *= 100 + return ins + + +def get_processing_word( + vocab_words=None, vocab_chars=None, lowercase=False, chars=False +): + def f(word): + # 0. get chars of words + if vocab_chars is not None and chars: + char_ids = [] + for char in word: + # ignore chars out of vocabulary + if char in vocab_chars: + char_ids += [vocab_chars[char]] + + # 1. preprocess word + if lowercase: + word = word.lower() + if word.isdigit(): + word = "0" + + # 2. get id of word + if vocab_words is not None: + if word in vocab_words: + word = vocab_words[word] + else: + word = vocab_words[UNK_TAG] + + # 3. return tuple char ids, word id + if vocab_chars is not None and chars: + return char_ids, word + else: + return word + + return f + + +def append_tags(src, des, name, part, encode="utf-16"): + with open("{}/{}.txt".format(src, part), encoding=encode) as input, open( + "{}/{}.txt".format(des, part), "a", encoding=encode + ) as output: + for line in input: + line = line.strip() + if len(line) > 0: + output.write("<{}> {} ".format(name, line, name)) + output.write("\n") + + +def is_dataset_tag(word): + return len(word) > 2 and word[0] == "<" and word[-1] == ">" + + +def to_tag_strings(i2ts, tag_mapping, pos_separate_col=True): + senlen = len(tag_mapping) + key_value_strs = [] + + for j in range(senlen): + val = i2ts[tag_mapping[j]] + pos_str = val + key_value_strs.append(pos_str) + return key_value_strs + + +def to_id_list(w2i): + i2w = [None] * len(w2i) + for w, i in w2i.items(): + i2w[i] = w + return i2w + + +def make_sure_path_exists(path): + try: + os.makedirs(path) + except OSError as exception: + if exception.errno != errno.EEXIST: + raise + + +def md5_for_file(fn): + md5 = hashlib.md5() + with open(fn, "rb") as f: + for chunk in iter(lambda: f.read(128 * md5.block_size), b""): + md5.update(chunk) + return md5.hexdigest() + + +def embedding_match_vocab( + vocab, + emb, + ori_vocab, + dtype=np.float32, + padding="", + unknown="", + normalize=True, + error="ignore", + init_method=None, +): + dim = emb.shape[-1] + matrix = np.random.randn(len(vocab), dim).astype(dtype) + hit_flags = np.zeros(len(vocab), dtype=bool) + + if init_method: + matrix = init_method(matrix) + for word, idx in ori_vocab.word2idx.items(): + try: + if word == padding and vocab.padding is not None: + word = vocab.padding + elif word == unknown and vocab.unknown is not None: + word = vocab.unknown + if word in vocab: + index = vocab.to_index(word) + matrix[index] = emb[idx] + hit_flags[index] = True + except Exception as e: + if error == "ignore": + warnings.warn("Error occurred at the {} line.".format(idx)) + else: + print("Error occurred at the {} line.".format(idx)) + raise e + + total_hits = np.sum(hit_flags) + print( + "Found {} out of {} words in the pre-training embedding.".format( + total_hits, len(vocab) + ) + ) + if init_method is None: + found_vectors = matrix[hit_flags] + if len(found_vectors) != 0: + mean = np.mean(found_vectors, axis=0, keepdims=True) + std = np.std(found_vectors, axis=0, keepdims=True) + unfound_vec_num = len(vocab) - total_hits + r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean + matrix[hit_flags == False] = r_vecs + + if normalize: + matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) + + return matrix + + +def embedding_load_with_cache(emb_file, cache_dir, vocab, **kwargs): + def match_cache(file, cache_dir): + md5 = md5_for_file(file) + cache_files = os.listdir(cache_dir) + for fn in cache_files: + if md5 in fn.split("-")[-1]: + return os.path.join(cache_dir, fn), True + return ( + "{}-{}.pkl".format(os.path.join(cache_dir, os.path.basename(file)), md5), + False, + ) + + def get_cache(file): + if not os.path.exists(file): + return None + with open(file, "rb") as f: + emb = pickle.load(f) + return emb + + os.makedirs(cache_dir, exist_ok=True) + cache_fn, match = match_cache(emb_file, cache_dir) + if not match: + print("cache missed, re-generating cache at {}".format(cache_fn)) + emb, ori_vocab = EmbedLoader.load_without_vocab( + emb_file, padding=None, unknown=None, normalize=False + ) + with open(cache_fn, "wb") as f: + pickle.dump((emb, ori_vocab), f) + + else: + print("cache matched at {}".format(cache_fn)) + + # use cache + print("loading embeddings ...") + emb = get_cache(cache_fn) + assert emb is not None + return embedding_match_vocab(vocab, emb[0], emb[1], **kwargs)