
- Fix the POS-tagging training script so that it actually runs

- Create converter.py in the api package
- Add an initializer to Pipeline so that a list of processors can be added in one pass
- Delete pos_tagger.py
- Clean up the overall code style
tags/v0.2.0
FengZiYjun, 5 years ago
commit 26e3abdf58
12 changed files with 330 additions and 209 deletions

  1. fastNLP/api/converter.py (+182, -0)
  2. fastNLP/api/pipeline.py (+12, -4)
  3. fastNLP/api/pos_tagger.py (+0, -44)
  4. fastNLP/api/processor.py (+22, -5)
  5. fastNLP/core/batch.py (+0, -3)
  6. fastNLP/core/dataset.py (+27, -37)
  7. fastNLP/core/instance.py (+2, -2)
  8. fastNLP/loader/dataset_loader.py (+3, -2)
  9. fastNLP/models/sequence_modeling.py (+7, -1)
  10. fastNLP/modules/decoder/CRF.py (+13, -11)
  11. reproduction/pos_tag_model/pos_tag.cfg (+5, -3)
  12. reproduction/pos_tag_model/train_pos_tag.py (+57, -97)

fastNLP/api/converter.py (+182, -0)

@@ -0,0 +1,182 @@
import re


class SpanConverter:
    def __init__(self, replace_tag, pattern):
        super(SpanConverter, self).__init__()

        self.replace_tag = replace_tag
        self.pattern = pattern

    def find_certain_span_and_replace(self, sentence):
        replaced_sentence = ''
        prev_end = 0
        for match in re.finditer(self.pattern, sentence):
            start, end = match.span()
            span = sentence[start:end]
            replaced_sentence += sentence[prev_end:start] + \
                                 self.span_to_special_tag(span)
            prev_end = end
        replaced_sentence += sentence[prev_end:]

        return replaced_sentence

    def span_to_special_tag(self, span):

        return self.replace_tag

    def find_certain_span(self, sentence):
        spans = []
        for match in re.finditer(self.pattern, sentence):
            spans.append(match.span())
        return spans


class AlphaSpanConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<ALPHA>'
        # Ideally only pure-alphabet spans should be handled, but <[a-zA-Z]+> is skipped (it should be a special tag).
        pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%.!<\\-"])'

        super(AlphaSpanConverter, self).__init__(replace_tag, pattern)


class DigitSpanConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<NUM>'
        pattern = '\d[\d\\.]*(?=[\u4e00-\u9fff ,%.!<-])'

        super(DigitSpanConverter, self).__init__(replace_tag, pattern)

    def span_to_special_tag(self, span):
        # return self.special_tag
        if span[0] == '0' and len(span) > 2:
            return '<NUM>'
        decimal_point_count = 0  # one might have more than one decimal pointers
        for idx, char in enumerate(span):
            if char == '.' or char == '﹒' or char == '·':
                decimal_point_count += 1
        if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·':
            # last digit being decimal point means this is not a number
            if decimal_point_count == 1:
                return span
            else:
                return '<UNKDGT>'
        if decimal_point_count == 1:
            return '<DEC>'
        elif decimal_point_count > 1:
            return '<UNKDGT>'
        else:
            return '<NUM>'


class TimeConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<TOC>'
        pattern = '\d+[::∶][\d::∶]+(?=[\u4e00-\u9fff ,%.!<-])'

        super().__init__(replace_tag, pattern)


class MixNumAlphaConverter(SpanConverter):
    def __init__(self):
        replace_tag = '<MIX>'
        pattern = None

        super().__init__(replace_tag, pattern)

    def find_certain_span_and_replace(self, sentence):
        replaced_sentence = ''
        start = 0
        matching_flag = False
        number_flag = False
        alpha_flag = False
        link_flag = False
        slash_flag = False
        bracket_flag = False
        for idx in range(len(sentence)):
            if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]):
                if not matching_flag:
                    replaced_sentence += sentence[start:idx]
                    start = idx
                if re.match('[0-9]', sentence[idx]):
                    number_flag = True
                elif re.match('[\'′&\\-]', sentence[idx]):
                    link_flag = True
                elif re.match('/', sentence[idx]):
                    slash_flag = True
                elif re.match('[\\(\\)]', sentence[idx]):
                    bracket_flag = True
                else:
                    alpha_flag = True
                matching_flag = True
            elif re.match('[\\.]', sentence[idx]):
                pass
            else:
                if matching_flag:
                    if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
                            or (slash_flag and alpha_flag) or (link_flag and number_flag) \
                            or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
                        span = sentence[start:idx]
                        start = idx
                        replaced_sentence += self.span_to_special_tag(span)
                    matching_flag = False
                    number_flag = False
                    alpha_flag = False
                    link_flag = False
                    slash_flag = False
                    bracket_flag = False

        replaced_sentence += sentence[start:]
        return replaced_sentence

    def find_certain_span(self, sentence):
        spans = []
        start = 0
        matching_flag = False
        number_flag = False
        alpha_flag = False
        link_flag = False
        slash_flag = False
        bracket_flag = False
        for idx in range(len(sentence)):
            if re.match('[0-9a-zA-Z/\\(\\)\'′&\\-]', sentence[idx]):
                if not matching_flag:
                    start = idx
                if re.match('[0-9]', sentence[idx]):
                    number_flag = True
                elif re.match('[\'′&\\-]', sentence[idx]):
                    link_flag = True
                elif re.match('/', sentence[idx]):
                    slash_flag = True
                elif re.match('[\\(\\)]', sentence[idx]):
                    bracket_flag = True
                else:
                    alpha_flag = True
                matching_flag = True
            elif re.match('[\\.]', sentence[idx]):
                pass
            else:
                if matching_flag:
                    if (number_flag and alpha_flag) or (link_flag and alpha_flag) \
                            or (slash_flag and alpha_flag) or (link_flag and number_flag) \
                            or (number_flag and bracket_flag) or (bracket_flag and alpha_flag):
                        spans.append((start, idx))
                        start = idx

                    matching_flag = False
                    number_flag = False
                    alpha_flag = False
                    link_flag = False
                    slash_flag = False
                    bracket_flag = False

        return spans


class EmailConverter(SpanConverter):
    def __init__(self):
        replaced_tag = "<EML>"
        pattern = '[0-9a-zA-Z]+[@][.﹒0-9a-zA-Z@]+(?=[\u4e00-\u9fff ,%.!<\\-"$])'

        super(EmailConverter, self).__init__(replaced_tag, pattern)
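
A minimal usage sketch of these converters (the sample sentence and the expected result are illustrative and not part of the commit; note that each pattern's lookahead needs a delimiter such as a Chinese character, space, or punctuation after the span):

from fastNLP.api.converter import AlphaSpanConverter, DigitSpanConverter

sentence = '他使用Python 处理了123.5 条数据,'
sentence = AlphaSpanConverter().find_certain_span_and_replace(sentence)
sentence = DigitSpanConverter().find_certain_span_and_replace(sentence)
print(sentence)  # expected: '他使用<ALPHA> 处理了<DEC> 条数据,'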

fastNLP/api/pipeline.py (+12, -4)

@@ -1,17 +1,25 @@
from fastNLP.api.processor import Processor


class Pipeline:
-    def __init__(self):
+    """
+    Pipeline takes a DataSet object as input, runs multiple processors sequentially, and
+    outputs a DataSet object.
+    """
+
+    def __init__(self, processors=None):
        self.pipeline = []
+        if isinstance(processors, list):
+            for proc in processors:
+                assert isinstance(proc, Processor), "Must be a Processor, not {}.".format(type(processor))
+            self.pipeline = processors

    def add_processor(self, processor):
        assert isinstance(processor, Processor), "Must be a Processor, not {}.".format(type(processor))
        self.pipeline.append(processor)

    def process(self, dataset):
-        assert len(self.pipeline)!=0, "You need to add some processor first."
+        assert len(self.pipeline) != 0, "You need to add some processor first."

        for proc_name, proc in self.pipeline:
            dataset = proc(dataset)
@@ -19,4 +27,4 @@ class Pipeline:
        return dataset

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)
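
A short sketch of the two ways to assemble a Pipeline after this change (the field names passed to the processors are made-up placeholders):

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, SeqLenProcessor

# incremental style, unchanged
pp = Pipeline()
pp.add_processor(FullSpaceToHalfSpaceProcessor('raw_sentence'))
pp.add_processor(SeqLenProcessor('word_seq', 'seq_lens'))

# new style: hand the whole processor list to the constructor
pp = Pipeline([FullSpaceToHalfSpaceProcessor('raw_sentence'),
               SeqLenProcessor('word_seq', 'seq_lens')])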

fastNLP/api/pos_tagger.py (+0, -44)

@@ -1,44 +0,0 @@
import pickle

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.loader.model_loader import ModelLoader
from fastNLP.core.predictor import Predictor


class POS_tagger:
    def __init__(self):
        pass

    def predict(self, query):
        """
        :param query: List[str]
        :return answer: List[str]

        """
        # TODO: build a DataSet from the query
        pos_dataset = DataSet()
        pos_dataset["text_field"] = np.array(query)

        # load the pipeline and the model
        pipeline = self.load_pipeline("./xxxx")

        # run the pipeline on the DataSet
        pos_dataset = pipeline(pos_dataset)

        # load the model
        model = ModelLoader().load_pytorch("./xxx")

        # call the predictor
        predictor = Predictor()
        output = predictor.predict(model, pos_dataset)

        # TODO: convert to the final output
        return None

    @staticmethod
    def load_pipeline(path):
        with open(path, "r") as fp:
            pipeline = pickle.load(fp)
        return pipeline

fastNLP/api/processor.py (+22, -5)

@@ -1,7 +1,7 @@

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary


class Processor:
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
@@ -10,15 +10,18 @@ class Processor:
        else:
            self.new_added_field_name = new_added_field_name

-    def process(self):
+    def process(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.process(*args, **kwargs)


class FullSpaceToHalfSpaceProcessor(Processor):
+    """Convert full-width characters to half-width ones, one character at a time.
+
+    """
+
    def __init__(self, field_name, change_alpha=True, change_digit=True, change_punctuation=True,
                 change_space=True):
        super(FullSpaceToHalfSpaceProcessor, self).__init__(field_name, None)
@@ -64,11 +67,12 @@ class FullSpaceToHalfSpaceProcessor(Processor):
        if self.change_space:
            FHs += FH_SPACE
        self.convert_map = {k: v for k, v in FHs}

    def process(self, dataset):
        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
        for ins in dataset:
            sentence = ins[self.field_name]
-            new_sentence = [None]*len(sentence)
+            new_sentence = [None] * len(sentence)
            for idx, char in enumerate(sentence):
                if char in self.convert_map:
                    char = self.convert_map[char]
@@ -98,7 +102,7 @@ class IndexerProcessor(Processor):
            index = [self.vocab.to_index(token) for token in tokens]
            ins[self.new_added_field_name] = index

-        dataset.set_need_tensor(**{self.new_added_field_name:True})
+        dataset.set_need_tensor(**{self.new_added_field_name: True})

        if self.delete_old_field:
            dataset.delete_field(self.field_name)
@@ -122,3 +126,16 @@ class VocabProcessor(Processor):
    def get_vocab(self):
        self.vocab.build_vocab()
        return self.vocab
+
+
+class SeqLenProcessor(Processor):
+    def __init__(self, field_name, new_added_field_name='seq_lens'):
+        super(SeqLenProcessor, self).__init__(field_name, new_added_field_name)
+
+    def process(self, dataset):
+        assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
+        for ins in dataset:
+            length = len(ins[self.field_name])
+            ins[self.new_added_field_name] = length
+        dataset.set_need_tensor(**{self.new_added_field_name: True})
+        return dataset
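
The new SeqLenProcessor simply materializes each instance's sequence length as a new field, as the training script later does; a minimal sketch with a toy DataSet (the index values are invented):

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.api.processor import SeqLenProcessor

dataset = DataSet()
for word_seq in [[2, 3, 4, 5], [6, 7, 8]]:   # toy index sequences
    ins = Instance()
    ins["word_seq"] = word_seq
    dataset.append(ins)

seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len")
seq_len_proc(dataset)   # each instance gains word_seq_origin_len = 4 and 3, marked via set_need_tensor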

fastNLP/core/batch.py (+0, -3)

@@ -1,5 +1,3 @@
-from collections import defaultdict
-
import torch


@@ -68,4 +66,3 @@
        self.curidx = endidx

        return batch_x, batch_y


fastNLP/core/dataset.py (+27, -37)

@@ -1,23 +1,27 @@
-import random
-import sys, os
-sys.path.append('../..')
-sys.path = [os.path.join(os.path.dirname(__file__), '../..')] + sys.path
-
-from collections import defaultdict
-from copy import deepcopy
-import numpy as np
-
-from fastNLP.core.field import TextField, LabelField
-from fastNLP.core.instance import Instance
-from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.fieldarray import FieldArray

_READERS = {}


+def construct_dataset(sentences):
+    """Construct a data set from a list of sentences.
+
+    :param sentences: list of str
+    :return dataset: a DataSet object
+    """
+    dataset = DataSet()
+    for sentence in sentences:
+        instance = Instance()
+        instance['raw_sentence'] = sentence
+        dataset.append(instance)
+    return dataset
+
+
class DataSet(object):
    """A DataSet object is a list of Instance objects.

    """

    class DataSetIter(object):
        def __init__(self, dataset):
            self.dataset = dataset
@@ -34,13 +38,12 @@ class DataSet(object):

        def __setitem__(self, name, val):
            if name not in self.dataset:
-                new_fields = [None]*len(self.dataset)
+                new_fields = [None] * len(self.dataset)
                self.dataset.add_field(name, new_fields)
            self.dataset[name][self.idx] = val

        def __repr__(self):
-            # TODO
-            pass
+            return " ".join([repr(self.dataset[name][self.idx]) for name in self.dataset])

    def __init__(self, instance=None):
        self.field_arrays = {}
@@ -72,7 +75,7 @@ class DataSet(object):
            self.field_arrays[name].append(field)

    def add_field(self, name, fields):
-        if len(self.field_arrays)!=0:
+        if len(self.field_arrays) != 0:
            assert len(self) == len(fields)
        self.field_arrays[name] = FieldArray(name, fields)

@@ -90,27 +93,10 @@ class DataSet(object):
        return len(field)

    def get_length(self):
-        """Fetch lengths of all fields in all instances in a dataset.
-
-        :return lengths: dict of (str: list). The str is the field name.
-            The list contains lengths of this field in all instances.
-
-        """
-        pass
-
-    def shuffle(self):
-        pass
-
-    def split(self, ratio, shuffle=True):
-        """Train/dev splitting
-
-        :param ratio: float, between 0 and 1. The ratio of development set in origin data set.
-        :param shuffle: bool, whether shuffle the data set before splitting. Default: True.
-        :return train_set: a DataSet object, representing the training set
-            dev_set: a DataSet object, representing the validation set
+        """The same as __len__

        """
-        pass
+        return len(self)

    def rename_field(self, old_name, new_name):
        """rename a field
@@ -118,7 +104,7 @@ class DataSet(object):
        if old_name in self.field_arrays:
            self.field_arrays[new_name] = self.field_arrays.pop(old_name)
        else:
-            raise KeyError
+            raise KeyError("{} is not a valid name. ".format(old_name))
        return self

    def set_is_target(self, **fields):
@@ -150,6 +136,7 @@ class DataSet(object):
                data = _READERS[name]().load(*args, **kwargs)
                self.extend(data)
                return self
+
            return _read
        else:
            return object.__getattribute__(self, name)
@@ -159,18 +146,21 @@ class DataSet(object):
        """decorator to add dataloader support
        """
        assert isinstance(method_name, str)
+
        def wrapper(read_cls):
            _READERS[method_name] = read_cls
            return read_cls
+
        return wrapper


if __name__ == '__main__':
    from fastNLP.core.instance import Instance
+
    ins = Instance(test='test0')
    dataset = DataSet([ins])
    for _iter in dataset:
        print(_iter['test'])
        _iter['test'] = 'abc'
        print(_iter['test'])
    print(dataset.field_arrays)
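
A quick sketch of the new construct_dataset helper and the DataSetIter behaviour exercised in __main__ (the sentences are invented):

from fastNLP.core.dataset import construct_dataset

ds = construct_dataset(["the first sentence .", "a second one ."])
print(len(ds))                  # 2 instances; get_length() now just returns len(self)
for ins in ds:
    print(ins['raw_sentence'])  # prints each raw sentence through the dataset iterator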

fastNLP/core/instance.py (+2, -2)

@@ -1,4 +1,4 @@
-import torch
-
+
class Instance(object):
    """An instance which consists of Fields is an example in the DataSet.
@@ -35,4 +35,4 @@ class Instance(object):
        return self.add_field(name, field)

    def __repr__(self):
        return self.fields.__repr__()

fastNLP/loader/dataset_loader.py (+3, -2)

@@ -1,9 +1,9 @@
import os

-from fastNLP.loader.base_loader import BaseLoader
from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
from fastNLP.core.field import *
+from fastNLP.core.instance import Instance
+from fastNLP.loader.base_loader import BaseLoader


def convert_seq_dataset(data):
@@ -393,6 +393,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                sent_words.append(token)
            pos_tag_examples.append([sent_words, sent_pos_tag])
            ner_examples.append([sent_words, sent_ner])
+        # List[List[List[str], List[str]]]
        return pos_tag_examples, ner_examples

    def convert(self, data):


fastNLP/models/sequence_modeling.py (+7, -1)

@@ -44,6 +44,9 @@ class SeqLabeling(BaseModel):
        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                   If truth is not None, return loss, a scalar. Used in training.
        """
+        assert word_seq.shape[0] == word_seq_origin_len.shape[0]
+        if truth is not None:
+            assert truth.shape == word_seq.shape
        self.mask = self.make_mask(word_seq, word_seq_origin_len)

        x = self.Embedding(word_seq)
@@ -80,7 +83,7 @@
        batch_size, max_len = x.size(0), x.size(1)
        mask = seq_mask(seq_len, max_len)
        mask = mask.byte().view(batch_size, max_len)
-        mask = mask.to(x)
+        mask = mask.to(x).float()
        return mask

    def decode(self, x, pad=True):
@@ -130,6 +133,9 @@ class AdvSeqLabel(SeqLabeling):
        :return y: If truth is None, return list of [decode path(list)]. Used in testing and predicting.
                   If truth is not None, return loss, a scalar. Used in training.
        """
+        word_seq = word_seq.long()
+        word_seq_origin_len = word_seq_origin_len.long()
+        truth = truth.long()
        self.mask = self.make_mask(word_seq, word_seq_origin_len)

        batch_size = word_seq.size(0)


fastNLP/modules/decoder/CRF.py (+13, -11)

@@ -3,6 +3,7 @@ from torch import nn


from fastNLP.modules.utils import initial_parameter

+
def log_sum_exp(x, dim=-1):
    max_value, _ = x.max(dim=dim, keepdim=True)
    res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
@@ -20,7 +21,7 @@ def seq_len_to_byte_mask(seq_lens):


class ConditionalRandomField(nn.Module):
-    def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None):
+    def __init__(self, tag_size, include_start_end_trans=False, initial_method=None):
        """
        :param tag_size: int, num of tags
        :param include_start_end_trans: bool, whether to include start/end tag
@@ -38,6 +39,7 @@ class ConditionalRandomField(nn.Module):

        # self.reset_parameter()
        initial_parameter(self, initial_method)
+
    def reset_parameter(self):
        nn.init.xavier_normal_(self.trans_m)
        if self.include_start_end_trans:
@@ -81,15 +83,15 @@ class ConditionalRandomField(nn.Module):
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

        # trans_socre [L-1, B]
-        trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :]
+        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
        # emit_score [L, B]
-        emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
+        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
        # score [L-1, B]
-        score = trans_score + emit_score[:seq_len-1, :]
+        score = trans_score + emit_score[:seq_len - 1, :]
        score = score.sum(0) + emit_score[-1]
        if self.include_start_end_trans:
            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
-            last_idx = masks.long().sum(0)
+            last_idx = mask.long().sum(0)
            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
            score += st_scores + ed_scores
        # return [B,]
@@ -120,14 +122,14 @@ class ConditionalRandomField(nn.Module):
        :return: scores, paths
        """
        batch_size, seq_len, n_tags = data.size()
-        data = data.transpose(0, 1).data # L, B, H
-        mask = mask.transpose(0, 1).data.float() # L, B
+        data = data.transpose(0, 1).data  # L, B, H
+        mask = mask.transpose(0, 1).data.float()  # L, B

        # dp
        vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
        vscore = data[0]
        if self.include_start_end_trans:
-            vscore += self.start_scores.view(1. -1)
+            vscore += self.start_scores.view(1. - 1)
        for i in range(1, seq_len):
            prev_score = vscore.view(batch_size, n_tags, 1)
            cur_score = data[i].view(batch_size, 1, n_tags)
@@ -145,15 +147,15 @@ class ConditionalRandomField(nn.Module):
        seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
        lens = (mask.long().sum(0) - 1)
        # idxes [L, B], batched idx from seq_len-1 to 0
-        idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
+        idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len

        ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
        ans_score, last_tags = vscore.max(1)
        ans[idxes[0], batch_idx] = last_tags
        for i in range(seq_len - 1):
            last_tags = vpath[idxes[i], batch_idx, last_tags]
-            ans[idxes[i+1], batch_idx] = last_tags
+            ans[idxes[i + 1], batch_idx] = last_tags

        if get_score:
            return ans_score, ans.transpose(0, 1)
        return ans.transpose(0, 1)
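
For context on the log_sum_exp helper shown above: subtracting the row-wise maximum before exponentiating is what keeps the CRF forward scores finite; a small numerical sketch (the values are arbitrary):

import torch

x = torch.tensor([[1000.0, 1000.5]])
print(torch.log(torch.exp(x).sum(dim=-1)))            # tensor([inf]) -- the naive form overflows
max_value, _ = x.max(dim=-1, keepdim=True)
res = torch.log(torch.sum(torch.exp(x - max_value), dim=-1, keepdim=True)) + max_value
print(res)                                            # tensor([[1000.9741]])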

reproduction/pos_tag_model/pos_tag.cfg (+5, -3)

@@ -1,10 +1,12 @@
[train]
-epochs = 30
-batch_size = 64
+epochs = 5
+batch_size = 2
pickle_path = "./save/"
-validate = true
+validate = false
save_best_dev = true
model_saved_path = "./save/"

+[model]
rnn_hidden_units = 100
word_emb_dim = 100
use_crf = true


reproduction/pos_tag_model/train_pos_tag.py (+57, -97)

@@ -1,130 +1,88 @@
import os
+import sys
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
+import torch
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import VocabProcessor, IndexerProcessor, SeqLenProcessor
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.instance import Instance
+from fastNLP.core.trainer import Trainer
from fastNLP.loader.config_loader import ConfigLoader, ConfigSection
-from fastNLP.core.trainer import SeqLabelTrainer
-from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader, BaseLoader
-from fastNLP.core.preprocess import SeqLabelPreprocess, load_pickle
-from fastNLP.saver.model_saver import ModelSaver
-from fastNLP.loader.model_loader import ModelLoader
-from fastNLP.core.tester import SeqLabelTester
+from fastNLP.loader.dataset_loader import PeopleDailyCorpusLoader
from fastNLP.models.sequence_modeling import AdvSeqLabel
-from fastNLP.core.predictor import SeqLabelInfer

-# not in the file's dir
-if len(os.path.dirname(__file__)) != 0:
-    os.chdir(os.path.dirname(__file__))
-datadir = "/home/zyfeng/data/"
cfgfile = './pos_tag.cfg'
-data_name = "CWS_POS_TAG_NER_people_daily.txt"
+datadir = "/home/zyfeng/fastnlp_0.2.0/test/data_for_tests/"
+data_name = "people_daily_raw.txt"

pos_tag_data_path = os.path.join(datadir, data_name)
pickle_path = "save"
data_infer_path = os.path.join(datadir, "infer.utf8")


-def infer():
-    # Config Loader
-    test_args = ConfigSection()
-    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})
-
-    # fetch dictionary size and number of labels from pickle files
-    word2index = load_pickle(pickle_path, "word2id.pkl")
-    test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "class2id.pkl")
-    test_args["num_classes"] = len(index2label)
-
-    # Define the same model
-    model = AdvSeqLabel(test_args)
-
-    try:
-        ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-        print('model loaded!')
-    except Exception as e:
-        print('cannot load model!')
-        raise
-
-    # Data Loader
-    raw_data_loader = BaseLoader(data_infer_path)
-    infer_data = raw_data_loader.load_lines()
-    print('data loaded')
-
-    # Inference interface
-    infer = SeqLabelInfer(pickle_path)
-    results = infer.predict(model, infer_data)
-
-    print(results)
-    print("Inference finished!")
-
-
def train():
    # load config
-    trainer_args = ConfigSection()
-    model_args = ConfigSection()
-    ConfigLoader().load_config(cfgfile, {"train": train_args, "test": test_args})
+    train_param = ConfigSection()
+    model_param = ConfigSection()
+    ConfigLoader().load_config(cfgfile, {"train": train_param, "model": model_param})
+    print("config loaded")

    # Data Loader
    loader = PeopleDailyCorpusLoader()
-    train_data, _ = loader.load()
-
-    # TODO: define processors
-    # define pipeline
-    pp = Pipeline()
-    # TODO: pp.add_processor()
-
-    # run the pipeline, get data_set
-    train_data = pp(train_data)
+    train_data, _ = loader.load(os.path.join(datadir, data_name))
+    print("data loaded")
+
+    dataset = DataSet()
+    for data in train_data:
+        instance = Instance()
+        instance["words"] = data[0]
+        instance["tag"] = data[1]
+        dataset.append(instance)
+    print("dataset transformed")
+
+    # processor_1 = FullSpaceToHalfSpaceProcessor('words')
+    # processor_1(dataset)
+    word_vocab_proc = VocabProcessor('words')
+    tag_vocab_proc = VocabProcessor("tag")
+    word_vocab_proc(dataset)
+    tag_vocab_proc(dataset)
+    word_indexer = IndexerProcessor(word_vocab_proc.get_vocab(), 'words', 'word_seq', delete_old_field=True)
+    word_indexer(dataset)
+    tag_indexer = IndexerProcessor(tag_vocab_proc.get_vocab(), 'tag', 'truth', delete_old_field=True)
+    tag_indexer(dataset)
+    seq_len_proc = SeqLenProcessor("word_seq", "word_seq_origin_len")
+    seq_len_proc(dataset)
+
+    print("processors defined")
+    # dataset.set_is_target(tag_ids=True)
+    model_param["vocab_size"] = len(word_vocab_proc.get_vocab())
+    model_param["num_classes"] = len(tag_vocab_proc.get_vocab())
+    print("vocab_size={} num_classes={}".format(len(word_vocab_proc.get_vocab()), len(tag_vocab_proc.get_vocab())))

    # define a model
-    model = AdvSeqLabel(train_args)
+    model = AdvSeqLabel(model_param)

    # call trainer to train
-    trainer = SeqLabelTrainer(train_args)
-    trainer.train(model, data_train, data_dev)
-
-    # save model
-    ModelSaver("./saved_model.pkl").save_pytorch(model, param_only=False)
-
-    # TODO:save pipeline
+    trainer = Trainer(**train_param.data)
+    trainer.train(model, dataset)
+
+    # save model & pipeline
+    pp = Pipeline([word_vocab_proc, word_indexer, seq_len_proc])
+    save_dict = {"pipeline": pp, "model": model}
+    torch.save(save_dict, "model_pp.pkl")


def test():
-    # Config Loader
-    test_args = ConfigSection()
-    ConfigLoader("config").load_config(cfgfile, {"POS_test": test_args})
-
-    # fetch dictionary size and number of labels from pickle files
-    word2index = load_pickle(pickle_path, "word2id.pkl")
-    test_args["vocab_size"] = len(word2index)
-    index2label = load_pickle(pickle_path, "class2id.pkl")
-    test_args["num_classes"] = len(index2label)
-
-    # load dev data
-    dev_data = load_pickle(pickle_path, "data_dev.pkl")
-
-    # Define the same model
-    model = AdvSeqLabel(test_args)
-
-    # Dump trained parameters into the model
-    ModelLoader.load_pytorch(model, "./save/saved_model.pkl")
-    print("model loaded!")
-
-    # Tester
-    tester = SeqLabelTester(**test_args.data)
-
-    # Start testing
-    tester.test(model, dev_data)
-
-    # print test results
-    print(tester.show_metrics())
-    print("model tested!")
+    pass
+
+
+def infer():
+    pass


if __name__ == "__main__":
+    train()
+    """
    import argparse

    parser = argparse.ArgumentParser(description='Run a chinese word segmentation model')
@@ -139,3 +97,5 @@ if __name__ == "__main__":
    else:
        print('no mode specified for model!')
        parser.print_help()
+
+    """

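Since train() now saves the fitted processors together with the model, a later inference script can restore both from one file; a hedged sketch of such a loader (the inference side itself is not part of this commit):

import torch

save_dict = torch.load("model_pp.pkl")   # written by train() above
pipeline = save_dict["pipeline"]         # Pipeline([word_vocab_proc, word_indexer, seq_len_proc])
model = save_dict["model"]               # the trained AdvSeqLabel
model.eval()
# applying the stored processors to a new DataSet mirrors the proc(dataset) calls in train()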