Browse Source

seq_len_to_mask 添加oneflow部分;完善oneflow的测试标签

tags/v1.0.0beta
x54-729 2 years ago
parent
commit
5a41cca483
7 changed files with 53 additions and 17 deletions
  1. +2
    -2
      fastNLP/core/collators/padders/oneflow_padder.py
  2. +14
    -1
      fastNLP/core/utils/seq_len_to_mask.py
  3. +7
    -7
      tests/core/drivers/oneflow_driver/test_ddp.py
  4. +2
    -2
      tests/core/drivers/oneflow_driver/test_dist_utils.py
  5. +3
    -3
      tests/core/drivers/oneflow_driver/test_single_device.py
  6. +24
    -2
      tests/core/utils/test_seq_len_to_mask.py
  7. +1
    -0
      tests/pytest.ini

+ 2
- 2
fastNLP/core/collators/padders/oneflow_padder.py View File

@@ -169,7 +169,7 @@ class OneflowTensorPadder(Padder):
else: else:
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]


tensor = oneflow.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device)
for i, field in enumerate(batch_field): for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
tensor[slices] = field tensor[slices] = field
@@ -221,6 +221,6 @@ def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0):
:return: :return:
""" """
shapes = get_shape(batch_field) shapes = get_shape(batch_field)
tensor = oneflow.full(shapes, dtype=dtype, fill_value=pad_val)
tensor = oneflow.full(shapes, dtype=dtype, value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype) tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor return tensor

+ 14
- 1
fastNLP/core/utils/seq_len_to_mask.py View File

@@ -1,7 +1,7 @@
from typing import Optional from typing import Optional


import numpy as np import numpy as np
from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_ONEFLOW
from .paddle_utils import paddle_to from .paddle_utils import paddle_to




@@ -14,6 +14,9 @@ if _NEED_IMPORT_PADDLE:
if _NEED_IMPORT_JITTOR: if _NEED_IMPORT_JITTOR:
import jittor import jittor


if _NEED_IMPORT_ONEFLOW:
import oneflow



def seq_len_to_mask(seq_len, max_len: Optional[int]=None): def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
r""" r"""
@@ -80,5 +83,15 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
except NameError as e: except NameError as e:
pass pass


try:
if isinstance(seq_len, oneflow.Tensor):
assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim == 1}."
batch_size = seq_len.shape[0]
broad_cast_seq_len = oneflow.arange(max_len).expand(batch_size, -1).to(seq_len)
mask = broad_cast_seq_len < seq_len.unsqueeze(1)
return mask
except NameError as e:
pass

raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, " raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, "
f"and jittor.Var, but got {type(seq_len)}") f"and jittor.Var, but got {type(seq_len)}")

+ 7
- 7
tests/core/drivers/oneflow_driver/test_ddp.py View File

@@ -80,7 +80,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
# #
############################################################################ ############################################################################


@pytest.mark.oneflow
@pytest.mark.oneflowdist
class TestDDPDriverFunction: class TestDDPDriverFunction:
""" """
测试 OneflowDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 测试 OneflowDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题
@@ -159,7 +159,7 @@ class TestDDPDriverFunction:
# #
############################################################################ ############################################################################


@pytest.mark.oneflow
@pytest.mark.oneflowdist
class TestSetDistReproDataloader: class TestSetDistReproDataloader:


@classmethod @classmethod
@@ -510,7 +510,7 @@ class TestSetDistReproDataloader:
# 测试 save 和 load 相关的功能 # 测试 save 和 load 相关的功能
# #
############################################################################ ############################################################################
@pytest.mark.oneflow
@pytest.mark.oneflowdist
class TestSaveLoad: class TestSaveLoad:
""" """
测试多卡情况下 save 和 load 相关函数的表现 测试多卡情况下 save 和 load 相关函数的表现
@@ -740,7 +740,7 @@ class TestSaveLoad:
rank_zero_rm(path) rank_zero_rm(path)




@pytest.mark.oneflow
@pytest.mark.oneflowdist
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) @pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False])) @pytest.mark.parametrize("drop_last", ([True, False]))
@@ -790,7 +790,7 @@ def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True):
pass pass




@pytest.mark.oneflow
@pytest.mark.oneflowdist
@pytest.mark.parametrize("shuffle", ([True, False])) @pytest.mark.parametrize("shuffle", ([True, False]))
@pytest.mark.parametrize("batch_size", ([1, 3, 16, 17])) @pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
@pytest.mark.parametrize("drop_last", ([True, False])) @pytest.mark.parametrize("drop_last", ([True, False]))
@@ -845,7 +845,7 @@ def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=T






@pytest.mark.oneflow
@pytest.mark.oneflowdist
@recover_logger @recover_logger
@pytest.mark.parametrize("inherit", ([True, False])) @pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_batch_sampler_dataloader(inherit): def test_customized_batch_sampler_dataloader(inherit):
@@ -897,7 +897,7 @@ def test_customized_batch_sampler_dataloader(inherit):
pass pass




@pytest.mark.oneflow
@pytest.mark.oneflowdist
@recover_logger @recover_logger
@pytest.mark.parametrize("inherit", ([True, False])) @pytest.mark.parametrize("inherit", ([True, False]))
def test_customized_sampler_dataloader(inherit): def test_customized_sampler_dataloader(inherit):


+ 2
- 2
tests/core/drivers/oneflow_driver/test_dist_utils.py View File

@@ -77,7 +77,7 @@ def test_tensor_object_transfer_tensor(device):
assert res["int"] == oneflow_dict["int"] assert res["int"] == oneflow_dict["int"]
assert res["string"] == oneflow_dict["string"] assert res["string"] == oneflow_dict["string"]


@pytest.mark.oneflow
@pytest.mark.oneflowdist
def test_fastnlp_oneflow_all_gather(): def test_fastnlp_oneflow_all_gather():
local_rank = int(os.environ["LOCAL_RANK"]) local_rank = int(os.environ["LOCAL_RANK"])
obj = { obj = {
@@ -113,7 +113,7 @@ def test_fastnlp_oneflow_all_gather():
assert len(data) == world_size assert len(data) == world_size
assert data[0] == data[1] assert data[0] == data[1]


@pytest.mark.oneflow
@pytest.mark.oneflowdist
def test_fastnlp_oneflow_broadcast_object(): def test_fastnlp_oneflow_broadcast_object():
local_rank = int(os.environ["LOCAL_RANK"]) local_rank = int(os.environ["LOCAL_RANK"])
if os.environ["LOCAL_RANK"] == "0": if os.environ["LOCAL_RANK"] == "0":


+ 3
- 3
tests/core/drivers/oneflow_driver/test_single_device.py View File

@@ -557,7 +557,7 @@ def test_save_and_load_model(only_state_dict):
res1 = driver1.model.evaluate_step(**batch) res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch)


assert oneflow.all(res1["preds"] == res2["preds"])
assert oneflow.all(res1["pred"] == res2["pred"])
finally: finally:
rank_zero_rm(path) rank_zero_rm(path)


@@ -623,7 +623,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) left_y_batches.update(batch["y"].reshape(-1, ).tolist())
res1 = driver1.model.evaluate_step(**batch) res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch)
assert oneflow.all(res1["preds"] == res2["preds"])
assert oneflow.all(res1["pred"] == res2["pred"])


assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset) assert len(left_x_batches | already_seen_x_set) == len(dataset)
@@ -698,7 +698,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
left_y_batches.update(batch["y"].reshape(-1, ).tolist()) left_y_batches.update(batch["y"].reshape(-1, ).tolist())
res1 = driver1.model.evaluate_step(**batch) res1 = driver1.model.evaluate_step(**batch)
res2 = driver2.model.evaluate_step(**batch) res2 = driver2.model.evaluate_step(**batch)
assert oneflow.all(res1["preds"] == res2["preds"])
assert oneflow.all(res1["pred"] == res2["pred"])


assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
assert len(left_x_batches | already_seen_x_set) == len(dataset) assert len(left_x_batches | already_seen_x_set) == len(dataset)


+ 24
- 2
tests/core/utils/test_seq_len_to_mask.py View File

@@ -1,7 +1,7 @@
import pytest import pytest
import numpy as np import numpy as np
from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH, _NEED_IMPORT_ONEFLOW
if _NEED_IMPORT_TORCH: if _NEED_IMPORT_TORCH:
import torch import torch


@@ -11,6 +11,9 @@ if _NEED_IMPORT_PADDLE:
if _NEED_IMPORT_JITTOR: if _NEED_IMPORT_JITTOR:
import jittor import jittor


if _NEED_IMPORT_ONEFLOW:
import oneflow



class TestSeqLenToMask: class TestSeqLenToMask:


@@ -20,7 +23,7 @@ class TestSeqLenToMask:
length = seq_len[i] length = seq_len[i]
mask_i = mask[i] mask_i = mask[i]
for j in range(max_len): for j in range(max_len):
assert mask_i[j] == (j<length), (i, j, length)
assert mask_i[j].item() == (j<length), (i, j, length)


def test_numpy_seq_len(self): def test_numpy_seq_len(self):
# 测试能否转换numpy类型的seq_len # 测试能否转换numpy类型的seq_len
@@ -100,3 +103,22 @@ class TestSeqLenToMask:
seq_len = jittor.randint(1, 10, shape=(10,)) seq_len = jittor.randint(1, 10, shape=(10,))
mask = seq_len_to_mask(seq_len, 100) mask = seq_len_to_mask(seq_len, 100)
assert 100 == mask.shape[1] assert 100 == mask.shape[1]

@pytest.mark.oneflow
def test_pytorch_seq_len(self):
# 1. 随机测试
seq_len = oneflow.randint(1, 10, size=(10, ))
max_len = seq_len.max()
mask = seq_len_to_mask(seq_len)
assert max_len == mask.shape[1]
self.evaluate_mask_seq_len(seq_len.tolist(), mask)

# 2. 异常检测
seq_len = oneflow.randn(3, 4)
with pytest.raises(AssertionError):
mask = seq_len_to_mask(seq_len)

# 3. pad到指定长度
seq_len = oneflow.randint(1, 10, size=(10, ))
mask = seq_len_to_mask(seq_len, 100)
assert 100 == mask.size(1)

+ 1
- 0
tests/pytest.ini View File

@@ -9,3 +9,4 @@ markers =
torchjittor torchjittor
deepspeed deepspeed
torchoneflow torchoneflow
oneflowdist

Loading…
Cancel
Save