Browse Source

Merge pull request #32 from choosewhatulike/master

add a new chinese word segmentation model
tags/v0.1.0
Xipeng Qiu GitHub 6 years ago
parent
commit
2f83010d9d
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 174 additions and 0 deletions
  1. +34
    -0
      reproduction/chinese_word_segment/cws.cfg
  2. +140
    -0
      reproduction/chinese_word_segment/run.py

+ 34
- 0
reproduction/chinese_word_segment/cws.cfg View File

@@ -0,0 +1,34 @@
[train]
epochs = 30
batch_size = 64
pickle_path = "./save/"
validate = true
save_best_dev = true
model_saved_path = "./save/"
rnn_hidden_units = 100
word_emb_dim = 100
use_crf = true
use_cuda = true

[test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 640
pickle_path = "./save/"
use_crf = true
use_cuda = true


[POS_test]
save_output = true
validate_in_training = true
save_dev_input = false
save_loss = true
batch_size = 640
pickle_path = "./save/"
use_crf = true
use_cuda = true
rnn_hidden_units = 100
word_emb_dim = 100

+ 140
- 0
reproduction/chinese_word_segment/run.py View File

@@ -0,0 +1,140 @@
import sys, os

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import TokenizeDatasetLoader, BaseLoader
from fastNLP.loader.preprocess import POSPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.inference import SeqLabelInfer
from fastNLP.core.optimizer import SGD

# not in the file's dir
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))
datadir = 'icwb2-data'
cfgfile = 'cws.cfg'
data_name = "pku_training.utf8"

cws_data_path = os.path.join(datadir, "training/pku_training.utf8")
pickle_path = "save"
data_infer_path = os.path.join(datadir, "infer.utf8")

def infer():
# Config Loader
test_args = ConfigSection()
ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "id2class.pkl")
test_args["num_classes"] = len(index2label)


# Define the same model
model = AdvSeqLabel(test_args)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model loaded!')
except Exception as e:
print('cannot load model!')
raise

# Data Loader
raw_data_loader = BaseLoader(data_name, data_infer_path)
infer_data = raw_data_loader.load_lines()
print('data loaded')

# Inference interface
infer = SeqLabelInfer(pickle_path)
results = infer.predict(model, infer_data)

print(results)
print("Inference finished!")


def train():
# Config Loader
train_args = ConfigSection()
test_args = ConfigSection()
ConfigLoader("good_name", "good_path").load_config(cfgfile, {"train": train_args, "test": test_args})

# Data Loader
loader = TokenizeDatasetLoader(data_name, cws_data_path)
train_data = loader.load_pku()

# Preprocessor
p = POSPreprocess(train_data, pickle_path, train_dev_split=0.3)
train_args["vocab_size"] = p.vocab_size
train_args["num_classes"] = p.num_classes

# Trainer
trainer = SeqLabelTrainer(train_args)

# Model
model = AdvSeqLabel(train_args)
try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model parameter loaded!')
except Exception as e:
pass
# Start training
trainer.train(model)
print("Training finished!")

# Saver
saver = ModelSaver("./save/saved_model.pkl")
saver.save_pytorch(model)
print("Model saved!")


def test():
# Config Loader
test_args = ConfigSection()
ConfigLoader("config", "").load_config(cfgfile, {"POS_test": test_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "id2class.pkl")
test_args["num_classes"] = len(index2label)

# Define the same model
model = AdvSeqLabel(test_args)

# Dump trained parameters into the model
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print("model loaded!")

# Tester
tester = SeqLabelTester(test_args)

# Start testing
tester.test(model)

# print test results
print(tester.show_matrices())
print("model tested!")


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
parser.add_argument('--mode', help='set the model\'s model', choices=['train', 'test', 'infer'])
args = parser.parse_args()
if args.mode == 'train':
train()
elif args.mode == 'test':
test()
elif args.mode == 'infer':
infer()
else:
print('no mode specified for model!')
parser.print_help()

Loading…
Cancel
Save