Browse Source

code optimization

* move used readers from reproduction to io/dataset_loader.py
(API shall not call anything from reproduction/)
tags/v0.3.1^2
FengZiYjun 6 years ago
parent
commit
c4ba75d160
10 changed files with 489 additions and 486 deletions
  1. +11
    -3
      fastNLP/api/api.py
  2. +100
    -0
      fastNLP/api/processor.py
  3. +373
    -0
      fastNLP/io/dataset_loader.py
  4. +1
    -1
      reproduction/Biaffine_parser/main.py
  5. +1
    -6
      reproduction/Biaffine_parser/run.py
  6. +0
    -51
      reproduction/Biaffine_parser/util.py
  7. +0
    -194
      reproduction/chinese_word_segment/cws_io/cws_reader.py
  8. +0
    -103
      reproduction/chinese_word_segment/process/cws_processor.py
  9. +1
    -125
      reproduction/pos_tag_model/pos_reader.py
  10. +2
    -3
      reproduction/pos_tag_model/train_pos_tag.py

+ 11
- 3
fastNLP/api/api.py View File

@@ -9,9 +9,7 @@ from fastNLP.core.dataset import DataSet


from fastNLP.api.utils import load_url from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.io.dataset_loader import ConllCWSReader, ZhConllPOSReader, ConllxDataLoader, add_seg_tag
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.metrics import SpanFPreRecMetric
@@ -31,6 +29,16 @@ class API:
self._dict = None self._dict = None


def predict(self, *args, **kwargs): def predict(self, *args, **kwargs):
"""Do prediction for the given input.
"""
raise NotImplementedError

def test(self, file_path):
"""Test performance over the given data set.

:param str file_path:
:return: a dictionary of metric values
"""
raise NotImplementedError raise NotImplementedError


def load(self, path, device): def load(self, path, device):


+ 100
- 0
fastNLP/api/processor.py View File

@@ -322,3 +322,103 @@ class SetInputProcessor(Processor):
def process(self, dataset): def process(self, dataset):
dataset.set_input(*self.fields, flag=self.flag) dataset.set_input(*self.fields, flag=self.flag)
return dataset return dataset


class VocabIndexerProcessor(Processor):
"""
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供
new_added_field_name, 则覆盖原有的field_name.

"""

def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
verbose=0, is_input=True):
"""

:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name.
:param min_freq: 创建的Vocabulary允许的单词最少出现次数.
:param max_size: 创建的Vocabulary允许的最大的单词数量
:param verbose: 0, 不输出任何信息;1,输出信息
:param bool is_input:
"""
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name)
self.min_freq = min_freq
self.max_size = max_size

self.verbose = verbose
self.is_input = is_input

def construct_vocab(self, *datasets):
"""
使用传入的DataSet创建vocabulary

:param datasets: DataSet类型的数据,用于构建vocabulary
:return:
"""
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size)
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
self.vocab.build_vocab()
if self.verbose:
print("Vocabulary Constructed, has {} items.".format(len(self.vocab)))

def process(self, *datasets, only_index_dataset=None):
"""
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary
后,则会index datasets与only_index_dataset。

:param datasets: DataSet类型的数据
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。
:return:
"""
if len(datasets) == 0 and not hasattr(self, 'vocab'):
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.")
if not hasattr(self, 'vocab'):
self.construct_vocab(*datasets)
else:
if self.verbose:
print("Using constructed vocabulary with {} items.".format(len(self.vocab)))
to_index_datasets = []
if len(datasets) != 0:
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
to_index_datasets.append(dataset)

if not (only_index_dataset is None):
if isinstance(only_index_dataset, list):
for dataset in only_index_dataset:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
to_index_datasets.append(dataset)
elif isinstance(only_index_dataset, DataSet):
to_index_datasets.append(only_index_dataset)
else:
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset)))

for dataset in to_index_datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
new_field_name=self.new_added_field_name, is_input=self.is_input)
# 只返回一个,infer时为了跟其他processor保持一致
if len(to_index_datasets) == 1:
return to_index_datasets[0]

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))
self.vocab = vocab

def delete_vocab(self):
del self.vocab

def get_vocab_size(self):
return len(self.vocab)

def set_verbose(self, verbose):
"""
设置processor verbose状态。

:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。
:return:
"""
self.verbose = verbose

+ 373
- 0
fastNLP/io/dataset_loader.py View File

@@ -90,6 +90,7 @@ class NativeDataSetLoader(DataSetLoader):
"""A simple example of DataSetLoader """A simple example of DataSetLoader


""" """

def __init__(self): def __init__(self):
super(NativeDataSetLoader, self).__init__() super(NativeDataSetLoader, self).__init__()


@@ -107,6 +108,7 @@ class RawDataSetLoader(DataSetLoader):
"""A simple example of raw data reader """A simple example of raw data reader


""" """

def __init__(self): def __init__(self):
super(RawDataSetLoader, self).__init__() super(RawDataSetLoader, self).__init__()


@@ -142,6 +144,7 @@ class POSDataSetLoader(DataSetLoader):


In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label.
""" """

def __init__(self): def __init__(self):
super(POSDataSetLoader, self).__init__() super(POSDataSetLoader, self).__init__()


@@ -540,3 +543,373 @@ class SNLIDataSetLoader(DataSetLoader):
data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len")
data_set.set_target("truth") data_set.set_target("truth")
return data_set return data_set


class ConllCWSReader(object):
def __init__(self):
pass

def load(self, path, cut_long_sent=False):
"""
返回的DataSet只包含raw_sentence这个field,内容为str。
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即
1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn
4 12日 12日 NT DATE 11 nmod:tmod
5 , , PU O 11 punct

1 这 这 DT O 3 det
2 款 款 M O 1 mark:clf
3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_char_lst(sample)
if res is None:
continue
line = ' '.join(res)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for raw_sentence in sents:
ds.append(Instance(raw_sentence=raw_sentence))

return ds

def get_char_lst(self, sample):
if len(sample) == 0:
return None
text = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
return None
text.append(t1)
return text


class POSCWSReader(DataSetLoader):
"""
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限.
迈 N
向 N
充 N
...
泽 I-PER
民 I-PER

( N
一 N
九 N
...


:param filepath:
:return:
"""

def __init__(self, in_word_splitter=None):
super().__init__()
self.in_word_splitter = in_word_splitter

def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
if in_word_splitter is None:
in_word_splitter = self.in_word_splitter
dataset = DataSet()
with open(filepath, 'r') as f:
words = []
for line in f:
line = line.strip()
if len(line) == 0: # new line
if len(words) == 0: # 不能接受空行
continue
line = ' '.join(words)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for sent in sents:
instance = Instance(raw_sentence=sent)
dataset.append(instance)
words = []
else:
line = line.split()[0]
if in_word_splitter is None:
words.append(line)
else:
words.append(line.split(in_word_splitter)[0])
return dataset


class NaiveCWSReader(DataSetLoader):
"""
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了
这是 fastNLP , 一个 非常 good 的 包 .
或者,即每个part后面还有一个pos tag
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
"""

def __init__(self, in_word_splitter=None):
super().__init__()

self.in_word_splitter = in_word_splitter

def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
"""
允许使用的情况有(默认以\t或空格作为seg)
这是 fastNLP , 一个 非常 good 的 包 .
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0]
:param filepath:
:param in_word_splitter:
:return:
"""
if in_word_splitter == None:
in_word_splitter = self.in_word_splitter
dataset = DataSet()
with open(filepath, 'r') as f:
for line in f:
line = line.strip()
if len(line.replace(' ', '')) == 0: # 不能接受空行
continue

if not in_word_splitter is None:
words = []
for part in line.split():
word = part.split(in_word_splitter)[0]
words.append(word)
line = ' '.join(words)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for sent in sents:
instance = Instance(raw_sentence=sent)
dataset.append(instance)

return dataset


def cut_long_sentence(sent, max_sample_length=200):
"""
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length

:param sent: str.
:param max_sample_length: int.
:return: list of str.
"""
sent_no_space = sent.replace(' ', '')
cutted_sentence = []
if len(sent_no_space) > max_sample_length:
parts = sent.strip().split()
new_line = ''
length = 0
for part in parts:
length += len(part)
new_line += part + ' '
if length > max_sample_length:
new_line = new_line[:-1]
cutted_sentence.append(new_line)
length = 0
new_line = ''
if new_line != '':
cutted_sentence.append(new_line[:-1])
else:
cutted_sentence.append(sent)
return cutted_sentence


class ZhConllPOSReader(object):
# 中文colln格式reader
def __init__(self):
pass

def load(self, path):
"""
返回的DataSet, 包含以下的field
words:list of str,
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..]
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即
1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn
4 12日 12日 NT DATE 11 nmod:tmod
5 , , PU O 11 punct

1 这 这 DT O 3 det
2 款 款 M O 1 mark:clf
3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if res is None:
continue
char_seq = []
pos_seq = []
for word, tag in zip(res[0], res[1]):
char_seq.extend(list(word))
if len(word) == 1:
pos_seq.append('S-{}'.format(tag))
elif len(word) > 1:
pos_seq.append('B-{}'.format(tag))
for _ in range(len(word) - 2):
pos_seq.append('M-{}'.format(tag))
pos_seq.append('E-{}'.format(tag))
else:
raise ValueError("Zero length of word detected.")

ds.append(Instance(words=char_seq,
tag=pos_seq))

return ds

def get_one(self, sample):
if len(sample) == 0:
return None
text = []
pos_tags = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
return None
text.append(t1)
pos_tags.append(t2)
return text, pos_tags


class ConllPOSReader(object):
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。
def __init__(self):
pass

def load(self, path):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if res is None:
continue
char_seq = []
pos_seq = []
for word, tag in zip(res[0], res[1]):
if len(word) == 1:
char_seq.append(word)
pos_seq.append('S-{}'.format(tag))
elif len(word) > 1:
pos_seq.append('B-{}'.format(tag))
for _ in range(len(word) - 2):
pos_seq.append('M-{}'.format(tag))
pos_seq.append('E-{}'.format(tag))
char_seq.extend(list(word))
else:
raise ValueError("Zero length of word detected.")

ds.append(Instance(words=char_seq,
tag=pos_seq))

return ds


class ConllxDataLoader(object):
def load(self, path):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

data = [self.get_one(sample) for sample in datalist]
return list(filter(lambda x: x is not None, data))

def get_one(self, sample):
sample = list(map(list, zip(*sample)))
if len(sample) == 0:
return None
for w in sample[7]:
if w == '_':
print('Error Sample {}'.format(sample))
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]


def add_seg_tag(data):
"""

:param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos])
"""

_processed = []
for word_list, pos_list, _, _ in data:
new_sample = []
for word, pos in zip(word_list, pos_list):
if len(word) == 1:
new_sample.append((word, 'S-' + pos))
else:
new_sample.append((word[0], 'B-' + pos))
for c in word[1:-1]:
new_sample.append((c, 'M-' + pos))
new_sample.append((word[-1], 'E-' + pos))
_processed.append(list(map(list, zip(*new_sample))))
return _processed

+ 1
- 1
reproduction/Biaffine_parser/main.py View File

@@ -5,7 +5,7 @@ sys.path.extend(['/home/yfshao/workdir/dev_fastnlp'])
import torch import torch
import argparse import argparse


from reproduction.Biaffine_parser.util import ConllxDataLoader, add_seg_tag
from fastNLP.io.dataset_loader import ConllxDataLoader, add_seg_tag
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance




+ 1
- 6
reproduction/Biaffine_parser/run.py View File

@@ -4,20 +4,15 @@ import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))


import fastNLP import fastNLP
import torch


from fastNLP.core.trainer import Trainer from fastNLP.core.trainer import Trainer
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline from fastNLP.api.pipeline import Pipeline
from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss from fastNLP.models.biaffine_parser import BiaffineParser, ParserMetric, ParserLoss
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.dataset import DataSet
from fastNLP.core.tester import Tester from fastNLP.core.tester import Tester
from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.io.model_io import ModelLoader from fastNLP.io.model_io import ModelLoader
from fastNLP.io.embed_loader import EmbedLoader
from fastNLP.io.model_io import ModelSaver
from reproduction.Biaffine_parser.util import ConllxDataLoader, MyDataloader
from fastNLP.io.dataset_loader import ConllxDataLoader
from fastNLP.api.processor import * from fastNLP.api.processor import *


BOS = '<BOS>' BOS = '<BOS>'


+ 0
- 51
reproduction/Biaffine_parser/util.py View File

@@ -1,34 +1,3 @@
class ConllxDataLoader(object):
def load(self, path):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

data = [self.get_one(sample) for sample in datalist]
return list(filter(lambda x: x is not None, data))

def get_one(self, sample):
sample = list(map(list, zip(*sample)))
if len(sample) == 0:
return None
for w in sample[7]:
if w == '_':
print('Error Sample {}'.format(sample))
return None
# return word_seq, pos_seq, head_seq, head_tag_seq
return sample[1], sample[3], list(map(int, sample[6])), sample[7]


class MyDataloader: class MyDataloader:
def load(self, data_path): def load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f: with open(data_path, "r", encoding="utf-8") as f:
@@ -56,23 +25,3 @@ class MyDataloader:
return data return data




def add_seg_tag(data):
"""

:param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos])
"""

_processed = []
for word_list, pos_list, _, _ in data:
new_sample = []
for word, pos in zip(word_list, pos_list):
if len(word) == 1:
new_sample.append((word, 'S-' + pos))
else:
new_sample.append((word[0], 'B-' + pos))
for c in word[1:-1]:
new_sample.append((c, 'M-' + pos))
new_sample.append((word[-1], 'E-' + pos))
_processed.append(list(map(list, zip(*new_sample))))
return _processed

+ 0
- 194
reproduction/chinese_word_segment/cws_io/cws_reader.py View File

@@ -1,197 +1,3 @@




from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.dataset_loader import DataSetLoader


def cut_long_sentence(sent, max_sample_length=200):
"""
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length

:param sent: str.
:param max_sample_length: int.
:return: list of str.
"""
sent_no_space = sent.replace(' ', '')
cutted_sentence = []
if len(sent_no_space) > max_sample_length:
parts = sent.strip().split()
new_line = ''
length = 0
for part in parts:
length += len(part)
new_line += part + ' '
if length > max_sample_length:
new_line = new_line[:-1]
cutted_sentence.append(new_line)
length = 0
new_line = ''
if new_line != '':
cutted_sentence.append(new_line[:-1])
else:
cutted_sentence.append(sent)
return cutted_sentence

class NaiveCWSReader(DataSetLoader):
"""
这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了
这是 fastNLP , 一个 非常 good 的 包 .
或者,即每个part后面还有一个pos tag
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
"""
def __init__(self, in_word_splitter=None):
super().__init__()

self.in_word_splitter = in_word_splitter

def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
"""
允许使用的情况有(默认以\t或空格作为seg)
这是 fastNLP , 一个 非常 good 的 包 .
也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY
如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0]
:param filepath:
:param in_word_splitter:
:return:
"""
if in_word_splitter == None:
in_word_splitter = self.in_word_splitter
dataset = DataSet()
with open(filepath, 'r') as f:
for line in f:
line = line.strip()
if len(line.replace(' ', ''))==0: # 不能接受空行
continue

if not in_word_splitter is None:
words = []
for part in line.split():
word = part.split(in_word_splitter)[0]
words.append(word)
line = ' '.join(words)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for sent in sents:
instance = Instance(raw_sentence=sent)
dataset.append(instance)

return dataset


class POSCWSReader(DataSetLoader):
"""
支持读取以下的情况, 即每一行是一个词, 用空行作为两句话的界限.
迈 N
向 N
充 N
...
泽 I-PER
民 I-PER

( N
一 N
九 N
...


:param filepath:
:return:
"""
def __init__(self, in_word_splitter=None):
super().__init__()
self.in_word_splitter = in_word_splitter

def load(self, filepath, in_word_splitter=None, cut_long_sent=False):
if in_word_splitter is None:
in_word_splitter = self.in_word_splitter
dataset = DataSet()
with open(filepath, 'r') as f:
words = []
for line in f:
line = line.strip()
if len(line) == 0: # new line
if len(words)==0: # 不能接受空行
continue
line = ' '.join(words)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for sent in sents:
instance = Instance(raw_sentence=sent)
dataset.append(instance)
words = []
else:
line = line.split()[0]
if in_word_splitter is None:
words.append(line)
else:
words.append(line.split(in_word_splitter)[0])
return dataset


class ConllCWSReader(object):
def __init__(self):
pass

def load(self, path, cut_long_sent=False):
"""
返回的DataSet只包含raw_sentence这个field,内容为str。
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即
1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn
4 12日 12日 NT DATE 11 nmod:tmod
5 , , PU O 11 punct

1 这 这 DT O 3 det
2 款 款 M O 1 mark:clf
3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_char_lst(sample)
if res is None:
continue
line = ' '.join(res)
if cut_long_sent:
sents = cut_long_sentence(line)
else:
sents = [line]
for raw_sentence in sents:
ds.append(Instance(raw_sentence=raw_sentence))

return ds

def get_char_lst(self, sample):
if len(sample)==0:
return None
text = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
return None
text.append(t1)
return text



+ 0
- 103
reproduction/chinese_word_segment/process/cws_processor.py View File

@@ -226,109 +226,6 @@ class Pre2Post2BigramProcessor(BigramProcessor):
return bigrams return bigrams




# 这里需要建立vocabulary了,但是遇到了以下的问题
# (1) 如果使用Processor的方式的话,但是在这种情况返回的不是dataset。所以建立vocabulary的工作用另外的方式实现,不借用
# Processor了
# TODO 如何将建立vocab和index这两步统一了?

class VocabIndexerProcessor(Processor):
"""
根据DataSet创建Vocabulary,并将其用数字index。新生成的index的field会被放在new_added_filed_name, 如果没有提供
new_added_field_name, 则覆盖原有的field_name.

"""
def __init__(self, field_name, new_added_filed_name=None, min_freq=1, max_size=None,
verbose=0, is_input=True):
"""

:param field_name: 从哪个field_name创建词表,以及对哪个field_name进行index操作
:param new_added_filed_name: index时,生成的index field的名称,如果不传入,则覆盖field_name.
:param min_freq: 创建的Vocabulary允许的单词最少出现次数.
:param max_size: 创建的Vocabulary允许的最大的单词数量
:param verbose: 0, 不输出任何信息;1,输出信息
:param bool is_input:
"""
super(VocabIndexerProcessor, self).__init__(field_name, new_added_filed_name)
self.min_freq = min_freq
self.max_size = max_size

self.verbose =verbose
self.is_input = is_input

def construct_vocab(self, *datasets):
"""
使用传入的DataSet创建vocabulary

:param datasets: DataSet类型的数据,用于构建vocabulary
:return:
"""
self.vocab = Vocabulary(min_freq=self.min_freq, max_size=self.max_size)
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: self.vocab.update(ins[self.field_name]))
self.vocab.build_vocab()
if self.verbose:
print("Vocabulary Constructed, has {} items.".format(len(self.vocab)))

def process(self, *datasets, only_index_dataset=None):
"""
若还未建立Vocabulary,则使用dataset中的DataSet建立vocabulary;若已经有了vocabulary则使用已有的vocabulary。得到vocabulary
后,则会index datasets与only_index_dataset。

:param datasets: DataSet类型的数据
:param only_index_dataset: DataSet, or list of DataSet. 该参数中的内容只会被用于index,不会被用于生成vocabulary。
:return:
"""
if len(datasets)==0 and not hasattr(self,'vocab'):
raise RuntimeError("You have to construct vocabulary first. Or you have to pass datasets to construct it.")
if not hasattr(self, 'vocab'):
self.construct_vocab(*datasets)
else:
if self.verbose:
print("Using constructed vocabulary with {} items.".format(len(self.vocab)))
to_index_datasets = []
if len(datasets)!=0:
for dataset in datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
to_index_datasets.append(dataset)

if not (only_index_dataset is None):
if isinstance(only_index_dataset, list):
for dataset in only_index_dataset:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
to_index_datasets.append(dataset)
elif isinstance(only_index_dataset, DataSet):
to_index_datasets.append(only_index_dataset)
else:
raise TypeError('Only DataSet or list of DataSet is allowed, not {}.'.format(type(only_index_dataset)))

for dataset in to_index_datasets:
assert isinstance(dataset, DataSet), "Only DataSet class is allowed, not {}.".format(type(dataset))
dataset.apply(lambda ins: [self.vocab.to_index(token) for token in ins[self.field_name]],
new_field_name=self.new_added_field_name, is_input=self.is_input)
# 只返回一个,infer时为了跟其他processor保持一致
if len(to_index_datasets) == 1:
return to_index_datasets[0]

def set_vocab(self, vocab):
assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary is allowed, not {}.".format(type(vocab))
self.vocab = vocab

def delete_vocab(self):
del self.vocab

def get_vocab_size(self):
return len(self.vocab)

def set_verbose(self, verbose):
"""
设置processor verbose状态。

:param verbose: int, 0,不输出任何信息;1,输出vocab 信息。
:return:
"""
self.verbose = verbose

class VocabProcessor(Processor): class VocabProcessor(Processor):
def __init__(self, field_name, min_freq=1, max_size=None): def __init__(self, field_name, min_freq=1, max_size=None):




+ 1
- 125
reproduction/pos_tag_model/pos_reader.py View File

@@ -1,6 +1,5 @@
from fastNLP.io.dataset_loader import ZhConllPOSReader


from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance


def cut_long_sentence(sent, max_sample_length=200): def cut_long_sentence(sent, max_sample_length=200):
sent_no_space = sent.replace(' ', '') sent_no_space = sent.replace(' ', '')
@@ -24,129 +23,6 @@ def cut_long_sentence(sent, max_sample_length=200):
return cutted_sentence return cutted_sentence




class ConllPOSReader(object):
# 返回的Dataset包含words(list of list, 里层的list是character), tag两个field(list of str, str是标有BIO的tag)。
def __init__(self):
pass

def load(self, path):
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if res is None:
continue
char_seq = []
pos_seq = []
for word, tag in zip(res[0], res[1]):
if len(word)==1:
char_seq.append(word)
pos_seq.append('S-{}'.format(tag))
elif len(word)>1:
pos_seq.append('B-{}'.format(tag))
for _ in range(len(word)-2):
pos_seq.append('M-{}'.format(tag))
pos_seq.append('E-{}'.format(tag))
char_seq.extend(list(word))
else:
raise ValueError("Zero length of word detected.")

ds.append(Instance(words=char_seq,
tag=pos_seq))

return ds



class ZhConllPOSReader(object):
# 中文colln格式reader
def __init__(self):
pass

def load(self, path):
"""
返回的DataSet, 包含以下的field
words:list of str,
tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..]
假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即
1 编者按 编者按 NN O 11 nmod:topic
2 : : PU O 11 punct
3 7月 7月 NT DATE 4 compound:nn
4 12日 12日 NT DATE 11 nmod:tmod
5 , , PU O 11 punct

1 这 这 DT O 3 det
2 款 款 M O 1 mark:clf
3 飞行 飞行 NN O 8 nsubj
4 从 从 P O 5 case
5 外型 外型 NN O 8 nmod:prep
"""
datalist = []
with open(path, 'r', encoding='utf-8') as f:
sample = []
for line in f:
if line.startswith('\n'):
datalist.append(sample)
sample = []
elif line.startswith('#'):
continue
else:
sample.append(line.split('\t'))
if len(sample) > 0:
datalist.append(sample)

ds = DataSet()
for sample in datalist:
# print(sample)
res = self.get_one(sample)
if res is None:
continue
char_seq = []
pos_seq = []
for word, tag in zip(res[0], res[1]):
char_seq.extend(list(word))
if len(word)==1:
pos_seq.append('S-{}'.format(tag))
elif len(word)>1:
pos_seq.append('B-{}'.format(tag))
for _ in range(len(word)-2):
pos_seq.append('M-{}'.format(tag))
pos_seq.append('E-{}'.format(tag))
else:
raise ValueError("Zero length of word detected.")

ds.append(Instance(words=char_seq,
tag=pos_seq))

return ds

def get_one(self, sample):
if len(sample)==0:
return None
text = []
pos_tags = []
for w in sample:
t1, t2, t3, t4 = w[1], w[3], w[6], w[7]
if t3 == '_':
return None
text.append(t1)
pos_tags.append(t2)
return text, pos_tags

if __name__ == '__main__': if __name__ == '__main__':
reader = ZhConllPOSReader() reader = ZhConllPOSReader()
d = reader.load('/home/hyan/train.conllx') d = reader.load('/home/hyan/train.conllx')

+ 2
- 3
reproduction/pos_tag_model/train_pos_tag.py View File

@@ -10,13 +10,12 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))




from fastNLP.api.pipeline import Pipeline from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import SeqLenProcessor
from fastNLP.api.processor import SeqLenProcessor, VocabIndexerProcessor
from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.metrics import SpanFPreRecMetric
from fastNLP.core.trainer import Trainer from fastNLP.core.trainer import Trainer
from fastNLP.io.config_io import ConfigLoader, ConfigSection from fastNLP.io.config_io import ConfigLoader, ConfigSection
from fastNLP.models.sequence_modeling import AdvSeqLabel from fastNLP.models.sequence_modeling import AdvSeqLabel
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.pos_tag_model.pos_reader import ZhConllPOSReader
from fastNLP.io.dataset_loader import ZhConllPOSReader
from fastNLP.api.processor import ModelProcessor, Index2WordProcessor from fastNLP.api.processor import ModelProcessor, Index2WordProcessor


cfgfile = './pos_tag.cfg' cfgfile = './pos_tag.cfg'


Loading…
Cancel
Save