Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
b56903bd11
4 changed files with 13 additions and 7 deletions
  1. +1
    -0
      fastNLP/core/collators/padders/get_padder.py
  2. +3
    -0
      fastNLP/core/collators/padders/paddle_padder.py
  3. +2
    -0
      fastNLP/core/collators/padders/torch_padder.py
  4. +7
    -7
      fastNLP/core/controllers/evaluator.py

+ 1
- 0
fastNLP/core/collators/padders/get_padder.py View File

@@ -118,6 +118,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
elif backend == 'numpy': elif backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'torch': elif backend == 'torch':
# 这里 ele_dtype 传入为 None 的原因是防止出现 paddle tensor 转换为 torch tensor
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'paddle': elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)


+ 3
- 0
fastNLP/core/collators/padders/paddle_padder.py View File

@@ -132,6 +132,9 @@ class PaddleTensorPadder(Padder):
try: try:
if not isinstance(batch_field[0], paddle.Tensor): if not isinstance(batch_field[0], paddle.Tensor):
batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field]
else:
if dtype is None:
dtype = batch_field[0].dtype
except AttributeError: except AttributeError:
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.") f"it must have tolist() method.")


+ 2
- 0
fastNLP/core/collators/padders/torch_padder.py View File

@@ -118,6 +118,8 @@ class TorchTensorPadder(Padder):
batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field]
else: else:
device = batch_field[0].device device = batch_field[0].device
if dtype is None:
dtype = batch_field[0].dtype
except AttributeError: except AttributeError:
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.") f"it must have tolist() method.")


+ 7
- 7
fastNLP/core/controllers/evaluator.py View File

@@ -169,14 +169,14 @@ class Evaluator:
raise e raise e
finally: finally:
self.finally_progress_bar() self.finally_progress_bar()

metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)
if len(metric_results) > 0: # 如果 metric 不为 None 需要 print 。
metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)
if self.verbose:
if self.progress_bar == 'rich':
f_rich_progress.print(metric_results)
else:
logger.info(metric_results)
self.driver.set_model_mode(mode='train') self.driver.set_model_mode(mode='train')
if self.verbose:
if self.progress_bar == 'rich':
f_rich_progress.print(metric_results)
else:
logger.info(metric_results)


return metric_results return metric_results




Loading…
Cancel
Save