
fastNLP high-level interface:

- fastNLP interface for sequence labeling works
- fastNLP interface for text classification works
Tag: v0.1.0
FengZiYjun committed 7 years ago · commit 77b3a0c67d
8 changed files with 148 additions and 76 deletions
1. fastNLP/core/inference.py (+2 -2)
2. fastNLP/core/trainer.py (+2 -8)
3. fastNLP/fastnlp.py (+93 -44)
4. fastNLP/models/cnn_text_classification.py (+9 -4)
5. test/data_for_tests/config (+17 -2)
6. test/seq_labeling.py (+2 -2)
7. test/test_fastNLP.py (+12 -3)
8. test/text_classify.py (+11 -11)

fastNLP/core/inference.py (+2 -2)

@@ -149,7 +149,7 @@ class SeqLabelInfer(Inference):
         """
         Transform list of batch outputs into strings.
         :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, tag_seq_length].
-        :return:
+        :return results: 2-D list of strings
         """
         results = []
         for batch in batch_outputs:
@@ -178,7 +178,7 @@ class ClassificationInfer(Inference):
         """
         Transform list of batch outputs into strings.
         :param batch_outputs: list of 2-D Tensor, of shape [num_batch, batch-size, num_classes].
-        :return:
+        :return results: list of strings
         """
         results = []
         for batch_out in batch_outputs:
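For readers skimming the diff: the two return conventions fixed in these docstrings differ only in output shape. A minimal sketch of both, assuming a hypothetical idx2label tag vocabulary and invented class names (neither is part of fastNLP's API):

import torch

idx2label = {0: "B", 1: "M", 2: "E", 3: "S"}  # hypothetical tag vocabulary

def seq_label_to_strings(batch_outputs):
    # SeqLabelInfer: list of [batch_size, tag_seq_length] tensors -> 2-D list of strings
    results = []
    for batch in batch_outputs:
        for sample in batch:
            results.append([idx2label[int(t)] for t in sample])
    return results

def class_scores_to_strings(batch_outputs, classes=("pos", "neg")):  # invented class names
    # ClassificationInfer: list of [batch_size, num_classes] tensors -> list of strings
    results = []
    for batch_out in batch_outputs:
        for i in torch.argmax(batch_out, dim=1):  # best-scoring class per sample
            results.append(classes[int(i)])
    return results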


fastNLP/core/trainer.py (+2 -8)

@@ -315,14 +315,8 @@ class ClassificationTrainer(BaseTrainer):

     def __init__(self, train_args):
         super(ClassificationTrainer, self).__init__(train_args)
-        if "learn_rate" in train_args:
-            self.learn_rate = train_args["learn_rate"]
-        else:
-            self.learn_rate = 1e-3
-        if "momentum" in train_args:
-            self.momentum = train_args["momentum"]
-        else:
-            self.momentum = 0.9
+        self.learn_rate = train_args["learn_rate"]
+        self.momentum = train_args["momentum"]

         self.iterator = None
         self.loss_func = None
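With the fallback defaults removed, "learn_rate" and "momentum" must now come from the caller's train_args. A sketch of a compatible dict, mirroring the [text_class] section added to test/data_for_tests/config further down:

from fastNLP.core.trainer import ClassificationTrainer

train_args = {
    "epochs": 1,
    "batch_size": 10,
    "pickle_path": "./data_for_tests/",
    "validate": False,
    "save_best_dev": False,
    "model_saved_path": "./data_for_tests/",
    "use_cuda": True,
    "learn_rate": 1e-3,  # previously the hard-coded default
    "momentum": 0.9,     # previously the hard-coded default
}
trainer = ClassificationTrainer(train_args)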


fastNLP/fastnlp.py (+93 -44)

@@ -1,4 +1,4 @@
-from fastNLP.core.inference import Inference
+from fastNLP.core.inference import SeqLabelInfer, ClassificationInfer
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader

@@ -10,14 +10,28 @@ Example:
     "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
 """
 FastNLP_MODEL_COLLECTION = {
-    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
+    "seq_label_model": {
+        "url": "www.fudan.edu.cn",
+        "class": "sequence_modeling.SeqLabeling",
+        "pickle": "seq_label_model.pkl",
+        "type": "seq_label"
+    },
+    "text_class_model": {
+        "url": "www.fudan.edu.cn",
+        "class": "cnn_text_classification.CNNText",
+        "pickle": "text_class_model.pkl",
+        "type": "text_class"
+    }
 }

+CONFIG_FILE_NAME = "config"
+SECTION_NAME = "text_class_model"
+

 class FastNLP(object):
     """
     High-level interface for direct model inference.
-    Usage:
+    Example Usage:
         fastnlp = FastNLP()
         fastnlp.load("zh_pos_tag_model")
         text = "这是最好的基于深度学习的中文分词系统。"
@@ -35,6 +49,7 @@ class FastNLP(object):
         """
         self.model_dir = model_dir
         self.model = None
+        self.infer_type = None  # "seq_label"/"text_class"

     def load(self, model_name):
         """
@@ -46,21 +61,21 @@ class FastNLP(object):
             raise ValueError("No FastNLP model named {}.".format(model_name))

         if not self.model_exist(model_dir=self.model_dir):
-            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name][0])
+            self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])

-        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name][1])
+        model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])

         model_args = ConfigSection()
         # To do: customized config file for model init parameters
-        ConfigLoader.load_config(self.model_dir + "config", {"POS_infer": model_args})
+        ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})

         # Construct the model
         model = model_class(model_args)

         # To do: framework independent
-        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name][2])
+        ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])

         self.model = model
+        self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]

         print("Model loaded. ")

@@ -71,12 +86,16 @@ class FastNLP(object):
         :return results:
         """

-        infer = Inference(self.model_dir)
+        infer = self._create_inference(self.model_dir)

         # string ---> 2-D list of string
         infer_input = self.string_to_list(raw_input)

+        # 2-D list of string ---> list of strings
         results = infer.predict(self.model, infer_input)

-        outputs = self.make_output(results)
+        # list of strings ---> final answers
+        outputs = self._make_output(results, infer_input)
         return outputs

     @staticmethod
@@ -95,6 +114,14 @@ class FastNLP(object):
             module = getattr(module, sub)
         return module

+    def _create_inference(self, model_dir):
+        if self.infer_type == "seq_label":
+            return SeqLabelInfer(model_dir)
+        elif self.infer_type == "text_class":
+            return ClassificationInfer(model_dir)
+        else:
+            raise ValueError("fail to create inference instance")
+
     def _load(self, model_dir, model_name):
         # To do
         return 0
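Taken together, load() records the model's type and run() dispatches on it through _create_inference(). A hypothetical end-to-end trace (assumes the pickled model and config already sit under model_dir):

from fastNLP.fastnlp import FastNLP

nlp = FastNLP("./data_for_tests/")
nlp.load("seq_label_model")  # downloads if missing, sets infer_type = "seq_label"
# run(): string -> string_to_list -> SeqLabelInfer.predict -> _make_output
results = nlp.run("这是最好的基于深度学习的中文分词系统。")
print(results)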
@@ -117,7 +144,6 @@ class FastNLP(object):

     def string_to_list(self, text, delimiter="\n"):
         """
-        For word seg only, currently.
         This function is used to transform raw input to lists, which is done by DatasetLoader in training.
         Split text string into three-level lists.
         [
@@ -127,7 +153,7 @@ class FastNLP(object):
         ]
         :param text: string
         :param delimiter: str, character used to split text into sentences.
-        :return data: three-level lists
+        :return data: two-level lists
         """
         data = []
         sents = text.strip().split(delimiter)
@@ -136,38 +162,61 @@ class FastNLP(object):
             for ch in sent:
                 characters.append(ch)
             data.append(characters)
-        # To refactor: this is used in make_output
-        self.data = data
         return data

-    def make_output(self, results):
-        """
-        Transform model output into user-friendly contents.
-        Example: In CWS, convert <BMES> labeling into segmented text.
-        :param results:
-        :return:
-        """
-        outputs = []
-        for sent_char, sent_label in zip(self.data, results):
-            words = []
-            word = ""
-            for char, label in zip(sent_char, sent_label):
-                if label[0] == "B":
-                    if word != "":
-                        words.append(word)
-                    word = char
-                elif label[0] == "M":
-                    word += char
-                elif label[0] == "E":
-                    word += char
-                    words.append(word)
-                    word = ""
-                elif label[0] == "S":
-                    if word != "":
-                        words.append(word)
-                        word = ""
-                    words.append(char)
-                else:
-                    raise ValueError("invalid label")
-            outputs.append(" ".join(words))
+    def _make_output(self, results, infer_input):
+        if self.infer_type == "seq_label":
+            outputs = make_seq_label_output(results, infer_input)
+        elif self.infer_type == "text_class":
+            outputs = make_class_output(results, infer_input)
+        else:
+            raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
         return outputs


+def make_seq_label_output(result, infer_input):
+    """
+    Transform model output into user-friendly contents.
+    :param result: 1-D list of strings. (model output)
+    :param infer_input: 2-D list of string (model input)
+    :return outputs:
+    """
+    return result
+
+
+def make_class_output(result, infer_input):
+    return result
+
+
+def interpret_word_seg_results(infer_input, results):
+    """
+    Transform model output into user-friendly contents.
+    Example: In CWS, convert <BMES> labeling into segmented text.
+    :param results: list of strings. (model output)
+    :param infer_input: 2-D list of string (model input)
+    :return output: list of strings
+    """
+    outputs = []
+    for sent_char, sent_label in zip(infer_input, results):
+        words = []
+        word = ""
+        for char, label in zip(sent_char, sent_label):
+            if label[0] == "B":
+                if word != "":
+                    words.append(word)
+                word = char
+            elif label[0] == "M":
+                word += char
+            elif label[0] == "E":
+                word += char
+                words.append(word)
+                word = ""
+            elif label[0] == "S":
+                if word != "":
+                    words.append(word)
+                    word = ""
+                words.append(char)
+            else:
+                raise ValueError("invalid label")
+        outputs.append(" ".join(words))
+    return outputs
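To make the <BMES> scheme concrete, a worked example through interpret_word_seg_results (labels invented): B/M/E mark the beginning, middle, and end of a multi-character word, S a single-character word.

chars = [["这", "是", "系", "统"]]   # model input: one char list per sentence
labels = [["S", "S", "B", "E"]]      # model output: one tag per char
print(interpret_word_seg_results(chars, labels))
# ['这 是 系统'] -- two single-char words, then "系统" assembled from B..E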

fastNLP/models/cnn_text_classification.py (+9 -4)

@@ -15,12 +15,17 @@ class CNNText(torch.nn.Module):
     Classification.'
     """

-    def __init__(self, class_num=9,
-                 kernel_nums=[100, 100, 100], kernel_sizes=[3, 4, 5],
-                 embed_num=1000, embed_dim=300, pretrained_embed=None,
-                 drop_prob=0.5):
+    def __init__(self, args):
         super(CNNText, self).__init__()

+        class_num = args["num_classes"]
+        kernel_nums = [100, 100, 100]
+        kernel_sizes = [3, 4, 5]
+        embed_num = args["vocab_size"]
+        embed_dim = 300
+        pretrained_embed = None
+        drop_prob = 0.5
+
         # no support for pre-trained embedding currently
         self.embed = nn.Embedding(embed_num, embed_dim, padding_idx=0)
         self.conv_pool = ConvMaxpool(
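CNNText is now built from a single dict-like args object; kernel counts and sizes, embedding dimension, and dropout are fixed inside the constructor. A minimal construction sketch using the values from the new [text_class_model] config section (a plain dict stands in for ConfigSection):

from fastNLP.models.cnn_text_classification import CNNText

model_args = {"vocab_size": 867, "num_classes": 18}
cnn = CNNText(model_args)  # 867-token embedding table, 18 output classes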


test/data_for_tests/config (+17 -2)

@@ -89,5 +89,20 @@ rnn_hidden_units = 100
 rnn_layers = 1
 rnn_bi_direction = true
 word_emb_dim = 100
-vocab_size = 52
-num_classes = 22
+vocab_size = 53
+num_classes = 27
+
+[text_class]
+epochs = 1
+batch_size = 10
+pickle_path = "./data_for_tests/"
+validate = false
+save_best_dev = false
+model_saved_path = "./data_for_tests/"
+use_cuda = true
+learn_rate = 1e-3
+momentum = 0.9
+
+[text_class_model]
+vocab_size = 867
+num_classes = 18
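These sections are read through ConfigLoader, each ConfigSection being filled from the section of the same name, exactly as the updated test/text_classify.py below does:

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection

train_args, model_args = ConfigSection(), ConfigSection()
ConfigLoader.load_config("data_for_tests/config",
                         {"text_class": train_args, "text_class_model": model_args})
# train_args now holds epochs, batch_size, learn_rate, momentum, ...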

test/seq_labeling.py (+2 -2)

@@ -112,5 +112,5 @@ def train_and_test():


 if __name__ == "__main__":
-    # train_and_test()
-    infer()
+    train_and_test()
+    # infer()

test/test_fastNLP.py (+12 -3)

@@ -1,9 +1,18 @@
 from fastNLP.fastnlp import FastNLP


-def foo():
+def word_seg():
     nlp = FastNLP("./data_for_tests/")
-    nlp.load("zh_pos_tag_model")
+    nlp.load("seq_label_model")
+    text = "这是最好的基于深度学习的中文分词系统。"
+    result = nlp.run(text)
+    print(result)
+    print("FastNLP finished!")
+
+
+def text_class():
+    nlp = FastNLP("./data_for_tests/")
+    nlp.load("text_class_model")
     text = "这是最好的基于深度学习的中文分词系统。"
     result = nlp.run(text)
     print(result)
@@ -11,4 +20,4 @@ def foo():


 if __name__ == "__main__":
-    foo()
+    text_class()

test/text_classify.py (+11 -11)

@@ -5,6 +5,7 @@ import os

 from fastNLP.core.inference import ClassificationInfer
 from fastNLP.core.trainer import ClassificationTrainer
+from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.dataset_loader import ClassDatasetLoader
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.loader.preprocess import ClassPreprocess
@@ -29,9 +30,13 @@ def infer():
     print("vocabulary size:", vocab_size)
     print("number of classes:", n_classes)

+    model_args = ConfigSection()
+    ConfigLoader.load_config("data_for_tests/config", {"text_class_model": model_args})
+
     # construct model
     print("Building model...")
-    cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
+    cnn = CNNText(model_args)

     # Dump trained parameters into the model
     ModelLoader.load_pytorch(cnn, "./data_for_tests/saved_model.pkl")
     print("model loaded!")
@@ -42,6 +47,9 @@ def infer():


 def train():
+    train_args, model_args = ConfigSection(), ConfigSection()
+    ConfigLoader.load_config("data_for_tests/config", {"text_class": train_args, "text_class_model": model_args})
+
     # load dataset
     print("Loading data...")
     ds_loader = ClassDatasetLoader("train", os.path.join(data_dir, train_file))
@@ -56,19 +64,11 @@ def train():

     # construct model
     print("Building model...")
-    cnn = CNNText(class_num=n_classes, embed_num=vocab_size)
+    cnn = CNNText(model_args)

     # train
     print("Training...")
-    train_args = {
-        "epochs": 1,
-        "batch_size": 10,
-        "pickle_path": data_dir,
-        "validate": False,
-        "save_best_dev": False,
-        "model_saved_path": "./data_for_tests/",
-        "use_cuda": True
-    }
-
     trainer = ClassificationTrainer(train_args)
     trainer.train(cnn)


