Browse Source

1. 从安装文件中删除api/automl的安装

2. metric中存在seq_len的bug
3. sampler中存在命名错误,已修改
tags/v0.4.10
yh_cc 5 years ago
parent
commit
b9558e21e6
4 changed files with 11 additions and 10 deletions
  1. +2
    -0
      MANIFEST.in
  2. +4
    -5
      fastNLP/core/metrics.py
  3. +4
    -4
      fastNLP/core/sampler.py
  4. +1
    -1
      test/core/test_sampler.py

+ 2
- 0
MANIFEST.in View File

@@ -3,3 +3,5 @@ include LICENSE
include README.md
prune test/
prune reproduction/
prune fastNLP/api
prune fastNLP/automl

+ 4
- 5
fastNLP/core/metrics.py View File

@@ -269,7 +269,7 @@ class AccuracyMetric(MetricBase):
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param seq_len: 参数映射表中 `seq_lens` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
"""
def __init__(self, pred=None, target=None, seq_len=None):
@@ -458,7 +458,7 @@ class SpanFPreRecMetric(MetricBase):
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'.
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_lens'取数据。
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。
:param str encoding_type: 目前支持bio, bmes
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这
个label
@@ -729,9 +729,8 @@ class BMESF1PreRecMetric(MetricBase):
f"{pred.size()[:-1]}, got {target.size()}.")

for idx in range(len(pred)):
seq_len = seq_len[idx]
target_tags = target[idx][:seq_len].tolist()
pred_tags = pred[idx][:seq_len]
target_tags = target[idx][:seq_len[idx]].tolist()
pred_tags = pred[idx][:seq_len[idx]]
pred_tags = self._validate_tags(pred_tags)
start_idx = 0
for t_idx, (t_tag, p_tag) in enumerate(zip(target_tags, pred_tags)):


+ 4
- 4
fastNLP/core/sampler.py View File

@@ -59,16 +59,16 @@ class BucketSampler(Sampler):

:param int num_buckets: bucket的数量
:param int batch_size: batch的大小
:param str seq_lens_field_name: 对应序列长度的 `field` 的名字
:param str seq_len_field_name: 对应序列长度的 `field` 的名字
"""
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_len'):
def __init__(self, num_buckets=10, batch_size=32, seq_len_field_name='seq_len'):
self.num_buckets = num_buckets
self.batch_size = batch_size
self.seq_lens_field_name = seq_lens_field_name
self.seq_len_field_name = seq_len_field_name
def __call__(self, data_set):
seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content
seq_lens = data_set.get_all_fields()[self.seq_len_field_name].content
total_sample_num = len(seq_lens)
bucket_indexes = []


+ 1
- 1
test/core/test_sampler.py View File

@@ -38,7 +38,7 @@ class TestSampler(unittest.TestCase):
assert len(_) == 10

def test_BucketSampler(self):
sampler = BucketSampler(num_buckets=3, batch_size=16, seq_lens_field_name="seq_len")
sampler = BucketSampler(num_buckets=3, batch_size=16, seq_len_field_name="seq_len")
data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
data_set.apply(lambda ins: len(ins["x"]), new_field_name="seq_len")
indices = sampler(data_set)


Loading…
Cancel
Save