| @@ -20,87 +20,10 @@ fastNLP is a modular Natural Language Processing system based on PyTorch, for fa | |||
| ## Resources | |||
| - [Documentation](https://github.com/fastnlp/fastNLP) | |||
| - [Documentation](https://fastnlp.readthedocs.io/en/latest/) | |||
| - [Source Code](https://github.com/fastnlp/fastNLP) | |||
| ## Example | |||
| ### Basic Usage | |||
| A typical fastNLP routine is composed of four phases: loading dataset, pre-processing data, constructing model and training model. | |||
| ```python | |||
| from fastNLP.models.base_model import BaseModel | |||
| from fastNLP.modules import encoder | |||
| from fastNLP.modules import aggregation | |||
| from fastNLP.modules import decoder | |||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader | |||
| from fastNLP.loader.preprocess import ClassPreprocess | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.core.inference import ClassificationInfer | |||
| class ClassificationModel(BaseModel): | |||
| """ | |||
| Simple text classification model based on CNN. | |||
| """ | |||
| def __init__(self, num_classes, vocab_size): | |||
| super(ClassificationModel, self).__init__() | |||
| self.emb = encoder.Embedding(nums=vocab_size, dims=300) | |||
| self.enc = encoder.Conv( | |||
| in_channels=300, out_channels=100, kernel_size=3) | |||
| self.agg = aggregation.MaxPool() | |||
| self.dec = decoder.MLP(100, num_classes=num_classes) | |||
| def forward(self, x): | |||
| x = self.emb(x) # [N,L] -> [N,L,C] | |||
| x = self.enc(x) # [N,L,C_in] -> [N,L,C_out] | |||
| x = self.agg(x) # [N,L,C] -> [N,C] | |||
| x = self.dec(x) # [N,C] -> [N, N_class] | |||
| return x | |||
| data_dir = 'data' # directory to save data and model | |||
| train_path = 'test/data_for_tests/text_classify.txt' # training set file | |||
| # load dataset | |||
| ds_loader = ClassDatasetLoader("train", train_path) | |||
| data = ds_loader.load() | |||
| # pre-process dataset | |||
| pre = ClassPreprocess(data_dir) | |||
| vocab_size, n_classes = pre.process(data, "data_train.pkl") | |||
| # construct model | |||
| model_args = { | |||
| 'num_classes': n_classes, | |||
| 'vocab_size': vocab_size | |||
| } | |||
| model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size) | |||
| # train model | |||
| train_args = { | |||
| "epochs": 20, | |||
| "batch_size": 50, | |||
| "pickle_path": data_dir, | |||
| "validate": False, | |||
| "save_best_dev": False, | |||
| "model_saved_path": None, | |||
| "use_cuda": True, | |||
| "learn_rate": 1e-3, | |||
| "momentum": 0.9} | |||
| trainer = ClassificationTrainer(train_args) | |||
| trainer.train(model) | |||
| # predict using model | |||
| seqs = [x[0] for x in data] | |||
| infer = ClassificationInfer(data_dir) | |||
| labels_pred = infer.predict(model, seqs) | |||
| ``` | |||
| ## Installation | |||
| @@ -1,3 +1,4 @@ | |||
| sphinx | |||
| -e git://github.com/snide/sphinx_rtd_theme.git#egg=sphinx_rtd_theme | |||
| sphinxcontrib.katex | |||
| numpy>=1.14.2 | |||
| http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl | |||
| torchvision>=0.1.8 | |||
| sphinx-rtd-theme==0.4.1 | |||
| @@ -42,6 +42,8 @@ release = '1.0' | |||
| extensions = [ | |||
| 'sphinx.ext.autodoc', | |||
| 'sphinx.ext.viewcode', | |||
| 'sphinx.ext.autosummary', | |||
| ] | |||
| # Add any paths that contain templates here, relative to this directory. | |||
| @@ -1,62 +1,54 @@ | |||
| fastNLP.core package | |||
| ==================== | |||
| fastNLP.core | |||
| ============= | |||
| Submodules | |||
| ---------- | |||
| fastNLP.core.action module | |||
| -------------------------- | |||
| fastNLP.core.action | |||
| -------------------- | |||
| .. automodule:: fastNLP.core.action | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.core.metrics module | |||
| --------------------------- | |||
| fastNLP.core.loss | |||
| ------------------ | |||
| .. automodule:: fastNLP.core.loss | |||
| :members: | |||
| fastNLP.core.metrics | |||
| --------------------- | |||
| .. automodule:: fastNLP.core.metrics | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.core.optimizer module | |||
| ----------------------------- | |||
| fastNLP.core.optimizer | |||
| ----------------------- | |||
| .. automodule:: fastNLP.core.optimizer | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.core.predictor module | |||
| ----------------------------- | |||
| fastNLP.core.predictor | |||
| ----------------------- | |||
| .. automodule:: fastNLP.core.predictor | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.core.tester module | |||
| -------------------------- | |||
| fastNLP.core.preprocess | |||
| ------------------------ | |||
| .. automodule:: fastNLP.core.preprocess | |||
| :members: | |||
| fastNLP.core.tester | |||
| -------------------- | |||
| .. automodule:: fastNLP.core.tester | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.core.trainer module | |||
| --------------------------- | |||
| fastNLP.core.trainer | |||
| --------------------- | |||
| .. automodule:: fastNLP.core.trainer | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.core | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,62 +1,36 @@ | |||
| fastNLP.loader package | |||
| ====================== | |||
| fastNLP.loader | |||
| =============== | |||
| Submodules | |||
| ---------- | |||
| fastNLP.loader.base\_loader module | |||
| ---------------------------------- | |||
| fastNLP.loader.base\_loader | |||
| ---------------------------- | |||
| .. automodule:: fastNLP.loader.base_loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.loader.config\_loader module | |||
| ------------------------------------ | |||
| fastNLP.loader.config\_loader | |||
| ------------------------------ | |||
| .. automodule:: fastNLP.loader.config_loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.loader.dataset\_loader module | |||
| ------------------------------------- | |||
| fastNLP.loader.dataset\_loader | |||
| ------------------------------- | |||
| .. automodule:: fastNLP.loader.dataset_loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.loader.embed\_loader module | |||
| ----------------------------------- | |||
| fastNLP.loader.embed\_loader | |||
| ----------------------------- | |||
| .. automodule:: fastNLP.loader.embed_loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.loader.model\_loader module | |||
| ----------------------------------- | |||
| fastNLP.loader.model\_loader | |||
| ----------------------------- | |||
| .. automodule:: fastNLP.loader.model_loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.loader.preprocess module | |||
| -------------------------------- | |||
| .. automodule:: fastNLP.loader.preprocess | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.loader | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,46 +1,30 @@ | |||
| fastNLP.models package | |||
| ====================== | |||
| fastNLP.models | |||
| =============== | |||
| Submodules | |||
| ---------- | |||
| fastNLP.models.base\_model module | |||
| --------------------------------- | |||
| fastNLP.models.base\_model | |||
| --------------------------- | |||
| .. automodule:: fastNLP.models.base_model | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.models.char\_language\_model module | |||
| ------------------------------------------- | |||
| fastNLP.models.char\_language\_model | |||
| ------------------------------------- | |||
| .. automodule:: fastNLP.models.char_language_model | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.models.cnn\_text\_classification module | |||
| ----------------------------------------------- | |||
| fastNLP.models.cnn\_text\_classification | |||
| ----------------------------------------- | |||
| .. automodule:: fastNLP.models.cnn_text_classification | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.models.sequence\_modeling module | |||
| ---------------------------------------- | |||
| fastNLP.models.sequence\_modeling | |||
| ---------------------------------- | |||
| .. automodule:: fastNLP.models.sequence_modeling | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.models | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,54 +1,36 @@ | |||
| fastNLP.modules.aggregation package | |||
| =================================== | |||
| fastNLP.modules.aggregation | |||
| ============================ | |||
| Submodules | |||
| ---------- | |||
| fastNLP.modules.aggregation.attention module | |||
| -------------------------------------------- | |||
| fastNLP.modules.aggregation.attention | |||
| -------------------------------------- | |||
| .. automodule:: fastNLP.modules.aggregation.attention | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.aggregation.avg\_pool module | |||
| -------------------------------------------- | |||
| fastNLP.modules.aggregation.avg\_pool | |||
| -------------------------------------- | |||
| .. automodule:: fastNLP.modules.aggregation.avg_pool | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.aggregation.kmax\_pool module | |||
| --------------------------------------------- | |||
| fastNLP.modules.aggregation.kmax\_pool | |||
| --------------------------------------- | |||
| .. automodule:: fastNLP.modules.aggregation.kmax_pool | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.aggregation.max\_pool module | |||
| -------------------------------------------- | |||
| fastNLP.modules.aggregation.max\_pool | |||
| -------------------------------------- | |||
| .. automodule:: fastNLP.modules.aggregation.max_pool | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.aggregation.self\_attention module | |||
| -------------------------------------------------- | |||
| fastNLP.modules.aggregation.self\_attention | |||
| -------------------------------------------- | |||
| .. automodule:: fastNLP.modules.aggregation.self_attention | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.modules.aggregation | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,22 +1,18 @@ | |||
| fastNLP.modules.decoder package | |||
| =============================== | |||
| fastNLP.modules.decoder | |||
| ======================== | |||
| Submodules | |||
| ---------- | |||
| fastNLP.modules.decoder.CRF module | |||
| ---------------------------------- | |||
| fastNLP.modules.decoder.CRF | |||
| ---------------------------- | |||
| .. automodule:: fastNLP.modules.decoder.CRF | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.decoder.MLP | |||
| ---------------------------- | |||
| .. automodule:: fastNLP.modules.decoder.MLP | |||
| :members: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.modules.decoder | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,78 +1,54 @@ | |||
| fastNLP.modules.encoder package | |||
| =============================== | |||
| fastNLP.modules.encoder | |||
| ======================== | |||
| Submodules | |||
| ---------- | |||
| fastNLP.modules.encoder.char\_embedding module | |||
| ---------------------------------------------- | |||
| fastNLP.modules.encoder.char\_embedding | |||
| ---------------------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.char_embedding | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.conv module | |||
| ----------------------------------- | |||
| fastNLP.modules.encoder.conv | |||
| ----------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.conv | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.conv\_maxpool module | |||
| -------------------------------------------- | |||
| fastNLP.modules.encoder.conv\_maxpool | |||
| -------------------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.conv_maxpool | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.embedding module | |||
| ---------------------------------------- | |||
| fastNLP.modules.encoder.embedding | |||
| ---------------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.embedding | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.linear module | |||
| ------------------------------------- | |||
| fastNLP.modules.encoder.linear | |||
| ------------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.linear | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.lstm module | |||
| ----------------------------------- | |||
| fastNLP.modules.encoder.lstm | |||
| ----------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.lstm | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.masked\_rnn module | |||
| ------------------------------------------ | |||
| fastNLP.modules.encoder.masked\_rnn | |||
| ------------------------------------ | |||
| .. automodule:: fastNLP.modules.encoder.masked_rnn | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.encoder.variational\_rnn module | |||
| ----------------------------------------------- | |||
| fastNLP.modules.encoder.variational\_rnn | |||
| ----------------------------------------- | |||
| .. automodule:: fastNLP.modules.encoder.variational_rnn | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.modules.encoder | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,10 +1,5 @@ | |||
| fastNLP.modules.interaction package | |||
| =================================== | |||
| Module contents | |||
| --------------- | |||
| fastNLP.modules.interaction | |||
| ============================ | |||
| .. automodule:: fastNLP.modules.interaction | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,8 +1,5 @@ | |||
| fastNLP.modules package | |||
| ======================= | |||
| Subpackages | |||
| ----------- | |||
| fastNLP.modules | |||
| ================ | |||
| .. toctree:: | |||
| @@ -11,30 +8,18 @@ Subpackages | |||
| fastNLP.modules.encoder | |||
| fastNLP.modules.interaction | |||
| Submodules | |||
| ---------- | |||
| fastNLP.modules.other\_modules module | |||
| ------------------------------------- | |||
| fastNLP.modules.other\_modules | |||
| ------------------------------- | |||
| .. automodule:: fastNLP.modules.other_modules | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.modules.utils module | |||
| ---------------------------- | |||
| fastNLP.modules.utils | |||
| ---------------------- | |||
| .. automodule:: fastNLP.modules.utils | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.modules | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,8 +1,5 @@ | |||
| fastNLP package | |||
| =============== | |||
| Subpackages | |||
| ----------- | |||
| fastNLP | |||
| ======== | |||
| .. toctree:: | |||
| @@ -12,22 +9,12 @@ Subpackages | |||
| fastNLP.modules | |||
| fastNLP.saver | |||
| Submodules | |||
| ---------- | |||
| fastNLP.fastnlp module | |||
| ---------------------- | |||
| fastNLP.fastnlp | |||
| ---------------- | |||
| .. automodule:: fastNLP.fastnlp | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,30 +1,18 @@ | |||
| fastNLP.saver package | |||
| ===================== | |||
| fastNLP.saver | |||
| ============== | |||
| Submodules | |||
| ---------- | |||
| fastNLP.saver.logger module | |||
| --------------------------- | |||
| fastNLP.saver.logger | |||
| --------------------- | |||
| .. automodule:: fastNLP.saver.logger | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| fastNLP.saver.model\_saver module | |||
| --------------------------------- | |||
| fastNLP.saver.model\_saver | |||
| --------------------------- | |||
| .. automodule:: fastNLP.saver.model_saver | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| Module contents | |||
| --------------- | |||
| .. automodule:: fastNLP.saver | |||
| :members: | |||
| :undoc-members: | |||
| :show-inheritance: | |||
| @@ -1,16 +1,54 @@ | |||
| .. fastNLP documentation master file, created by | |||
| sphinx-quickstart on Mon Aug 20 17:06:44 2018. | |||
| You can adapt this file completely to your liking, but it should at least | |||
| contain the root `toctree` directive. | |||
| fastNLP documentation | |||
| ===================== | |||
| fastNLP,目前仍在孵化中。 | |||
| Welcome to fastNLP's documentation! | |||
| =================================== | |||
| Introduction | |||
| ------------ | |||
| fastNLP是一个基于PyTorch的模块化自然语言处理系统,用于快速开发NLP工具。 | |||
| 它将基于深度学习的NLP模型划分为不同的模块。 | |||
| 这些模块分为4类:encoder(编码),interaction(交互), aggregration(聚合) and decoder(解码), | |||
| 而每个类别包含不同的实现模块。 | |||
| 大多数当前的NLP模型可以构建在这些模块上,这极大地简化了开发NLP模型的过程。 | |||
| fastNLP的架构如下左图所示: | |||
| .. image:: figures/procedures_and_sequence_labeling.png | |||
| 在constructing model部分,以序列标注(上右图)和文本分类(下图)为例进行说明: | |||
| .. image:: figures/text_classification.png | |||
| * encoder module:将输入编码为一些抽象表示,输入的是单词序列,输出向量序列。 | |||
| * interaction module:使表示中的信息相互交互,输入的是向量序列,输出的也是向量序列。 | |||
| * aggregation module:聚合和减少信息,输入向量序列,输出一个向量。 | |||
| * decoder module:将表示解码为输出,输出一个label(文本分类)或者输出label序列(序列标注) | |||
| 其中interaction module和aggregation module在模型中不一定存在,例如上面的序列标注模型。 | |||
| User's Guide | |||
| ------------ | |||
| .. toctree:: | |||
| :maxdepth: 2 | |||
| user/installation | |||
| user/quickstart | |||
| API Reference | |||
| ------------- | |||
| If you are looking for information on a specific function, class or | |||
| method, this part of the documentation is for you. | |||
| .. toctree:: | |||
| :maxdepth: 4 | |||
| :caption: Contents: | |||
| :maxdepth: 2 | |||
| fastNLP | |||
| fastNLP API <fastNLP> | |||
| @@ -1,7 +0,0 @@ | |||
| fastNLP | |||
| ======= | |||
| .. toctree:: | |||
| :maxdepth: 4 | |||
| fastNLP | |||
| @@ -0,0 +1,31 @@ | |||
| ============ | |||
| Installation | |||
| ============ | |||
| .. contents:: | |||
| :local: | |||
| Cloning From GitHub | |||
| ~~~~~~~~~~~~~~~~~~~ | |||
| If you just want to use fastNLP, use: | |||
| .. code:: shell | |||
| git clone https://github.com/fastnlp/fastNLP | |||
| cd fastNLP | |||
| PyTorch Installation | |||
| ~~~~~~~~~~~~~~~~~~~~ | |||
| Visit the [PyTorch official website] for installation instructions based | |||
| on your system. In general, you could use: | |||
| .. code:: shell | |||
| # using conda | |||
| conda install pytorch torchvision -c pytorch | |||
| # or using pip | |||
| pip3 install torch torchvision | |||
| @@ -0,0 +1,84 @@ | |||
| ========== | |||
| Quickstart | |||
| ========== | |||
| Example | |||
| ------- | |||
| Basic Usage | |||
| ~~~~~~~~~~~ | |||
| A typical fastNLP routine is composed of four phases: loading dataset, | |||
| pre-processing data, constructing model and training model. | |||
| .. code:: python | |||
| from fastNLP.models.base_model import BaseModel | |||
| from fastNLP.modules import encoder | |||
| from fastNLP.modules import aggregation | |||
| from fastNLP.modules import decoder | |||
| from fastNLP.loader.dataset_loader import ClassDatasetLoader | |||
| from fastNLP.loader.preprocess import ClassPreprocess | |||
| from fastNLP.core.trainer import ClassificationTrainer | |||
| from fastNLP.core.inference import ClassificationInfer | |||
| class ClassificationModel(BaseModel): | |||
| """ | |||
| Simple text classification model based on CNN. | |||
| """ | |||
| def __init__(self, num_classes, vocab_size): | |||
| super(ClassificationModel, self).__init__() | |||
| self.emb = encoder.Embedding(nums=vocab_size, dims=300) | |||
| self.enc = encoder.Conv( | |||
| in_channels=300, out_channels=100, kernel_size=3) | |||
| self.agg = aggregation.MaxPool() | |||
| self.dec = decoder.MLP(100, num_classes=num_classes) | |||
| def forward(self, x): | |||
| x = self.emb(x) # [N,L] -> [N,L,C] | |||
| x = self.enc(x) # [N,L,C_in] -> [N,L,C_out] | |||
| x = self.agg(x) # [N,L,C] -> [N,C] | |||
| x = self.dec(x) # [N,C] -> [N, N_class] | |||
| return x | |||
| data_dir = 'data' # directory to save data and model | |||
| train_path = 'test/data_for_tests/text_classify.txt' # training set file | |||
| # load dataset | |||
| ds_loader = ClassDatasetLoader("train", train_path) | |||
| data = ds_loader.load() | |||
| # pre-process dataset | |||
| pre = ClassPreprocess(data_dir) | |||
| vocab_size, n_classes = pre.process(data, "data_train.pkl") | |||
| # construct model | |||
| model_args = { | |||
| 'num_classes': n_classes, | |||
| 'vocab_size': vocab_size | |||
| } | |||
| model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size) | |||
| # train model | |||
| train_args = { | |||
| "epochs": 20, | |||
| "batch_size": 50, | |||
| "pickle_path": data_dir, | |||
| "validate": False, | |||
| "save_best_dev": False, | |||
| "model_saved_path": None, | |||
| "use_cuda": True, | |||
| "learn_rate": 1e-3, | |||
| "momentum": 0.9} | |||
| trainer = ClassificationTrainer(train_args) | |||
| trainer.train(model) | |||
| # predict using model | |||
| seqs = [x[0] for x in data] | |||
| infer = ClassificationInfer(data_dir) | |||
| labels_pred = infer.predict(model, seqs) | |||