@@ -34,6 +34,7 @@ class TorchBackend(Backend): | |||||
if method is None: | if method is None: | ||||
raise AggregateMethodError(should_have_aggregate_method=True) | raise AggregateMethodError(should_have_aggregate_method=True) | ||||
tensor = self._gather_all(tensor) | tensor = self._gather_all(tensor) | ||||
# tensor = self.all_gather_object(tensor) | |||||
if isinstance(tensor[0], torch.Tensor): | if isinstance(tensor[0], torch.Tensor): | ||||
tensor = torch.stack(tensor) | tensor = torch.stack(tensor) | ||||
# 第一步, aggregate结果 | # 第一步, aggregate结果 | ||||
@@ -34,6 +34,7 @@ class Element: | |||||
自动aggregate对应的元素 | 自动aggregate对应的元素 | ||||
""" | """ | ||||
self._check_value_initialized() | |||||
try: | try: | ||||
self._value = self.backend.aggregate(self._value, self.aggregate_method) | self._value = self.backend.aggregate(self._value, self.aggregate_method) | ||||
except AggregateMethodError as e: | except AggregateMethodError as e: | ||||
@@ -216,9 +216,9 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
class SpanFPreRecMetric(Metric): | class SpanFPreRecMetric(Metric): | ||||
def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None, | |||||
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', | |||||
beta=1, aggregate_when_get_metric: bool = True,) -> None: | |||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||||
only_gross: bool = True, f_type='micro', | |||||
beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: | |||||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | ||||
if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | ||||
@@ -249,9 +249,18 @@ class SpanFPreRecMetric(Metric): | |||||
self.only_gross = only_gross | self.only_gross = only_gross | ||||
self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||||
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||||
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||||
self._true_positives = {} | |||||
self._false_positives = {} | |||||
self._false_negatives = {} | |||||
for word, _ in tag_vocab: | |||||
word = word.lower() | |||||
if word != 'o': | |||||
word = word.split('-')[1] | |||||
if word in self._true_positives: | |||||
continue | |||||
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||||
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) | |||||
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) | |||||
def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
evaluate_result = {} | evaluate_result = {} | ||||
@@ -284,10 +293,17 @@ class SpanFPreRecMetric(Metric): | |||||
evaluate_result['rec'] = rec_sum / len(tags) | evaluate_result['rec'] = rec_sum / len(tags) | ||||
if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
tp, fn, fp = [], [], [] | |||||
for val in self._true_positives.values(): | |||||
tp.append(val.get_scalar()) | |||||
for val in self._false_negatives.values(): | |||||
fn.append(val.get_scalar()) | |||||
for val in self._false_positives.values(): | |||||
fp.append(val.get_scalar()) | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | f, pre, rec = _compute_f_pre_rec(self.beta_square, | ||||
sum(val.get_scalar() for val in self._true_positives.values()), | |||||
sum(val.get_scalar() for val in self._false_negatives.values()), | |||||
sum(val.get_scalar() for val in self._false_positives.values())) | |||||
sum(tp), | |||||
sum(fn), | |||||
sum(fp)) | |||||
evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
@@ -1,6 +1,6 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader | |||||
from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
@@ -9,17 +9,17 @@ class TestFdl(unittest.TestCase): | |||||
def test_init_v1(self): | def test_init_v1(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||||
# for batch in fdl: | # for batch in fdl: | ||||
# print(batch) | # print(batch) | ||||
fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||||
fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||||
# for batch in fdl1: | # for batch in fdl1: | ||||
# print(batch) | # print(batch) | ||||
def test_set_padding(self): | def test_set_padding(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
ds.set_pad_val("x", val=-1) | ds.set_pad_val("x", val=-1) | ||||
fdl = FDataLoader(ds, batch_size=3) | |||||
fdl = TorchDataLoader(ds, batch_size=3) | |||||
fdl.set_input("x", "y") | fdl.set_input("x", "y") | ||||
for batch in fdl: | for batch in fdl: | ||||
print(batch) | print(batch) | ||||
@@ -36,7 +36,7 @@ class TestFdl(unittest.TestCase): | |||||
_dict["Y"].append(ins['y']) | _dict["Y"].append(ins['y']) | ||||
return _dict | return _dict | ||||
fdl = FDataLoader(ds, batch_size=3, as_numpy=True) | |||||
fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) | |||||
fdl.set_input("x", "y") | fdl.set_input("x", "y") | ||||
fdl.add_collator(collate_fn) | fdl.add_collator(collate_fn) | ||||
for batch in fdl: | for batch in fdl: | ||||
@@ -44,7 +44,7 @@ class TestFdl(unittest.TestCase): | |||||
def test_get_batch_indices(self): | def test_get_batch_indices(self): | ||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
fdl = FDataLoader(ds, batch_size=3, shuffle=True) | |||||
fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) | |||||
fdl.set_input("y", "x") | fdl.set_input("y", "x") | ||||
for batch in fdl: | for batch in fdl: | ||||
print(fdl.get_batch_indices()) | print(fdl.get_batch_indices()) | ||||
@@ -67,30 +67,30 @@ class TestFdl(unittest.TestCase): | |||||
return object.__getattribute__(self, item) | return object.__getattribute__(self, item) | ||||
dataset = _DataSet() | dataset = _DataSet() | ||||
dl = FDataLoader(dataset, batch_size=2, shuffle=True) | |||||
dl = TorchDataLoader(dataset, batch_size=2, shuffle=True) | |||||
# dl.set_inputs('data', 'labels') | # dl.set_inputs('data', 'labels') | ||||
# dl.set_pad_val('labels', val=None) | # dl.set_pad_val('labels', val=None) | ||||
for batch in dl: | for batch in dl: | ||||
print(batch) | print(batch) | ||||
print(dl.get_batch_indices()) | print(dl.get_batch_indices()) | ||||
def test_prepare_dataloader(self): | |||||
def test_prepare_torch_dataloader(self): | |||||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl, FDataLoader) | |||||
dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||||
assert isinstance(dl, TorchDataLoader) | |||||
ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | ||||
dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | ||||
dl_bundle = prepare_dataloader(dbl) | |||||
assert isinstance(dl_bundle['train'], FDataLoader) | |||||
assert isinstance(dl_bundle['val'], FDataLoader) | |||||
dl_bundle = prepare_torch_dataloader(dbl) | |||||
assert isinstance(dl_bundle['train'], TorchDataLoader) | |||||
assert isinstance(dl_bundle['val'], TorchDataLoader) | |||||
ds_dict = {'train_1': ds, 'val': ds1} | ds_dict = {'train_1': ds, 'val': ds1} | ||||
dl_dict = prepare_dataloader(ds_dict) | |||||
assert isinstance(dl_dict['train_1'], FDataLoader) | |||||
assert isinstance(dl_dict['val'], FDataLoader) | |||||
dl_dict = prepare_torch_dataloader(ds_dict) | |||||
assert isinstance(dl_dict['train_1'], TorchDataLoader) | |||||
assert isinstance(dl_dict['val'], TorchDataLoader) | |||||
sequence = [ds, ds1] | sequence = [ds, ds1] | ||||
seq_ds = prepare_dataloader(sequence) | |||||
assert isinstance(seq_ds[0], FDataLoader) | |||||
assert isinstance(seq_ds[1], FDataLoader) | |||||
seq_ds = prepare_torch_dataloader(sequence) | |||||
assert isinstance(seq_ds[0], TorchDataLoader) | |||||
assert isinstance(seq_ds[1], TorchDataLoader) |
@@ -118,7 +118,6 @@ class TestAccuracy: | |||||
def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'], | def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'], | ||||
metric_kwargs: Dict[str, Any]) -> None: | metric_kwargs: Dict[str, Any]) -> None: | ||||
global pool | global pool | ||||
print(pool) | |||||
if is_ddp: | if is_ddp: | ||||
if sys.platform == "win32": | if sys.platform == "win32": | ||||
pytest.skip("DDP not supported on windows") | pytest.skip("DDP not supported on windows") | ||||
@@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method | |||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.metrics import SpanFPreRecMetric | from fastNLP.core.metrics import SpanFPreRecMetric | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
set_start_method("spawn", force=True) | set_start_method("spawn", force=True) | ||||
@@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None: | |||||
os.environ["MASTER_ADDR"] = "localhost" | os.environ["MASTER_ADDR"] = "localhost" | ||||
os.environ["MASTER_PORT"] = str(master_port) | os.environ["MASTER_PORT"] = str(master_port) | ||||
print(torch.cuda.device_count()) | |||||
if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): | if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): | ||||
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) | torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) | ||||
@@ -64,15 +64,15 @@ def find_free_network_port() -> int: | |||||
return port | return port | ||||
@pytest.fixture(scope='class', autouse=True) | |||||
def pre_process(): | |||||
global pool | |||||
pool = Pool(processes=NUM_PROCESSES) | |||||
master_port = find_free_network_port() | |||||
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||||
yield | |||||
pool.close() | |||||
pool.join() | |||||
# @pytest.fixture(scope='class', autouse=True) | |||||
# def pre_process(): | |||||
# global pool | |||||
# pool = Pool(processes=NUM_PROCESSES) | |||||
# master_port = find_free_network_port() | |||||
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||||
# yield | |||||
# pool.close() | |||||
# pool.join() | |||||
def _test(local_rank: int, | def _test(local_rank: int, | ||||
@@ -87,18 +87,19 @@ def _test(local_rank: int, | |||||
# dataset 也类似(每个进程有自己的一个) | # dataset 也类似(每个进程有自己的一个) | ||||
dataset = copy.deepcopy(dataset) | dataset = copy.deepcopy(dataset) | ||||
metric.to(device) | metric.to(device) | ||||
print(os.environ.get("MASTER_PORT", "xx")) | |||||
# 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上) | # 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上) | ||||
for i in range(local_rank, len(dataset), world_size): | for i in range(local_rank, len(dataset), world_size): | ||||
pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len'] | pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len'] | ||||
print(tg, seq_len) | |||||
metric.update(pred, tg, seq_len) | metric.update(pred, tg, seq_len) | ||||
my_result = metric.get_metric() | my_result = metric.get_metric() | ||||
print(my_result) | |||||
print(sklearn_metric) | |||||
assert my_result == sklearn_metric | assert my_result == sklearn_metric | ||||
class SpanFPreRecMetricTest(unittest.TestCase): | class SpanFPreRecMetricTest(unittest.TestCase): | ||||
global pool | |||||
def test_case1(self): | def test_case1(self): | ||||
from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans | ||||
@@ -147,26 +148,26 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
-1.3508, -0.9513], | -1.3508, -0.9513], | ||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | ||||
-0.0842, -0.4294]], | -0.0842, -0.4294]], | ||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
-0.5837, 1.0184], | |||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]]]) | |||||
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], | |||||
[4, 1, 7, 0, 4, 7]]) | |||||
[ | |||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
-0.5837, 1.0184], | |||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]] | |||||
] | |||||
]) | |||||
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) | |||||
fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) | ||||
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, | ||||
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, | 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, | ||||
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} | 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} | ||||
assert expect_bio_res == fastnlp_bio_metric.get_metric() | assert expect_bio_res == fastnlp_bio_metric.get_metric() | ||||
# print(fastnlp_bio_metric.get_metric()) | # print(fastnlp_bio_metric.get_metric()) | ||||
@@ -325,44 +326,52 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') | ||||
def test_case5(self): | def test_case5(self): | ||||
global pool | |||||
# pool = Pool(NUM_PROCESSES) | |||||
# master_port = find_free_network_port() | |||||
# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||||
# global pool | |||||
pool = Pool(NUM_PROCESSES) | |||||
master_port = find_free_network_port() | |||||
pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) | |||||
number_labels = 4 | number_labels = 4 | ||||
# bio tag | # bio tag | ||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | ||||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | ||||
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | ||||
dataset = DataSet({'pred': [torch.FloatTensor( | |||||
[[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||||
-0.3782, 0.8240], | |||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||||
-0.3562, -1.4116], | |||||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||||
2.0023, 0.7075], | |||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||||
0.3832, -0.1540], | |||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||||
-1.3508, -0.9513], | |||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||||
-0.0842, -0.4294]], | |||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
-0.5837, 1.0184], | |||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]]])] * 100, | |||||
'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4], | |||||
[4, 1, 7, 0, 4, 7]])] * 100, | |||||
'seq_len': [[6, 6]] * 100}) | |||||
dataset = DataSet({'pred': [ | |||||
torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||||
-0.3782, 0.8240], | |||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||||
-0.3562, -1.4116], | |||||
[1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, | |||||
2.0023, 0.7075], | |||||
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, | |||||
0.3832, -0.1540], | |||||
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, | |||||
-1.3508, -0.9513], | |||||
[1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, | |||||
-0.0842, -0.4294]] | |||||
]), | |||||
torch.FloatTensor([ | |||||
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, | |||||
-1.4138, -0.8853], | |||||
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, | |||||
-1.0726, 0.0364], | |||||
[0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, | |||||
-0.8836, -0.9320], | |||||
[0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, | |||||
-1.6857, 1.1571], | |||||
[1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, | |||||
-0.5837, 1.0184], | |||||
[1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, | |||||
-0.9025, 0.0864]] | |||||
]) | |||||
], | |||||
'tg': [ | |||||
torch.LongTensor([[3, 6, 0, 8, 2, 4]]), | |||||
torch.LongTensor([[4, 1, 7, 0, 4, 7]]) | |||||
], | |||||
'seq_len': [ | |||||
[6], [6] | |||||
]}) | |||||
metric_kwargs = { | metric_kwargs = { | ||||
'tag_vocab': fastnlp_bio_vocab, | 'tag_vocab': fastnlp_bio_vocab, | ||||
'only_gross': False, | 'only_gross': False, | ||||
@@ -372,7 +381,6 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, | 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, | ||||
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} | 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} | ||||
processes = NUM_PROCESSES | processes = NUM_PROCESSES | ||||
print(torch.cuda.device_count()) | |||||
pool.starmap( | pool.starmap( | ||||
partial( | partial( | ||||
@@ -384,3 +392,5 @@ class SpanFPreRecMetricTest(unittest.TestCase): | |||||
), | ), | ||||
[(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] | [(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] | ||||
) | ) | ||||
pool.close() | |||||
pool.join() |