diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py index 2228b240..37717f54 100644 --- a/fastNLP/core/drivers/torch_driver/dist_utils.py +++ b/fastNLP/core/drivers/torch_driver/dist_utils.py @@ -402,7 +402,7 @@ def fastnlp_torch_all_gather(obj:Any, device=None, group=None)->List: if _TORCH_GREATER_EQUAL_1_8: objs = [None for _ in range(dist.get_world_size(group))] dist.all_gather_object(objs, obj) - apply_to_collection(obj, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 + objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 return objs group = group if group is not None else torch.distributed.group.WORLD data = convert_to_tensors(obj, device=device) diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py index f1db0151..8945ab01 100644 --- a/fastNLP/core/metrics/backend/torch_backend/backend.py +++ b/fastNLP/core/metrics/backend/torch_backend/backend.py @@ -33,8 +33,7 @@ class TorchBackend(Backend): if dist.is_initialized(): if method is None: raise AggregateMethodError(should_have_aggregate_method=True) - tensor = self._gather_all(tensor) - # tensor = self.all_gather_object(tensor) + tensor = fastnlp_torch_all_gather(tensor) if isinstance(tensor[0], torch.Tensor): tensor = torch.stack(tensor) # 第一步, aggregate结果 @@ -69,59 +68,6 @@ class TorchBackend(Backend): def get_scalar(self, tensor) -> float: return tensor.item() - @staticmethod - def _gather_all(result, group: Optional[Any] = None) -> List: - """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. - Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case - tensors are padded, gathered and then trimmed to secure equal workload for all processes. - - Args: - result: the value to sync - group: the process group to gather results from. Defaults to all processes (world) - - Return: - gathered_result: list with size equal to the process group where - gathered_result[i] corresponds to result tensor from process i - """ - - if group is None: - group = dist.group.WORLD - - # convert tensors to contiguous format - result = result.contiguous() - - world_size = dist.get_world_size(group) - dist.barrier(group=group) - - # if the tensor is scalar, things are easy - if result.ndim == 0: - return _simple_gather_all_tensors(result, group, world_size) - - # 1. Gather sizes of all tensors - local_size = torch.tensor(result.shape, device=result.device) - local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] - dist.all_gather(local_sizes, local_size, group=group) - max_size = torch.stack(local_sizes).max(dim=0).values - all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) - - # 2. If shapes are all the same, then do a simple gather: - if all_sizes_equal: - return _simple_gather_all_tensors(result, group, world_size) - - # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate - pad_dims = [] - pad_by = (max_size - local_size).detach().cpu() - for val in reversed(pad_by): - pad_dims.append(0) - pad_dims.append(val.item()) - result_padded = torch.nn.functional.pad(result, pad_dims) - gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] - dist.all_gather(gathered_result, result_padded, group) - for idx, item_size in enumerate(local_sizes): - slice_param = [slice(dim_size) for dim_size in item_size] - gathered_result[idx] = gathered_result[idx][slice_param] - return gathered_result - def tensor2numpy(self, tensor) -> np.array: """ 将对应的tensor转为numpy对象 diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py index 483e9a49..22ba2635 100644 --- a/fastNLP/core/metrics/element.py +++ b/fastNLP/core/metrics/element.py @@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK class Element: - def __init__(self, value: float, aggregate_method, backend: Backend, name=None): + def __init__(self, name, value: float, aggregate_method, backend: Backend): + self.name = name self.init_value = value self.aggregate_method = aggregate_method - self.name = name if backend == 'auto': - raise RuntimeError("You have to specify the backend.") + raise RuntimeError(f"You have to specify the backend for Element:{self.name}.") elif isinstance(backend, AutoBackend): self.backend = backend else: @@ -41,14 +41,9 @@ class Element: msg = 'If you see this message, please report a bug.' if self.name and e.should_have_aggregate_method: msg = f"Element:{self.name} has no specified `aggregate_method`." - elif e.should_have_aggregate_method: - msg = "Element has no specified `aggregate_method`." elif self.name and not e.should_have_aggregate_method: msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ f'aggregate_method:{self.aggregate_method}.' - elif not e.should_have_aggregate_method: - msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ - f'aggregate_method:{self.aggregate_method}.' if e.only_warn: if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: logger.warning(msg) @@ -75,6 +70,7 @@ class Element: return self._value def get_scalar(self) -> float: + self._check_value_initialized() return self.backend.get_scalar(self._value) def fill_value(self, value): @@ -96,7 +92,7 @@ class Element: def _check_value_when_call(self): if self.value is None: - prefix = f'Element:`{self.name}`' if self.name else 'Element' + prefix = f'Element:`{self.name}`' raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " "element, or use it after it being used by the `Metric.compute()` method.") @@ -274,9 +270,10 @@ class Element: """ try: if self._value is None: - prefix = f'Element:`{self.name}`' if self.name else 'Element' + prefix = f'Element:`{self.name}`' raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " "element, or use it after it being used by the `Metric.compute()` method.") return getattr(self._value, item) except AttributeError as e: + logger.error(f"Element:{self.name} has no `{item}` attribute.") raise e diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 097671da..2fb575fc 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -35,7 +35,7 @@ class Metric: def elements(self) -> dict: return self._elements - def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: + def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: """ 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 tensor 直接进行加减乘除计算即可。 @@ -57,11 +57,9 @@ class Metric: else: backend = AutoBackend(backend) - # 当name为None,默认为变量取得变量名 - if name is None: - name = f'ele_var_{len(self._elements)}' + assert name is not None and name not in self.elements - element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) + element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend) self.elements[name] = element setattr(self, name, element) return element diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py index b92c24dc..716cea30 100644 --- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py @@ -219,6 +219,23 @@ class SpanFPreRecMetric(Metric): def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: + r""" + + :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), + 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. + :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 + :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 + :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 + :param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. + :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label + :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec + :param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) + :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 + :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() + 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 + :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, + 当 backend 不支持分布式时,该参数无意义。 + """ super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) if f_type not in ('micro', 'macro'): raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) @@ -255,7 +272,7 @@ class SpanFPreRecMetric(Metric): for word, _ in tag_vocab: word = word.lower() if word != 'o': - word = word.split('-')[1] + word = word[2:] if word in self._true_positives: continue self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) @@ -266,8 +283,8 @@ class SpanFPreRecMetric(Metric): evaluate_result = {} if not self.only_gross or self.f_type == 'macro': tags = set(self._false_negatives.keys()) - tags.update(set(self._false_positives.keys())) - tags.update(set(self._true_positives.keys())) + tags.update(self._false_positives.keys()) + tags.update(self._true_positives.keys()) f_sum = 0 pre_sum = 0 rec_sum = 0 @@ -275,6 +292,9 @@ class SpanFPreRecMetric(Metric): tp = self._true_positives[tag].get_scalar() fn = self._false_negatives[tag].get_scalar() fp = self._false_positives[tag].get_scalar() + if tp == fn == fp == 0: + continue + f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) f_sum += f pre_sum += pre diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index 68a28335..2a2cd31c 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -150,7 +150,7 @@ def seed_jittor_global_seed(global_seed): pass -def dump_fastnlp_backend(default:bool = False): +def dump_fastnlp_backend(default:bool = False, backend=None): """ 将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, 若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 @@ -162,6 +162,7 @@ def dump_fastnlp_backend(default:bool = False): 会保存的环境变量为 FASTNLP_BACKEND 。 :param default: + :param backend: 保存使用的 backend 为哪个值,允许的值有 ['torch', 'paddle', 'jittor']。如果为 None ,则使用环境变量中的值。 :return: """ if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: @@ -176,10 +177,16 @@ def dump_fastnlp_backend(default:bool = False): os.makedirs(os.path.dirname(env_path), exist_ok=True) envs = {} - if FASTNLP_BACKEND in os.environ: - envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] + assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." + if backend is None: + if FASTNLP_BACKEND in os.environ: + envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] + else: + envs[FASTNLP_BACKEND] = backend if len(envs): with open(env_path, 'w', encoding='utf8') as f: json.dump(fp=f, obj=envs) print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") + else: + raise RuntimeError("No backend specified.") \ No newline at end of file diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index 773c1e22..38b79b44 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -48,7 +48,8 @@ def set_env_on_import_paddle(): # TODO jittor may need set this def set_env_on_import_jittor(): # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH - pass + if 'log_silent' not in os.environ: + os.environ['log_silent'] = '1' def set_env_on_import(): diff --git a/tests/core/metrics/test_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py similarity index 85% rename from tests/core/metrics/test_f1_rec_acc_torch.py rename to tests/core/metrics/test_span_f1_rec_acc_torch.py index 34067080..5908663a 100644 --- a/tests/core/metrics/test_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -99,7 +99,7 @@ def _test(local_rank: int, assert my_result == sklearn_metric -class SpanFPreRecMetricTest(unittest.TestCase): +class TestSpanFPreRecMetric: def test_case1(self): from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans @@ -136,33 +136,31 @@ class SpanFPreRecMetricTest(unittest.TestCase): fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) - bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, - -0.3782, 0.8240], - [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, + bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, + -0.3782, 0.8240], + [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, -0.3562, -1.4116], - [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, - 2.0023, 0.7075], - [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, + [ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, + 2.0023, 0.7075], + [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, 0.3832, -0.1540], - [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, + [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, -1.3508, -0.9513], - [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, - -0.0842, -0.4294]], - [ - [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, - -1.4138, -0.8853], - [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, - -1.0726, 0.0364], - [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, - -0.8836, -0.9320], - [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, - -1.6857, 1.1571], - [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, - -0.5837, 1.0184], - [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, - -0.9025, 0.0864]] - ] - ]) + [ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, + -0.0842, -0.4294]], + + [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, + -1.4138, -0.8853], + [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, + -1.0726, 0.0364], + [ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, + -0.8836, -0.9320], + [ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, + -1.6857, 1.1571], + [ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, + -0.5837, 1.0184], + [ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, + -0.9025, 0.0864]]]) bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, @@ -254,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): # print(expected_metric) metric_value = metric.get_metric() for key, value in expected_metric.items(): - self.assertAlmostEqual(value, metric_value[key], places=5) + np.allclose(value, metric_value[key]) def test_auto_encoding_type_infer(self): # 检查是否可以自动check encode的类型 @@ -271,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase): vocab.add_word('o') vocabs[encoding_type] = vocab for e in ['bio', 'bioes', 'bmeso']: - with self.subTest(e=e): - metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) - assert metric.encoding_type == e + metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) + assert metric.encoding_type == e bmes_vocab = _generate_tags('bmes') vocab = Vocabulary() @@ -286,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): vocab = Vocabulary() for i in range(10): vocab.add_word(str(i)) - with self.assertRaises(Exception): + with pytest.raises(Exception): metric = SpanFPreRecMetric(vocab) def test_encoding_type(self): @@ -305,21 +302,20 @@ class SpanFPreRecMetricTest(unittest.TestCase): vocab.add_word('o') vocabs[encoding_type] = vocab for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): - with self.subTest(e1=e1, e2=e2): - if e1 == e2: + if e1 == e2: + metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) + else: + s2 = set(e2) + s2.update(set(e1)) + if s2 == set(e2): + continue + with pytest.raises(AssertionError): metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) - else: - s2 = set(e2) - s2.update(set(e1)) - if s2 == set(e2): - continue - with self.assertRaises(AssertionError): - metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) for encoding_type in ['bio', 'bioes', 'bmeso']: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') - with self.assertWarns(Warning): + with pytest.warns(Warning): vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') vocab = Vocabulary().add_word_lst(list('bmes'))