
- Fix the POS tagging training script so that it runs

- Create converter.py under api
- Add an initializer to Pipeline so that processors can be added in one pass
- Delete pos_tagger.py
- Clean up the overall code style
tags/v0.2.0
FengZiYjun, 5 years ago
commit 26e3abdf58
12 changed files with 330 additions and 209 deletions
  1. fastNLP/api/converter.py  +182 -0
  2. fastNLP/api/pipeline.py  +12 -4
  3. fastNLP/api/pos_tagger.py  +0 -44
  4. fastNLP/api/processor.py  +22 -5
  5. fastNLP/core/batch.py  +0 -3
  6. fastNLP/core/dataset.py  +27 -37
  7. fastNLP/core/instance.py  +2 -2
  8. fastNLP/loader/dataset_loader.py  +3 -2
  9. fastNLP/models/sequence_modeling.py  +7 -1
  10. fastNLP/modules/decoder/CRF.py  +13 -11
  11. reproduction/pos_tag_model/pos_tag.cfg  +5 -3
  12. reproduction/pos_tag_model/train_pos_tag.py  +57 -97

fastNLP/api/converter.py  +182 -0

@@ -0,0 +1,182 @@
import re


class SpanConverter:
def __init__(self, replace_tag, pattern):
super(SpanConverter, self).__init__()

self.replace_tag = replace_tag
self.pattern = pattern

def find_certain_span_and_replace(self, sentence):
replaced_sentence = ''
prev_end = 0
for match in re.finditer(self.pattern, sentence):
start, end = match.span()
span = sentence[start:end]
replaced_sentence += sentence[prev_end:start] + \
self.span_to_special_tag(span)
prev_end = end
replaced_sentence += sentence[prev_end:]

return replaced_sentence

def span_to_special_tag(self, span):

return self.replace_tag

def find_certain_span(self, sentence):
spans = []
for match in re.finditer(self.pattern, sentence):
spans.append(match.span())
return spans


class AlphaSpanConverter(SpanConverter):
def __init__(self):
replace_tag = '<ALPHA>'
# Ideally only spans made purely of letters are handled here; <[a-zA-Z]+> is deliberately not matched, since that should be a special tag.
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])'

super(AlphaSpanConverter, self).__init__(replace_tag, pattern)


class DigitSpanConverter(SpanConverter):
def __init__(self):
replace_tag = '<NUM>'
pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])'

super(DigitSpanConverter, self).__init__(replace_tag, pattern)

def span_to_special_tag(self, span):
# return self.special_tag
if span[0] == '0' and len(span) > 2:
return '<NUM>'
decimal_point_count = 0  # a span may contain more than one decimal point
for idx, char in enumerate(span):
if char == '.' or char == '﹒' or char == '·':
decimal_point_count += 1
if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·':  # a trailing decimal point means the span is not a plain number
if decimal_point_count == 1:
return span
else:
return '<UNKDGT>'
if decimal_point_count == 1:
return '<DEC>'
elif decimal_point_count > 1:
return '<UNKDGT>'
else:
return '<NUM>'


class TimeConverter(SpanConverter):
def __init__(self):
replace_tag = '<TOC>'
pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])'

super().__init__(replace_tag, pattern)


class MixNumAlphaConverter(SpanConverter):
def __init__(self):
replace_tag = '<MIX>'
pattern = None

super().__init__(replace_tag, pattern)

def find_certain_span_and_replace(self, sentence):
replaced_sentence = ''
start = 0
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False
for idx in range(len(sentence)):
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]):
if not matching_flag:
replaced_sentence += sentence[start:idx]
start = idx
if re.match('[0-9]', sentence[idx]):
number_flag = True
elif re.match('[\'′&\\-]', sentence[idx]):
link_flag = True
elif re.match('/', sentence[idx]):
slash_flag = True
elif re.match('[\\(\\)]', sentence[idx]):
bracket_flag = True
else:
alpha_flag = True
matching_flag = True
elif re.match('[\\.]', sentence[idx]):
pass
else:
if matching_flag:
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
or (slash_flag and alpha_flag) or (link_flag and number_flag) \
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
span = sentence[start:idx]
start = idx
replaced_sentence += self.span_to_special_tag(span)
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False

replaced_sentence += sentence[start:]
return replaced_sentence

def find_certain_span(self, sentence):
spans = []
start = 0
matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False
for idx in range(len(sentence)):
if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]):
if not matching_flag:
start = idx
if re.match('[0-9]', sentence[idx]):
number_flag = True
elif re.match('[\'′&\\-]', sentence[idx]):
link_flag = True
elif re.match('/', sentence[idx]):
slash_flag = True
elif re.match('[\\(\\)]', sentence[idx]):
bracket_flag = True
else:
alpha_flag = True
matching_flag = True
elif re.match('[\\.]', sentence[idx]):
pass
else:
if matching_flag:
if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
or (slash_flag and alpha_flag) or (link_flag and number_flag) \
or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
spans.append((start, idx))
start = idx

matching_flag = False
number_flag = False
alpha_flag = False
link_flag = False
slash_flag = False
bracket_flag = False

return spans


class EmailConverter(SpanConverter):
def __init__(self):
replaced_tag = "<EML>"
pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])'

super(EmailConverter, self).__init__(replaced_tag, pattern)
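A rough usage sketch for these converters (assuming the fastNLP package from this commit is importable; the sample sentence is invented for illustration):

    from fastNLP.api.converter import AlphaSpanConverter, DigitSpanConverter

    sentence = "GDP增长3.5%,详见ABC 报告。"  # made-up example sentence
    for conv in (AlphaSpanConverter(), DigitSpanConverter()):
        print(conv.find_certain_span(sentence))          # (start, end) offsets of every match
        sentence = conv.find_certain_span_and_replace(sentence)
    print(sentence)  # letter spans become <ALPHA>, "3.5" becomes <DEC>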

fastNLP/api/pipeline.py  +12 -4

@@ -1,17 +1,25 @@
from fastNLP.api.processor import Processor



class Pipeline:
def __init__(self):
"""
Pipeline takes a DataSet object as input, runs multiple processors sequentially, and
outputs a DataSet object.
"""

def __init__(self, processors=None):
self.pipeline = []
if isinstance(processors, list):
for proc in processors:
assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(proc))
self.pipeline = processors

def add_processor(self, processor):
assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
self.pipeline.append(processor)

def process(self, dataset):
assert len(self.pipeline)!=0, "You need to add some processor first."
assert len(self.pipeline) != 0, "You need to add some processor first."

for proc in self.pipeline:
dataset = proc(dataset)
@@ -19,4 +27,4 @@ class Pipeline:
return dataset

def __call__(self, *args, **kwargs):
return self.process(*args, **kwargs)
return self.process(*args, **kwargs)
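A sketch of the new one-shot initializer in use (toy data; assumes this commit's package is importable and that the processors behave as shown in fastNLP/api/processor.py below):

    from fastNLP.api.pipeline import Pipeline
    from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.instance import Instance

    dataset = DataSet()
    for words in [["我", "爱", "北京"], ["今", "天", "天", "气", "好"]]:
        ins = Instance()
        ins["words"] = words        # toy word sequences
        dataset.append(ins)

    vocab_proc = VocabProcessor("words")
    vocab_proc(dataset)             # build the vocabulary first
    pp = Pipeline([                 # processors passed in one go instead of add_processor()
        IndexerProcessor(vocab_proc.get_vocab(), "words", "word_seq", delete_old_field=False),
        SeqLenProcessor("word_seq", "word_seq_origin_len"),
    ])
    dataset = pp(dataset)           # runs each processor in order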

fastNLP/api/pos_tagger.py  +0 -44

@@ -1,44 +0,0 @@
import pickle

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.predictor import Predictor


class POS_tagger:
def __init__(self):
pass

def predict(self, query):
"""
:param query: List[str]
:return answer: List[str]

"""
# TODO: build a DataSet from the query
pos_dataset = DataSet()
pos_dataset["text_field"] = np.array(query)

# load the pipeline and the model
pipeline = self.load_pipeline("./xxxx")

# run the pipeline with the DataSet as input
pos_dataset = pipeline(pos_dataset)

# load the model
model = ModelLoader().load_pytorch("./xxx")

# call the predictor
predictor = Predictor()
output = predictor.predict(model, pos_dataset)

# TODO: convert to the final output format
return None

@staticmethod
def load_pipeline(path):
with open(path, "r") as fp:
pipeline = pickle.load(fp)
return pipeline

fastNLP/api/processor.py  +22 -5

@@ -1,7 +1,7 @@

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary


class Processor:
def __init__(self, field_name, new_added_field_name):
self.field_name = field_name
@@ -10,15 +10,18 @@ class Processor:
else:
self.new_added_field_name = new_added_field_name

def process(self):
def process(self, *args, **kwargs):
pass

def __call__(self, *args, **kwargs):
return self.process(*args, **kwargs)



class FullSpaceToHalfSpaceProcessor(Processor):
"""全角转半角,以字符为处理单元

"""

def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
change_space=True):
super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
@@ -64,11 +67,12 @@ class FullSpaceToHalfSpaceProcessor(Processor):
if self.change_space:
FHs += FH_SPACE
self.convert_map = {k: v for k, v in FHs}

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
sentence = ins[self.field_name]
new_sentence = [None]*len(sentence)
new_sentence = [None] * len(sentence)
for idx, char in enumerate(sentence):
if char in self.convert_map:
char = self.convert_map[char]
@@ -98,7 +102,7 @@ class IndexerProcessor(Processor):
index = [self.vocab.to_index(token) for token in tokens]
ins[self.new_added_field_name] = index

dataset.set_need_tensor(**{self.new_added_field_name:True})
dataset.set_need_tensor(**{self.new_added_field_name: True})

if self.delete_old_field:
dataset.delete_field(self.field_name)
@@ -122,3 +126,16 @@ class VocabProcessor(Processor):
def get_vocab(self):
self.vocab.build_vocab()
return self.vocab


class SeqLenProcessor(Processor):
def __init__(self, field_name, new_added_field_name='seq_lens'):
super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)

def process(self, dataset):
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
for ins in dataset:
length = len(ins[self.field_name])
ins[self.new_added_field_name] = length
dataset.set_need_tensor(**{self.new_added_field_name: True})
return dataset
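For reference, the core of the full-width to half-width mapping that FullSpaceToHalfSpaceProcessor performs can be sketched in plain Python (a standalone illustration, not the processor's actual conversion tables):

    def to_half_width(text):
        """Map full-width ASCII letters/digits/punctuation and the ideographic
        space to their half-width counterparts, one character at a time."""
        out = []
        for ch in text:
            code = ord(ch)
            if code == 0x3000:               # ideographic space -> ASCII space
                code = 0x20
            elif 0xFF01 <= code <= 0xFF5E:   # full-width '!' .. '~'
                code -= 0xFEE0
            out.append(chr(code))
        return "".join(out)

    print(to_half_width("ＡＢＣ　１２３！"))   # -> "ABC 123!"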

fastNLP/core/batch.py  +0 -3

@@ -1,5 +1,3 @@
from collections import defaultdict

import torch


@@ -68,4 +66,3 @@ class Batch(object):
self.curidx = endidx

return batch_x, batch_y


fastNLP/core/dataset.py  +27 -37

@@ -1,23 +1,27 @@
import random
import sys, os
sys.path.append('../..')
sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path

from collections import defaultdict
from copy import deepcopy
import numpy as np

from fastNLP.core.field import TextField, LabelField
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.fieldarray import FieldArray

_READERS = {}


def construct_dataset(sentences):
"""Construct a data set from a list of sentences.

:param sentences: list of str
:return dataset: a DataSet object
"""
dataset = DataSet()
for sentence in sentences:
instance = Instance()
instance['raw_sentence'] = sentence
dataset.append(instance)
return dataset


class DataSet(object):
"""A DataSet object is a list of Instance objects.

"""

class DataSetIter(object):
def __init__(self, dataset):
self.dataset = dataset
@@ -34,13 +38,12 @@ class DataSet(object):

def __setitem__(self, name, val):
if name not in self.dataset:
new_fields = [None]*len(self.dataset)
new_fields = [None] * len(self.dataset)
self.dataset.add_field(name, new_fields)
self.dataset[name][self.idx] = val

def __repr__(self):
# TODO
pass
return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset])

def __init__(self, instance=None):
self.field_arrays = {}
@@ -72,7 +75,7 @@ class DataSet(object):
self.field_arrays[name].append(field)

def add_field(self, name, fields):
if len(self.field_arrays)!=0:
if len(self.field_arrays) != 0:
assert len(self) == len(fields)
self.field_arrays[name] = FieldArray(name, fields)

@@ -90,27 +93,10 @@ class DataSet(object):
return len(field)

def get_length(self):
"""Fetch lengths of all fields in all instances in a dataset.

:return lengths: dict of (str: list). The str is the field name.
The list contains lengths of this field in all instances.

"""
pass

def shuffle(self):
pass

def split(self, ratio, shuffle=True):
"""Train/dev splitting

:param ratio: float, between 0 and 1. The ratio of development set in origin data set.
:param shuffle: bool, whether shuffle the data set before splitting. Default: True.
:return train_set: a DataSet object, representing the training set
dev_set: a DataSet object, representing the validation set
"""The same as __len__

"""
pass
return len(self)

def rename_field(self, old_name, new_name):
"""rename a field
@@ -118,7 +104,7 @@ class DataSet(object):
if old_name in self.field_arrays:
self.field_arrays[new_name] = self.field_arrays.pop(old_name)
else:
raise KeyError
raise KeyError("{} is not a valid name. ".format(old_name))
return self

def set_is_target(self, **fields):
@@ -150,6 +136,7 @@ class DataSet(object):
data = _READERS[name]().load(*args, **kwargs)
self.extend(data)
return self

return _read
else:
return object.__getattribute__(self, name)
@@ -159,18 +146,21 @@ class DataSet(object):
"""decorator to add dataloader support
"""
assert isinstance(method_name, str)

def wrapper(read_cls):
_READERS[method_name] = read_cls
return read_cls

return wrapper


if __name__ == '__main__':
from fastNLP.core.instance import Instance

ins = Instance(test='test0')
dataset = DataSet([ins])
for _iter in dataset:
print(_iter['test'])
_iter['test'] = 'abc'
print(_iter['test'])
print(dataset.field_arrays)
print(dataset.field_arrays)
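A minimal sketch of the construct_dataset helper added above (assuming this commit's package is importable; the sentences are invented):

    from fastNLP.core.dataset import construct_dataset

    ds = construct_dataset(["今天天气不错。", "fastNLP 是一个自然语言处理工具包。"])
    print(len(ds))                    # 2 instances
    for ins in ds:
        print(ins["raw_sentence"])    # each instance holds the sentence under 'raw_sentence'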

fastNLP/core/instance.py  +2 -2

@@ -1,4 +1,4 @@
import torch

class Instance(object):
"""An instance which consists of Fields is an example in the DataSet.
@@ -35,4 +35,4 @@ class Instance(object):
return self.add_field(name, field)

def __repr__(self):
return self.fields.__repr__()
return self.fields.__repr__()

fastNLP/loader/dataset_loader.py  +3 -2

@@ -1,9 +1,9 @@
import os

from fastNLP.loader.base_loader import BaseLoader
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.field import *
from fastNLP.core.instance import Instance
from fastNLP.loader.base_loader import BaseLoader


def convert_seq_dataset(data):
@@ -393,6 +393,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
sent_words.append(token)
pos_tag_examples.append([sent_words, sent_pos_tag])
ner_examples.append([sent_words, sent_ner])
# List[List[List[str], List[str]]]
return pos_tag_examples, ner_examples

def convert(self, data):


fastNLP/models/sequence_modeling.py  +7 -1

@@ -44,6 +44,9 @@ class SeqLabeling(BaseModel):
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
If truth is not None, return loss, a scalar. Used in training.
"""
assert word_seq.shape[0] == word_seq_origin_len.shape[0]
if truth is not None:
assert truth.shape == word_seq.shape
self.mask = self.make_mask(word_seq, word_seq_origin_len)

x = self.Embedding(word_seq)
@@ -80,7 +83,7 @@ class SeqLabeling(BaseModel):
batch_size, max_len = x.size(0), x.size(1)
mask = seq_mask(seq_len, max_len)
mask = mask.byte().view(batch_size, max_len)
mask = mask.to(x)
mask = mask.to(x).float()
return mask

def decode(self, x, pad=True):
@@ -130,6 +133,9 @@ class AdvSeqLabel(SeqLabeling):
:return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
If truth is not None, return loss, a scalar. Used in training.
"""
word_seq = word_seq.long()
word_seq_origin_len = word_seq_origin_len.long()
truth = truth.long() if truth is not None else None
self.mask = self.make_mask(word_seq, word_seq_origin_len)

batch_size = word_seq.size(0)


fastNLP/modules/decoder/CRF.py  +13 -11

@@ -3,6 +3,7 @@ from torch import nn

from fastNLP.modules.utils import initial_parameter


def log_sum_exp(x, dim=-1):
max_value, _ = x.max(dim=dim, keepdim=True)
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens):


class ConditionalRandomField(nn.Module):
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
"""
:param tag_size: int, num of tags
:param include_start_end_trans: bool, whether to include start/end tag
@@ -38,6 +39,7 @@ class ConditionalRandomField(nn.Module):

# self.reset_parameter()
initial_parameter(self, initial_method)

def reset_parameter(self):
nn.init.xavier_normal_(self.trans_m)
if self.include_start_end_trans:
@@ -81,15 +83,15 @@ class ConditionalRandomField(nn.Module):
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

# trans_score [L-1, B]
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :]
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
# emit_score [L, B]
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
# score [L-1, B]
score = trans_score + emit_score[:seq_len-1, :]
score = trans_score + emit_score[:seq_len - 1, :]
score = score.sum(0) + emit_score[-1]
if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = masks.long().sum(0)
last_idx = mask.long().sum(0)
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
score += st_scores + ed_scores
# return [B,]
@@ -120,14 +122,14 @@ class ConditionalRandomField(nn.Module):
:return: scores, paths
"""
batch_size, seq_len, n_tags = data.size()
data = data.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.float() # L, B
data = data.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.float() # L, B

# dp
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = data[0]
if self.include_start_end_trans:
vscore += self.start_scores.view(1. -1)
vscore += self.start_scores.view(1, -1)
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = data[i].view(batch_size, 1, n_tags)
@@ -145,15 +147,15 @@ class ConditionalRandomField(nn.Module):
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len

ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i+1], batch_idx] = last_tags
ans[idxes[i + 1], batch_idx] = last_tags

if get_score:
return ans_score, ans.transpose(0, 1)
return ans.transpose(0, 1)
return ans.transpose(0, 1)
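The log_sum_exp helper at the top of this file relies on the usual max-shift trick for numerical stability; a standalone sanity check against torch.logsumexp (illustrative only):

    import torch

    def log_sum_exp(x, dim=-1):
        # numerically stable log-sum-exp via the max-shift trick, as in the CRF module above
        max_value, _ = x.max(dim=dim, keepdim=True)
        res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
        return res.squeeze(dim)

    x = torch.tensor([[1000.0, 999.0, 998.0]])  # naive exp() would overflow here
    print(log_sum_exp(x))                       # tensor([1000.4076])
    print(torch.logsumexp(x, dim=-1))           # the built-in agrees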

reproduction/pos_tag_model/pos_tag.cfg  +5 -3

@@ -1,10 +1,12 @@
[train]
epochs = 30
batch_size = 64
epochs = 5
batch_size = 2
pickle_path = "./save/"
validate = true
validate = false
save_best_dev = true
model_saved_path = "./save/"

[model]
rnn_hidden_units = 100
word_emb_dim = 100
use_crf = true


reproduction/pos_tag_model/train_pos_tag.py  +57 -97

@@ -1,130 +1,88 @@
import os
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
import torch

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.trainer import Trainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
from fastNLP.core.trainer import SeqLabelTrainer
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader
from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
from fastNLP.saver.model_saver import ModelSaver
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.tester import SeqLabelTester
from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel
from fastNLP.core.predictor import SeqLabelInfer

# cd into this file's directory if the script was launched from elsewhere
if len(os.path.dirname(__file__)) != 0:
os.chdir(os.path.dirname(__file__))
datadir = "/home/zyfeng/data/"
cfgfile = './pos_tag.cfg'
data_name = "CWS_POS_TAG_NER_people_daily.txt"
datadir = "/home/zyfeng/fastnlp_0.2.0/test/data_for_tests/"
data_name = "people_daily_raw.txt"

pos_tag_data_path = os.path.join(datadir, data_name)
pickle_path = "save"
data_infer_path = os.path.join(datadir, "infer.utf8")


def infer():
# Config Loader
test_args = ConfigSection()
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "class2id.pkl")
test_args["num_classes"] = len(index2label)

# Define the same model
model = AdvSeqLabel(test_args)

try:
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print('model loaded!')
except Exception as e:
print('cannot load model!')
raise

# Data Loader
raw_data_loader = BaseLoader(data_infer_path)
infer_data = raw_data_loader.load_lines()
print('data loaded')

# Inference interface
infer = SeqLabelInfer(pickle_path)
results = infer.predict(model, infer_data)

print(results)
print("Inference finished!")


def train():
def train():
# load config
trainer_args = ConfigSection()
model_args = ConfigSection()
ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args})
train_param = ConfigSection()
model_param = ConfigSection()
ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
print("config loaded")

# Data Loader
loader = PeopleDailyCorpusLoader()
train_data, _ = loader.load()

# TODO: define processors
# define pipeline
pp = Pipeline()
# TODO: pp.add_processor()

# run the pipeline, get data_set
train_data = pp(train_data)
train_data, _ = loader.load(os.path.join(datadir, data_name))
print("data loaded")

dataset = DataSet()
for data in train_data:
instance = Instance()
instance["words"] = data[0]
instance["tag"] = data[1]
dataset.append(instance)
print("dataset transformed")

# processor_1 = FullSpaceToHalfSpaceProcessor('words')
# processor_1(dataset)
word_vocab_proc = VocabProcessor('words')
tag_vocab_proc = VocabProcessor("tag")
word_vocab_proc(dataset)
tag_vocab_proc(dataset)
word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True)
word_indexer(dataset)
tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True)
tag_indexer(dataset)
seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len")
seq_len_proc(dataset)

print("processors defined")
# dataset.set_is_target(tag_ids=True)
model_param["vocab_size"] = len(word_vocab_proc.get_vocab())
model_param["num_classes"] = len(tag_vocab_proc.get_vocab())
print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab())))

# define a model
model = AdvSeqLabel(train_args)
model = AdvSeqLabel(model_param)

# call trainer to train
trainer = SeqLabelTrainer(train_args)
trainer.train(model, data_train, data_dev)

# save model
ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False)

# TODO:save pipeline
trainer = Trainer(**train_param.data)
trainer.train(model, dataset)

# save model & pipeline
pp = Pipeline([word_vocab_proc, word_indexer, seq_len_proc])
save_dict = {"pipeline": pp, "model": model}
torch.save(save_dict, "model_pp.pkl")


def test():
# Config Loader
test_args = ConfigSection()
ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})

# fetch dictionary size and number of labels from pickle files
word2index = load_pickle(pickle_path, "word2id.pkl")
test_args["vocab_size"] = len(word2index)
index2label = load_pickle(pickle_path, "class2id.pkl")
test_args["num_classes"] = len(index2label)

# load dev data
dev_data = load_pickle(pickle_path, "data_dev.pkl")

# Define the same model
model = AdvSeqLabel(test_args)
pass

# Dump trained parameters into the model
ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
print("model loaded!")

# Tester
tester = SeqLabelTester(**test_args.data)

# Start testing
tester.test(model, dev_data)

# print test results
print(tester.show_metrics())
print("model tested!")
def infer():
pass


if __name__ == "__main__":
train()
"""
import argparse

parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
@@ -139,3 +97,5 @@ if __name__ == "__main__":
else:
print('no mode specified for model!')
parser.print_help()

"""
