
refine git commits

tags/v0.2.0
yunfan committed 6 years ago
parent commit 2aaa381827

7 changed files with 32 additions and 39 deletions
1. fastNLP/api/api.py (+12, -21)
2. fastNLP/core/dataset.py (+8, -3)
3. fastNLP/core/metrics.py (+4, -4)
4. fastNLP/core/trainer.py (+6, -7)
5. fastNLP/models/sequence_modeling.py (+0, -2)
6. reproduction/pos_tag_model/pos_tag.cfg (+1, -1)
7. setup.py (+1, -1)

fastNLP/api/api.py (+12, -21)

@@ -6,6 +6,7 @@ warnings.filterwarnings('ignore')
 import os

 from fastNLP.core.dataset import DataSet
+
 from fastNLP.api.model_zoo import load_url
 from fastNLP.api.processor import ModelProcessor
 from reproduction.chinese_word_segment.cws_io.cws_reader import ConlluCWSReader

@@ -120,7 +121,7 @@ class POS(API):
         f1 = round(test_result['F'] * 100, 2)
         pre = round(test_result['P'] * 100, 2)
         rec = round(test_result['R'] * 100, 2)
-        print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
+        # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

         return f1, pre, rec

@@ -179,7 +180,7 @@ class CWS(API):
         f1 = round(f1 * 100, 2)
         pre = round(pre * 100, 2)
         rec = round(rec * 100, 2)
-        print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))
+        # print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1, pre, rec))

         return f1, pre, rec

@@ -251,30 +252,23 @@ class Parser(API):


 class Analyzer:
-    def __init__(self, seg=True, pos=True, parser=True, device='cpu'):
-
-        self.seg = seg
-        self.pos = pos
-        self.parser = parser
-
-        if self.seg:
-            self.cws = CWS(device=device)
-        if self.pos:
-            self.pos = POS(device=device)
-        if parser:
-            self.parser = None
+    def __init__(self, device='cpu'):
+
+        self.cws = CWS(device=device)
+        self.pos = POS(device=device)
+        self.parser = Parser(device=device)

     def predict(self, content, seg=False, pos=False, parser=False):
         if seg is False and pos is False and parser is False:
             seg = True
         output_dict = {}
-        if self.seg:
+        if seg:
             seg_output = self.cws.predict(content)
             output_dict['seg'] = seg_output
-        if self.pos:
+        if pos:
             pos_output = self.pos.predict(content)
             output_dict['pos'] = pos_output
-        if self.parser:
+        if parser:
             parser_output = self.parser.predict(content)
             output_dict['parser'] = parser_output

@@ -301,7 +295,7 @@ if __name__ == "__main__":
     # s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。' ,
     #      '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
     #      '那么这款无人机到底有多厉害?']
-    # print(pos.test('/Users/yh/Desktop/test_data/small_test.conll'))
+    # print(pos.test('/Users/yh/Desktop/test_data/pos_test.conll'))
     # print(pos.predict(s))

@@ -317,7 +311,4 @@ if __name__ == "__main__":
     s = ['编者按:7月12日,英国航空航天系统公司公布了该公司研制的第一款高科技隐形无人机雷电之神。',
         '这款飞行从外型上来看酷似电影中的太空飞行器,据英国方面介绍,可以实现洲际远程打击。',
         '那么这款无人机到底有多厉害?']
-    print(cws.test('/Users/yh/Desktop/test_data/small_test.conll'))
-    print(cws.predict(s))

-    print(parser.predict(s))
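The Analyzer refactor in this file drops the construction-time seg/pos/parser switches: all three components are now always built, and the per-call flags on predict() decide which outputs to produce. A minimal usage sketch under that reading of the diff (the import path follows the file layout, and the input sentence is made up):

    from fastNLP.api.api import Analyzer

    # All three components (CWS, POS, Parser) are constructed up front.
    analyzer = Analyzer(device='cpu')

    # With no flags given, predict() falls back to segmentation only.
    seg_only = analyzer.predict(['这是一个测试句子。'])
    print(seg_only['seg'])

    # Per-call flags now select the outputs, replacing the old __init__ switches.
    both = analyzer.predict(['这是一个测试句子。'], seg=True, pos=True)
    print(both['seg'], both['pos'])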

fastNLP/core/dataset.py (+8, -3)

@@ -313,9 +313,14 @@ class DataSet(object):
             for col in headers:
                 _dict[col] = []
             for line_idx, line in enumerate(f, start_idx):
-                contents = line.rstrip('\r\n').split(sep)
-                assert len(contents)==len(headers), "Line {} has {} parts, while header has {}."\
-                    .format(line_idx, len(contents), len(headers))
+                contents = line.split(sep)
+                if len(contents)!=len(headers):
+                    if dropna:
+                        continue
+                    else:
+                        # TODO change error type
+                        raise ValueError("Line {} has {} parts, while header has {} parts."\
+                            .format(line_idx, len(contents), len(headers)))
                 for header, content in zip(headers, contents):
                     _dict[header].append(content)
             return cls(_dict)
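The rewritten loop trades the old assert for an explicit dropna policy: rows with the wrong number of fields are either skipped or rejected with a ValueError. A self-contained sketch of that row-validation logic (headers and rows are made up; the real method reads from a file handle):

    headers = ['word', 'tag']
    lines = ['猫\tNN', '跑\tVV\textra', '狗\tNN']  # second row has too many parts
    dropna = True

    _dict = {col: [] for col in headers}
    for line_idx, line in enumerate(lines, start=1):
        contents = line.split('\t')
        if len(contents) != len(headers):
            if dropna:
                continue  # silently drop the malformed row
            raise ValueError("Line {} has {} parts, while header has {} parts."
                             .format(line_idx, len(contents), len(headers)))
        for header, content in zip(headers, contents):
            _dict[header].append(content)

    print(_dict)  # {'word': ['猫', '狗'], 'tag': ['NN', 'NN']}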

fastNLP/core/metrics.py (+4, -4)

@@ -38,15 +38,15 @@ class SeqLabelEvaluator(Evaluator):
     def __call__(self, predict, truth, **_):
         """

-        :param predict: list of dict, the network outputs from all batches.
+        :param predict: list of List, the network outputs from all batches.
         :param truth: list of dict, the ground truths from all batch_y.
         :return accuracy:
         """
-        total_correct, total_count = 0., 0.
+        total_correct, total_count = 0., 0.
         for x, y in zip(predict, truth):
-            # x = torch.tensor(x)
+            x = torch.tensor(x)
             y = y.to(x)  # make sure they are in the same device
-            mask = (y > 0)
+            mask = (y > 0)
             correct = torch.sum(((x == y) * mask).long())
             total_correct += float(correct)
             total_count += float(torch.sum(mask.long()))
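With x = torch.tensor(x) uncommented, the evaluator accepts plain Python lists as predictions (matching the "list of List" docstring fix) and converts them before the masked comparison. A small sketch of that accuracy computation on made-up label ids, where 0 is padding:

    import torch

    x = torch.tensor([[3, 1, 0, 0], [2, 2, 4, 0]])  # predictions
    y = torch.tensor([[3, 2, 0, 0], [2, 2, 4, 0]])  # truth

    mask = (y > 0)  # ignore padding positions
    correct = torch.sum(((x == y) * mask).long())
    total = torch.sum(mask.long())
    print(float(correct) / float(total))  # 0.8: 4 of the 5 real positions match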


fastNLP/core/trainer.py (+6, -7)

@@ -4,6 +4,7 @@ from datetime import datetime
 import warnings
 from collections import defaultdict
 import os
+import itertools
 import shutil

 from tensorboardX import SummaryWriter

@@ -121,10 +122,7 @@ class Trainer(object):
             for batch_x, batch_y in data_iterator:
                 prediction = self.data_forward(model, batch_x)

-                # TODO: refactor self.get_loss
-                loss = prediction["loss"] if "loss" in prediction else self.get_loss(prediction, batch_y)
-                # acc = self._evaluator([{"predict": prediction["predict"]}], [{"truth": batch_x["truth"]}])
-
+                loss = self.get_loss(prediction, batch_y)
                 self.grad_backward(loss)
                 self.update()
                 self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)

@@ -133,7 +131,7 @@ class Trainer(object):
                     self._summary_writer.add_scalar(name + "_mean", param.mean(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_std", param.std(), global_step=self.step)
                     # self._summary_writer.add_scalar(name + "_grad_sum", param.sum(), global_step=self.step)
-                if n_print > 0 and self.step % n_print == 0:
+                if self.print_every > 0 and self.step % self.print_every == 0:
                     end = time.time()
                     diff = timedelta(seconds=round(end - start))
                     print_output = "[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(

@@ -241,7 +239,7 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None,

     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _syn_model_data(model, batch_x, batch_y)
+        _syn_model_data(model, batch_x, batch_y)
         # forward check
         if batch_count==0:
             _check_forward_error(model_func=model.forward, check_level=check_level,

@@ -269,7 +267,8 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None,
                 model_name, loss.size()
             ))
         loss.backward()
-        if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
+        model.zero_grad()
+        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
             break

     if dev_data is not None:
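Two behavioral fixes sit in this file: the training loop now always routes through self.get_loss() (the prediction["loss"] dict fallback is gone), and the progress printout is throttled by the renamed print_every attribute instead of the local n_print. A standalone sketch of that throttling logic with made-up values:

    import time
    from datetime import timedelta

    print_every = 2  # was the local n_print before this commit
    start = time.time()
    for step in range(1, 7):
        loss = 1.0 / step  # stand-in for self.get_loss(prediction, batch_y)
        if print_every > 0 and step % print_every == 0:
            diff = timedelta(seconds=round(time.time() - start))
            print("[epoch: {:>3} step: {:>4}] train loss: {:>4.6} time: {}".format(
                1, step, loss, diff))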


fastNLP/models/sequence_modeling.py (+0, -2)

@@ -1,4 +1,3 @@
-import numpy as np
 import torch
 import numpy as np


@@ -141,7 +140,6 @@ class AdvSeqLabel(SeqLabeling):
             idx_sort = idx_sort.cuda()
             idx_unsort = idx_unsort.cuda()
             self.mask = self.mask.cuda()
-            truth = truth.cuda() if truth is not None else None

         x = self.Embedding(word_seq)
         x = self.norm1(x)

reproduction/pos_tag_model/pos_tag.cfg (+1, -1)

@@ -36,4 +36,4 @@ pickle_path = "./save/"
 use_crf = true
 use_cuda = true
 rnn_hidden_units = 100
-word_emb_dim = 100
+word_emb_dim = 100

setup.py (+1, -1)

@@ -13,7 +13,7 @@ with open('requirements.txt', encoding='utf-8') as f:

 setup(
     name='fastNLP',
-    version='0.1.0',
+    version='0.1.1',
     description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
     long_description=readme,
     license=license,

