From 0431a958b0e2a0f456af905f96a643fc6058bc2a Mon Sep 17 00:00:00 2001 From: benbijituo Date: Thu, 26 Sep 2019 14:31:37 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86CNXNLI=E7=9A=84?= =?UTF-8?q?=5Fload()=EF=BC=8C=E5=8F=AF=E4=BB=A5=E5=A4=84=E7=90=86=E7=89=B9?= =?UTF-8?q?=E6=AE=8A=E7=9A=84instance=E6=A0=BC=E5=BC=8F=E5=A6=82=E4=B8=8B?= =?UTF-8?q?=EF=BC=9A=20=E2=80=9CXXX\t"XXX\tXXX?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/loader/matching.py | 12 +++------ fastNLP/io/pipe/matching.py | 38 +++++++++++++++++++++++++++ test/data_for_tests/io/XNLI/train.txt | 1 + 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index 0969eeef..d3da9bdc 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -340,7 +340,7 @@ class QuoraLoader(Loader): class CNXNLILoader(Loader): """ 别名: - 数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理 + 数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize 原始数据为: train中的数据包括premise,hypo和label三个field dev和test中的数据为csv或json格式,包括十多个field,这里只取与以上三个field中的数据 @@ -358,8 +358,6 @@ class CNXNLILoader(Loader): super(CNXNLILoader, self).__init__() def _load(self, path: str = None): - #csv_loader = CSVLoader(sep='\t') - #ds_all = csv_loader._load(path) ds_all = DataSet() with open(path, 'r', encoding='utf-8') as f: head_name_list = f.readline().strip().split('\t') @@ -386,17 +384,15 @@ class CNXNLILoader(Loader): return ds_zh def _load_train(self, path: str = None): - #csv_loader = CSVLoader(sep='\t') - #ds = csv_loader._load(path) ds = DataSet() with open(path, 'r', encoding='utf-8') as f: next(f) for line in f: raw_instance = line.strip().split('\t') - premise = raw_instance[0] - hypo = raw_instance[1] - label = raw_instance[-1] + premise = "".join(raw_instance[0].split())# 把已经分好词的premise和hypo强制还原为character segmentation + hypo = "".join(raw_instance[1].split()) + label = "".join(raw_instance[-1].split()) if premise: ds.append(Instance(premise=premise, hypo=hypo, label=label)) diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index dac21dca..dbe69525 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -466,6 +466,7 @@ class LCQMCBertPipe(MatchingBertPipe): data_bundle = LCQMCLoader().load(paths) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) + data_bundle = TruncateBertPipe(task='cn').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) return data_bundle @@ -475,6 +476,7 @@ class BQCorpusBertPipe(MatchingBertPipe): data_bundle = BQCorpusLoader().load(paths) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) + data_bundle = TruncateBertPipe(task='cn').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) return data_bundle @@ -485,5 +487,41 @@ class CNXNLIBertPipe(MatchingBertPipe): data_bundle = GranularizePipe(task='XNLI').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) + data_bundle = TruncateBertPipe(task='cn').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) return data_bundle + + +class TruncateBertPipe(Pipe): + def __init__(self, task='cn'): + super().__init__() + self.task = task + + def _truncate(self, sentence_index:list, sep_index_vocab): + # 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index + sep_index_words = sentence_index.index(sep_index_vocab) + words_before_sep = sentence_index[:sep_index_words] + words_after_sep = sentence_index[sep_index_words:] # 注意此部分包括了[SEP] + if self.task == 'cn': + # 中文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过250 + words_before_sep = words_before_sep[:250] + words_after_sep = words_after_sep[:250] + elif self.task == 'en': + # 英文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过215 + words_before_sep = words_before_sep[:215] + words_after_sep = words_after_sep[:215] + else: + raise RuntimeError("Only support 'cn' or 'en' task.") + + return words_before_sep + words_after_sep + + def process(self, data_bundle: DataBundle) -> DataBundle: + for name in data_bundle.datasets.keys(): + dataset = data_bundle.get_dataset(name) + sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') + dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words') + + # truncate之后需要更新seq_len + dataset.add_seq_len(field_name='words') + return data_bundle + diff --git a/test/data_for_tests/io/XNLI/train.txt b/test/data_for_tests/io/XNLI/train.txt index 45d1ce9e..8a2fd3a3 100644 --- a/test/data_for_tests/io/XNLI/train.txt +++ b/test/data_for_tests/io/XNLI/train.txt @@ -6,3 +6,4 @@ premise hypo label 一 段 时间 来 看 , 这 一 运动 似乎 要 取得 成功 , 但 政治 事件 , 加 上 帕内尔 在 一个 令 人 愤慨 的 离婚案 中 被 称为 共同 答辩人 , 导致 许多 人 撤回 他们 的 支持 . 帕内尔 在 一个 令 人 愤慨 的 离婚 问题 上 的 法律 问题 使 这 场 运动 受到 了 影响 . entailment 看 在 这里 , 他 说 我们 不 希望 任何 律师 混在 这 一 点 . 他 说 看看 那 张 纸 neutral Soderstrom 在 创伤 中心 进行 了 多次 筛选 测试 . 测试 必须 在 创伤 中心 进行 比较 , 否则 就 会 无效 . neutral +嗯 , 这 是 一 种 明显 的 我 的 意思 是 , 他们 甚至 把 它 带 到 现在 呢 , 他们 在 电视 上 做 广告 , 你 知道 如果 你 知道 , 如果 你 知道 这样 做 , 或者 如果 你 需要 这 个呃 , 我们 会 告 你 和 你 你 不用 给 我们 钱 , 但 他们 不 告诉 你 的 是 如果 他们 赢 了 你 给 他们 至少 三分之一 他们 赢 的 东西 , 所以 我 不 知道 它 是呃 , 它 得到 了 现在 做 更 多 的 生意 , 而 不 是呃 实际上 是 在 处理 犯罪 而 不 是 与 呃嗯 他们 的 律师 只 是 为了 钱 , 我 相信 , 我 知道 我 同意 你 , 我 认为 你 是 真实 的 你. 非常 正确 的 是 , 我 认为 他们 应该 有 同等 数量 的 你 知道 也许 他们 可以 有 几 个 , 但 我 认为 大多数 他们 应该 不 是 律师 在 事实 , 这 是 方式 他们 已经 进入 政治 , 这 是 因为 在 法律 上 , 你 知道 的 循环 和 一切 , 但 我 不 知道 我们 是 在 马里兰州 和呃 , 我们 有 同样 的 东西 人满为患 , 和呃 他们 让 他们 出来 我 的 意思 是 只 是 普通 的 监狱 判决 的 事情 , 他们 让. 他们 是 因为 他们 没有 任何 地方 可以 留住 他们 所以 你 可以 知道呃 , 除非 是 一个 重大 的 罪行 , 但呃 , 即使 是 小小的 东西 , 我 的 意思 是 那些 在 美国 失去 的 人 是 受害者 和 谁 可能 是 抢劫 或 毒品 , 或者 其他 什么 , 他们 是 谁 要 支付 , 他们 是 一个 会 受苦 , 另 一个 你 知道 的 人 , 如果 他们 被 逮捕 , 如果 他们 逮捕 他们嗯 , 然后 呢 , 你 知道 的 时间 法律 接管 了 一 半 时间 呃 他们 要么 让 他们 走 , 或者 他们 下 了 一个 句子 , 因为 他们 有 一个 律师 , 你 知道 的 感觉 他们 是 不 是 所有 在 一起 当 他们 做到 了 .它 我 不 知道 我们 怎么 到 这 一 点 , 虽然 . neutral From dd94641a271f3e14fb14bc2d5fb2271fa4056565 Mon Sep 17 00:00:00 2001 From: yhcc Date: Thu, 26 Sep 2019 20:45:35 +0800 Subject: [PATCH 2/6] Update README.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 更新readme错误换行 --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 916f2c1a..f9fdc768 100644 --- a/README.md +++ b/README.md @@ -137,6 +137,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: 致谢 感谢 [深脑云](http://www.dbcloud.ai/) 提供的模型与数据存储、下载服务。 + From da2416a1b89450e00841b0e931aaadfcd98a6eff Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Fri, 27 Sep 2019 03:08:29 +0800 Subject: [PATCH 3/6] fix test bugs in: 1. use prettytable to print instance; 2. CNXNLI loader and pipe. --- test/core/test_dataset.py | 2 +- test/io/loader/test_matching_loader.py | 2 +- test/io/pipe/test_matching.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 9820eff6..852041f8 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -230,7 +230,7 @@ class TestDataSetIter(unittest.TestCase): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) for iter in ds: self.assertEqual(iter.__repr__(), """+--------------+--------+ -| x | y | +| x | y | +--------------+--------+ | [1, 2, 3, 4] | [5, 6] | +--------------+--------+""") diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py index abe21aa9..70367f6d 100644 --- a/test/io/loader/test_matching_loader.py +++ b/test/io/loader/test_matching_loader.py @@ -31,7 +31,7 @@ class TestMatchingLoad(unittest.TestCase): 'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True), 'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False), 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False), - 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 7, 6), False), + 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False), 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False), } for k, v in data_set_dict.items(): diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py index 52d372d5..7e68863d 100644 --- a/test/io/pipe/test_matching.py +++ b/test/io/pipe/test_matching.py @@ -38,7 +38,7 @@ class TestRunMatchingPipe(unittest.TestCase): 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), - 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 7, 6), (37, 3), False), + 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (37, 3), False), 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), } for k, v in data_set_dict.items(): From 995f4ef18d8dd29bb06507d0f3b092eb669a6055 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Fri, 27 Sep 2019 03:32:03 +0800 Subject: [PATCH 4/6] add distilbert from pytorch-transformers package --- fastNLP/io/file_utils.py | 2 + fastNLP/modules/encoder/bert.py | 124 +++++++++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 11 deletions(-) diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index a4abb575..9e7ac6f6 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -37,6 +37,8 @@ PRETRAINED_BERT_MODEL_DIR = { 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', + 'en-distilbert-base-uncased': 'distilbert-base-uncased.zip', + 'multi-base-cased': 'bert-base-multilingual-cased.zip', 'multi-base-uncased': 'bert-base-multilingual-uncased.zip', diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py index 16b456fb..821b9c5c 100644 --- a/fastNLP/modules/encoder/bert.py +++ b/fastNLP/modules/encoder/bert.py @@ -16,6 +16,7 @@ import unicodedata import torch from torch import nn +import numpy as np from ..utils import _get_file_name_base_on_postfix from ...io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR @@ -24,6 +25,24 @@ from ...core import logger CONFIG_FILE = 'bert_config.json' VOCAB_NAME = 'vocab.txt' +BERT_KEY_RENAME_MAP_1 = { + 'gamma': 'weight', + 'beta': 'bias', + 'distilbert.embeddings': 'bert.embeddings', + 'distilbert.transformer': 'bert.encoder', +} + +BERT_KEY_RENAME_MAP_2 = { + 'q_lin': 'self.query', + 'k_lin': 'self.key', + 'v_lin': 'self.value', + 'out_lin': 'output.dense', + 'sa_layer_norm': 'attention.output.LayerNorm', + 'ffn.lin1': 'intermediate.dense', + 'ffn.lin2': 'output.dense', + 'output_layer_norm': 'output.LayerNorm', +} + class BertConfig(object): """Configuration class to store the configuration of a `BertModel`. @@ -162,6 +181,55 @@ class BertLayerNorm(nn.Module): return self.weight * x + self.bias +class DistilBertEmbeddings(nn.Module): + def __init__(self, config): + super(DistilBertEmbeddings, self).__init__() + + def create_sinusoidal_embeddings(n_pos, dim, out): + position_enc = np.array([ + [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] + for pos in range(n_pos) + ]) + out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + out.requires_grad = False + + self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + if config.sinusoidal_pos_embds: + create_sinusoidal_embeddings(n_pos=config.max_position_embeddings, + dim=config.hidden_size, + out=self.position_embeddings.weight) + + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=1e-12) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, input_ids, token_type_ids): + """ + Parameters + ---------- + input_ids: torch.tensor(bs, max_seq_length) + The token ids to embed. + token_type_ids: no used. + Outputs + ------- + embeddings: torch.tensor(bs, max_seq_length, dim) + The embedded tokens (plus position embeddings, no token_type embeddings) + """ + seq_length = input_ids.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) # (max_seq_length) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) # (bs, max_seq_length) + + word_embeddings = self.word_embeddings(input_ids) # (bs, max_seq_length, dim) + position_embeddings = self.position_embeddings(position_ids) # (bs, max_seq_length, dim) + + embeddings = word_embeddings + position_embeddings # (bs, max_seq_length, dim) + embeddings = self.LayerNorm(embeddings) # (bs, max_seq_length, dim) + embeddings = self.dropout(embeddings) # (bs, max_seq_length, dim) + return embeddings + + class BertEmbeddings(nn.Module): """Construct the embeddings from word, position and token_type embeddings. """ @@ -383,9 +451,22 @@ class BertModel(nn.Module): super(BertModel, self).__init__() self.config = config self.hidden_size = self.config.hidden_size - self.embeddings = BertEmbeddings(config) + self.model_type = 'bert' + if hasattr(config, 'sinusoidal_pos_embds'): + self.model_type = 'distilbert' + elif 'model_type' in kwargs: + self.model_type = kwargs['model_type'].lower() + + if self.model_type == 'distilbert': + self.embeddings = DistilBertEmbeddings(config) + else: + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config) - self.pooler = BertPooler(config) + if self.model_type != 'distilbert': + self.pooler = BertPooler(config) + else: + logger.info('DistilBert has NOT pooler, will use hidden states of [CLS] token as pooled output.') self.apply(self.init_bert_weights) def init_bert_weights(self, module): @@ -427,7 +508,10 @@ class BertModel(nn.Module): extended_attention_mask, output_all_encoded_layers=output_all_encoded_layers) sequence_output = encoded_layers[-1] - pooled_output = self.pooler(sequence_output) + if self.model_type != 'distilbert': + pooled_output = self.pooler(sequence_output) + else: + pooled_output = sequence_output[:, 0] if not output_all_encoded_layers: encoded_layers = encoded_layers[-1] return encoded_layers, pooled_output @@ -445,9 +529,7 @@ class BertModel(nn.Module): # Load config config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json') config = BertConfig.from_json_file(config_file) - # logger.info("Model config {}".format(config)) - # Instantiate model. - model = cls(config, *inputs, **kwargs) + if state_dict is None: weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin') state_dict = torch.load(weights_path, map_location='cpu') @@ -455,20 +537,40 @@ class BertModel(nn.Module): logger.error(f'Cannot load parameters through `state_dict` variable.') raise RuntimeError(f'Cannot load parameters through `state_dict` variable.') + model_type = 'BERT' + old_keys = [] + new_keys = [] + for key in state_dict.keys(): + new_key = None + for key_name in BERT_KEY_RENAME_MAP_1: + if key_name in key: + new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_1[key_name]) + if 'distilbert' in key: + model_type = 'DistilBert' + break + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + state_dict[new_key] = state_dict.pop(old_key) + old_keys = [] new_keys = [] for key in state_dict.keys(): new_key = None - if 'gamma' in key: - new_key = key.replace('gamma', 'weight') - if 'beta' in key: - new_key = key.replace('beta', 'bias') + for key_name in BERT_KEY_RENAME_MAP_2: + if key_name in key: + new_key = key.replace(key_name, BERT_KEY_RENAME_MAP_2[key_name]) + break if new_key: old_keys.append(key) new_keys.append(new_key) for old_key, new_key in zip(old_keys, new_keys): state_dict[new_key] = state_dict.pop(old_key) + # Instantiate model. + model = cls(config, model_type=model_type, *inputs, **kwargs) + missing_keys = [] unexpected_keys = [] error_msgs = [] @@ -494,7 +596,7 @@ class BertModel(nn.Module): logger.warning("Weights from pretrained model not used in {}: {}".format( model.__class__.__name__, unexpected_keys)) - logger.info(f"Load pre-trained BERT parameters from file {weights_path}.") + logger.info(f"Load pre-trained {model_type} parameters from file {weights_path}.") return model From 209a8fedbac770bcacd9141fa575e3fc95515979 Mon Sep 17 00:00:00 2001 From: Yige Xu Date: Fri, 27 Sep 2019 04:16:25 +0800 Subject: [PATCH 5/6] Update test_matching.py --- test/io/pipe/test_matching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py index 7e68863d..bfd65db2 100644 --- a/test/io/pipe/test_matching.py +++ b/test/io/pipe/test_matching.py @@ -38,7 +38,7 @@ class TestRunMatchingPipe(unittest.TestCase): 'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True), 'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True), 'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False), - 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (37, 3), False), + 'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False), 'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False), } for k, v in data_set_dict.items(): From a98fed61e546bb6f27fe12b5011a31ef82a46e75 Mon Sep 17 00:00:00 2001 From: yunfan Date: Fri, 27 Sep 2019 07:17:44 +0000 Subject: [PATCH 6/6] [bugfix] package now ignores reproduction/ & test/ --- setup.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7d1d8034..72f92d16 100644 --- a/setup.py +++ b/setup.py @@ -11,9 +11,12 @@ with open('LICENSE', encoding='utf-8') as f: with open('requirements.txt', encoding='utf-8') as f: reqs = f.read() +pkgs = [p for p in find_packages() if p.startswith('fastNLP')] +print(pkgs) + setup( name='FastNLP', - version='0.4.9', + version='0.4.10', url='https://github.com/fastnlp/fastNLP', description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team', long_description=readme, @@ -21,6 +24,6 @@ setup( license='Apache License', author='FudanNLP', python_requires='>=3.6', - packages=find_packages(), + packages=pkgs, install_requires=reqs.strip().split('\n'), )