diff --git a/reproduction/text_classification/model/BertTC.py b/reproduction/text_classification/model/BertTC.py
new file mode 100644
index 00000000..702c0cd1
--- /dev/null
+++ b/reproduction/text_classification/model/BertTC.py
@@ -0,0 +1,24 @@
+from fastNLP.embeddings import BertEmbedding
+import torch
+import torch.nn as nn
+from fastNLP.core.const import Const as C
+
+class BertTC(nn.Module):
+    def __init__(self, vocab, num_class, bert_model_dir_or_name, fine_tune=False):
+        super(BertTC, self).__init__()
+        self.embed = BertEmbedding(vocab, requires_grad=fine_tune,
+                                   model_dir_or_name=bert_model_dir_or_name, include_cls_sep=True)
+        self.classifier = nn.Linear(self.embed.embedding_dim, num_class)
+
+    def forward(self, words):
+        embedding_cls = self.embed(words)[:, 0]  # [CLS] position; include_cls_sep=True prepends it to every sequence
+        output = self.classifier(embedding_cls)
+        return {C.OUTPUT: output}
+
+    def predict(self, words):
+        return self.forward(words)
+
+if __name__ == "__main__":
+    ta = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    tb = ta[:, 0]  # sanity check: [:, 0] picks the first element of each row
+    print(tb)
diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py
index 3482de70..6b56608a 100644
--- a/reproduction/text_classification/train_char_cnn.py
+++ b/reproduction/text_classification/train_char_cnn.py
@@ -8,6 +8,7 @@ sys.path.append('../..')
 from fastNLP.core.const import Const as C
 import torch.nn as nn
 from fastNLP.io.data_loader import YelpLoader
+from fastNLP.io.pipe.classification import YelpFullPipe, YelpPolarityPipe, SST2Pipe, IMDBPipe
 #from data.sstLoader import sst2Loader
 from model.char_cnn import CharacterLevelCNN
 from fastNLP import CrossEntropyLoss, AccuracyMetric
@@ -46,6 +47,8 @@ class Config():
     extra_characters=''
     max_length=1014
     weight_decay = 1e-5
+    to_lower=True
+    tokenizer = 'spacy'  # use spacy for tokenization
 
     char_cnn_config={
         "alphabet": {
@@ -111,12 +114,35 @@ ops=Config
 ##1. Task setup: use a dataloader to load the dataInfo
 #dataloader=SST2Loader()
 #dataloader=IMDBLoader()
-dataloader=YelpLoader(fine_grained=True)
-datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
+# dataloader=YelpLoader(fine_grained=True)
+# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
 char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
 ops.number_of_characters=len(char_vocab)
 ops.embedding_dim=ops.number_of_characters
 
+# load the data set selected by ops.task
+if ops.task == 'yelp_p':
+    data_bundle = YelpPolarityPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+elif ops.task == 'yelp_f':
+    data_bundle = YelpFullPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+elif ops.task == 'imdb':
+    data_bundle = IMDBPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+elif ops.task == 'sst-2':
+    data_bundle = SST2Pipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+else:
+    raise RuntimeError(f'Task {ops.task} is not supported yet!')
+
+
+def wordtochar(words):
+    chars = []
+    for word in words:
+        #word = word.lower()
+        for char in word:
+            chars.append(char)
+        chars.append('')  # empty string marks the word boundary
+    chars.pop()  # drop the trailing separator
+    return chars
+
 #chartoindex
 def chartoindex(chars):
     max_seq_len=ops.max_length
@@ -136,13 +162,14 @@ def chartoindex(chars):
         char_index_list=[zero_index]*max_seq_len
     return char_index_list
 
-for dataset in datainfo.datasets.values():
+for dataset in data_bundle.datasets.values():
+    dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars')
     dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars')
 
-datainfo.datasets['train'].set_input('chars')
-datainfo.datasets['test'].set_input('chars')
-datainfo.datasets['train'].set_target('target')
-datainfo.datasets['test'].set_target('target')
+data_bundle.datasets['train'].set_input('chars')
+data_bundle.datasets['test'].set_input('chars')
+data_bundle.datasets['train'].set_target('target')
+data_bundle.datasets['test'].set_target('target')
 
 ##2. Define/assemble the model. Anything works here: a model that fastNLP already packages, such as CNNText, can simply be instantiated directly. The class below is only a placeholder skeleton; build a model here that follows fastNLP's input/output conventions.
 class ModelFactory(nn.Module):
@@ -165,7 +192,7 @@ class ModelFactory(nn.Module):
 
 ## 2. Or directly reuse a fastNLP model
 #vocab=datainfo.vocabs['words']
-vocab_label=datainfo.vocabs['target']
+vocab_label=data_bundle.vocabs['target']
 '''
 # emded_char=CNNCharEmbedding(vocab)
 # embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
@@ -212,5 +239,5 @@ if __name__=="__main__":
 
     #print(vocab_label)
     #print(datainfo.datasets["train"])
 
-    train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch)
+    train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch)
\ No newline at end of file
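For readers trying the new model outside this patch, the sketch below shows one way `BertTC` could be driven end to end with a fastNLP pipe. It is a minimal sketch, not part of the diff: the choice of `SST2Pipe`, the `'en-base-uncased'` weight name, and the optimizer settings are all assumptions.

# Minimal usage sketch for BertTC (assumed: SST-2 data via SST2Pipe and the
# 'en-base-uncased' BERT weights; neither is prescribed by this PR).
import torch.optim as optim
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.io.pipe.classification import SST2Pipe
from model.BertTC import BertTC

data_bundle = SST2Pipe(lower=True).process_from_file()
vocab = data_bundle.vocabs['words']            # word vocabulary built by the pipe
num_class = len(data_bundle.vocabs['target'])  # number of distinct labels

# fine_tune=False (the default) freezes BERT, so only the linear head is trained
model = BertTC(vocab, num_class, bert_model_dir_or_name='en-base-uncased')

trainer = Trainer(data_bundle.datasets['train'], model,
                  optimizer=optim.Adam(model.parameters(), lr=2e-5),
                  loss=CrossEntropyLoss(),    # default key mapping matches C.OUTPUT ('pred')
                  metrics=AccuracyMetric(),
                  dev_data=data_bundle.datasets['dev'])
trainer.train()

Because forward returns {C.OUTPUT: output}, the default 'pred'/'target' key mapping of CrossEntropyLoss and AccuracyMetric works without explicit key arguments.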