@@ -5,10 +5,8 @@ from ..base_loader import DataInfo, DataSetLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance
-from ..embed_loader import EmbeddingOption, EmbedLoader
+from ..utils import check_dataloader_paths, get_tokenizer
-spacy.prefer_gpu()
-sptk = spacy.load('en')

 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'

@@ -37,6 +35,7 @@ class SSTLoader(DataSetLoader):
         tag_v['0'] = tag_v['1']
         tag_v['4'] = tag_v['3']
         self.tag_v = tag_v
+        self.tokenizer = get_tokenizer()

     def _load(self, path):
         """
@@ -55,29 +54,37 @@ class SSTLoader(DataSetLoader):
             ds.append(Instance(words=words, target=tag))
         return ds

-    @staticmethod
-    def _get_one(data, subtree):
+    def _get_one(self, data, subtree):
         tree = Tree.fromstring(data)
         if subtree:
-            return [([x.text for x in sptk.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees()]
-        return [([x.text for x in sptk.tokenizer(' '.join(tree.leaves()))], tree.label())]
+            return [([x.text for x in self.tokenizer(' '.join(t.leaves()))], t.label()) for t in tree.subtrees()]
+        return [([x.text for x in self.tokenizer(' '.join(tree.leaves()))], tree.label())]

     def process(self,
-                paths,
-                train_ds: Iterable[str] = None,
+                paths, train_subtree=True,
                 src_vocab_op: VocabularyOption = None,
-                tgt_vocab_op: VocabularyOption = None,
-                src_embed_op: EmbeddingOption = None):
+                tgt_vocab_op: VocabularyOption = None):
+        paths = check_dataloader_paths(paths)
         input_name, target_name = 'words', 'target'
         src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
         tgt_vocab = Vocabulary(unknown=None, padding=None) \
             if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

-        info = DataInfo(datasets=self.load(paths))
-        _train_ds = [info.datasets[name]
-                     for name in train_ds] if train_ds else info.datasets.values()
-        src_vocab.from_dataset(*_train_ds, field_name=input_name)
-        tgt_vocab.from_dataset(*_train_ds, field_name=target_name)
+        info = DataInfo()
+        origin_subtree = self.subtree
+        self.subtree = train_subtree
+        info.datasets['train'] = self._load(paths['train'])
+        self.subtree = origin_subtree
+        for n, p in paths.items():
+            if n != 'train':
+                info.datasets[n] = self._load(p)
+        src_vocab.from_dataset(
+            info.datasets['train'],
+            field_name=input_name,
+            no_create_entry_dataset=[ds for n, ds in info.datasets.items() if n != 'train'])
+        tgt_vocab.from_dataset(info.datasets['train'], field_name=target_name)
         src_vocab.index_dataset(
             *info.datasets.values(),
             field_name=input_name, new_field_name=input_name)

@@ -89,10 +96,5 @@ class SSTLoader(DataSetLoader):
             target_name: tgt_vocab
         }
-        if src_embed_op is not None:
-            src_embed_op.vocab = src_vocab
-            init_emb = EmbedLoader.load_with_vocab(**src_embed_op)
-            info.embeddings[input_name] = init_emb
         return info
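For orientation, the reworked `process` above now takes a plain `paths` argument (a file, a directory, or a dict, validated by `check_dataloader_paths`) plus `train_subtree`, and builds both vocabularies from the train split only. A minimal sketch of calling it; the import path and the data paths below are assumptions, not part of this diff:

```python
# Hypothetical module path and data locations; adjust to your checkout.
from fastNLP.io.data_loader.sst import SSTLoader

loader = SSTLoader()
info = loader.process(
    {'train': '/data/sst/train.txt', 'dev': '/data/sst/dev.txt', 'test': '/data/sst/test.txt'},
    train_subtree=True)           # only the train file is expanded into subtrees
print(info.datasets.keys())       # expected: train, dev, test
print(len(info.vocabs['words']), len(info.vocabs['target']))
```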
@@ -0,0 +1,69 @@
import os
from typing import Union, Dict


def check_dataloader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
    """
    Check that the paths passed to a dataloader are valid. If they are, return a dict containing
    at least the 'train' key, e.g.
        {
            'train': '/some/path/to/',  # always present; the vocabulary should be built on this file,
                                        # the remaining files only need to be processed and indexed.
            'test': 'xxx'               # may or may not be present
            ...
        }
    If paths is invalid, the corresponding error is raised directly.

    :param paths: the path(s). May be a single file path (treated as the train file); a directory, in
        which case the function looks for a train file (a filename containing 'train'), test.txt and
        dev.txt; or a dict whose keys are user-defined names and whose values are file paths.
    :return:
    """
    if isinstance(paths, str):
        if os.path.isfile(paths):
            return {'train': paths}
        elif os.path.isdir(paths):
            filenames = os.listdir(paths)
            files = {}
            for filename in filenames:
                path_pair = None
                if 'train' in filename:
                    path_pair = ('train', filename)
                if 'dev' in filename:
                    if path_pair:
                        raise Exception("File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
                    path_pair = ('dev', filename)
                if 'test' in filename:
                    if path_pair:
                        raise Exception("File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
                    path_pair = ('test', filename)
                if path_pair:
                    files[path_pair[0]] = os.path.join(paths, path_pair[1])
            return files
        else:
            raise FileNotFoundError(f"{paths} is not a valid file path.")

    elif isinstance(paths, dict):
        if paths:
            if 'train' not in paths:
                raise KeyError("You have to include `train` in your dict.")
            for key, value in paths.items():
                if isinstance(key, str) and isinstance(value, str):
                    if not os.path.isfile(value):
                        raise TypeError(f"{value} is not a valid file.")
                else:
                    raise TypeError("All keys and values in paths should be str.")
            return paths
        else:
            raise ValueError("Empty paths is not allowed.")
    else:
        raise TypeError(f"paths only supports str and dict. not {type(paths)}.")


def get_tokenizer():
    try:
        import spacy
        spacy.prefer_gpu()
        en = spacy.load('en')
        print('use spacy tokenizer')
        return lambda x: [w.text for w in en.tokenizer(x)]
    except Exception as e:
        print('use raw tokenizer')
        return lambda x: x.split()
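A quick illustration of the two helpers above. The file paths are hypothetical and must exist on disk for `check_dataloader_paths` to return instead of raising; `get_tokenizer` falls back to whitespace splitting when the spaCy English model is not installed:

```python
from reproduction.utils import check_dataloader_paths, get_tokenizer

# A single existing file is treated as the training file.
print(check_dataloader_paths('/data/sst/train.txt'))   # {'train': '/data/sst/train.txt'}

# A dict must contain a 'train' key and only paths to existing files.
paths = check_dataloader_paths({'train': '/data/sst/train.txt',
                                'dev': '/data/sst/dev.txt'})

tokenize = get_tokenizer()            # prints 'use spacy tokenizer' or 'use raw tokenizer'
print(tokenize("A fairly dull movie."))
```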
@@ -6,7 +6,7 @@ paper: [Star-Transformer](https://arxiv.org/abs/1902.09113)
 |Pos Tagging|CTB 9.0|-|ACC 92.31|
 |Pos Tagging|CONLL 2012|-|ACC 96.51|
 |Named Entity Recognition|CONLL 2012|-|F1 85.66|
-|Text Classification|SST|-|49.18|
+|Text Classification|SST|-|51.2|
 |Natural Language Inference|SNLI|-|83.76|

 ## Usage
@@ -0,0 +1,105 @@
import argparse
import torch
import os

from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
from fastNLP.modules.encoder.embedding import StaticEmbedding

from reproduction.matching.data.MatchingDataLoader import QNLILoader, RTELoader, SNLILoader, MNLILoader
from reproduction.matching.model.cntn import CNTNModel

# define hyper-parameters
argument = argparse.ArgumentParser()
argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glove')
argument.add_argument('--batch-size-per-gpu', type=int, default=256)
argument.add_argument('--n-epochs', type=int, default=200)
argument.add_argument('--lr', type=float, default=1e-5)
argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask')
argument.add_argument('--save-dir', type=str, default=None)
argument.add_argument('--cntn-depth', type=int, default=1)
argument.add_argument('--cntn-ns', type=int, default=200)
argument.add_argument('--cntn-k-top', type=int, default=10)
argument.add_argument('--cntn-r', type=int, default=5)
argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
argument.add_argument('--max-len', type=int, default=50)
arg = argument.parse_args()

# dataset dict
dev_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'dev',
    'mnli': 'dev_matched',
}

test_dict = {
    'qnli': 'dev',
    'rte': 'dev',
    'snli': 'test',
    'mnli': 'dev_matched',
}

# set num_labels
if arg.dataset == 'qnli' or arg.dataset == 'rte':
    num_labels = 2
else:
    num_labels = 3

# load data set
if arg.dataset == 'qnli':
    data_info = QNLILoader().process(
        paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
elif arg.dataset == 'rte':
    data_info = RTELoader().process(
        paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
elif arg.dataset == 'snli':
    data_info = SNLILoader().process(
        paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
elif arg.dataset == 'mnli':
    data_info = MNLILoader().process(
        paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
        get_index=True, concat=False, auto_pad_length=arg.max_len)
else:
    raise ValueError(f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!')

# load embedding
if arg.embedding == 'word2vec':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300', requires_grad=True)
elif arg.embedding == 'glove':
    embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300',
                                requires_grad=True)
else:
    raise ValueError(f'now we only support word2vec or glove embedding for cntn model!')

# define model
model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=num_labels, depth=arg.cntn_depth,
                  r=arg.cntn_r)
print(model)

# define trainer
trainer = Trainer(train_data=data_info.datasets['train'], model=model,
                  optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
                  batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
                  n_epochs=arg.n_epochs, print_every=-1,
                  dev_data=data_info.datasets[dev_dict[arg.dataset]],
                  metrics=AccuracyMetric(), metric_key='acc',
                  device=[i for i in range(torch.cuda.device_count())],
                  check_code_level=-1)

# train model
trainer.train(load_best_model=True)

# define tester
tester = Tester(
    data=data_info.datasets[test_dict[arg.dataset]],
    model=model,
    metrics=AccuracyMetric(),
    batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
    device=[i for i in range(torch.cuda.device_count())]
)

# test model
tester.test()
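One thing worth noting in the script above: `torch.cuda.device_count() * arg.batch_size_per_gpu` evaluates to 0 on a CPU-only machine. A hedged variant that keeps the batch size usable without GPUs (variable names here are my own, not from the script) could be:

```python
import torch

batch_size_per_gpu = 256                          # mirrors the script's default
n_devices = max(1, torch.cuda.device_count())     # treat a CPU-only box as one device
batch_size = n_devices * batch_size_per_gpu
print(n_devices, batch_size)
```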
@@ -0,0 +1,120 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from torch.nn import CrossEntropyLoss

from fastNLP.models import BaseModel
from fastNLP.modules.encoder.embedding import TokenEmbedding
from fastNLP.core.const import Const


class DynamicKMaxPooling(nn.Module):
    """
    :param k_top: Fixed number of pooling output features for the topmost convolutional layer.
    :param l: Number of convolutional layers.
    """

    def __init__(self, k_top, l):
        super(DynamicKMaxPooling, self).__init__()
        self.k_top = k_top
        self.L = l

    def forward(self, x, l):
        """
        :param x: Input sequence.
        :param l: Index of the current convolutional layer (1-based).
        """
        s = x.size()[3]
        k_ll = ((self.L - l) / self.L) * s
        k_l = int(round(max(self.k_top, np.ceil(k_ll))))
        out = F.adaptive_max_pool2d(x, (x.size()[2], k_l))
        return out


class CNTNModel(BaseModel):
    """
    A CNN-based model for question-answer matching:
    'Qiu, Xipeng, and Xuanjing Huang.
    Convolutional neural tensor network architecture for community-based question answering.
    Twenty-Fourth International Joint Conference on Artificial Intelligence. 2015.'

    :param init_embedding: Embedding.
    :param ns: Sentence embedding size.
    :param k_top: Fixed number of pooling output features for the topmost convolutional layer.
    :param num_labels: Number of labels.
    :param depth: Number of convolutional layers.
    :param r: Number of weight tensor slices.
    :param dropout_rate: Dropout rate.
    """

    def __init__(self, init_embedding: TokenEmbedding, ns=200, k_top=10, num_labels=2, depth=2, r=5,
                 dropout_rate=0.3):
        super(CNTNModel, self).__init__()
        self.embedding = init_embedding
        self.depth = depth
        self.kmaxpooling = DynamicKMaxPooling(k_top, depth)
        self.conv_q = nn.ModuleList()
        self.conv_a = nn.ModuleList()
        width = self.embedding.embed_size
        for i in range(depth):
            self.conv_q.append(nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Conv2d(
                    in_channels=1,
                    out_channels=width // 2,
                    kernel_size=(width, 3),
                    padding=(0, 2))
            ))
            self.conv_a.append(nn.Sequential(
                nn.Dropout(p=dropout_rate),
                nn.Conv2d(
                    in_channels=1,
                    out_channels=width // 2,
                    kernel_size=(width, 3),
                    padding=(0, 2))
            ))
            width = width // 2

        self.fc_q = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns))
        self.fc_a = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(width * k_top, ns))
        self.weight_M = nn.Bilinear(ns, ns, r)
        self.weight_V = nn.Linear(2 * ns, r)
        self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels))

    def forward(self, words1, words2, seq_len1, seq_len2, target=None):
        """
        :param words1: [batch, seq_len] Question token indices.
        :param words2: [batch, seq_len] Answer token indices.
        :param seq_len1: [batch]
        :param seq_len2: [batch]
        :param target: [batch] Gold labels.
        :return:
        """
        in_q = self.embedding(words1)
        in_a = self.embedding(words2)
        in_q = in_q.permute(0, 2, 1).unsqueeze(1)
        in_a = in_a.permute(0, 2, 1).unsqueeze(1)

        for i in range(self.depth):
            in_q = F.relu(self.conv_q[i](in_q))
            in_q = in_q.squeeze().unsqueeze(1)
            in_q = self.kmaxpooling(in_q, i + 1)
            in_a = F.relu(self.conv_a[i](in_a))
            in_a = in_a.squeeze().unsqueeze(1)
            in_a = self.kmaxpooling(in_a, i + 1)

        in_q = self.fc_q(in_q.view(in_q.size(0), -1))
        in_a = self.fc_a(in_a.view(in_a.size(0), -1))
        score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1))))

        if target is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(score, target)
            return {Const.LOSS: loss, Const.OUTPUT: score}
        else:
            return {Const.OUTPUT: score}

    def predict(self, **kwargs):
        return self.forward(**kwargs)
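The dynamic k-max pooling above shrinks the pooled width layer by layer down to `k_top`, following `k_l = max(k_top, ceil((L - l) / L * s))`. A minimal sketch of that schedule and of the `F.adaptive_max_pool2d` call used by the module; the concrete numbers (3 layers, `k_top=10`, width 50) are assumptions for illustration only:

```python
import math
import torch
import torch.nn.functional as F

L, k_top, s = 3, 10, 50                      # hypothetical depth, k_top and input width
for l in range(1, L + 1):
    k_l = int(round(max(k_top, math.ceil((L - l) / L * s))))
    print(f"layer {l}: pooled width k_l = {k_l}")   # 34, 17, 10

# Same pooling step as in DynamicKMaxPooling.forward, on a dummy tensor:
x = torch.randn(2, 1, 150, s)                # [batch, 1, conv_channels, seq_width]
out = F.adaptive_max_pool2d(x, (x.size(2), 34))
print(out.shape)                             # torch.Size([2, 1, 150, 34])
```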
@@ -8,16 +8,23 @@ from fastNLP.core.const import Const as C

 class IDCNN(nn.Module):
-    def __init__(self, init_embed, char_embed,
+    def __init__(self,
+                 init_embed,
+                 char_embed,
                  num_cls,
                  repeats, num_layers, num_filters, kernel_size,
                  use_crf=False, use_projection=False, block_loss=False,
                  input_dropout=0.3, hidden_dropout=0.2, inner_dropout=0.0):
         super(IDCNN, self).__init__()
         self.word_embeddings = Embedding(init_embed)
-        self.char_embeddings = Embedding(char_embed)
-        embedding_size = self.word_embeddings.embedding_dim + \
-            self.char_embeddings.embedding_dim
+        if char_embed is None:
+            self.char_embeddings = None
+            embedding_size = self.word_embeddings.embedding_dim
+        else:
+            self.char_embeddings = Embedding(char_embed)
+            embedding_size = self.word_embeddings.embedding_dim + \
+                self.char_embeddings.embedding_dim

         self.conv0 = nn.Sequential(
             nn.Conv1d(in_channels=embedding_size,

@@ -31,7 +38,7 @@ class IDCNN(nn.Module):
             block = []
             for layer_i in range(num_layers):
-                dilated = 2 ** layer_i
+                dilated = 2 ** layer_i if layer_i + 1 < num_layers else 1
                 block.append(nn.Conv1d(
                     in_channels=num_filters,
                     out_channels=num_filters,

@@ -67,11 +74,24 @@ class IDCNN(nn.Module):
         self.crf = ConditionalRandomField(
             num_tags=num_cls) if use_crf else None
         self.block_loss = block_loss
+        self.reset_parameters()

-    def forward(self, words, chars, seq_len, target=None):
-        e1 = self.word_embeddings(words)
-        e2 = self.char_embeddings(chars)
-        x = torch.cat((e1, e2), dim=-1)  # b,l,h
+    def reset_parameters(self):
+        for m in self.modules():
+            if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+                nn.init.xavier_normal_(m.weight, gain=1)
+                if m.bias is not None:
+                    nn.init.normal_(m.bias, mean=0, std=0.01)

+    def forward(self, words, seq_len, target=None, chars=None):
+        if self.char_embeddings is None:
+            x = self.word_embeddings(words)
+        else:
+            if chars is None:
+                raise ValueError('must provide chars for model with char embedding')
+            e1 = self.word_embeddings(words)
+            e2 = self.char_embeddings(chars)
+            x = torch.cat((e1, e2), dim=-1)  # b,l,h

         mask = seq_len_to_mask(seq_len)
         x = x.transpose(1, 2)  # b,h,l

@@ -84,21 +104,24 @@ class IDCNN(nn.Module):
         def compute_loss(y, t, mask):
             if self.crf is not None and target is not None:
-                loss = self.crf(y, t, mask)
+                loss = self.crf(y.transpose(1, 2), t, mask)
             else:
                 t.masked_fill_(mask == 0, -100)
                 loss = F.cross_entropy(y, t, ignore_index=-100)
             return loss

-        if self.block_loss:
-            losses = [compute_loss(o, target, mask) for o in output]
-            loss = sum(losses)
+        if target is not None:
+            if self.block_loss:
+                losses = [compute_loss(o, target, mask) for o in output]
+                loss = sum(losses)
+            else:
+                loss = compute_loss(output[-1], target, mask)
         else:
-            loss = compute_loss(output[-1], target, mask)
+            loss = None

         scores = output[-1]
         if self.crf is not None:
-            pred = self.crf.viterbi_decode(scores, target, mask)
+            pred, _ = self.crf.viterbi_decode(scores.transpose(1, 2), mask)
         else:
             pred = scores.max(1)[1] * mask.long()

@@ -107,5 +130,13 @@ class IDCNN(nn.Module):
             C.OUTPUT: pred,
         }

-    def predict(self, words, chars, seq_len):
-        return self.forward(words, chars, seq_len)[C.OUTPUT]
+    def predict(self, words, seq_len, chars=None):
+        res = self.forward(
+            words=words,
+            seq_len=seq_len,
+            chars=chars,
+            target=None
+        )[C.OUTPUT]
+        return {
+            C.OUTPUT: res
+        }
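The changed dilation line above makes the last convolution in each repeated block fall back to dilation 1 instead of continuing the exponential growth. A tiny sketch of the resulting per-block schedule, using `num_layers=3` as configured in the training script that follows:

```python
# Dilation used by each layer of one IDCNN block: 2**0, 2**1, ..., then 1 for the last layer.
num_layers = 3
dilations = [2 ** i if i + 1 < num_layers else 1 for i in range(num_layers)]
print(dilations)  # [1, 2, 1]
```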
@@ -0,0 +1,99 @@
from reproduction.seqence_labelling.ner.data.OntoNoteLoader import OntoNoteNERDataLoader
from fastNLP.core.callback import FitlogCallback, LRScheduler
from fastNLP import GradientClipCallback
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR
from torch.optim import SGD, Adam
from fastNLP import Const
from fastNLP import RandomSampler, BucketSampler
from fastNLP import SpanFPreRecMetric
from fastNLP import Trainer
from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN
from fastNLP.core.utils import Option
from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding
from fastNLP.core.utils import cache_results
import sys
import torch.cuda
import os

os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

encoding_type = 'bioes'


def get_path(path):
    return os.path.join(os.environ['HOME'], path)


data_path = get_path('workdir/datasets/ontonotes-v4')

ops = Option(
    batch_size=128,
    num_epochs=100,
    lr=3e-4,
    repeats=3,
    num_layers=3,
    num_filters=400,
    use_crf=True,
    gradient_clip=5,
)


@cache_results('ontonotes-cache')
def load_data():
    data = OntoNoteNERDataLoader(encoding_type=encoding_type).process(data_path,
                                                                      lower=True)
    # char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30],
    #                               kernel_sizes=[3])
    word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
                                 model_dir_or_name='en-glove-840b-300',
                                 requires_grad=True)
    return data, [word_embed]


data, embeds = load_data()
print(data.datasets['train'][0])
print(list(data.vocabs.keys()))

for ds in data.datasets.values():
    ds.rename_field('cap_words', 'chars')
    ds.set_input('chars')

word_embed = embeds[0]
char_embed = CNNCharEmbedding(data.vocabs['cap_words'])
# for ds in data.datasets:
#     ds.rename_field('')

print(data.vocabs[Const.TARGET].word2idx)

model = IDCNN(init_embed=word_embed,
              char_embed=char_embed,
              num_cls=len(data.vocabs[Const.TARGET]),
              repeats=ops.repeats,
              num_layers=ops.num_layers,
              num_filters=ops.num_filters,
              kernel_size=3,
              use_crf=ops.use_crf, use_projection=True,
              block_loss=True,
              input_dropout=0.33, hidden_dropout=0.2, inner_dropout=0.2)

print(model)

callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='norm'), ]

optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0)
# scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15)))
# optimizer = SWATS(model.parameters(), verbose=True)
# optimizer = Adam(model.parameters(), lr=0.005)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer,
                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
                  device=device, dev_data=data.datasets['dev'], batch_size=ops.batch_size,
                  metrics=SpanFPreRecMetric(
                      tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  check_code_level=-1,
                  callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs)
trainer.train()
@@ -5,7 +5,8 @@ from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
+import csv
+from typing import Union, Dict

 class SSTLoader(DataSetLoader):
     URL = 'https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip'

@@ -97,3 +98,90 @@ class SSTLoader(DataSetLoader):
         return info
+class sst2Loader(DataSetLoader):
+    '''
+    Data source "SST": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
+    '''
+    def __init__(self):
+        super(sst2Loader, self).__init__()
+
+    def _load(self, path: str) -> DataSet:
+        ds = DataSet()
+        all_count = 0
+        csv_reader = csv.reader(open(path, encoding='utf-8'), delimiter='\t')
+        skip_row = 0
+        for idx, row in enumerate(csv_reader):
+            if idx <= skip_row:
+                continue
+            target = row[1]
+            words = row[0].split()
+            ds.append(Instance(words=words, target=target))
+            all_count += 1
+        print("all count:", all_count)
+        return ds
+
+    def process(self,
+                paths: Union[str, Dict[str, str]],
+                src_vocab_opt: VocabularyOption = None,
+                tgt_vocab_opt: VocabularyOption = None,
+                src_embed_opt: EmbeddingOption = None,
+                char_level_op=False):
+        paths = check_dataloader_paths(paths)
+        datasets = {}
+        info = DataInfo()
+        for name, path in paths.items():
+            dataset = self.load(path)
+            datasets[name] = dataset
+
+        def wordtochar(words):
+            chars = []
+            for word in words:
+                word = word.lower()
+                for char in word:
+                    chars.append(char)
+            return chars
+
+        input_name, target_name = 'words', 'target'
+        info.vocabs = {}
+
+        # split the words field into characters
+        if char_level_op:
+            for dataset in datasets.values():
+                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')
+
+        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
+        src_vocab.from_dataset(datasets['train'], field_name='words')
+        src_vocab.index_dataset(*datasets.values(), field_name='words')
+
+        tgt_vocab = Vocabulary(unknown=None, padding=None) \
+            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
+        tgt_vocab.from_dataset(datasets['train'], field_name='target')
+        tgt_vocab.index_dataset(*datasets.values(), field_name='target')
+
+        info.vocabs = {
+            "words": src_vocab,
+            "target": tgt_vocab
+        }
+
+        info.datasets = datasets
+
+        if src_embed_opt is not None:
+            embed = EmbedLoader.load_with_vocab(**src_embed_opt, vocab=src_vocab)
+            info.embeddings['words'] = embed
+
+        return info
+
+
+if __name__ == "__main__":
+    datapath = {"train": "/remote-home/ygwang/workspace/GLUE/SST-2/train.tsv",
+                "dev": "/remote-home/ygwang/workspace/GLUE/SST-2/dev.tsv"}
+    datainfo = sst2Loader().process(datapath, char_level_op=True)
+    # print(datainfo.datasets["train"])
+    len_count = 0
+    for instance in datainfo.datasets["train"]:
+        len_count += len(instance["chars"])
+
+    ave_len = len_count / len(datainfo.datasets["train"])
+    print(ave_len)
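For reference, the `char_level_op` path above simply flattens lowercased words into characters without any separator between words. A self-contained check of that behaviour (the helper is copied verbatim from the loader above):

```python
def wordtochar(words):
    chars = []
    for word in words:
        word = word.lower()
        for char in word:
            chars.append(char)
    return chars

print(wordtochar(["The", "movie"]))
# ['t', 'h', 'e', 'm', 'o', 'v', 'i', 'e']  -- note: no space token between words
```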
@@ -8,19 +8,7 @@ from fastNLP.io.base_loader import DataInfo,DataSetLoader
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict
-from reproduction.utils import check_dataloader_paths
-def get_tokenizer():
-    try:
-        import spacy
-        en = spacy.load('en')
-        print('use spacy tokenizer')
-        return lambda x: [w.text for w in en.tokenizer(x)]
-    except Exception as e:
-        print('use raw tokenizer')
-        return lambda x: x.split()
+from reproduction.utils import check_dataloader_paths, get_tokenizer

 def clean_str(sentence, tokenizer, char_lower=False):
     """

@@ -118,7 +106,7 @@ class yelpLoader(DataSetLoader):
         print("all count:",all_count)
         return ds
     '''

     def _load(self, path):
         ds = DataSet()
         csv_reader=csv.reader(open(path,encoding='utf-8'))

@@ -128,7 +116,7 @@ class yelpLoader(DataSetLoader):
             all_count+=1
             if len(row)==2:
                 target=self.tag_v[row[0]+".0"]
-                words=clean_str(row[1],self.tokenizer,self.lower)
+                words = clean_str(row[1], self.tokenizer, self.lower)
                 if len(words)!=0:
                     ds.append(Instance(words=words,target=target))
                     real_count += 1
@@ -1,4 +1,3 @@
 import torch
 import torch.nn as nn
 from fastNLP.modules.utils import get_embeddings

@@ -11,13 +10,11 @@ class DPCNN(nn.Module):
         super().__init__()
         self.region_embed = RegionEmbedding(
             init_embed, out_dim=n_filters, kernel_sizes=[1, 3, 5])
         embed_dim = self.region_embed.embedding_dim
         self.conv_list = nn.ModuleList()
         for i in range(n_layers):
             self.conv_list.append(nn.Sequential(
                 nn.ReLU(),
                 nn.Conv1d(n_filters, n_filters, kernel_size,
                           padding=kernel_size//2),
                 nn.Conv1d(n_filters, n_filters, kernel_size,

@@ -27,12 +24,10 @@ class DPCNN(nn.Module):
         self.embed_drop = nn.Dropout(embed_dropout)
         self.classfier = nn.Sequential(
             nn.Dropout(cls_dropout),
             nn.Linear(n_filters, num_cls),
         )
         self.reset_parameters()

     def reset_parameters(self):
         for m in self.modules():
             if isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Linear)):

@@ -40,7 +35,6 @@ class DPCNN(nn.Module):
                 if m.bias is not None:
                     nn.init.normal_(m.bias, mean=0, std=0.01)

     def forward(self, words, seq_len=None):
         words = words.long()
         # get region embeddings

@@ -58,20 +52,18 @@ class DPCNN(nn.Module):
         x = self.classfier(x)
         return {C.OUTPUT: x}

     def predict(self, words, seq_len=None):
         x = self.forward(words, seq_len)[C.OUTPUT]
         return {C.OUTPUT: torch.argmax(x, 1)}


 class RegionEmbedding(nn.Module):
     def __init__(self, init_embed, out_dim=300, kernel_sizes=None):
         super().__init__()
         if kernel_sizes is None:
             kernel_sizes = [5, 9]
         assert isinstance(
             kernel_sizes, list), 'kernel_sizes should be List(int)'
         self.embed = get_embeddings(init_embed)
         try:
             embed_dim = self.embed.embedding_dim

@@ -103,4 +95,3 @@ if __name__ == '__main__':
     model = DPCNN((10000, 300), 20)
     y = model(x)
     print(y.size(), y.mean(1), y.std(1))
@@ -9,6 +9,7 @@ from fastNLP import CrossEntropyLoss, AccuracyMetric
 from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
 from reproduction.text_classification.model.dpcnn import DPCNN
 from data.yelpLoader import yelpLoader
+from fastNLP.core.sampler import BucketSampler
 import torch.nn as nn
 from fastNLP.core import LRScheduler
 from fastNLP.core.const import Const as C

@@ -20,7 +21,6 @@ os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

 # hyper

 class Config():
@@ -29,19 +29,20 @@ class Config():
     embedding_grad = True
     train_epoch = 30
     batch_size = 100
-    num_classes = 2
     task = "yelp_p"
-    #datadir = '/remote-home/yfshao/workdir/datasets/SST'
-    datadir = '/remote-home/yfshao/workdir/datasets/yelp_polarity'
+    #datadir = 'workdir/datasets/SST'
+    datadir = 'workdir/datasets/yelp_polarity'
+    # datadir = 'workdir/datasets/yelp_full'
     #datafile = {"train": "train.txt", "dev": "dev.txt", "test": "test.txt"}
     datafile = {"train": "train.csv", "test": "test.csv"}
     lr = 1e-3
-    src_vocab_op = VocabularyOption()
+    src_vocab_op = VocabularyOption(max_size=100000)
     embed_dropout = 0.3
     cls_dropout = 0.1
-    weight_decay = 1e-4
+    weight_decay = 1e-5

     def __init__(self):
+        self.datadir = os.path.join(os.environ['HOME'], self.datadir)
         self.datapath = {k: os.path.join(self.datadir, v)
                          for k, v in self.datafile.items()}
@@ -54,6 +55,8 @@ print('RNG SEED: {}'.format(ops.seed))

 # 1. task-related info: load the dataInfo with the dataloader
 #datainfo=SSTLoader(fine_grained=True).process(paths=ops.datapath, train_ds=['train'])

 @cache_results(ops.model_dir_or_name+'-data-cache')
 def load_data():
     datainfo = yelpLoader(fine_grained=True, lower=True).process(
@@ -62,31 +65,23 @@ def load_data():
         ds.apply_field(len, C.INPUT, C.INPUT_LEN)
         ds.set_input(C.INPUT, C.INPUT_LEN)
         ds.set_target(C.TARGET)
-    return datainfo
+    embedding = StaticEmbedding(
+        datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad,
+        normalize=False
+    )
+    return datainfo, embedding

-datainfo = load_data()
+datainfo, embedding = load_data()

 # 2. or directly reuse one of fastNLP's models
-vocab = datainfo.vocabs['words']
 # embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
-#embedding = StaticEmbedding(vocab)
-embedding = StaticEmbedding(
-    vocab, model_dir_or_name='en-word2vec-300', requires_grad=ops.embedding_grad,
-    normalize=False
-)
-print(len(datainfo.datasets['train']))
-print(len(datainfo.datasets['test']))
+print(datainfo)
 print(datainfo.datasets['train'][0])
-print(len(vocab))
-print(len(datainfo.vocabs['target']))

-model = DPCNN(init_embed=embedding, num_cls=ops.num_classes,
+model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
               embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
 print(model)
@@ -97,11 +92,11 @@ optimizer = SGD([param for param in model.parameters() if param.requires_grad ==
                 lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)

 callbacks = []
-callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
-# callbacks.append
-#     LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
-#                          ops.train_epoch * 0.8 else ops.lr * 0.1))
-# )
+# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+callbacks.append(
+    LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
+                         ops.train_epoch * 0.8 else ops.lr * 0.1))
+)

 # callbacks.append(
 #     FitlogCallback(data=datainfo.datasets, verbose=1)

@@ -113,6 +108,7 @@ print(device)

 # 4. define the train method
 trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+                  sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
                   metrics=[metric],
                   dev_data=datainfo.datasets['test'], device=device,
                   check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,

@@ -122,4 +118,3 @@ trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=l
 if __name__ == "__main__":
     print(trainer.train())
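A note on the LambdaLR callback enabled above: PyTorch's `LambdaLR` multiplies the optimizer's initial learning rate by whatever `lr_lambda` returns, so a lambda that returns `ops.lr` acts as a factor and yields an effective rate of roughly `ops.lr * ops.lr` (1e-6 here), which is worth double-checking against the intent. A small standalone sketch of what the schedule actually produces, assuming the `LRScheduler` callback steps the scheduler once per epoch:

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

lr, train_epoch = 1e-3, 30                      # mirrors ops.lr and ops.train_epoch
param = torch.nn.Parameter(torch.zeros(1))
optimizer = SGD([param], lr=lr, momentum=0.9)
scheduler = LambdaLR(optimizer, lambda epoch: lr if epoch < train_epoch * 0.8 else lr * 0.1)

for epoch in range(train_epoch):
    if epoch in (0, 24):
        # effective lr = initial_lr * lr_lambda(epoch): ~1e-6 early on, ~1e-7 after 80% of epochs
        print(epoch, optimizer.param_groups[0]['lr'])
    optimizer.step()
    scheduler.step()
```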
@@ -57,4 +57,13 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
     else:
         raise TypeError(f"paths only supports str and dict. not {type(paths)}.")

+def get_tokenizer():
+    try:
+        import spacy
+        spacy.prefer_gpu()
+        en = spacy.load('en')
+        print('use spacy tokenizer')
+        return lambda x: [w.text for w in en.tokenizer(x)]
+    except Exception as e:
+        print('use raw tokenizer')
+        return lambda x: x.split()