|
|
@@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method |
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
|
|
|
|
set_start_method("spawn", force=True) |
|
|
|
|
|
|
|
|
|
|
@@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None: |
|
|
|
|
|
|
|
os.environ["MASTER_ADDR"] = "localhost" |
|
|
|
os.environ["MASTER_PORT"] = str(master_port) |
|
|
|
print(torch.cuda.device_count()) |
|
|
|
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): |
|
|
|
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) |
|
|
|
|
|
|
@@ -64,15 +64,15 @@ def find_free_network_port() -> int: |
|
|
|
return port |
|
|
|
|
|
|
|
|
|
|
|
@pytest.fixture(scope='class', autouse=True) |
|
|
|
def pre_process(): |
|
|
|
global pool |
|
|
|
pool = Pool(processes=NUM_PROCESSES) |
|
|
|
master_port = find_free_network_port() |
|
|
|
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) |
|
|
|
yield |
|
|
|
pool.close() |
|
|
|
pool.join() |
|
|
|
# @pytest.fixture(scope='class', autouse=True) |
|
|
|
# def pre_process(): |
|
|
|
# global pool |
|
|
|
# pool = Pool(processes=NUM_PROCESSES) |
|
|
|
# master_port = find_free_network_port() |
|
|
|
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) |
|
|
|
# yield |
|
|
|
# pool.close() |
|
|
|
# pool.join() |
|
|
|
|
|
|
|
|
|
|
|
def _test(local_rank: int, |
|
|
@@ -87,18 +87,19 @@ def _test(local_rank: int, |
|
|
|
# dataset 也类似(每个进程有自己的一个) |
|
|
|
dataset = copy.deepcopy(dataset) |
|
|
|
metric.to(device) |
|
|
|
print(os.environ.get("MASTER_PORT", "xx")) |
|
|
|
# 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上) |
|
|
|
for i in range(local_rank, len(dataset), world_size): |
|
|
|
pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len'] |
|
|
|
print(tg, seq_len) |
|
|
|
metric.update(pred, tg, seq_len) |
|
|
|
|
|
|
|
my_result = metric.get_metric() |
|
|
|
print(my_result) |
|
|
|
print(sklearn_metric) |
|
|
|
assert my_result == sklearn_metric |
|
|
|
|
|
|
|
|
|
|
|
class SpanFPreRecMetricTest(unittest.TestCase): |
|
|
|
global pool |
|
|
|
|
|
|
|
def test_case1(self): |
|
|
|
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans |
|
|
@@ -147,26 +148,26 @@ class SpanFPreRecMetricTest(unittest.TestCase): |
|
|
|
-1.3508, -0.9513], |
|
|
|
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, |
|
|
|
-0.0842, -0.4294]], |
|
|
|
|
|
|
|
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, |
|
|
|
-1.4138, -0.8853], |
|
|
|
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, |
|
|
|
-1.0726, 0.0364], |
|
|
|
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, |
|
|
|
-0.8836, -0.9320], |
|
|
|
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, |
|
|
|
-1.6857, 1.1571], |
|
|
|
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, |
|
|
|
-0.5837, 1.0184], |
|
|
|
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, |
|
|
|
-0.9025, 0.0864]]]) |
|
|
|
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], |
|
|
|
[4, 1, 7, 0, 4, 7]]) |
|
|
|
[ |
|
|
|
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, |
|
|
|
-1.4138, -0.8853], |
|
|
|
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, |
|
|
|
-1.0726, 0.0364], |
|
|
|
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, |
|
|
|
-0.8836, -0.9320], |
|
|
|
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, |
|
|
|
-1.6857, 1.1571], |
|
|
|
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, |
|
|
|
-0.5837, 1.0184], |
|
|
|
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, |
|
|
|
-0.9025, 0.0864]] |
|
|
|
] |
|
|
|
]) |
|
|
|
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) |
|
|
|
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) |
|
|
|
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, |
|
|
|
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, |
|
|
|
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} |
|
|
|
|
|
|
|
assert expect_bio_res == fastnlp_bio_metric.get_metric() |
|
|
|
# print(fastnlp_bio_metric.get_metric()) |
|
|
|
|
|
|
@@ -325,44 +326,52 @@ class SpanFPreRecMetricTest(unittest.TestCase): |
|
|
|
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') |
|
|
|
|
|
|
|
def test_case5(self): |
|
|
|
global pool |
|
|
|
# pool = Pool(NUM_PROCESSES) |
|
|
|
# master_port = find_free_network_port() |
|
|
|
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) |
|
|
|
# global pool |
|
|
|
pool = Pool(NUM_PROCESSES) |
|
|
|
master_port = find_free_network_port() |
|
|
|
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) |
|
|
|
number_labels = 4 |
|
|
|
# bio tag |
|
|
|
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) |
|
|
|
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) |
|
|
|
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) |
|
|
|
dataset = DataSet({'pred': [torch.FloatTensor( |
|
|
|
[[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, |
|
|
|
-0.3782, 0.8240], |
|
|
|
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, |
|
|
|
-0.3562, -1.4116], |
|
|
|
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, |
|
|
|
2.0023, 0.7075], |
|
|
|
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, |
|
|
|
0.3832, -0.1540], |
|
|
|
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, |
|
|
|
-1.3508, -0.9513], |
|
|
|
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, |
|
|
|
-0.0842, -0.4294]], |
|
|
|
|
|
|
|
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, |
|
|
|
-1.4138, -0.8853], |
|
|
|
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, |
|
|
|
-1.0726, 0.0364], |
|
|
|
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, |
|
|
|
-0.8836, -0.9320], |
|
|
|
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, |
|
|
|
-1.6857, 1.1571], |
|
|
|
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, |
|
|
|
-0.5837, 1.0184], |
|
|
|
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, |
|
|
|
-0.9025, 0.0864]]])] * 100, |
|
|
|
'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4], |
|
|
|
[4, 1, 7, 0, 4, 7]])] * 100, |
|
|
|
'seq_len': [[6, 6]] * 100}) |
|
|
|
dataset = DataSet({'pred': [ |
|
|
|
torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, |
|
|
|
-0.3782, 0.8240], |
|
|
|
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, |
|
|
|
-0.3562, -1.4116], |
|
|
|
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, |
|
|
|
2.0023, 0.7075], |
|
|
|
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, |
|
|
|
0.3832, -0.1540], |
|
|
|
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, |
|
|
|
-1.3508, -0.9513], |
|
|
|
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, |
|
|
|
-0.0842, -0.4294]] |
|
|
|
|
|
|
|
]), |
|
|
|
torch.FloatTensor([ |
|
|
|
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, |
|
|
|
-1.4138, -0.8853], |
|
|
|
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, |
|
|
|
-1.0726, 0.0364], |
|
|
|
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, |
|
|
|
-0.8836, -0.9320], |
|
|
|
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, |
|
|
|
-1.6857, 1.1571], |
|
|
|
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, |
|
|
|
-0.5837, 1.0184], |
|
|
|
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, |
|
|
|
-0.9025, 0.0864]] |
|
|
|
]) |
|
|
|
], |
|
|
|
'tg': [ |
|
|
|
torch.LongTensor([[3, 6, 0, 8, 2, 4]]), |
|
|
|
torch.LongTensor([[4, 1, 7, 0, 4, 7]]) |
|
|
|
], |
|
|
|
'seq_len': [ |
|
|
|
[6], [6] |
|
|
|
]}) |
|
|
|
metric_kwargs = { |
|
|
|
'tag_vocab': fastnlp_bio_vocab, |
|
|
|
'only_gross': False, |
|
|
@@ -372,7 +381,6 @@ class SpanFPreRecMetricTest(unittest.TestCase): |
|
|
|
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, |
|
|
|
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} |
|
|
|
processes = NUM_PROCESSES |
|
|
|
print(torch.cuda.device_count()) |
|
|
|
|
|
|
|
pool.starmap( |
|
|
|
partial( |
|
|
@@ -384,3 +392,5 @@ class SpanFPreRecMetricTest(unittest.TestCase): |
|
|
|
), |
|
|
|
[(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] |
|
|
|
) |
|
|
|
pool.close() |
|
|
|
pool.join() |