Browse Source

- add Const

- fix bugs
tags/v0.4.10
yunfan 6 years ago
parent
commit
e864aecb03
5 changed files with 70 additions and 18 deletions
  1. +46
    -0
      fastNLP/core/const.py
  2. +5
    -5
      fastNLP/io/dataset_loader.py
  3. +12
    -11
      fastNLP/models/star_transformer.py
  4. +1
    -1
      test/core/test_callbacks.py
  5. +6
    -1
      test/io/test_dataset_loader.py

+ 46
- 0
fastNLP/core/const.py View File

@@ -0,0 +1,46 @@
class Const():
"""fastNLP中field命名常量。
具体列表::

INPUT 模型的序列输入 words(复数words1, words2)
CHAR_INPUT 模型character输入 chars(复数chars1, chars2)
INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2)
OUTPUT 模型输出 pred(复数pred1, pred2)
TARGET 真实目标 target(复数target1,target2)

"""
INPUT = 'words'
CHAR_INPUT = 'chars'
INPUT_LEN = 'seq_len'
OUTPUT = 'pred'
TARGET = 'target'

@staticmethod
def INPUTS(i):
"""得到第 i 个 ``INPUT`` 的命名"""
i = int(i) + 1
return Const.INPUT + str(i)

@staticmethod
def CHAR_INPUTS(i):
"""得到第 i 个 ``CHAR_INPUT`` 的命名"""
i = int(i) + 1
return Const.CHAR_INPUT + str(i)

@staticmethod
def INPUT_LENS(i):
"""得到第 i 个 ``INPUT_LEN`` 的命名"""
i = int(i) + 1
return Const.INPUT_LEN + str(i)

@staticmethod
def OUTPUTS(i):
"""得到第 i 个 ``OUTPUT`` 的命名"""
i = int(i) + 1
return Const.OUTPUT + str(i)

@staticmethod
def TARGETS(i):
"""得到第 i 个 ``TARGET`` 的命名"""
i = int(i) + 1
return Const.TARGET + str(i)

+ 5
- 5
fastNLP/io/dataset_loader.py View File

@@ -193,9 +193,9 @@ class ConllLoader(DataSetLoader):

:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexs`` 一一对应
:param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
"""
def __init__(self, headers, indexs=None, dropna=True):
def __init__(self, headers, indexs=None, dropna=False):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
raise TypeError('invalid headers: {}, should be list of strings'.format(headers))
@@ -314,7 +314,7 @@ class JsonLoader(DataSetLoader):
`value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``True``
Default: ``False``
"""
def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__()
@@ -375,9 +375,9 @@ class CSVLoader(DataSetLoader):
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None``
:param str sep: CSV文件中列与列之间的分隔符. Default: ","
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``True``
Default: ``False``
"""
def __init__(self, headers=None, sep=",", dropna=True):
def __init__(self, headers=None, sep=",", dropna=False):
self.headers = headers
self.sep = sep
self.dropna = dropna


+ 12
- 11
fastNLP/models/star_transformer.py View File

@@ -1,8 +1,9 @@
"""Star-Transformer 的 一个 Pytorch 实现.
"""
from fastNLP.modules.encoder.star_transformer import StarTransformer
from fastNLP.core.utils import seq_lens_to_masks
from ..modules.encoder.star_transformer import StarTransformer
from ..core.utils import seq_lens_to_masks
from ..modules.utils import get_embeddings
from ..core.const import Const

import torch
from torch import nn
@@ -139,7 +140,7 @@ class STSeqLabel(nn.Module):
nodes, _ = self.enc(words, mask)
output = self.cls(nodes)
output = output.transpose(1,2) # make hidden to be dim 1
return {'output': output} # [bsz, n_cls, seq_len]
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len]

def predict(self, words, seq_len):
"""
@@ -149,8 +150,8 @@ class STSeqLabel(nn.Module):
:return output: [batch, seq_len] 输出序列中每个元素的分类
"""
y = self.forward(words, seq_len)
_, pred = y['output'].max(1)
return {'output': pred}
_, pred = y[Const.OUTPUT].max(1)
return {Const.OUTPUT: pred}


class STSeqCls(nn.Module):
@@ -201,7 +202,7 @@ class STSeqCls(nn.Module):
nodes, relay = self.enc(words, mask)
y = 0.5 * (relay + nodes.max(1)[0])
output = self.cls(y) # [bsz, n_cls]
return {'output': output}
return {Const.OUTPUT: output}

def predict(self, words, seq_len):
"""
@@ -211,8 +212,8 @@ class STSeqCls(nn.Module):
:return output: [batch, num_cls] 输出序列的分类
"""
y = self.forward(words, seq_len)
_, pred = y['output'].max(1)
return {'output': pred}
_, pred = y[Const.OUTPUT].max(1)
return {Const.OUTPUT: pred}


class STNLICls(nn.Module):
@@ -269,7 +270,7 @@ class STNLICls(nn.Module):
y1 = enc(words1, mask1)
y2 = enc(words2, mask2)
output = self.cls(y1, y2) # [bsz, n_cls]
return {'output': output}
return {Const.OUTPUT: output}

def predict(self, words1, words2, seq_len1, seq_len2):
"""
@@ -281,5 +282,5 @@ class STNLICls(nn.Module):
:return output: [batch, num_cls] 输出分类的概率
"""
y = self.forward(words1, words2, seq_len1, seq_len2)
_, pred = y['output'].max(1)
return {'output': pred}
_, pred = y[Const.OUTPUT].max(1)
return {Const.OUTPUT: pred}

+ 1
- 1
test/core/test_callbacks.py View File

@@ -3,7 +3,7 @@ import unittest
import numpy as np
import torch

from fastNLP.core.callback import EchoCallback, EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
LRFinder, \
TensorboardCallback
from fastNLP.core.dataset import DataSet


+ 6
- 1
test/io/test_dataset_loader.py View File

@@ -1,7 +1,7 @@
import unittest

from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \
CSVLoader, SNLILoader
CSVLoader, SNLILoader, JsonLoader

class TestDatasetLoader(unittest.TestCase):

@@ -24,3 +24,8 @@ class TestDatasetLoader(unittest.TestCase):
def test_SNLILoader(self):
ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl')
assert len(ds) == 3

def test_JsonLoader(self):
ds = JsonLoader().load('test/data_for_tests/sample_snli.jsonl')
assert len(ds) == 3


Loading…
Cancel
Save