
* delete readme_example.py because it is out of date
* rename preprocess.py to utils.py, because there is nothing about preprocessing in it
* everything in loader/ and saver/ is moved directly into io/ (see the import sketch after this list)
* corresponding unit tests are moved to test/io
* delete fastnlp.py, because we have new and better APIs
* rename Biaffine_parser/run_test.py to Biaffine_parser/main.py; otherwise, tests will fail
* a lot of legacy code still needs to be refined
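
For downstream code, the package reorganization amounts to an import-path change. A minimal before/after sketch in Python (module names taken from the diffs below; adjust to whatever your project actually imports):

# old layout (pre-v0.2.0): separate loader/ and saver/ packages
# from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
# from fastNLP.saver.model_saver import ModelSaver
# from fastNLP.core.preprocess import save_pickle, load_pickle

# new layout: a single io/ package, and core/preprocess.py renamed to core/utils.py
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.model_saver import ModelSaver
from fastNLP.core.utils import save_pickle, load_pickle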
tags/v0.2.0
FengZiYjun (yunfan), commit e9d7074ba1
41 changed files with 113 additions and 830 deletions
  1. examples/readme_example.py (+0, -75)
  2. fastNLP/api/api.py (+26, -16)
  3. fastNLP/core/field.py (+2, -1)
  4. fastNLP/core/fieldarray.py (+2, -2)
  5. fastNLP/core/predictor.py (+0, -14)
  6. fastNLP/core/sampler.py (+12, -10)
  7. fastNLP/core/tester.py (+3, -23)
  8. fastNLP/core/trainer.py (+3, -50)
  9. fastNLP/core/utils.py (+0, -2)
  10. fastNLP/core/vocabulary.py (+1, -1)
  11. fastNLP/fastnlp.py (+0, -343)
  12. fastNLP/io/__init__.py (+0, -0)
  13. fastNLP/io/base_loader.py (+0, -0)
  14. fastNLP/io/config_loader.py (+1, -1)
  15. fastNLP/io/config_saver.py (+2, -2)
  16. fastNLP/io/dataset_loader.py (+1, -1)
  17. fastNLP/io/embed_loader.py (+1, -4)
  18. fastNLP/io/logger.py (+0, -0)
  19. fastNLP/io/model_loader.py (+4, -4)
  20. fastNLP/io/model_saver.py (+0, -0)
  21. fastNLP/modules/dropout.py (+3, -1)
  22. reproduction/Biaffine_parser/infer.py (+2, -4)
  23. reproduction/Biaffine_parser/main.py (+0, -2)
  24. reproduction/Biaffine_parser/run.py (+4, -9)
  25. reproduction/LSTM+self_attention_sentiment_analysis/main.py (+4, -4)
  26. reproduction/chinese_word_segment/cws_io/cws_reader.py (+2, -2)
  27. reproduction/chinese_word_segment/run.py (+6, -7)
  28. reproduction/pos_tag_model/train_pos_tag.py (+2, -2)
  29. test/core/test_dataset.py (+1, -1)
  30. test/core/test_predictor.py (+2, -4)
  31. test/io/__init__.py (+0, -0)
  32. test/io/config (+0, -0)
  33. test/io/test_config_loader.py (+1, -1)
  34. test/io/test_config_saver.py (+2, -2)
  35. test/io/test_dataset_loader.py (+3, -3)
  36. test/io/test_embed_loader.py (+2, -4)
  37. test/model/seq_labeling.py (+5, -5)
  38. test/model/test_cws.py (+6, -7)
  39. test/model/test_seq_label.py (+5, -5)
  40. test/model/text_classify.py (+5, -5)
  41. test/test_fastNLP.py (+0, -213)

examples/readme_example.py (+0, -75)

@@ -1,75 +0,0 @@
from fastNLP.core.loss import Loss
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.preprocess import ClassPreprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules import aggregator
from fastNLP.modules import decoder
from fastNLP.modules import encoder


class ClassificationModel(BaseModel):
"""
Simple text classification model based on CNN.
"""

def __init__(self, num_classes, vocab_size):
super(ClassificationModel, self).__init__()

self.emb = encoder.Embedding(nums=vocab_size, dims=300)
self.enc = encoder.Conv(
in_channels=300, out_channels=100, kernel_size=3)
self.agg = aggregator.MaxPool()
self.dec = decoder.MLP(size_layer=[100, num_classes])

def forward(self, x):
x = self.emb(x) # [N,L] -> [N,L,C]
x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
x = self.agg(x) # [N,L,C] -> [N,C]
x = self.dec(x) # [N,C] -> [N, N_class]
return x


data_dir = 'save/' # directory to save data and model
train_path = './data_for_tests/text_classify.txt' # training set file

# load dataset
ds_loader = ClassDataSetLoader()
data = ds_loader.load()

# pre-process dataset
pre = ClassPreprocess()
train_set, dev_set = pre.run(data, train_dev_split=0.3, pickle_path=data_dir)
n_classes, vocab_size = pre.num_classes, pre.vocab_size

# construct model
model_args = {
'num_classes': n_classes,
'vocab_size': vocab_size
}
model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)

# construct trainer
train_args = {
"epochs": 3,
"batch_size": 16,
"pickle_path": data_dir,
"validate": False,
"save_best_dev": False,
"model_saved_path": None,
"use_cuda": True,
"loss": Loss("cross_entropy"),
"optimizer": Optimizer("Adam", lr=0.001)
}
trainer = ClassificationTrainer(**train_args)

# start training
trainer.train(model, train_data=train_set, dev_data=dev_set)

# predict using model
data_infer = [x[0] for x in data]
infer = ClassificationInfer(data_dir)
labels_pred = infer.predict(model.cpu(), data_infer)
print(labels_pred)

fastNLP/api/api.py (+26, -16)

@@ -1,5 +1,7 @@
import torch
import warnings

import torch

warnings.filterwarnings('ignore')
import os

@@ -17,7 +19,6 @@ from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SeqLabelEvaluator2
from fastNLP.core.tester import Tester


model_urls = {
}

@@ -228,7 +229,7 @@ class Parser(API):
elif p.field_name == 'pos_list':
p.field_name = 'gold_pos'
pp(ds)
head_cor, label_cor, total = 0,0,0
head_cor, label_cor, total = 0, 0, 0
for ins in ds:
head_gold = ins['gold_heads']
head_pred = ins['heads']
@@ -236,7 +237,7 @@ class Parser(API):
total += length
for i in range(length):
head_cor += 1 if head_pred[i] == head_gold[i] else 0
uas = head_cor/total
uas = head_cor / total
print('uas:{:.2f}'.format(uas))

for p in pp:
@@ -247,25 +248,34 @@ class Parser(API):

return uas


if __name__ == "__main__":
# pos_model_path = '../../reproduction/pos_tag_model/pos_crf.pkl'
pos = POS(device='cpu')
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' ,
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
# the paths below live on (server) 102
"""
pos_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/pos_crf-5e26d3b0.pkl'
pos = POS(model_path=pos_model_path, device='cpu')
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
#print(pos.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
print(pos.predict(s))
"""

# cws_model_path = '../../reproduction/chinese_word_segment/models/cws_crf.pkl'
cws = CWS(device='cuda:0')
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂' ,
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
"""
cws_model_path = '/home/hyan/fastNLP_models/upload-demo/upload/cws_crf-5a8a3e66.pkl'
cws = CWS(model_path=cws_model_path, device='cuda:0')
s = ['本品是一个抗酸抗胆汁的胃黏膜保护剂',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
#print(cws.test('../../reproduction/chinese_word_segment/new-clean.txt.conll'))
cws.predict(s)
parser = Parser(device='cuda:0')
print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
"""

parser_model_path = "/home/hyan/fastNLP_models/upload-demo/upload/parser-d57cd5fc.pkl"
parser = Parser(model_path=parser_model_path, device='cuda:0')
# print(parser.test('../../reproduction/Biaffine_parser/test.conll'))
s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
'这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
'那么这款无人机到底有多厉害?']
print(parser.predict(s))
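
Taken together, the __main__ block above is effectively the replacement for the deleted fastnlp.py entry point: construct a POS, CWS, or Parser object and call predict() on a list of raw sentences. A hedged sketch of that usage (the .pkl paths are placeholders; the ones in the diff are machine-specific):

from fastNLP.api.api import POS, CWS, Parser

s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
     '那么这款无人机到底有多厉害?']

pos = POS(model_path='/path/to/pos_crf.pkl', device='cpu')          # part-of-speech tagging
print(pos.predict(s))

cws = CWS(model_path='/path/to/cws_crf.pkl', device='cuda:0')       # word segmentation
print(cws.predict(s))

parser = Parser(model_path='/path/to/parser.pkl', device='cuda:0')  # dependency parsing
print(parser.predict(s))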


fastNLP/core/field.py (+2, -1)

@@ -1,5 +1,4 @@
import torch
import numpy as np


class Field(object):
@@ -30,6 +29,7 @@ class Field(object):
def __repr__(self):
return self.content.__repr__()


class TextField(Field):
def __init__(self, text, is_target):
"""
@@ -43,6 +43,7 @@ class LabelField(Field):
"""The Field representing a single label. Can be a string or integer.

"""

def __init__(self, label, is_target=True):
super(LabelField, self).__init__(label, is_target)



fastNLP/core/fieldarray.py (+2, -2)

@@ -1,6 +1,6 @@
import torch
import numpy as np


class FieldArray(object):
def __init__(self, name, content, padding_val=0, is_target=False, need_tensor=False):
self.name = name
@@ -10,7 +10,7 @@ class FieldArray(object):
self.need_tensor = need_tensor

def __repr__(self):
#TODO
# TODO
return '{}: {}'.format(self.name, self.content.__repr__())

def append(self, val):


fastNLP/core/predictor.py (+0, -14)

@@ -50,20 +50,6 @@ class Predictor(object):
return y


class SeqLabelInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] SeqLabelInfer will be deprecated. Please use Predictor directly.")
super(SeqLabelInfer, self).__init__()


class ClassificationInfer(Predictor):
def __init__(self, pickle_path):
print(
"[FastNLP Warning] ClassificationInfer will be deprecated. Please use Predictor directly.")
super(ClassificationInfer, self).__init__()


def seq_label_post_processor(batch_outputs, label_vocab):
results = []
for batch in batch_outputs:


fastNLP/core/sampler.py (+12, -10)

@@ -1,6 +1,8 @@
from itertools import chain

import numpy as np
import torch
from itertools import chain

def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.
@@ -43,6 +45,7 @@ class RandomSampler(BaseSampler):
def __call__(self, data_set):
return list(np.random.permutation(len(data_set)))


class BucketSampler(BaseSampler):

def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'):
@@ -56,14 +59,14 @@ class BucketSampler(BaseSampler):
total_sample_num = len(seq_lens)

bucket_indexes = []
num_sample_per_bucket = total_sample_num//self.num_buckets
num_sample_per_bucket = total_sample_num // self.num_buckets
for i in range(self.num_buckets):
bucket_indexes.append([num_sample_per_bucket*i, num_sample_per_bucket*(i+1)])
bucket_indexes.append([num_sample_per_bucket * i, num_sample_per_bucket * (i + 1)])
bucket_indexes[-1][1] = total_sample_num

sorted_seq_lens = list(sorted([(idx, seq_len) for
idx, seq_len in zip(range(total_sample_num), seq_lens)],
key=lambda x:x[1]))
key=lambda x: x[1]))

batchs = []

@@ -73,19 +76,18 @@ class BucketSampler(BaseSampler):
end_idx = bucket_indexes[b_idx][1]
sorted_bucket_seq_lens = sorted_seq_lens[start_idx:end_idx]
left_init_indexes.extend([tup[0] for tup in sorted_bucket_seq_lens])
num_batch_per_bucket = len(left_init_indexes)//self.batch_size
num_batch_per_bucket = len(left_init_indexes) // self.batch_size
np.random.shuffle(left_init_indexes)
for i in range(num_batch_per_bucket):
batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
if (left_init_indexes)!=0:
batchs.append(left_init_indexes[i * self.batch_size:(i + 1) * self.batch_size])
left_init_indexes = left_init_indexes[num_batch_per_bucket * self.batch_size:]
if (left_init_indexes) != 0:
batchs.append(left_init_indexes)
np.random.shuffle(batchs)

return list(chain(*batchs))



def simple_sort_bucketing(lengths):
"""

@@ -105,6 +107,7 @@ def simple_sort_bucketing(lengths):
# TODO: need to return buckets
return [idx for idx, _ in sorted_lengths]


def k_means_1d(x, k, max_iter=100):
"""Perform k-means on 1-D data.

@@ -159,4 +162,3 @@ def k_means_bucketing(lengths, buckets):
if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
bucket_data[bucket_id].append(idx)
return bucket_data
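
As a usage note on the sampler above: BucketSampler sorts instances by their seq_lens field, slices them into num_buckets length buckets, and shuffles whole batches, so each batch contains sequences of similar length. A minimal sketch, assuming the sampler is called on a fastNLP DataSet the way RandomSampler is:

from fastNLP.core.sampler import BucketSampler

sampler = BucketSampler(num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens')
# data_set: a fastNLP DataSet that has a 'seq_lens' field
indices = sampler(data_set)  # flat index list; consecutive 32-index slices share a bucket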


fastNLP/core/tester.py (+3, -23)

@@ -1,10 +1,11 @@
import torch
from collections import defaultdict

import torch

from fastNLP.core.batch import Batch
from fastNLP.core.metrics import Evaluator
from fastNLP.core.sampler import RandomSampler
from fastNLP.saver.logger import create_logger
from fastNLP.io.logger import create_logger

logger = create_logger(__name__, "./train_test.log")

@@ -119,24 +120,3 @@ class Tester(object):

"""
return ", ".join([str(key) + "=" + str(value) for key, value in results.items()])


class SeqLabelTester(Tester):
def __init__(self, **test_args):
print(
"[FastNLP Warning] SeqLabelTester will be deprecated. Please use Tester directly.")
super(SeqLabelTester, self).__init__(**test_args)


class ClassificationTester(Tester):
def __init__(self, **test_args):
print(
"[FastNLP Warning] ClassificationTester will be deprecated. Please use Tester directly.")
super(ClassificationTester, self).__init__(**test_args)


class SNLITester(Tester):
def __init__(self, **test_args):
print(
"[FastNLP Warning] SNLITester will be deprecated. Please use Tester directly.")
super(SNLITester, self).__init__(**test_args)

fastNLP/core/trainer.py (+3, -50)

@@ -9,11 +9,10 @@ from fastNLP.core.batch import Batch
from fastNLP.core.loss import Loss
from fastNLP.core.metrics import Evaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.tester import SeqLabelTester, ClassificationTester, SNLITester
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.tester import Tester
from fastNLP.saver.logger import create_logger
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.io.logger import create_logger
from fastNLP.io.model_saver import ModelSaver

logger = create_logger(__name__, "./train_test.log")
logger.disabled = True
@@ -182,19 +181,10 @@ class Trainer(object):
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
for name, param in self._model.named_parameters():
if param.requires_grad:
<<<<<<< HEAD
# self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=step)
pass

if kwargs["n_print"] > 0 and step % kwargs["n_print"] == 0:
=======
self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
# self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
if kwargs["n_print"] > 0 and self.step % kwargs["n_print"] == 0:
>>>>>>> 5924fe0... fix and update tester, trainer, seq_model, add parser pipeline builder
end = time.time()
diff = timedelta(seconds=round(end - kwargs["start"]))
print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
@@ -339,40 +329,3 @@ class Trainer(object):
def set_validator(self, validor):
self.validator = validor


class SeqLabelTrainer(Trainer):
"""Trainer for Sequence Labeling

"""

def __init__(self, **kwargs):
print(
"[FastNLP Warning] SeqLabelTrainer will be deprecated. Please use Trainer directly.")
super(SeqLabelTrainer, self).__init__(**kwargs)

def _create_validator(self, valid_args):
return SeqLabelTester(**valid_args)


class ClassificationTrainer(Trainer):
"""Trainer for text classification."""

def __init__(self, **train_args):
print(
"[FastNLP Warning] ClassificationTrainer will be deprecated. Please use Trainer directly.")
super(ClassificationTrainer, self).__init__(**train_args)

def _create_validator(self, valid_args):
return ClassificationTester(**valid_args)


class SNLITrainer(Trainer):
"""Trainer for text SNLI."""

def __init__(self, **train_args):
print(
"[FastNLP Warning] SNLITrainer will be deprecated. Please use Trainer directly.")
super(SNLITrainer, self).__init__(**train_args)

def _create_validator(self, valid_args):
return SNLITester(**valid_args)

fastNLP/core/preprocess.py → fastNLP/core/utils.py

@@ -2,8 +2,6 @@ import _pickle
import os


# the first vocab in dict with the index = 5

def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.


fastNLP/core/vocabulary.py (+1, -1)

@@ -13,7 +13,7 @@ DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,

def isiterable(p_object):
try:
it = iter(p_object)
_ = iter(p_object)
except TypeError:
return False
return True


fastNLP/fastnlp.py (+0, -343)

@@ -1,343 +0,0 @@
import os

from fastNLP.core.dataset import DataSet
from fastNLP.loader.dataset_loader import convert_seq_dataset
from fastNLP.core.predictor import SeqLabelInfer, ClassificationInfer
from fastNLP.core.preprocess import load_pickle
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader

"""
mapping from model name to [URL, file_name.class_name, model_pickle_name]
Notice that the class of the model should be in "models" directory.

Example:
"seq_label_model": {
"url": "www.fudan.edu.cn",
"class": "sequence_modeling.SeqLabeling", # file_name.class_name in models/
"pickle": "seq_label_model.pkl",
"type": "seq_label",
"config_file_name": "config", # the name of the config file which stores model initialization parameters
"config_section_name": "text_class_model" # the name of the section in the config file which stores model init params
},
"text_class_model": {
"url": "www.fudan.edu.cn",
"class": "cnn_text_classification.CNNText",
"pickle": "text_class_model.pkl",
"type": "text_class"
}
"""
FastNLP_MODEL_COLLECTION = {
"cws_basic_model": {
"url": "",
"class": "sequence_modeling.AdvSeqLabel",
"pickle": "cws_basic_model_v_0.pkl",
"type": "seq_label",
"config_file_name": "cws.cfg",
"config_section_name": "text_class_model"
},
"pos_tag_model": {
"url": "",
"class": "sequence_modeling.AdvSeqLabel",
"pickle": "pos_tag_model_v_0.pkl",
"type": "seq_label",
"config_file_name": "pos_tag.cfg",
"config_section_name": "pos_tag_model"
},
"text_classify_model": {
"url": "",
"class": "cnn_text_classification.CNNText",
"pickle": "text_class_model_v0.pkl",
"type": "text_class",
"config_file_name": "text_classify.cfg",
"config_section_name": "model"
}
}


class FastNLP(object):
"""
High-level interface for direct model inference.
Example Usage
::
fastnlp = FastNLP()
fastnlp.load("zh_pos_tag_model")
text = "这是最好的基于深度学习的中文分词系统。"
result = fastnlp.run(text)
print(result) # ["这", "是", "最好", "的", "基于", "深度学习", "的", "中文", "分词", "系统", "。"]

"""

def __init__(self, model_dir="./"):
"""
:param model_dir: this directory should contain the following files:
1. a trained model
2. a config file, which is a fastNLP's configuration.
3. two Vocab files, which are pickle objects of Vocab instances, representing feature and label vocabs.
"""
self.model_dir = model_dir
self.model = None
self.infer_type = None # "seq_label"/"text_class"
self.word_vocab = None
self.label_vocab = None

def load(self, model_name, config_file="config", section_name="model"):
"""
Load a pre-trained FastNLP model together with additional data.
:param model_name: str, the name of a FastNLP model.
:param config_file: str, the name of the config file which stores the initialization information of the model.
(default: "config")
:param section_name: str, the name of the corresponding section in the config file. (default: model)
"""
assert type(model_name) is str
if model_name not in FastNLP_MODEL_COLLECTION:
raise ValueError("No FastNLP model named {}.".format(model_name))

if not self.model_exist(model_dir=self.model_dir):
self._download(model_name, FastNLP_MODEL_COLLECTION[model_name]["url"])

model_class = self._get_model_class(FastNLP_MODEL_COLLECTION[model_name]["class"])
print("Restore model class {}".format(str(model_class)))

model_args = ConfigSection()
ConfigLoader.load_config(os.path.join(self.model_dir, config_file), {section_name: model_args})
print("Restore model hyper-parameters {}".format(str(model_args.data)))

# fetch dictionary size and number of labels from pickle files
self.word_vocab = load_pickle(self.model_dir, "word2id.pkl")
model_args["vocab_size"] = len(self.word_vocab)
self.label_vocab = load_pickle(self.model_dir, "label2id.pkl")
model_args["num_classes"] = len(self.label_vocab)

# Construct the model
model = model_class(model_args)
print("Model constructed.")

# To do: framework independent
ModelLoader.load_pytorch(model, os.path.join(self.model_dir, FastNLP_MODEL_COLLECTION[model_name]["pickle"]))
print("Model weights loaded.")

self.model = model
self.infer_type = FastNLP_MODEL_COLLECTION[model_name]["type"]

print("Inference ready.")

def run(self, raw_input):
"""
Perform inference over given input using the loaded model.
:param raw_input: list of string. Each list is an input query.
:return results:
"""

infer = self._create_inference(self.model_dir)

# tokenize: list of string ---> 2-D list of string
infer_input = self.tokenize(raw_input, language="zh")

# create DataSet: 2-D list of strings ----> DataSet
infer_data = self._create_data_set(infer_input)

# DataSet ---> 2-D list of tags
results = infer.predict(self.model, infer_data)

# 2-D list of tags ---> list of final answers
outputs = self._make_output(results, infer_input)
return outputs

@staticmethod
def _get_model_class(file_class_name):
"""
Feature the class specified by <file_class_name>
:param file_class_name: str, contains the name of the Python module followed by the name of the class.
Example: "sequence_modeling.SeqLabeling"
:return module: the model class
"""
import_prefix = "fastNLP.models."
parts = (import_prefix + file_class_name).split(".")
from_module = ".".join(parts[:-1])
module = __import__(from_module)
for sub in parts[1:]:
module = getattr(module, sub)
return module

def _create_inference(self, model_dir):
"""Specify which task to perform.

:param model_dir:
:return:
"""
if self.infer_type == "seq_label":
return SeqLabelInfer(model_dir)
elif self.infer_type == "text_class":
return ClassificationInfer(model_dir)
else:
raise ValueError("fail to create inference instance")

def _create_data_set(self, infer_input):
"""Create a DataSet object given the raw inputs.

:param infer_input: 2-D lists of strings
:return data_set: a DataSet object
"""
if self.infer_type in ["seq_label", "text_class"]:
data_set = convert_seq_dataset(infer_input)
data_set.index_field("word_seq", self.word_vocab)
if self.infer_type == "seq_label":
data_set.set_origin_len("word_seq")
return data_set
else:
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))


def _load(self, model_dir, model_name):

return 0

def _download(self, model_name, url):
"""
Download the model weights from <url> and save in <self.model_dir>.
:param model_name:
:param url:
"""
print("Downloading {} from {}".format(model_name, url))
# TODO: download model via url

def model_exist(self, model_dir):
"""
Check whether the desired model is already in the directory.
:param model_dir:
"""
return True

def tokenize(self, text, language):
"""Extract tokens from strings.
For English, extract words separated by space.
For Chinese, extract characters.
TODO: more complex tokenization methods

:param text: list of string
:param language: str, one of ('zh', 'en'), Chinese or English.
:return data: list of list of string, each string is a token.
"""
assert language in ("zh", "en")
data = []
for sent in text:
if language == "en":
tokens = sent.strip().split()
elif language == "zh":
tokens = [char for char in sent]
else:
raise RuntimeError("Unknown language {}".format(language))
data.append(tokens)
return data

def _make_output(self, results, infer_input):
"""Transform the infer output into user-friendly output.

:param results: 1 or 2-D list of strings.
If self.infer_type == "seq_label", it is of shape [num_examples, tag_seq_length]
If self.infer_type == "text_class", it is of shape [num_examples]
:param infer_input: 2-D list of string, the input query before inference.
:return outputs: list. Each entry is a prediction.
"""
if self.infer_type == "seq_label":
outputs = make_seq_label_output(results, infer_input)
elif self.infer_type == "text_class":
outputs = make_class_output(results, infer_input)
else:
raise RuntimeError("fail to make outputs with infer type {}".format(self.infer_type))
return outputs


def make_seq_label_output(result, infer_input):
"""Transform model output into user-friendly contents.

:param result: 2-D list of strings. (model output)
:param infer_input: 2-D list of string (model input)
:return ret: list of list of tuples
[
[(word_11, label_11), (word_12, label_12), ...],
[(word_21, label_21), (word_22, label_22), ...],
...
]
"""
ret = []
for example_x, example_y in zip(infer_input, result):
ret.append([(x, y) for x, y in zip(example_x, example_y)])
return ret

def make_class_output(result, infer_input):
"""Transform model output into user-friendly contents.

:param result: 2-D list of strings. (model output)
:param infer_input: 1-D list of string (model input)
:return ret: the same as result, [label_1, label_2, ...]
"""
return result


def interpret_word_seg_results(char_seq, label_seq):
"""Transform model output into user-friendly contents.

Example: In CWS, convert <BMES> labeling into segmented text.
:param char_seq: list of string,
:param label_seq: list of string, the same length as char_seq
Each entry is one of ('B', 'M', 'E', 'S').
:return output: list of words
"""
words = []
word = ""
for char, label in zip(char_seq, label_seq):
if label[0] == "B":
if word != "":
words.append(word)
word = char
elif label[0] == "M":
word += char
elif label[0] == "E":
word += char
words.append(word)
word = ""
elif label[0] == "S":
if word != "":
words.append(word)
word = ""
words.append(char)
else:
raise ValueError("invalid label {}".format(label[0]))
return words


def interpret_cws_pos_results(char_seq, label_seq):
"""Transform model output into user-friendly contents.

:param char_seq: list of string
:param label_seq: list of string, the same length as char_seq.
:return outputs: list of tuple (words, pos_tag):
"""

def pos_tag_check(seq):
"""check whether all entries are the same """
return len(set(seq)) <= 1

word = []
word_pos = []
outputs = []
for char, label in zip(char_seq, label_seq):
tmp = label.split("-")
cws_label, pos_tag = tmp[0], tmp[1]

if cws_label == "B" or cws_label == "M":
word.append(char)
word_pos.append(pos_tag)
elif cws_label == "E":
word.append(char)
word_pos.append(pos_tag)
if not pos_tag_check(word_pos):
raise RuntimeError("character-wise pos tags inconsistent. ")
outputs.append(("".join(word), word_pos[0]))
word.clear()
word_pos.clear()
elif cws_label == "S":
outputs.append((char, pos_tag))
return outputs

fastNLP/loader/__init__.py → fastNLP/io/__init__.py


fastNLP/loader/base_loader.py → fastNLP/io/base_loader.py


fastNLP/loader/config_loader.py → fastNLP/io/config_loader.py

@@ -2,7 +2,7 @@ import configparser
import json
import os

from fastNLP.loader.base_loader import BaseLoader
from fastNLP.io.base_loader import BaseLoader


class ConfigLoader(BaseLoader):

fastNLP/saver/config_saver.py → fastNLP/io/config_saver.py

@@ -1,7 +1,7 @@
import os

from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.saver.logger import create_logger
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.logger import create_logger


class ConfigSaver(object):

fastNLP/loader/dataset_loader.py → fastNLP/io/dataset_loader.py

@@ -3,7 +3,7 @@ import os
from fastNLP.core.dataset import DataSet
from fastNLP.core.field import *
from fastNLP.core.instance import Instance
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.io.base_loader import BaseLoader


def convert_seq_dataset(data):

fastNLP/loader/embed_loader.py → fastNLP/io/embed_loader.py

@@ -1,10 +1,7 @@
import _pickle
import os

import torch

from fastNLP.loader.base_loader import BaseLoader
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import BaseLoader


class EmbedLoader(BaseLoader):

fastNLP/saver/logger.py → fastNLP/io/logger.py


fastNLP/loader/model_loader.py → fastNLP/io/model_loader.py

@@ -1,6 +1,6 @@
import torch

from fastNLP.loader.base_loader import BaseLoader
from fastNLP.io.base_loader import BaseLoader


class ModelLoader(BaseLoader):
@@ -19,10 +19,10 @@ class ModelLoader(BaseLoader):
:param model_path: str, the path to the saved model.
"""
empty_model.load_state_dict(torch.load(model_path))
@staticmethod
def load_pytorch(model_path):
def load_pytorch_model(model_path):
"""Load the entire model.

"""
return torch.load(model_path)
return torch.load(model_path)
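
After this change, ModelLoader offers two ways in: load_pytorch(empty_model, model_path) fills an already-constructed module via load_state_dict, while the renamed load_pytorch_model(model_path) unpickles a whole model object. A hedged sketch (SeqLabeling stands in for any saved model class, and args for its hyper-parameters):

from fastNLP.io.model_loader import ModelLoader
from fastNLP.models.sequence_modeling import SeqLabeling

model = SeqLabeling(args)                     # construct the empty model first
ModelLoader.load_pytorch(model, 'model.pkl')  # fill in the saved weights

model = ModelLoader.load_pytorch_model('model.pkl')  # or load the entire pickled model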

fastNLP/saver/model_saver.py → fastNLP/io/model_saver.py


fastNLP/modules/dropout.py (+3, -1)

@@ -1,13 +1,15 @@
import torch


class TimestepDropout(torch.nn.Dropout):
"""This module accepts a `[batch_size, num_timesteps, embedding_dim)]` and use a single
dropout mask of shape `(batch_size, embedding_dim)` to apply on every time step.
"""

def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
dropout_mask = dropout_mask.unsqueeze(1) # [batch_size, 1, embedding_dim]
if self.inplace:
x *= dropout_mask
return
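
To make the shared-mask behavior above concrete: one Bernoulli mask of shape (batch_size, embedding_dim) is drawn per example and broadcast over the time axis, so a dropped channel is dropped at every timestep. A small demonstration sketch of the in-place path shown in the diff:

import torch
from fastNLP.modules.dropout import TimestepDropout

drop = TimestepDropout(p=0.5, inplace=True)
drop.train()  # dropout is only active in training mode
x = torch.ones(4, 7, 16)  # [batch_size, num_timesteps, embedding_dim]
drop(x)  # modifies x in place; returns None on the inplace path
# every timestep shares the same zeroed channels:
assert torch.equal(x[:, 0, :] == 0, x[:, 1, :] == 0)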


reproduction/Biaffine_parser/infer.py (+2, -4)

@@ -1,13 +1,11 @@
import sys
import os
import sys

sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])

from fastNLP.api.processor import *
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.dataset import DataSet
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_loader import ConfigSection, ConfigLoader

import _pickle as pickle
import torch


reproduction/Biaffine_parser/run_test.py → reproduction/Biaffine_parser/main.py

@@ -1,11 +1,9 @@
import sys
import os

sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])

import torch
import argparse
import numpy as np

from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.core.dataset import DataSet

reproduction/Biaffine_parser/run.py (+4, -9)

@@ -3,8 +3,6 @@ import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from collections import defaultdict
import math
import torch
import re

@@ -13,16 +11,13 @@ from fastNLP.core.metrics import Evaluator
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.field import TextField, SeqLabelField
from fastNLP.core.preprocess import load_pickle
from fastNLP.core.tester import Tester
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.models.biaffine_parser import BiaffineParser
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.io.model_saver import ModelSaver

BOS = '<BOS>'
EOS = '<EOS>'


reproduction/LSTM+self_attention_sentiment_analysis/main.py (+4, -4)

@@ -1,10 +1,10 @@
import torch.nn.functional as F

from fastNLP.core.preprocess import ClassPreprocess as Preprocess
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader
from fastNLP.loader.config_loader import ConfigSection
from fastNLP.loader.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_loader import ConfigLoader
from fastNLP.io.config_loader import ConfigSection
from fastNLP.io.dataset_loader import ClassDataSetLoader as Dataset_loader
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
from fastNLP.modules.decoder.MLP import MLP


reproduction/chinese_word_segment/cws_io/cws_reader.py (+2, -2)

@@ -1,8 +1,8 @@


from fastNLP.loader.dataset_loader import DataSetLoader
from fastNLP.core.instance import Instance
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.dataset_loader import DataSetLoader


def cut_long_sentence(sent, max_sample_length=200):


reproduction/chinese_word_segment/run.py (+6, -7)

@@ -3,17 +3,16 @@ import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import BaseLoader, TokenizeDataSetLoader
from fastNLP.core.preprocess import load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.io.dataset_loader import BaseLoader, TokenizeDataSetLoader
from fastNLP.core.utils import load_pickle
from fastNLP.io.model_saver import ModelSaver
from fastNLP.io.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.dataset import DataSet
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.utils import save_pickle
from fastNLP.core.metrics import SeqLabelEvaluator

# not in the file's dir


reproduction/pos_tag_model/train_pos_tag.py (+2, -2)

@@ -13,8 +13,8 @@ from fastNLP.core.instance import Instance
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.trainer import Trainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import PeopleDailyCorpusLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel




test/core/test_dataset.py (+1, -1)

@@ -1,6 +1,6 @@
import unittest

from fastNLP.loader.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset
from fastNLP.io.dataset_loader import convert_seq2seq_dataset, convert_seq_dataset


class TestDataSet(unittest.TestCase):


test/core/test_predictor.py (+2, -4)

@@ -1,12 +1,10 @@
import os
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.predictor import Predictor
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.utils import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.base_loader import BaseLoader
from fastNLP.loader.dataset_loader import convert_seq_dataset
from fastNLP.io.dataset_loader import convert_seq_dataset
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import SeqLabeling



fastNLP/saver/__init__.py → test/io/__init__.py


test/loader/config → test/io/config


test/loader/test_config_loader.py → test/io/test_config_loader.py

@@ -3,7 +3,7 @@ import json
import os
import unittest

from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_loader import ConfigSection, ConfigLoader


class TestConfigLoader(unittest.TestCase):

test/saver/test_config_saver.py → test/io/test_config_saver.py

@@ -1,8 +1,8 @@
import os
import unittest

from fastNLP.loader.config_loader import ConfigSection, ConfigLoader
from fastNLP.saver.config_saver import ConfigSaver
from fastNLP.io.config_loader import ConfigSection, ConfigLoader
from fastNLP.io.config_saver import ConfigSaver


class TestConfigSaver(unittest.TestCase):

test/loader/test_dataset_loader.py → test/io/test_dataset_loader.py

@@ -1,9 +1,9 @@
import os
import unittest

from fastNLP.loader.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
PeopleDailyCorpusLoader, ConllLoader
from fastNLP.core.dataset import DataSet
from fastNLP.io.dataset_loader import POSDataSetLoader, LMDataSetLoader, TokenizeDataSetLoader, \
PeopleDailyCorpusLoader, ConllLoader


class TestDatasetLoader(unittest.TestCase):
def test_case_1(self):

test/loader/test_embed_loader.py → test/io/test_embed_loader.py

@@ -1,10 +1,8 @@
import unittest
import os
import unittest

import torch

from fastNLP.loader.embed_loader import EmbedLoader
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.embed_loader import EmbedLoader


class TestEmbedLoader(unittest.TestCase):

test/model/seq_labeling.py (+5, -5)

@@ -3,17 +3,17 @@ import sys

sys.path.append("..")
import argparse
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import BaseLoader
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import BaseLoader
from fastNLP.io.model_saver import ModelSaver
from fastNLP.io.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.dataset import SeqLabelDataSet, change_field_is_target
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.utils import save_pickle, load_pickle

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./seq_label/", help="path to save pickle files")


test/model/test_cws.py (+6, -7)

@@ -1,17 +1,16 @@
import os

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.predictor import SeqLabelInfer
from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader, BaseLoader, RawDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.utils import save_pickle, load_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import TokenizeDataSetLoader, RawDataSetLoader
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_saver import ModelSaver
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

data_name = "pku_training.utf8"
cws_data_path = "./test/data_for_tests/cws_pku_utf_8"


test/model/test_seq_label.py (+5, -5)

@@ -2,15 +2,15 @@ import os

from fastNLP.core.metrics import SeqLabelEvaluator
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.preprocess import save_pickle
from fastNLP.core.tester import SeqLabelTester
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.core.utils import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import TokenizeDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import TokenizeDataSetLoader
from fastNLP.io.model_loader import ModelLoader
from fastNLP.io.model_saver import ModelSaver
from fastNLP.models.sequence_modeling import SeqLabeling
from fastNLP.saver.model_saver import ModelSaver

pickle_path = "./seq_label/"
model_name = "seq_label_model.pkl"


test/model/text_classify.py (+5, -5)

@@ -8,15 +8,15 @@ import sys
sys.path.append("..")
from fastNLP.core.predictor import ClassificationInfer
from fastNLP.core.trainer import ClassificationTrainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.loader.dataset_loader import ClassDataSetLoader
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.io.config_loader import ConfigLoader, ConfigSection
from fastNLP.io.dataset_loader import ClassDataSetLoader
from fastNLP.io.model_loader import ModelLoader
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.io.model_saver import ModelSaver
from fastNLP.core.optimizer import Optimizer
from fastNLP.core.loss import Loss
from fastNLP.core.dataset import TextClassifyDataSet
from fastNLP.core.preprocess import save_pickle, load_pickle
from fastNLP.core.utils import save_pickle, load_pickle

parser = argparse.ArgumentParser()
parser.add_argument("-s", "--save", type=str, default="./test_classification/", help="path to save pickle files")


test/test_fastNLP.py (+0, -213)

@@ -1,213 +0,0 @@
# encoding: utf-8
import os

from fastNLP.core.preprocess import save_pickle
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.fastnlp import FastNLP
from fastNLP.fastnlp import interpret_word_seg_results, interpret_cws_pos_results
from fastNLP.models.cnn_text_classification import CNNText
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.saver.model_saver import ModelSaver

PATH_TO_CWS_PICKLE_FILES = "/home/zyfeng/fastNLP/reproduction/chinese_word_segment/save/"
PATH_TO_POS_TAG_PICKLE_FILES = "/home/zyfeng/data/crf_seg/"
PATH_TO_TEXT_CLASSIFICATION_PICKLE_FILES = "/home/zyfeng/data/text_classify/"

DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
DEFAULT_RESERVED_LABEL = ['<reserved-2>',
'<reserved-3>',
'<reserved-4>'] # dict index = 2~4

DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1,
DEFAULT_RESERVED_LABEL[0]: 2, DEFAULT_RESERVED_LABEL[1]: 3,
DEFAULT_RESERVED_LABEL[2]: 4}


def word_seg(model_dir, config, section):
nlp = FastNLP(model_dir=model_dir)
nlp.load("cws_basic_model", config_file=config, section_name=section)
text = ["这是最好的基于深度学习的中文分词系统。",
"大王叫我来巡山。",
"我党多年来致力于改善人民生活水平。"]
results = nlp.run(text)
print(results)
for example in results:
words, labels = [], []
for res in example:
words.append(res[0])
labels.append(res[1])
print(interpret_word_seg_results(words, labels))


def mock_cws():
os.makedirs("mock", exist_ok=True)
text = ["这是最好的基于深度学习的中文分词系统。",
"大王叫我来巡山。",
"我党多年来致力于改善人民生活水平。"]

word2id = Vocabulary()
word_list = [ch for ch in "".join(text)]
word2id.update(word_list)
save_pickle(word2id, "./mock/", "word2id.pkl")

class2id = Vocabulary(need_default=False)
label_list = ['B', 'M', 'E', 'S']
class2id.update(label_list)
save_pickle(class2id, "./mock/", "label2id.pkl")

model_args = {"vocab_size": len(word2id), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(class2id)}
config_file = """
[test_section]
vocab_size = {}
word_emb_dim = 50
rnn_hidden_units = 50
num_classes = {}
""".format(len(word2id), len(class2id))
with open("mock/test.cfg", "w", encoding="utf-8") as f:
f.write(config_file)

model = AdvSeqLabel(model_args)
ModelSaver("mock/cws_basic_model_v_0.pkl").save_pytorch(model)


def test_word_seg():
# fake the model and pickles
print("start mocking")
mock_cws()
# run the inference codes
print("start testing")
word_seg("./mock/", "test.cfg", "test_section")
# clean up environments
print("clean up")
os.system("rm -rf mock")


def pos_tag(model_dir, config, section):
nlp = FastNLP(model_dir=model_dir)
nlp.load("pos_tag_model", config_file=config, section_name=section)
text = ["这是最好的基于深度学习的中文分词系统。",
"大王叫我来巡山。",
"我党多年来致力于改善人民生活水平。"]
results = nlp.run(text)
for example in results:
words, labels = [], []
for res in example:
words.append(res[0])
labels.append(res[1])
try:
print(interpret_cws_pos_results(words, labels))
except RuntimeError:
print("inconsistent pos tags. this is for test only.")


def mock_pos_tag():
os.makedirs("mock", exist_ok=True)
text = ["这是最好的基于深度学习的中文分词系统。",
"大王叫我来巡山。",
"我党多年来致力于改善人民生活水平。"]

vocab = Vocabulary()
word_list = [ch for ch in "".join(text)]
vocab.update(word_list)
save_pickle(vocab, "./mock/", "word2id.pkl")

idx2label = Vocabulary(need_default=False)
label_list = ['B-n', 'M-v', 'E-nv', 'S-adj', 'B-v', 'M-vn', 'S-adv']
idx2label.update(label_list)
save_pickle(idx2label, "./mock/", "label2id.pkl")

model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
config_file = """
[test_section]
vocab_size = {}
word_emb_dim = 50
rnn_hidden_units = 50
num_classes = {}
""".format(len(vocab), len(idx2label))
with open("mock/test.cfg", "w", encoding="utf-8") as f:
f.write(config_file)

model = AdvSeqLabel(model_args)
ModelSaver("mock/pos_tag_model_v_0.pkl").save_pytorch(model)


def test_pos_tag():
mock_pos_tag()
pos_tag("./mock/", "test.cfg", "test_section")
os.system("rm -rf mock")


def text_classify(model_dir, config, section):
nlp = FastNLP(model_dir=model_dir)
nlp.load("text_classify_model", config_file=config, section_name=section)
text = [
"世界物联网大会明日在京召开龙头股启动在即",
"乌鲁木齐市新增一处城市中心旅游目的地",
"朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"]
results = nlp.run(text)
print(results)


def mock_text_classify():
os.makedirs("mock", exist_ok=True)
text = ["世界物联网大会明日在京召开龙头股启动在即",
"乌鲁木齐市新增一处城市中心旅游目的地",
"朱元璋的大明朝真的源于明教吗?——告诉你一个真实的“明教”"
]
vocab = Vocabulary()
word_list = [ch for ch in "".join(text)]
vocab.update(word_list)
save_pickle(vocab, "./mock/", "word2id.pkl")

idx2label = Vocabulary(need_default=False)
label_list = ['class_A', 'class_B', 'class_C', 'class_D', 'class_E', 'class_F']
idx2label.update(label_list)
save_pickle(idx2label, "./mock/", "label2id.pkl")

model_args = {"vocab_size": len(vocab), "word_emb_dim": 50, "rnn_hidden_units": 50, "num_classes": len(idx2label)}
config_file = """
[test_section]
vocab_size = {}
word_emb_dim = 50
rnn_hidden_units = 50
num_classes = {}
""".format(len(vocab), len(idx2label))
with open("mock/test.cfg", "w", encoding="utf-8") as f:
f.write(config_file)

model = CNNText(model_args)
ModelSaver("mock/text_class_model_v0.pkl").save_pytorch(model)


def test_text_classify():
mock_text_classify()
text_classify("./mock/", "test.cfg", "test_section")
os.system("rm -rf mock")


def test_word_seg_interpret():
foo = [[('这', 'S'), ('是', 'S'), ('最', 'S'), ('好', 'S'), ('的', 'S'), ('基', 'B'), ('于', 'E'), ('深', 'B'), ('度', 'E'),
('学', 'B'), ('习', 'E'), ('的', 'S'), ('中', 'B'), ('文', 'E'), ('分', 'B'), ('词', 'E'), ('系', 'B'), ('统', 'E'),
('。', 'S')]]
chars = [x[0] for x in foo[0]]
labels = [x[1] for x in foo[0]]
print(interpret_word_seg_results(chars, labels))


def test_interpret_cws_pos_results():
foo = [
[('这', 'S-r'), ('是', 'S-v'), ('最', 'S-d'), ('好', 'S-a'), ('的', 'S-u'), ('基', 'B-p'), ('于', 'E-p'), ('深', 'B-d'),
('度', 'E-d'), ('学', 'B-v'), ('习', 'E-v'), ('的', 'S-u'), ('中', 'B-nz'), ('文', 'E-nz'), ('分', 'B-vn'),
('词', 'E-vn'), ('系', 'B-n'), ('统', 'E-n'), ('。', 'S-w')]
]
chars = [x[0] for x in foo[0]]
labels = [x[1] for x in foo[0]]
print(interpret_cws_pos_results(chars, labels))

if __name__ == "__main__":
test_word_seg()
test_pos_tag()
test_text_classify()
test_word_seg_interpret()
test_interpret_cws_pos_results()
