import os
import socket
import sys
from typing import Union

import numpy as np
import torch
import torch.distributed


def setup_ddp(rank: int, world_size: int, master_port: int) -> None:
    """Set up the DDP environment for a single-node (localhost) run."""
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(master_port)
    # Use the CPU-friendly gloo backend; skip initialization on Windows,
    # where this setup is not supported.
    if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"):
        torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)
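
# Usage sketch (illustrative assumption, not part of the original helpers): each
# worker spawned via ``torch.multiprocessing.spawn`` receives its rank as the
# first argument and calls ``setup_ddp`` before any collective op. The port
# 29500 is an arbitrary placeholder; ``find_free_network_port`` below avoids
# hard-coding one.
#
#     import torch.multiprocessing as mp
#
#     def _demo_worker(rank: int, world_size: int, port: int) -> None:
#         setup_ddp(rank, world_size, port)
#         t = torch.ones(1)
#         torch.distributed.all_reduce(t)  # default op is SUM, so t == world_size
#
#     if __name__ == "__main__":
#         mp.spawn(_demo_worker, args=(2, 29500), nprocs=2)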


def find_free_network_port() -> int:
    """Find a free port on localhost.

    Useful in single-node training when we don't want to connect to a real
    master node but still have to set the ``MASTER_PORT`` environment variable.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind(("", 0))  # port 0 asks the OS to pick a currently free port
    s.listen(1)
    port = s.getsockname()[1]
    s.close()
    return port
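
# Hedged usage note (assumption, not in the original file): the returned port is
# typically passed straight to ``setup_ddp`` so runs never collide on a
# hard-coded port. The socket is closed before the port is used, so another
# process could grab it in the meantime; that small race is accepted for tests.
#
#     port = find_free_network_port()
#     setup_ddp(rank=0, world_size=1, master_port=port)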


def _assert_allclose(my_result: Union[float, np.ndarray],
                     sklearn_result: Union[float, np.ndarray],
                     atol: float = 1e-8) -> None:
    """Assert that a test result matches the sklearn reference via ``np.allclose``.

    The inputs need not be arrays with matching dimensions, since broadcasting
    applies; other cases such as ``np.allclose(np.array([[1e10, ], ]), 1e10 + 1)``
    are also True.

    :param my_result: the result under test (any array-like, not restricted by device etc.)
    :param sklearn_result: the reference value computed with sklearn
    :param atol: absolute tolerance passed to ``np.allclose``
    """
    assert np.allclose(a=my_result, b=sklearn_result, atol=atol)
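
# Illustrative comparison (assumed example, not from the original module):
# checking a hand-computed accuracy against sklearn's ``accuracy_score``.
#
#     from sklearn.metrics import accuracy_score
#
#     preds = torch.tensor([0, 1, 1, 0])
#     target = torch.tensor([0, 1, 0, 0])
#     my_acc = (preds == target).float().mean().numpy()
#     _assert_allclose(my_acc, accuracy_score(target.numpy(), preds.numpy()))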