
Merge remote-tracking branch 'upstream/master'

tags/v0.1.0
Ke Zhen 6 years ago
parent commit de89674436
13 changed files with 187 additions and 10814 deletions
  1. +2 -2     fastNLP/core/preprocess.py
  2. +2 -2     fastNLP/core/tester.py
  3. +26 -55   fastNLP/core/trainer.py
  4. +109 -70  fastNLP/fastnlp.py
  5. BIN       fastnlp-architecture.jpg
  6. +0 -5331  reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
  7. +0 -5331  reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
  8. BIN       reproduction/HAN-document_classification/data/test_samples.pkl
  9. BIN       reproduction/HAN-document_classification/data/train_samples.pkl
  10. BIN      reproduction/HAN-document_classification/data/yelp.word2vec
  11. +19 -14  reproduction/chinese_word_segment/run.py
  12. +2 -2    test/seq_labeling.py
  13. +27 -7   test/test_fastNLP.py

+2 -2 fastNLP/core/preprocess.py

@@ -19,13 +19,13 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
 def save_pickle(obj, pickle_path, file_name):
     with open(os.path.join(pickle_path, file_name), "wb") as f:
         _pickle.dump(obj, f)
-    print("{} saved. ".format(file_name))
+    print("{} saved in {}".format(file_name, pickle_path))


 def load_pickle(pickle_path, file_name):
     with open(os.path.join(pickle_path, file_name), "rb") as f:
         obj = _pickle.load(f)
-    print("{} loaded. ".format(file_name))
+    print("{} loaded from {}".format(file_name, pickle_path))
     return obj

+2 -2 fastNLP/core/tester.py

@@ -98,7 +98,7 @@ class BaseTester(object):
             print_output = "[test step {}] {}".format(step, eval_results)
             logger.info(print_output)
-            if step % self.print_every_step == 0:
+            if self.print_every_step > 0 and step % self.print_every_step == 0:
                 print(print_output)
             step += 1

@@ -187,7 +187,7 @@ class SeqLabelTester(BaseTester):
         # make sure "results" is in the same device as "truth"
         results = results.to(truth)
         accuracy = torch.sum(results == truth.view((-1,))).to(torch.float) / results.shape[0]
-        return [loss.data, accuracy.data]
+        return [float(loss), float(accuracy)]

     def metrics(self):
         batch_loss = np.mean([x[0] for x in self.eval_history])
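
An aside on the [float(loss), float(accuracy)] change above, as a minimal sketch (not part of the commit): float() on a 0-dim torch tensor detaches it into a plain Python number, so eval_history holds scalars that numpy can average without keeping autograd graphs or device memory alive.

    # Illustrative only; any 0-dim tensor behaves this way.
    import torch

    loss = torch.tensor(0.25, requires_grad=True)
    print(float(loss))      # 0.25 -- a plain Python float, no grad history
    print(type(loss.data))  # <class 'torch.Tensor'> -- the old return type kept tensors around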


+26 -55 fastNLP/core/trainer.py

@@ -4,7 +4,6 @@ import os
 import time
 from datetime import timedelta

-import numpy as np
 import torch

 from fastNLP.core.action import Action
@@ -47,7 +46,7 @@ class BaseTrainer(object):
         Otherwise, error will raise.
         """
         default_args = {"epochs": 3, "batch_size": 8, "validate": True, "use_cuda": True, "pickle_path": "./save/",
-                        "save_best_dev": True, "model_name": "default_model_name.pkl",
+                        "save_best_dev": True, "model_name": "default_model_name.pkl", "print_every_step": 1,
                         "loss": Loss(None),
                         "optimizer": Optimizer("Adam", lr=0.001, weight_decay=0)
                         }
@@ -86,6 +85,7 @@ class BaseTrainer(object):
         self.save_best_dev = default_args["save_best_dev"]
         self.use_cuda = default_args["use_cuda"]
         self.model_name = default_args["model_name"]
+        self.print_every_step = default_args["print_every_step"]

         self._model = None
         self._loss_func = default_args["loss"].get()  # return a pytorch loss function or None
@@ -93,48 +93,35 @@ class BaseTrainer(object):
         self._optimizer_proto = default_args["optimizer"]

     def train(self, network, train_data, dev_data=None):
-        """General Training Steps
+        """General Training Procedure
         :param network: a model
         :param train_data: three-level list, the training set.
         :param dev_data: three-level list, the validation data (optional)
-
-        The method is framework independent.
-        Work by calling the following methods:
-            - prepare_input
-            - mode
-            - define_optimizer
-            - data_forward
-            - get_loss
-            - grad_backward
-            - update
-        Subclasses must implement these methods with a specific framework.
         """
-        # prepare model and data, transfer model to gpu if available
+        # transfer model to gpu if available
         if torch.cuda.is_available() and self.use_cuda:
             self._model = network.cuda()
+            # self._model is used to access model-specific loss
         else:
             self._model = network

-        # define tester over dev data
+        # define Tester over dev data
         if self.validate:
             default_valid_args = {"save_output": True, "validate_in_training": True, "save_dev_input": True,
                                   "save_loss": True, "batch_size": self.batch_size, "pickle_path": self.pickle_path,
-                                  "use_cuda": self.use_cuda}
+                                  "use_cuda": self.use_cuda, "print_every_step": 0}
             validator = self._create_validator(default_valid_args)
             logger.info("validator defined as {}".format(str(validator)))

+        # optimizer and loss
         self.define_optimizer()
         logger.info("optimizer defined as {}".format(str(self._optimizer)))
         self.define_loss()
         logger.info("loss function defined as {}".format(str(self._loss_func)))

-        # main training epochs
-        n_samples = len(train_data)
-        n_batches = n_samples // self.batch_size
-        n_print = 1
+        # main training procedure
         start = time.time()
         logger.info("training epochs started")

         for epoch in range(1, self.n_epochs + 1):
             logger.info("training epoch {}".format(epoch))
@@ -144,23 +131,30 @@ class BaseTrainer(object):
             data_iterator = iter(Batchifier(RandomSampler(train_data), self.batch_size, drop_last=False))
             logger.info("prepared data iterator")

-            self._train_step(data_iterator, network, start=start, n_print=n_print, epoch=epoch)
+            # one forward and backward pass
+            self._train_step(data_iterator, network, start=start, n_print=self.print_every_step, epoch=epoch)

+            # validation
             if self.validate:
                 logger.info("validation started")
                 validator.test(network, dev_data)

                 if self.save_best_dev and self.best_eval_result(validator):
                     self.save_model(network, self.model_name)
-                    print("saved better model selected by dev")
-                    logger.info("saved better model selected by dev")
+                    print("Saved better model selected by validation.")
+                    logger.info("Saved better model selected by validation.")

                 valid_results = validator.show_matrices()
                 print("[epoch {}] {}".format(epoch, valid_results))
                 logger.info("[epoch {}] {}".format(epoch, valid_results))

     def _train_step(self, data_iterator, network, **kwargs):
-        """Training process in one epoch."""
+        """Training process in one epoch.
+            kwargs should contain:
+                - n_print: int, print training information every n steps.
+                - start: time.time(), the starting time of this step.
+                - epoch: int,
+        """
         step = 0
         for batch_x, batch_y in self.make_batch(data_iterator):
@@ -170,7 +164,7 @@ class BaseTrainer(object):
             self.grad_backward(loss)
             self.update()

-            if step % kwargs["n_print"] == 0:
+            if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
                 end = time.time()
                 diff = timedelta(seconds=round(end - kwargs["start"]))
                 print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.2} time: {}".format(
@@ -287,10 +281,11 @@ class BaseTrainer(object):
         raise NotImplementedError

     def save_model(self, network, model_name):
-        """
+        """Save this model with such a name.
+        This method may be called multiple times by Trainer to overwritten a better model.

         :param network: the PyTorch model
         :param model_name: str
-            model_best_dev.pkl may be overwritten by a better model in future epochs.
         """
         if model_name[-4:] != ".pkl":
             model_name += ".pkl"
@@ -300,33 +295,9 @@ class BaseTrainer(object):
         raise NotImplementedError


-class ToyTrainer(BaseTrainer):
-    """
-    An example to show the definition of Trainer.
-    """
-
-    def __init__(self, training_args):
-        super(ToyTrainer, self).__init__(training_args)
-
-    def load_train_data(self, data_path):
-        data_train = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
-        data_dev = _pickle.load(open(data_path + "/data_train.pkl", "rb"))
-        return data_train, data_dev, 0, 1
-
-    def data_forward(self, network, x):
-        return network(x)
-
-    def grad_backward(self, loss):
-        self._model.zero_grad()
-        loss.backward()
-
-    def get_loss(self, pred, truth):
-        return np.mean(np.square(pred - truth))
-
-
 class SeqLabelTrainer(BaseTrainer):
     """
-    Trainer for Sequence Modeling
+    Trainer for Sequence Labeling

     """
@@ -384,7 +355,7 @@ class SeqLabelTrainer(BaseTrainer):


 class ClassificationTrainer(BaseTrainer):
-    """Trainer for classification."""
+    """Trainer for text classification."""

     def __init__(self, **train_args):
         super(ClassificationTrainer, self).__init__(**train_args)
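
For readers skimming the trainer changes: a minimal sketch (not from the commit) of the print_every_step / n_print contract the new guards enforce, where 0 disables step-wise console output and N > 0 prints every N steps. The should_print helper is hypothetical, introduced only to illustrate the condition.

    # Mirrors the guard `kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0`.
    def should_print(step, n_print):
        return n_print > 0 and step % n_print == 0

    assert should_print(10, 5)      # multiples of n_print are printed
    assert not should_print(7, 5)   # other steps stay quiet
    assert not should_print(10, 0)  # 0 disables printing, as the validator's args use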


+109 -70 fastNLP/fastnlp.py

@@ -1,4 +1,5 @@
 from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
+from fastNLP.core.preprocess import load_pickle
 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.loader.model_loader import ModelLoader
@@ -7,14 +8,13 @@ mapping from model name to [URL, file_name.class_name, model_pickle_name]
 Notice that the class of the model should be in "models" directory.

 Example:
-    "zh_pos_tag_model": ["www.fudan.edu.cn", "sequence_modeling.SeqLabeling", "saved_model.pkl"]
-"""
-FastNLP_MODEL_COLLECTION = {
     "seq_label_model": {
         "url": "www.fudan.edu.cn",
-        "class": "sequence_modeling.SeqLabeling",
+        "class": "sequence_modeling.SeqLabeling",  # file_name.class_name in models/
         "pickle": "seq_label_model.pkl",
-        "type": "seq_label"
+        "type": "seq_label",
+        "config_file_name": "config",  # the name of the config file which stores model initialization parameters
+        "config_section_name": "text_class_model"  # the name of the section in the config file which stores model init params
     },
     "text_class_model": {
         "url": "www.fudan.edu.cn",
@@ -22,11 +22,18 @@ FastNLP_MODEL_COLLECTION = {
         "pickle": "text_class_model.pkl",
         "type": "text_class"
     }
+"""
+FastNLP_MODEL_COLLECTION = {
+    "cws_basic_model": {
+        "url": "",
+        "class": "sequence_modeling.AdvSeqLabel",
+        "pickle": "cws_basic_model_v_0.pkl",
+        "type": "seq_label",
+        "config_file_name": "config",
+        "config_section_name": "text_class_model"
+    }
 }

-CONFIG_FILE_NAME = "config"
-SECTION_NAME = "text_class_model"
-

 class FastNLP(object):
     """
@@ -51,10 +58,13 @@ class FastNLP(object):
         self.model = None
         self.infer_type = None  # "seq_label"/"text_class"

-    def load(self, model_name):
+    def load(self, model_name, config_file="config", section_name="model"):
         """
         Load a pre-trained FastNLP model together with additional data.
         :param model_name: str, the name of a FastNLP model.
+        :param config_file: str, the name of the config file which stores the initialization information of the model.
+                (default: "config")
+        :param section_name: str, the name of the corresponding section in the config file. (default: model)
         """
         assert type(model_name) is str
         if model_name not in FastNLP_MODEL_COLLECTION:
@@ -64,37 +74,47 @@ class FastNLP(object):
             self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])

         model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
+        print("Restore model class {}".format(str(model_class)))

         model_args = ConfigSection()
-        ConfigLoader.load_config(self.model_dir + CONFIG_FILE_NAME, {SECTION_NAME: model_args})
+        ConfigLoader.load_config(self.model_dir + config_file, {section_name: model_args})
+        print("Restore model hyper-parameters {}".format(str(model_args.data)))
+
+        # fetch dictionary size and number of labels from pickle files
+        word2index = load_pickle(self.model_dir, "word2id.pkl")
+        model_args["vocab_size"] = len(word2index)
+        index2label = load_pickle(self.model_dir, "id2class.pkl")
+        model_args["num_classes"] = len(index2label)

         # Construct the model
         model = model_class(model_args)
+        print("Model constructed.")

         # To do: framework independent
         ModelLoader.load_pytorch(model, self.model_dir + FastNLP_MODEL_COLLECTION[model_name]["pickle"])
+        print("Model weights loaded.")

         self.model = model
         self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]

-        print("Model loaded. ")
+        print("Inference ready.")

     def run(self, raw_input):
         """
         Perform inference over given input using the loaded model.
-        :param raw_input: str, raw text
+        :param raw_input: list of string. Each list is an input query.
         :return results:
         """

         infer = self._create_inference(self.model_dir)

-        # string ---> 2-D list of string
-        infer_input = self.string_to_list(raw_input)
+        # tokenize: list of string ---> 2-D list of string
+        infer_input = self.tokenize(raw_input, language="zh")

-        # 2-D list of string ---> list of strings
+        # 2-D list of string ---> 2-D list of tags
         results = infer.predict(self.model, infer_input)

-        # list of strings ---> final answers
+        # 2-D list of tags ---> list of final answers
         outputs = self._make_output(results, infer_input)
         return outputs
@@ -142,81 +162,100 @@ class FastNLP(object):
         """
         return True

-    def string_to_list(self, text, delimiter="\n"):
-        """
-        This function is used to transform raw input to lists, which is done by DatasetLoader in training.
-        Split text string into three-level lists.
-        [
-            [word_11, word_12, ...],
-            [word_21, word_22, ...],
-            ...
-        ]
-        :param text: string
-        :param delimiter: str, character used to split text into sentences.
-        :return data: two-level lists
+    def tokenize(self, text, language):
+        """Extract tokens from strings.
+        For English, extract words separated by space.
+        For Chinese, extract characters.
+        TODO: more complex tokenization methods
+
+        :param text: list of string
+        :param language: str, one of ('zh', 'en'), Chinese or English.
+        :return data: list of list of string, each string is a token.
         """
+        assert language in ("zh", "en")
         data = []
-        sents = text.strip().split(delimiter)
-        for sent in sents:
-            characters = []
-            for ch in sent:
-                characters.append(ch)
-            data.append(characters)
+        for sent in text:
+            if language == "en":
+                tokens = sent.strip().split()
+            elif language == "zh":
+                tokens = [char for char in sent]
+            else:
+                raise RuntimeError("Unknown language {}".format(language))
+            data.append(tokens)
         return data

     def _make_output(self, results, infer_input):
+        """Transform the infer output into user-friendly output.
+
+        :param results: 1 or 2-D list of strings.
+                If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length]
+                If self.infer_type == "text_class", it is of shape [num_examples]
+        :param infer_input: 2-D list of string, the input query before inference.
+        :return outputs: list. Each entry is a prediction.
+        """
         if self.infer_type == "seq_label":
             outputs = make_seq_label_output(results, infer_input)
         elif self.infer_type == "text_class":
             outputs = make_class_output(results, infer_input)
         else:
-            raise ValueError("fail to make outputs with infer type {}".format(self.infer_type))
+            raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
         return outputs


 def make_seq_label_output(result, infer_input):
-    """
-    Transform model output into user-friendly contents.
-    :param result: 1-D list of strings. (model output)
+    """Transform model output into user-friendly contents.
+
+    :param result: 2-D list of strings. (model output)
     :param infer_input: 2-D list of string (model input)
-    :return outputs:
+    :return ret: list of list of tuples
+        [
+            [(word_11, label_11), (word_12, label_12), ...],
+            [(word_21, label_21), (word_22, label_22), ...],
+            ...
+        ]
     """
-    return result
-
+    ret = []
+    for example_x, example_y in zip(infer_input, result):
+        ret.append([(x, y) for x, y in zip(example_x, example_y)])
+    return ret


 def make_class_output(result, infer_input):
+    """Transform model output into user-friendly contents.
+
+    :param result: 2-D list of strings. (model output)
+    :param infer_input: 1-D list of string (model input)
+    :return ret: the same as result, [label_1, label_2, ...]
+    """
     return result


-def interpret_word_seg_results(infer_input, results):
-    """
-    Transform model output into user-friendly contents.
+def interpret_word_seg_results(char_seq, label_seq):
+    """Transform model output into user-friendly contents.
+
     Example: In CWS, convert <BMES> labeling into segmented text.
-    :param results: list of strings. (model output)
-    :param infer_input: 2-D list of string (model input)
-    :return output: list of strings
+    :param char_seq: list of string,
+    :param label_seq: list of string, the same length as char_seq
+            Each entry is one of ('B', 'M', 'E', 'S').
+    :return output: list of words
     """
-    outputs = []
-    for sent_char, sent_label in zip(infer_input, results):
-        words = []
-        word = ""
-        for char, label in zip(sent_char, sent_label):
-            if label[0] == "B":
-                if word != "":
-                    words.append(word)
-                word = char
-            elif label[0] == "M":
-                word += char
-            elif label[0] == "E":
-                word += char
-                words.append(word)
-                word = ""
-            elif label[0] == "S":
-                if word != "":
-                    words.append(word)
-                word = ""
-                words.append(char)
-            else:
-                raise ValueError("invalid label")
-        outputs.append(" ".join(words))
-    return outputs
+    words = []
+    word = ""
+    for char, label in zip(char_seq, label_seq):
+        if label[0] == "B":
+            if word != "":
+                words.append(word)
+            word = char
+        elif label[0] == "M":
+            word += char
+        elif label[0] == "E":
+            word += char
+            words.append(word)
+            word = ""
+        elif label[0] == "S":
+            if word != "":
+                words.append(word)
+            word = ""
+            words.append(char)
+        else:
+            raise ValueError("invalid label {}".format(label[0]))
+    return words
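
A usage sketch of the rewritten BMES decoding above (illustrative only, assuming the fastNLP.fastnlp module from this commit is importable): B begins a word, M continues it, E ends it, and S marks a single-character word.

    from fastNLP.fastnlp import interpret_word_seg_results

    chars = ["中", "文", "分", "词", "好"]
    labels = ["B", "E", "B", "E", "S"]
    print(interpret_word_seg_results(chars, labels))  # expected: ['中文', '分词', '好']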

BIN fastnlp-architecture.jpg
Width: 960 | Height: 540 | Size: 36 kB

+0 -5331 reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.neg
File diff suppressed because it is too large

+0 -5331 reproduction/CNN-sentence_classification/rt-polaritydata/rt-polarity.pos
File diff suppressed because it is too large


BIN reproduction/HAN-document_classification/data/test_samples.pkl

BIN reproduction/HAN-document_classification/data/train_samples.pkl

BIN reproduction/HAN-document_classification/data/yelp.word2vec

+19 -14 reproduction/chinese_word_segment/run.py

@@ -1,26 +1,26 @@
-import sys, os
+import os
+import sys

 sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

 from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
 from fastNLP.core.trainer import SeqLabelTrainer
 from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
-from fastNLP.loader.preprocess import POSPreprocess, load_pickle
+from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
 from fastNLP.saver.model_saver import ModelSaver
 from fastNLP.loader.model_loader import ModelLoader
 from fastNLP.core.tester import SeqLabelTester
 from fastNLP.models.sequence_modeling import AdvSeqLabel
-from fastNLP.core.inference import SeqLabelInfer
-from fastNLP.core.optimizer import SGD
+from fastNLP.core.predictor import SeqLabelInfer

 # not in the file's dir
 if len(os.path.dirname(__file__)) != 0:
     os.chdir(os.path.dirname(__file__))
-datadir = 'icwb2-data'
-cfgfile = 'cws.cfg'
+datadir = "/home/zyfeng/data/"
+cfgfile = './cws.cfg'
 data_name = "pku_training.utf8"

-cws_data_path = os.path.join(datadir, "training/pku_training.utf8")
+cws_data_path = os.path.join(datadir, "pku_training.utf8")
 pickle_path = "save"
 data_infer_path = os.path.join(datadir, "infer.utf8")

@@ -70,12 +70,13 @@ def train():
     train_data = loader.load_pku()

     # Preprocessor
-    p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3)
-    train_args["vocab_size"] = p.vocab_size
-    train_args["num_classes"] = p.num_classes
+    preprocessor = SeqLabelPreprocess()
+    data_train, data_dev = preprocessor.run(train_data, pickle_path=pickle_path, train_dev_split=0.3)
+    train_args["vocab_size"] = preprocessor.vocab_size
+    train_args["num_classes"] = preprocessor.num_classes

     # Trainer
-    trainer = SeqLabelTrainer(train_args)
+    trainer = SeqLabelTrainer(**train_args.data)

     # Model
     model = AdvSeqLabel(train_args)
@@ -83,10 +84,11 @@ def train():
         ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
         print('model parameter loaded!')
     except Exception as e:
+        print("No saved model. Continue.")
         pass

     # Start training
-    trainer.train(model)
+    trainer.train(model, data_train, data_dev)
     print("Training finished!")

     # Saver
@@ -106,6 +108,9 @@ def test():
     index2label = load_pickle(pickle_path, "id2class.pkl")
     test_args["num_classes"] = len(index2label)

+    # load dev data
+    dev_data = load_pickle(pickle_path, "data_dev.pkl")
+
     # Define the same model
     model = AdvSeqLabel(test_args)

@@ -114,10 +119,10 @@ def test():
     print("model loaded!")

     # Tester
-    tester = SeqLabelTester(test_args)
+    tester = SeqLabelTester(**test_args.data)

     # Start testing
-    tester.test(model)
+    tester.test(model, dev_data)

     # print test results
     print(tester.show_matrices())
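
One note on the SeqLabelTrainer(**train_args.data) change above: the trainers now accept keyword arguments merged over their defaults rather than a single config object. A minimal sketch of that pattern with hypothetical values (make_trainer and the argument values are not from the repo):

    # Defaults are overridden only by keys the caller actually supplies,
    # mirroring BaseTrainer's default_args handling shown earlier.
    default_args = {"epochs": 3, "batch_size": 8, "use_cuda": True}

    def make_trainer(**kwargs):
        args = dict(default_args)
        args.update(kwargs)  # per the trainer docstring, unexpected keys raise in the real check
        return args

    print(make_trainer(batch_size=32))  # {'epochs': 3, 'batch_size': 32, 'use_cuda': True}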


+2 -2 test/seq_labeling.py

@@ -123,7 +123,7 @@ def train_and_test():
     tester = SeqLabelTester(save_output=False,
                             save_loss=False,
                             save_best_dev=False,
-                            batch_size=8,
+                            batch_size=4,
                             use_cuda=False,
                             pickle_path=pickle_path,
                             model_name="seq_label_in_test.pkl",
@@ -140,4 +140,4 @@ def train_and_test():

 if __name__ == "__main__":
     train_and_test()
-    infer()
+    # infer()

+27 -7 test/test_fastNLP.py

@@ -1,13 +1,24 @@
+import sys
+sys.path.append("..")
 from fastNLP.fastnlp import FastNLP
+from fastNLP.fastnlp import interpret_word_seg_results

+PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"

 def word_seg():
-    nlp = FastNLP("./data_for_tests/")
-    nlp.load("seq_label_model")
-    text = "这是最好的基于深度学习的中文分词系统。"
-    result = nlp.run(text)
-    print(result)
-    print("FastNLP finished!")
+    nlp = FastNLP(model_dir=PATH_TO_CWS_PICKLE_FILES)
+    nlp.load("cws_basic_model", config_file="cws.cfg", section_name="POS_test")
+    text = ["这是最好的基于深度学习的中文分词系统。",
+            "大王叫我来巡山。",
+            "我党多年来致力于改善人民生活水平。"]
+    results = nlp.run(text)
+    print(results)
+    for example in results:
+        words, labels = [], []
+        for res in example:
+            words.append(res[0])
+            labels.append(res[1])
+        print(interpret_word_seg_results(words, labels))


 def text_class():
@@ -19,5 +30,14 @@ def text_class():
     print("FastNLP finished!")


+def test_word_seg_interpret():
+    foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'),
+            ('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'),
+            ('。', 'S')]]
+    chars = [x[0] for x in foo[0]]
+    labels = [x[1] for x in foo[0]]
+    print(interpret_word_seg_results(chars, labels))
+

 if __name__ == "__main__":
-    text_class()
+    word_seg()