import sys
sys.path.append('../../')

from reproduction.text_classification.data.IMDBLoader import IMDBLoader
from fastNLP.embeddings import BertEmbedding
from reproduction.text_classification.model.lstm import BiLSTMSentiment
from fastNLP import Trainer
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP import cache_results
from fastNLP import Tester

# Cache the returned result in imdb.pkl so the next run skips the preprocessing automatically.
@cache_results('imdb.pkl')
def get_data():
    data_bundle = IMDBLoader().process('imdb/')
    return data_bundle

data_bundle = get_data()

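# Print the DataBundle to inspect the dataset splits, their sizes, and the vocabularies.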
print(data_bundle)

# BERT accepts at most 512 tokens, and English words are further split into
# word pieces, so drop examples longer than 400 words to leave some headroom.
data_bundle.datasets['train'].drop(lambda x: len(x['words']) > 400)
data_bundle.datasets['dev'].drop(lambda x: len(x['words']) > 400)
data_bundle.datasets['test'].drop(lambda x: len(x['words']) > 400)
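
# requires_grad=False freezes the BERT weights, so only the downstream
# BiLSTM classifier is trained on top of the fixed embeddings.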
bert_embed = BertEmbedding(data_bundle.vocabs['words'], requires_grad=False,
                           model_dir_or_name="en-base-uncased")
model = BiLSTMSentiment(bert_embed, len(data_bundle.vocabs['target']))

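# optimizer=None falls back to the Trainer's default optimizer;
# device=0 runs training on the first GPU.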
Trainer(data_bundle.datasets['train'], model, optimizer=None, loss=CrossEntropyLoss(), device=0,
        batch_size=10, dev_data=data_bundle.datasets['dev'], metrics=AccuracyMetric()).train()

# Evaluate the trained model on the test set.
Tester(data_bundle.datasets['test'], model, batch_size=32, metrics=AccuracyMetric()).test()
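
# Note: get_data() cached its result in imdb.pkl; delete that file, or call
# get_data(_refresh=True), to force the preprocessing to run again.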