
Merge branch 'dev0.8.0' of https://github.com/fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
lxr-tech committed 3 years ago · commit 040dfa654a
23 changed files with 195 additions and 189 deletions
  1. fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py (+1 -1)
  2. fastNLP/core/drivers/torch_driver/initialize_torch_driver.py (+2 -2)
  3. fastNLP/core/metrics/utils.py (+2 -3)
  4. tests/core/callbacks/test_load_best_model_callback_torch.py (+2 -2)
  5. tests/core/controllers/_test_trainer_fleet.py (+0 -1)
  6. tests/core/controllers/_test_trainer_fleet_outside.py (+0 -1)
  7. tests/core/controllers/test_trainer_paddle.py (+1 -3)
  8. tests/core/dataloaders/jittor_dataloader/__init__.py (+0 -0)
  9. tests/core/dataloaders/paddle_dataloader/__init__.py (+0 -0)
 10. tests/core/dataloaders/torch_dataloader/__init__.py (+0 -0)
 11. tests/core/drivers/paddle_driver/test_dist_utils.py (+0 -1)
 12. tests/core/drivers/paddle_driver/test_fleet.py (+0 -2)
 13. tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py (+0 -3)
 14. tests/core/drivers/paddle_driver/test_single_device.py (+0 -3)
 15. tests/core/drivers/paddle_driver/test_utils.py (+0 -2)
 16. tests/core/drivers/torch_driver/test.py (+31 -0)
 17. tests/core/drivers/torch_driver/test_ddp.py (+0 -2)
 18. tests/core/drivers/torch_driver/test_initialize_torch_driver.py (+0 -3)
 19. tests/core/drivers/torch_driver/test_single_device.py (+0 -2)
 20. tests/core/drivers/torch_driver/test_utils.py (+0 -2)
 21. tests/core/log/test_logger_torch.py (+0 -0)
 22. tests/core/samplers/test_reproducible_batch_sampler.py (+147 -147)
 23. tests/core/samplers/test_unrepeated_sampler.py (+9 -9)

fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py (+1 -1)

@@ -14,7 +14,7 @@ if _NEED_IMPORT_PADDLE:
     import paddle

 def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]],
-                             model: paddle.nn.Layer, **kwargs) -> PaddleDriver:
+                             model: "paddle.nn.Layer", **kwargs) -> PaddleDriver:
     r"""
     Determine and initialize a concrete `Driver` instance according to the arguments `driver` and `device`, then return it;
     1. if the current process is detected to have been launched by the user via `python -m paddle.distributed.launch xxx.py`, then
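The quoting of the annotation is the whole fix here: an unquoted `paddle.nn.Layer` is evaluated when the `def` statement runs, so merely importing this module would crash in environments where `_NEED_IMPORT_PADDLE` is false and `paddle` was never imported. A string ("forward reference") annotation is only resolved on demand. A minimal standalone sketch of the pattern, not the actual fastNLP module:

from typing import List, Optional, Union

try:
    import paddle  # optional backend; may be absent
except ImportError:
    paddle = None

def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]],
                             model: "paddle.nn.Layer", **kwargs):
    # the quoted annotation is never evaluated at definition time,
    # so this module imports cleanly even without paddle installed
    ...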


fastNLP/core/drivers/torch_driver/initialize_torch_driver.py (+2 -2)

@@ -11,8 +11,8 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_BACKEND_LAUNCH


-def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.device, int, List[int]]],
-                            model: torch.nn.Module, **kwargs) -> TorchDriver:
+def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.device", int, List[int]]],
+                            model: "torch.nn.Module", **kwargs) -> TorchDriver:
     r"""
     Determine and initialize a concrete `Driver` instance according to the arguments `driver` and `device`, then return it;
     note: if the given `device` does not match the `driver`, an error is raised immediately;


fastNLP/core/metrics/utils.py (+2 -3)

@@ -11,9 +11,8 @@ _IS_ALLENNLP_AVAILABLE = _module_available('allennlp')
 if _IS_ALLENNLP_AVAILABLE:
     from allennlp.training.metrics import Metric as allennlp_Metric

-if _NEED_IMPORT_TORCH and _IS_TORCHMETRICS_AVAILABLE:
-    if _IS_TORCHMETRICS_AVAILABLE:
-        from torchmetrics import Metric as torchmetrics_Metric
+if _IS_TORCHMETRICS_AVAILABLE:
+    from torchmetrics import Metric as torchmetrics_Metric

 if _NEED_IMPORT_PADDLE:
     from paddle.metric import Metric as paddle_Metric
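`_module_available` (visible in the hunk context above) is the availability probe that makes these guarded imports safe; the dropped `_NEED_IMPORT_TORCH` condition was presumably redundant, since torchmetrics itself requires torch. fastNLP defines its own helper, whose implementation is not shown in this diff; a functionally equivalent sketch:

import importlib.util

def _module_available(module_name: str) -> bool:
    # check whether the module could be imported, without importing it
    return importlib.util.find_spec(module_name) is not None

_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics')
if _IS_TORCHMETRICS_AVAILABLE:
    from torchmetrics import Metric as torchmetrics_Metric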


tests/core/callbacks/test_load_best_model_callback_torch.py (+2 -2)

@@ -16,7 +16,7 @@ from fastNLP.core.controllers.trainer import Trainer
 from fastNLP.core.metrics.accuracy import Accuracy
 from fastNLP.core.callbacks.load_best_model_callback import LoadBestModelCallback
 from fastNLP.core import Evaluator
-from fastNLP.core.utils.utils import safe_rm
+from fastNLP.core import rank_zero_rm
 from fastNLP.core.drivers.torch_driver import TorchSingleDriver
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
@@ -112,7 +112,7 @@ def test_load_best_model_callback(
     results = evaluator.run()
     assert np.allclose(callbacks[0].monitor_value, results['acc#acc#dl1'])
     if save_folder:
-        safe_rm(save_folder)
+        rank_zero_rm(save_folder)
     if dist.is_initialized():
         dist.destroy_process_group()
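`safe_rm` is replaced by `rank_zero_rm`, now exported from `fastNLP.core`: in multi-process (DDP) tests, every rank reaching the cleanup line would otherwise race to delete the same folder. The diff does not show the implementation; a hypothetical sketch of such a helper:

import os
import shutil

def rank_zero_rm(path):
    # delete `path` on the rank-0 process only (RANK is set by
    # torch.distributed launchers; default to 0 for single-process runs)
    if int(os.environ.get("RANK", "0")) != 0:
        return
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    elif os.path.exists(path):
        os.remove(path)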



tests/core/controllers/test_trainer_fleet.py → tests/core/controllers/_test_trainer_fleet.py (+0 -1)

@@ -4,7 +4,6 @@
 python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet.py
 """
 import os
-os.environ["FASTNLP_BACKEND"] = "paddle"
 import sys
 sys.path.append("../../../")
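This merge strips the in-file `os.environ["FASTNLP_BACKEND"] = ...` assignments throughout the test suite (this file and the ones below): the variable only takes effect if it is set before fastNLP is imported anywhere in the process, so assigning it mid-file was fragile. The new leading underscore in the filename also takes the script out of pytest's default `test_*.py` collection, so it is only run manually via `paddle.distributed.launch`. If one wanted backend selection for a whole pytest run, a conftest.py hook is one option (hypothetical illustration, not part of this commit):

# conftest.py -- hypothetical; this commit removes the per-file assignments
import os

def pytest_configure(config):
    # must happen before any test module triggers `import fastNLP`
    os.environ.setdefault("FASTNLP_BACKEND", "paddle")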


tests/core/controllers/test_trainer_fleet_outside.py → tests/core/controllers/_test_trainer_fleet_outside.py (+0 -1)

@@ -4,7 +4,6 @@
 python -m paddle.distributed.launch --gpus=0,2,3 test_trainer_fleet_outside.py
 """
 import os
-os.environ["FASTNLP_BACKEND"] = "paddle"
 import sys
 sys.path.append("../../../")


tests/core/controllers/test_trainer_paddle.py (+1 -3)

@@ -1,6 +1,4 @@
 import pytest
-import os
-os.environ["FASTNLP_BACKEND"] = "paddle"
 from dataclasses import dataclass

 from fastNLP.core.controllers.trainer import Trainer
@@ -25,7 +23,7 @@ class TrainPaddleConfig:
     shuffle: bool = True
     evaluate_every = 2

-@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1)])
+@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
 # @pytest.mark.parametrize("driver,device", [("fleet", [0, 1])])
 @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc#acc", metric_threshold=0.0, larger_better=True),
                                         RichCallback(5)]])
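The `("fleet", [0, 1])` case moves from the commented-out decorator into the active parameter list, so the distributed fleet configuration is now exercised alongside the single-device ones. Stacked `parametrize` decorators combine as a cross product; a minimal self-contained illustration:

import pytest

@pytest.mark.parametrize("driver,device", [("paddle", "cpu"), ("paddle", 1), ("fleet", [0, 1])])
@pytest.mark.parametrize("callbacks", [["cb_a", "cb_b"]])
def test_cross_product(driver, device, callbacks):
    # runs 3 x 1 = 3 times, once per (driver, device) pair
    assert driver in ("paddle", "fleet")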


tests/core/dataloaders/jittor_dataloader/__init__.py (+0 -0)


tests/core/dataloaders/paddle_dataloader/__init__.py (+0 -0)


tests/core/dataloaders/torch_dataloader/__init__.py (+0 -0)


tests/core/drivers/paddle_driver/test_dist_utils.py (+0 -1)

@@ -3,7 +3,6 @@ import sys
 import signal
 import pytest
 import traceback
-os.environ["FASTNLP_BACKEND"] = "paddle"

 import numpy as np



tests/core/drivers/paddle_driver/test_fleet.py (+0 -2)

@@ -1,8 +1,6 @@
 import pytest
 import os
 from pathlib import Path

-os.environ["FASTNLP_BACKEND"] = "paddle"
-
 from fastNLP.core.drivers.paddle_driver.fleet import PaddleFleetDriver
 from fastNLP.core.samplers import (
     RandomSampler,


tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py (+0 -3)

@@ -1,8 +1,5 @@
-import os
 import pytest

-os.environ["FASTNLP_BACKEND"] = "paddle"
-
 from fastNLP.core.drivers import PaddleSingleDriver, PaddleFleetDriver
 from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
 from fastNLP.envs import get_gpu_count


tests/core/drivers/paddle_driver/test_single_device.py (+0 -3)

@@ -1,6 +1,3 @@
-import os
-from re import S
-os.environ["FASTNLP_BACKEND"] = "paddle"
 import pytest
 from pathlib import Path



tests/core/drivers/paddle_driver/test_utils.py (+0 -2)

@@ -1,6 +1,4 @@
 import os
 import pytest
-os.environ["FASTNLP_BACKEND"] = "paddle"
-
 from fastNLP.core.drivers.paddle_driver.utils import (
     get_device_from_visible,


tests/core/drivers/torch_driver/test.py (+31 -0)

@@ -0,0 +1,31 @@
import sys
sys.path.append("../../../../")
from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

import torch

device = [0, 1]
torch_model = TorchNormalModel_Classification_1(10, 10)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device]
driver = TorchDDPDriver(
    model=torch_model,
    parallel_device=device,
    fp16=False
)
driver.set_optimizers(torch_opt)
driver.setup()
print("-----------first--------------")

device = [0, 2]
torch_model = TorchNormalModel_Classification_1(10, 10)
torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01)
device = [torch.device(i) for i in device]
driver = TorchDDPDriver(
    model=torch_model,
    parallel_device=device,
    fp16=False
)
driver.set_optimizers(torch_opt)
driver.setup()
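The new script builds and sets up two TorchDDPDriver instances in the same process with different device lists, which appears to probe how the driver copes with an already-initialized distributed context. With raw torch.distributed, a process group can only be initialized once per process, so such code has to guard the call; a generic sketch of the standard pattern, not fastNLP's internal logic:

import torch.distributed as dist

def ensure_process_group(backend: str = "nccl") -> None:
    # init_process_group may be called at most once per process; assumes
    # the env:// variables (RANK, WORLD_SIZE, MASTER_ADDR/PORT) are set
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)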

tests/core/drivers/torch_driver/test_ddp.py (+0 -2)

@@ -1,8 +1,6 @@
 import pytest
 import os
 from pathlib import Path

-os.environ["FASTNLP_BACKEND"] = "torch"
-
 from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
 from fastNLP.core.samplers import (
     RandomSampler,


tests/core/drivers/torch_driver/test_initialize_torch_driver.py (+0 -3)

@@ -1,8 +1,5 @@
-import os
 import pytest

-os.environ["FASTNLP_BACKEND"] = "torch"
-
 from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver
 from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
 from fastNLP.envs import get_gpu_count


tests/core/drivers/torch_driver/test_single_device.py (+0 -2)

@@ -1,5 +1,3 @@
-import os
-os.environ["FASTNLP_BACKEND"] = "torch"
 import pytest
 from pathlib import Path



tests/core/drivers/torch_driver/test_utils.py (+0 -2)

@@ -1,6 +1,4 @@
 import os
 import pytest
-os.environ["FASTNLP_BACKEND"] = "torch"
-
 from fastNLP.core.drivers.torch_driver.utils import (
     replace_batch_sampler,


tests/core/log/test_logger.py → tests/core/log/test_logger_torch.py (+0 -0)


tests/core/samplers/test_reproducible_batch_sampler.py (+147 -147)

@@ -9,153 +9,153 @@ from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler
 from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
 from tests.helpers.datasets.torch_data import TorchNormalDataset


class TestReproducibleBatchSampler:
    # TODO: split this test up; it should check only one thing at a time
    def test_torch_dataloader_1(self):
        import torch
        from torch.utils.data import DataLoader
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            next(iter_dataloader)

        # 1. save the state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, RandomBatchSampler)
        state = _get_re_batchsampler.state_dict()
        assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
                         "sampler_type": "RandomBatchSampler"}

        # 2. resume from the checkpoint: build a new dataloader;
        # without changing batch_size;
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # with a changed batch_size;
        after_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=after_batch_size)
        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        real_res = []
        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))

        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])

        # is the second epoch after resumption a complete dataloader?
        # first finish the epoch in which training was resumed;
        begin_idx = 27
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

        # start a new epoch;
        begin_idx = 0
        iter_dataloader = iter(dataloader)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break

    def test_torch_dataloader_2(self):
        # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
        from torch.utils.data import DataLoader
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle to verify that the second epoch's index list is regenerated after resumption;
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # record all data of one epoch to verify that the restored state is correct;
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader).tolist())

        # 1. save the state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, RandomBatchSampler)
        state = _get_re_batchsampler.state_dict()

        # 2. resume from the checkpoint: build a new dataloader;
        # without changing batch_size;
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # first run the current epoch to its end;
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader).tolist())
            except StopIteration:
                break
        assert all_supposed_data == list(pre_index_list)

        # start new epochs;
        for _ in range(3):
            iter_dataloader = iter(dataloader)
            res = []
            while True:
                try:
                    res.append(next(iter_dataloader))
                except StopIteration:
                    break

    def test_3(self):
        import torch
        from torch.utils.data import DataLoader
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle to verify that the second epoch's index list is regenerated after resumption;
        dataloader = DataLoader(dataset, batch_size=before_batch_size)

        for idx, data in enumerate(dataloader):
            if idx > 3:
                break

        iterator = iter(dataloader)
        for each in iterator:
            pass
#
# class TestReproducibleBatchSampler:
#     # TODO: split this test up; it should check only one thing at a time
#     def test_torch_dataloader_1(self):
#         import torch
#         from torch.utils.data import DataLoader
#         # no shuffle
#         before_batch_size = 7
#         dataset = TorchNormalDataset(num_of_data=100)
#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         forward_steps = 3
#         iter_dataloader = iter(dataloader)
#         for _ in range(forward_steps):
#             next(iter_dataloader)
#
#         # 1. save the state
#         _get_re_batchsampler = dataloader.batch_sampler
#         assert isinstance(_get_re_batchsampler, RandomBatchSampler)
#         state = _get_re_batchsampler.state_dict()
#         assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
#                          "sampler_type": "RandomBatchSampler"}
#
#         # 2. resume from the checkpoint: build a new dataloader;
#         # without changing batch_size;
#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
#         re_batchsampler.load_state_dict(state)
#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         real_res = []
#         supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
#         forward_steps = 2
#         iter_dataloader = iter(dataloader)
#         for _ in range(forward_steps):
#             real_res.append(next(iter_dataloader))
#
#         for i in range(forward_steps):
#             assert all(real_res[i] == supposed_res[i])
#
#         # with a changed batch_size;
#         after_batch_size = 3
#         dataloader = DataLoader(dataset, batch_size=after_batch_size)
#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
#         re_batchsampler.load_state_dict(state)
#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         real_res = []
#         supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
#         forward_steps = 2
#         iter_dataloader = iter(dataloader)
#         for _ in range(forward_steps):
#             real_res.append(next(iter_dataloader))
#
#         for i in range(forward_steps):
#             assert all(real_res[i] == supposed_res[i])
#
#         # is the second epoch after resumption a complete dataloader?
#         # first finish the epoch in which training was resumed;
#         begin_idx = 27
#         while True:
#             try:
#                 data = next(iter_dataloader)
#                 _batch_size = len(data)
#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
#                 begin_idx += _batch_size
#             except StopIteration:
#                 break
#
#         # start a new epoch;
#         begin_idx = 0
#         iter_dataloader = iter(dataloader)
#         while True:
#             try:
#                 data = next(iter_dataloader)
#                 _batch_size = len(data)
#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
#                 begin_idx += _batch_size
#             except StopIteration:
#                 break
#
#     def test_torch_dataloader_2(self):
#         # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
#         from torch.utils.data import DataLoader
#         # no shuffle
#         before_batch_size = 7
#         dataset = TorchNormalDataset(num_of_data=100)
#         # enable shuffle to verify that the second epoch's index list is regenerated after resumption;
#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         # record all data of one epoch to verify that the restored state is correct;
#         all_supposed_data = []
#         forward_steps = 3
#         iter_dataloader = iter(dataloader)
#         for _ in range(forward_steps):
#             all_supposed_data.extend(next(iter_dataloader).tolist())
#
#         # 1. save the state
#         _get_re_batchsampler = dataloader.batch_sampler
#         assert isinstance(_get_re_batchsampler, RandomBatchSampler)
#         state = _get_re_batchsampler.state_dict()
#
#         # 2. resume from the checkpoint: build a new dataloader;
#         # without changing batch_size;
#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
#         re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
#         re_batchsampler.load_state_dict(state)
#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         # first run the current epoch to its end;
#         pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
#         while True:
#             try:
#                 all_supposed_data.extend(next(iter_dataloader).tolist())
#             except StopIteration:
#                 break
#         assert all_supposed_data == list(pre_index_list)
#
#         # start new epochs;
#         for _ in range(3):
#             iter_dataloader = iter(dataloader)
#             res = []
#             while True:
#                 try:
#                     res.append(next(iter_dataloader))
#                 except StopIteration:
#                     break
#
#     def test_3(self):
#         import torch
#         from torch.utils.data import DataLoader
#         before_batch_size = 7
#         dataset = TorchNormalDataset(num_of_data=100)
#         # enable shuffle to verify that the second epoch's index list is regenerated after resumption;
#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
#
#         for idx, data in enumerate(dataloader):
#             if idx > 3:
#                 break
#
#         iterator = iter(dataloader)
#         for each in iterator:
#             pass


class DatasetWithVaryLength:
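Everything the replaced tests assert hangs off one contract: the batch sampler exposes its epoch order (`index_list`) and position (`num_consumed_samples`) through `state_dict()`, and after `load_state_dict()` a fresh dataloader continues mid-epoch, then starts the next epoch from scratch. A stripped-down sketch of that contract (hypothetical class, not fastNLP's RandomBatchSampler):

from array import array

class ResumableBatchSampler:
    def __init__(self, num_samples: int, batch_size: int):
        self.batch_size = batch_size
        self.index_list = array("I", range(num_samples))  # this epoch's order
        self.num_consumed_samples = 0

    def __iter__(self):
        # resume from wherever the previous (possibly restored) run stopped
        while self.num_consumed_samples < len(self.index_list):
            batch = self.index_list[self.num_consumed_samples:
                                    self.num_consumed_samples + self.batch_size]
            self.num_consumed_samples += len(batch)
            yield list(batch)
        self.num_consumed_samples = 0  # the next epoch starts from scratch

    def state_dict(self):
        return {"index_list": self.index_list,
                "num_consumed_samples": self.num_consumed_samples,
                "sampler_type": type(self).__name__}

    def load_state_dict(self, state):
        self.index_list = state["index_list"]
        self.num_consumed_samples = state["num_consumed_samples"]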


tests/core/samplers/test_unrepeated_sampler.py (+9 -9)

@@ -28,12 +28,12 @@ class TestUnrepeatedSampler:
     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
     @pytest.mark.parametrize('shuffle', [False, True])
-    def test_multi(self, num_replica, num_of_data, shuffle):
+    def test_multi(self, num_replicas, num_of_data, shuffle):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
-            sampler.set_distributed(num_replica, rank=i)
+            sampler.set_distributed(num_replicas, rank=i)
             samplers.append(sampler)

         indexes = list(chain(*samplers))
@@ -52,12 +52,12 @@ class TestUnrepeatedSortedSampler:

     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, num_replica, num_of_data):
+    def test_multi(self, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i)
+            sampler.set_distributed(num_replicas, rank=i)
             samplers.append(sampler)

         # make sure the order is not scrambled
@@ -83,12 +83,12 @@ class TestUnrepeatedSequentialSampler:

     @pytest.mark.parametrize('num_replicas', [2, 3])
     @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
-    def test_multi(self, num_replica, num_of_data):
+    def test_multi(self, num_replicas, num_of_data):
         data = DatasetWithVaryLength(num_of_data=num_of_data)
         samplers = []
-        for i in range(num_replica):
+        for i in range(num_replicas):
             sampler = UnrepeatedSequentialSampler(dataset=data, length=data.data)
-            sampler.set_distributed(num_replica, rank=i)
+            sampler.set_distributed(num_replicas, rank=i)
             samplers.append(sampler)

         # make sure the order is not scrambled
