
Merge pull request #22 from fastnlp/dev/ner

Dev/ner
tags/v0.1.0
Coet (via GitHub), 6 years ago
commit 3560fb1f67
13 changed files with 619 additions and 84 deletions
1. fastNLP/core/inference.py (+5 -5)
2. fastNLP/core/tester.py (+2 -5)
3. fastNLP/core/trainer.py (+2 -8)
4. fastNLP/fastnlp.py (+93 -44)
5. fastNLP/models/cnn_text_classification.py (+9 -4)
6. fastNLP/models/sequence_modeling.py (+46 -0)
7. test/data_for_tests/config (+17 -2)
8. test/data_for_tests/people.txt (+154 -0)
9. test/ner.py (+137 -0)
10. test/ner_decode.py (+129 -0)
11. test/seq_labeling.py (+2 -2)
12. test/test_fastNLP.py (+12 -3)
13. test/text_classify.py (+11 -11)

fastNLP/core/inference.py (+5 -5)

@@ -63,7 +63,7 @@ class Inference(object):
"""
Perform inference.
:param network:
- :param data: multi-level lists of strings
+ :param data: two-level lists of strings
:return result: the model outputs
"""
# transform strings into indices
@@ -97,7 +97,7 @@ class Inference(object):

def prepare_input(self, data):
"""
- Transform three-level list of strings into that of index.
+ Transform two-level list of strings into that of index.
:param data:
[
[word_11, word_12, ...],
@@ -140,7 +140,7 @@ class SeqLabelInfer(Inference):
mask = mask.byte().view(batch_size, max_len)
y = network(x)
prediction = network.prediction(y, mask)
- return torch.Tensor(prediction, required_grad=False)
+ return torch.Tensor(prediction)

def make_batch(self, iterator, data, use_cuda):
return make_batch(iterator, data, use_cuda, output_length=True)
@@ -149,7 +149,7 @@ class SeqLabelInfer(Inference):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
- :return:
+ :return results: 2-D list of strings
"""
results = []
for batch in batch_outputs:
@@ -178,7 +178,7 @@ class ClassificationInfer(Inference):
"""
Transform list of batch outputs into strings.
:param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
- :return:
+ :return results: list of strings
"""
results = []
for batch_out in batch_outputs:
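
The docstring fixes above pin down the contract: predict takes a two-level list (a list of sentences, each a list of tokens) and returns one label sequence per sentence. A minimal sketch of that shape, assuming infer and model are a SeqLabelInfer instance and a loaded network:

data = [
    ["这", "是", "句", "子"],  # sentence 1: one token per element
    ["中", "文", "分", "词"],  # sentence 2
]
# results = infer.predict(model, data)  # one list of label strings per sentence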


fastNLP/core/tester.py (+2 -5)

@@ -37,10 +37,6 @@ class BaseTester(object):
else:
self.model = network

- # no backward setting for model
- for param in network.parameters():
- param.requires_grad = False

# turn on the testing mode; clean up the history
self.mode(network, test=True)
self.eval_history.clear()
@@ -112,6 +108,7 @@ class SeqLabelTester(BaseTester):
super(SeqLabelTester, self).__init__(test_args)
self.max_len = None
self.mask = None
+ self.seq_len = None
self.batch_result = None

def data_forward(self, network, inputs):
@@ -125,7 +122,7 @@ class SeqLabelTester(BaseTester):
if torch.cuda.is_available() and self.use_cuda:
mask = mask.cuda()
self.mask = mask
+ self.seq_len = seq_len
y = network(x)
return y
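
Keeping self.seq_len around matters because padded positions must be excluded when scoring. A hedged sketch (not the repo's code) of sequence-length-masked accuracy:

import torch

def masked_accuracy(pred, gold, seq_len):
    # pred, gold: [batch, max_len] tag id tensors; seq_len: [batch] true lengths
    correct, total = 0, 0
    for p, g, l in zip(pred, gold, seq_len):
        correct += (p[:l] == g[:l]).sum().item()  # compare real tokens only
        total += int(l)
    return correct / total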



fastNLP/core/trainer.py (+2 -8)

@@ -315,14 +315,8 @@ class ClassificationTrainer(BaseTrainer):

def __init__(self, train_args):
super(ClassificationTrainer, self).__init__(train_args)
if "learn_rate" in train_args:
self.learn_rate = train_args["learn_rate"]
else:
self.learn_rate = 1e-3
if "momentum" in train_args:
self.momentum = train_args["momentum"]
else:
self.momentum = 0.9
self.learn_rate = train_args["learn_rate"]
self.momentum = train_args["momentum"]

self.iterator = None
self.loss_func = None
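
With the fallback defaults gone, "learn_rate" and "momentum" are now required keys in train_args; a missing key raises KeyError instead of silently training with 1e-3 / 0.9. Callers therefore pass something like:

# Assumed minimal args after this change; both keys are now mandatory.
train_args = {"learn_rate": 1e-3, "momentum": 0.9}  # plus epochs, batch_size, ...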


fastNLP/fastnlp.py (+93 -44)

@@ -1,4 +1,4 @@
- from fastNLP.core.inference import Inference
+ from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader

@@ -10,14 +10,28 @@ Example:
"zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
"""
FastNLP_MODEL_COLLECTION = {
"zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
"seq_label_model": {
"url": "www.fudan.edu.cn",
"class": "sequence_modeling.SeqLabeling",
"pickle": "seq_label_model.pkl",
"type": "seq_label"
},
"text_class_model": {
"url": "www.fudan.edu.cn",
"class": "cnn_text_classification.CNNText",
"pickle": "text_class_model.pkl",
"type": "text_class"
}
}

CONFIG_FILE_NAME = "config"
SECTION_NAME = "text_class_model"


class FastNLP(object):
"""
High-level interface for direct model inference.
- Usage:
+ Example Usage:
fastnlp = FastNLP()
fastnlp.load("zh_pos_tag_model")
text = "这是最好的基于深度学习的中文分词系统。"
@@ -35,6 +49,7 @@ class FastNLP(object):
"""
self.model_dir = model_dir
self.model = None
+ self.infer_type = None  # "seq_label"/"text_class"

def load(self, model_name):
"""
@@ -46,21 +61,21 @@ class FastNLP(object):
raise ValueError("No FastNLP model named {}.".format(model_name))

if not self.model_exist(model_dir=self.model_dir):
- self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0])
+ self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])

- model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1])
+ model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])

model_args = ConfigSection()
# To do: customized config file for model init parameters
- ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args})
+ ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})

# Construct the model
model = model_class(model_args)

# To do: framework independent
- ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2])
+ ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])

self.model = model
+ self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]

print("Model loaded. ")

@@ -71,12 +86,16 @@ class FastNLP(object):
:return results:
"""

- infer = Inference(self.model_dir)
+ infer = self._create_inference(self.model_dir)

+ # string ---> 2-D list of string
infer_input = self.string_to_list(raw_input)

+ # 2-D list of string ---> list of strings
results = infer.predict(self.model, infer_input)

- outputs = self.make_output(results)
+ # list of strings ---> final answers
+ outputs = self._make_output(results, infer_input)
return outputs

@staticmethod
@@ -95,6 +114,14 @@ class FastNLP(object):
module = getattr(module, sub)
return module

+ def _create_inference(self, model_dir):
+ if self.infer_type == "seq_label":
+ return SeqLabelInfer(model_dir)
+ elif self.infer_type == "text_class":
+ return ClassificationInfer(model_dir)
+ else:
+ raise ValueError("fail to create inference instance")

def _load(self, model_dir, model_name):
# To do
return 0
@@ -117,7 +144,6 @@ class FastNLP(object):

def string_to_list(self, text, delimiter="\n"):
"""
- For word seg only, currently.
This function is used to transform raw input to lists, which is done by DatasetLoader in training.
Split text string into three-level lists.
[
@@ -127,7 +153,7 @@ class FastNLP(object):
]
:param text: string
:param delimiter: str, character used to split text into sentences.
- :return data: three-level lists
+ :return data: two-level lists
"""
data = []
sents = text.strip().split(delimiter)
@@ -136,38 +162,61 @@ class FastNLP(object):
for ch in sent:
characters.append(ch)
data.append(characters)
- # To refactor: this is used in make_output
- self.data = data
return data

- def make_output(self, results):
- """
- Transform model output into user-friendly contents.
- Example: In CWS, convert <BMES> labeling into segmented text.
- :param results:
- :return:
- """
- outputs = []
- for sent_char, sent_label in zip(self.data, results):
- words = []
- word = ""
- for char, label in zip(sent_char, sent_label):
- if label[0] == "B":
- if word != "":
- words.append(word)
- word = char
- elif label[0] == "M":
- word += char
- elif label[0] == "E":
- word += char
- words.append(word)
- word = ""
- elif label[0] == "S":
- if word != "":
- words.append(word)
- word = ""
- words.append(char)
- else:
- raise ValueError("invalid label")
- outputs.append(" ".join(words))
+ def _make_output(self, results, infer_input):
+ if self.infer_type == "seq_label":
+ outputs = make_seq_label_output(results, infer_input)
+ elif self.infer_type == "text_class":
+ outputs = make_class_output(results, infer_input)
+ else:
+ raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
return outputs


+ def make_seq_label_output(result, infer_input):
+ """
+ Transform model output into user-friendly contents.
+ :param result: 1-D list of strings. (model output)
+ :param infer_input: 2-D list of string (model input)
+ :return outputs:
+ """
+ return result


+ def make_class_output(result, infer_input):
+ return result


+ def interpret_word_seg_results(infer_input, results):
+ """
+ Transform model output into user-friendly contents.
+ Example: In CWS, convert <BMES> labeling into segmented text.
+ :param results: list of strings. (model output)
+ :param infer_input: 2-D list of string (model input)
+ :return output: list of strings
+ """
+ outputs = []
+ for sent_char, sent_label in zip(infer_input, results):
+ words = []
+ word = ""
+ for char, label in zip(sent_char, sent_label):
+ if label[0] == "B":
+ if word != "":
+ words.append(word)
+ word = char
+ elif label[0] == "M":
+ word += char
+ elif label[0] == "E":
+ word += char
+ words.append(word)
+ word = ""
+ elif label[0] == "S":
+ if word != "":
+ words.append(word)
+ word = ""
+ words.append(char)
+ else:
+ raise ValueError("invalid label")
+ outputs.append(" ".join(words))
+ return outputs
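
A worked example may help with interpret_word_seg_results: each character's BMES label decides whether it begins, extends, ends, or alone forms a word. With hypothetical inputs (not from the repo's tests):

chars = [["这", "是", "中", "文", "分", "词"]]
labels = [["S", "S", "B", "E", "B", "E"]]
# interpret_word_seg_results(chars, labels) -> ["这 是 中文 分词"]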

fastNLP/models/cnn_text_classification.py (+9 -4)

@@ -15,12 +15,17 @@ class CNNText(torch.nn.Module):
Classification.'
"""

- def __init__(self, class_num=9,
- kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5],
- embed_num=1000, embed_dim=300, pretrained_embed=None,
- drop_prob=0.5):
+ def __init__(self, args):
super(CNNText, self).__init__()

+ class_num = args["num_classes"]
+ kernel_nums = [100, 100, 100]
+ kernel_sizes = [3, 4, 5]
+ embed_num = args["vocab_size"]
+ embed_dim = 300
+ pretrained_embed = None
+ drop_prob = 0.5

# no support for pre-trained embedding currently
self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
self.conv_pool = ConvMaxpool(
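
CNNText is now built from an args dict rather than keyword defaults; only "num_classes" and "vocab_size" are read, and the remaining hyperparameters stay hard-coded. A hedged construction sketch, using the values the [text_class_model] config section below provides:

from fastNLP.models.cnn_text_classification import CNNText

args = {"num_classes": 18, "vocab_size": 867}  # values from test/data_for_tests/config
cnn = CNNText(args)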


fastNLP/models/sequence_modeling.py (+46 -0)

@@ -56,3 +56,49 @@ class SeqLabeling(BaseModel):
"""
tag_seq = self.Crf.viterbi_decode(x, mask)
return tag_seq


class AdvSeqLabel(SeqLabeling):
"""
Advanced Sequence Labeling Model
"""

def __init__(self, args, emb=None):
super(AdvSeqLabel, self).__init__(args)

vocab_size = args["vocab_size"]
word_emb_dim = args["word_emb_dim"]
hidden_dim = args["rnn_hidden_units"]
num_classes = args["num_classes"]

self.Embedding = encoder.embedding.Embedding(vocab_size, word_emb_dim, init_emb=emb)
self.Rnn = encoder.lstm.Lstm(word_emb_dim, hidden_dim, num_layers=3, dropout=0.3, bidirectional=True)
self.Linear1 = encoder.Linear(hidden_dim * 2, hidden_dim * 2 // 3)
self.batch_norm = torch.nn.BatchNorm1d(hidden_dim * 2 // 3)
self.relu = torch.nn.ReLU()
self.drop = torch.nn.Dropout(0.3)
self.Linear2 = encoder.Linear(hidden_dim * 2 // 3, num_classes)

self.Crf = decoder.CRF.ConditionalRandomField(num_classes)

def forward(self, x):
"""
:param x: LongTensor, [batch_size, max_len]
:return y: [batch_size, max_len, tag_size]
"""
batch_size = x.size(0)
max_len = x.size(1)
x = self.Embedding(x)
# [batch_size, max_len, word_emb_dim]
x = self.Rnn(x)
# [batch_size, max_len, hidden_size * direction]
x = x.contiguous()
x = x.view(batch_size * max_len, -1)
x = self.Linear1(x)
x = self.batch_norm(x)
x = self.relu(x)
x = self.drop(x)
x = self.Linear2(x)
x = x.view(batch_size, max_len, -1)
# [batch_size, max_len, num_classes]
return x
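
The flatten/restore steps exist because BatchNorm1d normalizes per feature channel and expects a [N, features] input. A standalone sketch of the same reshape trick (illustration only, not the repo's code):

import torch

batch, max_len, feat = 2, 20, 200  # feat = rnn_hidden_units * 2 directions
x = torch.randn(batch, max_len, feat)
x = x.contiguous().view(batch * max_len, feat)  # [40, 200], one row per token
x = torch.nn.BatchNorm1d(feat)(x)               # normalize each feature channel
x = x.view(batch, max_len, feat)                # restore [2, 20, 200]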

test/data_for_tests/config (+17 -2)

@@ -89,5 +89,20 @@ rnn_hidden_units = 100
rnn_layers = 1
rnn_bi_direction = true
word_emb_dim = 100
- vocab_size = 52
- num_classes = 22
+ vocab_size = 53
+ num_classes = 27

[text_class]
epochs = 1
batch_size = 10
pickle_path = "./data_for_tests/"
validate = false
save_best_dev = false
model_saved_path = "./data_for_tests/"
use_cuda = true
learn_rate = 1e-3
momentum = 0.9

[text_class_model]
vocab_size = 867
num_classes = 18
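
These sections are plain key/value blocks that ConfigLoader parses into ConfigSection objects; test/text_classify.py below consumes the two new ones like this:

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

train_args, model_args = ConfigSection(), ConfigSection()
ConfigLoader.load_config("data_for_tests/config",
                         {"text_class": train_args, "text_class_model": model_args})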

test/data_for_tests/people.txt (+154 -0)

@@ -123,6 +123,160 @@
张 S-q
) S-w

迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

中 B-nt
共 M-nt
中 M-nt
央 E-nt
总 B-n
书 M-n
记 E-n
、 S-w
国 B-n
家 E-n
主 B-n
席 E-n
江 B-nr
泽 M-nr
民 E-nr

( S-w
一 B-t
九 M-t
九 M-t
七 M-t
年 E-t
十 B-t
二 M-t
月 E-t
三 B-t
十 M-t
一 M-t
日 E-t
) S-w

1 B-t
2 M-t
月 E-t
3 B-t
1 M-t
日 E-t
, S-w
迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

迈 B-v
向 E-v
充 B-v
满 E-v
希 B-n
望 E-n
的 S-u
新 S-a
世 B-n
纪 E-n
— B-w
— E-w
一 B-t
九 M-t
九 M-t
八 M-t
年 E-t
新 B-t
年 E-t
讲 B-n
话 E-n
( S-w
附 S-v
图 B-n
片 E-n
1 S-m
张 S-q
) S-w

迈 B-v
向 E-v
充 B-v


test/ner.py (+137 -0)

@@ -0,0 +1,137 @@
import _pickle
import os

import numpy as np
import torch

from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.preprocess import POSPreprocess
from fastNLP.models.sequence_modeling import AdvSeqLabel


class MyNERTrainer(SeqLabelTrainer):
def __init__(self, train_args):
super(MyNERTrainer, self).__init__(train_args)
self.scheduler = None

def define_optimizer(self):
"""
override
:return:
"""
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001)
self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=3000, gamma=0.5)

def update(self):
"""
override
:return:
"""
self.optimizer.step()
self.scheduler.step()

def _create_validator(self, valid_args):
return MyNERTester(valid_args)

def best_eval_result(self, validator):
accuracy = validator.metrics()
if accuracy > self.best_accuracy:
self.best_accuracy = accuracy
return True
else:
return False


class MyNERTester(SeqLabelTester):
def __init__(self, test_args):
super(MyNERTester, self).__init__(test_args)

def _evaluate(self, prediction, batch_y, seq_len):
"""
:param prediction: [batch_size, seq_len, num_classes]
:param batch_y: [batch_size, seq_len]
:param seq_len: [batch_size]
:return:
"""
summ = 0
correct = 0
_, indices = torch.max(prediction, 2)
for p, y, l in zip(indices, batch_y, seq_len):
summ += l
correct += np.sum(p[:l].cpu().numpy() == y[:l].cpu().numpy())
return float(correct / summ)

def evaluate(self, predict, truth):
return self._evaluate(predict, truth, self.seq_len)

def metrics(self):
return np.mean(self.eval_history)

def show_matrices(self):
return "dev accuracy={:.2f}".format(float(self.metrics()))


def embedding_process(emb_file, word_dict, emb_dim, emb_pkl):
if os.path.exists(emb_pkl):
with open(emb_pkl, "rb") as f:
embedding_np = _pickle.load(f)
return embedding_np
with open(emb_file, "r", encoding="utf-8") as f:
embedding_np = np.random.uniform(-1, 1, size=(len(word_dict), emb_dim))
for line in f:
line = line.strip().split()
if len(line) != emb_dim + 1:
continue
if line[0] in word_dict:
embedding_np[word_dict[line[0]]] = [float(i) for i in line[1:]]
with open(emb_pkl, "wb") as f:
_pickle.dump(embedding_np, f)
return embedding_np


def data_load(data_file):
with open(data_file, "r", encoding="utf-8") as f:
all_data = []
sent = []
label = []
for line in f:
line = line.strip().split()

if not len(line) <= 1:
sent.append(line[0])
label.append(line[1])
else:
all_data.append([sent, label])
sent = []
label = []
return all_data


data_path = "data_for_tests/people.txt"
pick_path = "data_for_tests/"
emb_path = "data_for_tests/emb50.txt"
save_path = "data_for_tests/"
if __name__ == "__main__":
data = data_load(data_path)
p = POSPreprocess(data, pickle_path=pick_path, train_dev_split=0.3)
# emb = embedding_process(emb_path, p.word2index, 50, os.path.join(pick_path, "embedding.pkl"))
emb = None
args = {"epochs": 20,
"batch_size": 1,
"pickle_path": pick_path,
"validate": True,
"save_best_dev": True,
"model_saved_path": save_path,
"use_cuda": True,

"vocab_size": p.vocab_size,
"num_classes": p.num_classes,
"word_emb_dim": 50,
"rnn_hidden_units": 100
}
# emb = torch.Tensor(emb).float().cuda()
networks = AdvSeqLabel(args, emb)
trainer = MyNERTrainer(args)
trainer.train(network=networks)
print("Training finished!")

test/ner_decode.py (+129 -0)

@@ -0,0 +1,129 @@
import _pickle
import os

import torch

from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel


class Decode(SeqLabelTrainer):
def __init__(self, args):
super(Decode, self).__init__(args)

def decoder(self, network, sents, model_path):
self.model = network
self.model.load_state_dict(torch.load(model_path))
out_put = []
self.mode(network, test=True)
for batch_x in sents:
prediction = self.data_forward(self.model, batch_x)

seq_tag = self.model.prediction(prediction, batch_x[1])

out_put.append(list(seq_tag)[0])
return out_put


def process_sent(sents, word2id):
sents_num = []
for s in sents:
sent_num = []
for c in s:
if c in word2id:
sent_num.append(word2id[c])
else:
sent_num.append(word2id["<unk>"])
sents_num.append(([sent_num], [len(sent_num)])) # batch_size is 1

return sents_num


def process_tag(sents, tags, id2class):
Tags = []
for ttt in tags:
Tags.append([id2class[t] for t in ttt])

Segs = []
PosNers = []
for sent, tag in zip(sents, tags):
word__ = []
lll__ = []
for c, t in zip(sent, tag):

t = id2class[t]
l = t.split("-")
split_ = l[0]
pn = l[1]

if split_ == "S":
word__.append(c)
lll__.append(pn)
word_1 = ""
elif split_ == "E":
word_1 += c
word__.append(word_1)
lll__.append(pn)
word_1 = ""
elif split_ == "B":
word_1 = ""
word_1 += c
else:
word_1 += c
Segs.append(word__)
PosNers.append(lll__)
return Segs, PosNers


pickle_path = "data_for_tests/"
model_path = "data_for_tests/model_best_dev.pkl"
if __name__ == "__main__":

with open(os.path.join(pickle_path, "id2word.pkl"), "rb") as f:
id2word = _pickle.load(f)
with open(os.path.join(pickle_path, "word2id.pkl"), "rb") as f:
word2id = _pickle.load(f)
with open(os.path.join(pickle_path, "id2class.pkl"), "rb") as f:
id2class = _pickle.load(f)

sent = ["中共中央总书记、国家主席江泽民",
"逆向处理输入序列并返回逆序后的序列"] # here is input

args = {"epochs": 1,
"batch_size": 1,
"pickle_path": "data_for_tests/",
"validate": True,
"save_best_dev": True,
"model_saved_path": "data_for_tests/",
"use_cuda": False,

"vocab_size": len(word2id),
"num_classes": len(id2class),
"word_emb_dim": 50,
"rnn_hidden_units": 100,
}
"""
network = AdvSeqLabel(args, None)
decoder_ = Decode(args)
tags_num = decoder_.decoder(network, process_sent(sent, word2id), model_path=model_path)
output_seg, output_pn = process_tag(sent, tags_num, id2class) # here is output
print(output_seg)
print(output_pn)
"""
# Define the same model
model = AdvSeqLabel(args, None)

# Dump trained parameters into the model
ModelLoader.load_pytorch(model, "./data_for_tests/model_best_dev.pkl")
print("model loaded!")

# Inference interface
infer = SeqLabelInfer(pickle_path)
sent = [[ch for ch in s] for s in sent]
results = infer.predict(model, sent)

for res in results:
print(res)
print("Inference finished!")

test/seq_labeling.py (+2 -2)

@@ -112,5 +112,5 @@ def train_and_test():


if __name__ == "__main__":
- # train_and_test()
- infer()
+ train_and_test()
+ # infer()

test/test_fastNLP.py (+12 -3)

@@ -1,9 +1,18 @@
from fastNLP.fastnlp import FastNLP


- def foo():
+ def word_seg():
nlp = FastNLP("./data_for_tests/")
- nlp.load("zh_pos_tag_model")
+ nlp.load("seq_label_model")
text = "这是最好的基于深度学习的中文分词系统。"
result = nlp.run(text)
print(result)
print("FastNLP finished!")


+ def text_class():
+ nlp = FastNLP("./data_for_tests/")
+ nlp.load("text_class_model")
+ text = "这是最好的基于深度学习的中文分词系统。"
+ result = nlp.run(text)
+ print(result)
@@ -11,4 +20,4 @@ def foo():


if __name__ == "__main__":
- foo()
+ text_class()

test/text_classify.py (+11 -11)

@@ -5,6 +5,7 @@ import os

from fastNLP.core.inference import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
+ from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDatasetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.preprocess import ClassPreprocess
@@ -29,9 +30,13 @@ def infer():
print("vocabulary size:", vocab_size)
print("number of classes:", n_classes)

+ model_args = ConfigSection()
+ ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})

# construct model
print("Building model...")
- cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
+ cnn = CNNText(model_args)

# Dump trained parameters into the model
ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
print("model loaded!")
@@ -42,6 +47,9 @@ def infer():


def train():
+ train_args, model_args = ConfigSection(), ConfigSection()
+ ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})

# load dataset
print("Loading data...")
ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
@@ -56,19 +64,11 @@ def train():

# construct model
print("Building model...")
- cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
+ cnn = CNNText(model_args)

# train
print("Training...")
- train_args = {
- "epochs": 1,
- "batch_size": 10,
- "pickle_path": data_dir,
- "validate": False,
- "save_best_dev": False,
- "model_saved_path": "./data_for_tests/",
- "use_cuda": True
- }

trainer = ClassificationTrainer(train_args)
trainer.train(cnn)


