From 5a41cca4838bf4b7a2725e1a42b4d33b951fd5ed Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Sun, 17 Jul 2022 18:39:45 +0800
Subject: [PATCH] seq_len_to_mask: add oneflow support; improve oneflow test markers
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/collators/padders/oneflow_padder.py  |  4 +--
 fastNLP/core/utils/seq_len_to_mask.py         | 15 ++++++++++-
 tests/core/drivers/oneflow_driver/test_ddp.py | 14 +++++-----
 .../drivers/oneflow_driver/test_dist_utils.py |  4 +--
 .../oneflow_driver/test_single_device.py      |  6 ++---
 tests/core/utils/test_seq_len_to_mask.py      | 26 +++++++++++++++++--
 tests/pytest.ini                              |  1 +
 7 files changed, 53 insertions(+), 17 deletions(-)

diff --git a/fastNLP/core/collators/padders/oneflow_padder.py b/fastNLP/core/collators/padders/oneflow_padder.py
index 3f2b8bce..30d73e26 100644
--- a/fastNLP/core/collators/padders/oneflow_padder.py
+++ b/fastNLP/core/collators/padders/oneflow_padder.py
@@ -169,7 +169,7 @@ class OneflowTensorPadder(Padder):
             else:
                 max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
 
-        tensor = oneflow.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
+        tensor = oneflow.full(max_shape, value=pad_val, dtype=dtype, device=device)
         for i, field in enumerate(batch_field):
             slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
             tensor[slices] = field
@@ -221,6 +221,6 @@ def get_padded_oneflow_tensor(batch_field, dtype=None, pad_val=0):
     :return:
     """
     shapes = get_shape(batch_field)
-    tensor = oneflow.full(shapes, dtype=dtype, fill_value=pad_val)
+    tensor = oneflow.full(shapes, dtype=dtype, value=pad_val)
     tensor = fill_tensor(batch_field, tensor, dtype=dtype)
     return tensor
diff --git a/fastNLP/core/utils/seq_len_to_mask.py b/fastNLP/core/utils/seq_len_to_mask.py
index 710c0a2b..5458602d 100644
--- a/fastNLP/core/utils/seq_len_to_mask.py
+++ b/fastNLP/core/utils/seq_len_to_mask.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 import numpy as np
-from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
+from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_ONEFLOW
 
 from .paddle_utils import paddle_to
 
@@ -14,6 +14,9 @@ if _NEED_IMPORT_PADDLE:
 if _NEED_IMPORT_JITTOR:
     import jittor
 
+if _NEED_IMPORT_ONEFLOW:
+    import oneflow
+
 
 def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
     r"""
@@ -80,5 +83,15 @@ def seq_len_to_mask(seq_len, max_len: Optional[int]=None):
     except NameError as e:
         pass
 
+    try:
+        if isinstance(seq_len, oneflow.Tensor):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = oneflow.arange(max_len).expand(batch_size, -1).to(seq_len)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError as e:
+        pass
+
     raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, "
                     f"and jittor.Var, but got {type(seq_len)}")
\ No newline at end of file
diff --git a/tests/core/drivers/oneflow_driver/test_ddp.py b/tests/core/drivers/oneflow_driver/test_ddp.py
index 8fa92924..c1d49230 100644
--- a/tests/core/drivers/oneflow_driver/test_ddp.py
+++ b/tests/core/drivers/oneflow_driver/test_ddp.py
@@ -80,7 +80,7 @@ def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=
 #
 ############################################################################
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 class TestDDPDriverFunction:
     """
     测试 OneflowDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题
@@ -159,7 +159,7 @@ class TestDDPDriverFunction:
 #
 ############################################################################
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 class TestSetDistReproDataloader:
 
     @classmethod
@@ -510,7 +510,7 @@ class TestSetDistReproDataloader:
 # 测试 save 和 load 相关的功能
 #
 ############################################################################
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 class TestSaveLoad:
     """
     测试多卡情况下 save 和 load 相关函数的表现
@@ -740,7 +740,7 @@ class TestSaveLoad:
         rank_zero_rm(path)
 
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 @pytest.mark.parametrize("shuffle", ([True, False]))
 @pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
 @pytest.mark.parametrize("drop_last", ([True, False]))
@@ -790,7 +790,7 @@ def test_shuffle_dataloader(shuffle, batch_size, drop_last, reproducible=True):
         pass
 
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 @pytest.mark.parametrize("shuffle", ([True, False]))
 @pytest.mark.parametrize("batch_size", ([1, 3, 16, 17]))
 @pytest.mark.parametrize("drop_last", ([True, False]))
@@ -845,7 +845,7 @@ def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=T
 
 
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 @recover_logger
 @pytest.mark.parametrize("inherit", ([True, False]))
 def test_customized_batch_sampler_dataloader(inherit):
@@ -897,7 +897,7 @@ def test_customized_batch_sampler_dataloader(inherit):
         pass
 
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 @recover_logger
 @pytest.mark.parametrize("inherit", ([True, False]))
 def test_customized_sampler_dataloader(inherit):
diff --git a/tests/core/drivers/oneflow_driver/test_dist_utils.py b/tests/core/drivers/oneflow_driver/test_dist_utils.py
index 45951519..10d06197 100644
--- a/tests/core/drivers/oneflow_driver/test_dist_utils.py
+++ b/tests/core/drivers/oneflow_driver/test_dist_utils.py
@@ -77,7 +77,7 @@ def test_tensor_object_transfer_tensor(device):
     assert res["int"] == oneflow_dict["int"]
     assert res["string"] == oneflow_dict["string"]
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 def test_fastnlp_oneflow_all_gather():
     local_rank = int(os.environ["LOCAL_RANK"])
     obj = {
@@ -113,7 +113,7 @@ def test_fastnlp_oneflow_all_gather():
     assert len(data) == world_size
     assert data[0] == data[1]
 
-@pytest.mark.oneflow
+@pytest.mark.oneflowdist
 def test_fastnlp_oneflow_broadcast_object():
     local_rank = int(os.environ["LOCAL_RANK"])
     if os.environ["LOCAL_RANK"] == "0":
diff --git a/tests/core/drivers/oneflow_driver/test_single_device.py b/tests/core/drivers/oneflow_driver/test_single_device.py
index 011997ca..674e460f 100644
--- a/tests/core/drivers/oneflow_driver/test_single_device.py
+++ b/tests/core/drivers/oneflow_driver/test_single_device.py
@@ -557,7 +557,7 @@ def test_save_and_load_model(only_state_dict):
             res1 = driver1.model.evaluate_step(**batch)
             res2 = driver2.model.evaluate_step(**batch)
 
-            assert oneflow.all(res1["preds"] == res2["preds"])
+            assert oneflow.all(res1["pred"] == res2["pred"])
 
     finally:
         rank_zero_rm(path)
@@ -623,7 +623,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16):
             left_y_batches.update(batch["y"].reshape(-1, ).tolist())
             res1 = driver1.model.evaluate_step(**batch)
             res2 = driver2.model.evaluate_step(**batch)
-            assert oneflow.all(res1["preds"] == res2["preds"])
+            assert oneflow.all(res1["pred"] == res2["pred"])
 
         assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
         assert len(left_x_batches | already_seen_x_set) == len(dataset)
@@ -698,7 +698,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16):
             left_y_batches.update(batch["y"].reshape(-1, ).tolist())
             res1 = driver1.model.evaluate_step(**batch)
             res2 = driver2.model.evaluate_step(**batch)
-            assert oneflow.all(res1["preds"] == res2["preds"])
+            assert oneflow.all(res1["pred"] == res2["pred"])
 
         assert len(left_x_batches) + len(already_seen_x_set) == len(dataset)
         assert len(left_x_batches | already_seen_x_set) == len(dataset)
diff --git a/tests/core/utils/test_seq_len_to_mask.py b/tests/core/utils/test_seq_len_to_mask.py
index 64c84837..8a09571c 100644
--- a/tests/core/utils/test_seq_len_to_mask.py
+++ b/tests/core/utils/test_seq_len_to_mask.py
@@ -1,7 +1,7 @@
 import pytest
 import numpy as np
 from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
-from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH, _NEED_IMPORT_ONEFLOW
 
 if _NEED_IMPORT_TORCH:
     import torch
@@ -11,6 +11,9 @@ if _NEED_IMPORT_PADDLE:
 
 if _NEED_IMPORT_JITTOR:
     import jittor
 
+if _NEED_IMPORT_ONEFLOW:
+    import oneflow
+
 
 class TestSeqLenToMask:
@@ -20,7 +23,7 @@ class TestSeqLenToMask:
             length = seq_len[i]
             mask_i = mask[i]
             for j in range(max_len):
-                assert mask_i[j] == (j
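
Note: the snippet below is a small usage sketch, not part of the patch, illustrating the
broadcast-and-compare logic of the oneflow branch added to seq_len_to_mask above. It
assumes oneflow is installed; the tensor values are made up for illustration, and max_len
is passed explicitly.

    import oneflow

    seq_len = oneflow.tensor([1, 3, 2])   # lengths of three sequences in a batch
    max_len = 3                           # padded length of the batch

    # arange(max_len) is expanded to shape (batch_size, max_len) and compared against
    # each length, so row i holds seq_len[i] leading True entries followed by False.
    broad_cast_seq_len = oneflow.arange(max_len).expand(seq_len.shape[0], -1).to(seq_len)
    mask = broad_cast_seq_len < seq_len.unsqueeze(1)
    # row 0 -> [True, False, False], row 1 -> [True, True, True], row 2 -> [True, True, False]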