@@ -50,20 +50,20 @@ def prepare_callbacks(callbacks, progress_bar): | |||||
raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") | ||||
_callbacks += callbacks | _callbacks += callbacks | ||||
has_no_progress = False | |||||
has_no_progress = True | |||||
for _callback in _callbacks: | for _callback in _callbacks: | ||||
if isinstance(_callback, ProgressCallback): | if isinstance(_callback, ProgressCallback): | ||||
has_no_progress = True | |||||
if not has_no_progress: | |||||
has_no_progress = False | |||||
if has_no_progress and progress_bar is not None: | |||||
callback = choose_progress_callback(progress_bar) | callback = choose_progress_callback(progress_bar) | ||||
if callback is not None: | if callback is not None: | ||||
_callbacks.append(callback) | _callbacks.append(callback) | ||||
elif progress_bar is not None and progress_bar != 'auto': | |||||
logger.warning(f"Since you have passed in ProgressBar callback, progress_bar will be ignored.") | |||||
has_no_progress = False | |||||
elif has_no_progress is False and progress_bar not in ('auto', None): | |||||
logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.") | |||||
if has_no_progress and progress_bar is None: | |||||
rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " | |||||
"during training.") | |||||
if has_no_progress: | |||||
logger.rank_zero_warning("No progress bar is provided, there will have no progress output during training.") | |||||
return _callbacks | return _callbacks | ||||
@@ -87,17 +87,20 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) | ||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
logger.info(f"Loading best model with {self.monitor_name}: {self.monitor_value}...") | |||||
if self.real_save_folder: | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
if self.delete_after_after: | |||||
trainer.driver.barrier() | |||||
self._delete_folder() | |||||
trainer.driver.barrier() | |||||
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | |||||
if self.real_save_folder: | |||||
logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") | |||||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||||
model_load_fn=self.model_load_fn) | |||||
else: | |||||
logger.info( | |||||
f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||||
self.buffer.seek(0) | |||||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||||
if self.delete_after_after: | |||||
trainer.driver.barrier() | |||||
self._delete_folder() | |||||
trainer.driver.barrier() | |||||
def _delete_folder(self): | def _delete_folder(self): | ||||
if self.real_save_folder: | if self.real_save_folder: | ||||
@@ -138,8 +138,6 @@ class PaddleTensorPadder(Padder): | |||||
shapes = [field.shape for field in batch_field] | shapes = [field.shape for field in batch_field] | ||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | ||||
if isinstance(dtype, np.dtype): | |||||
print(dtype) | |||||
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
@@ -4,11 +4,11 @@ __all__ = [ | |||||
] | ] | ||||
from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | from typing import Optional, Callable, Sequence, Union, Tuple, Dict, Mapping, List | ||||
import inspect | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.core.dataloaders.utils import indice_collate_wrapper | from fastNLP.core.dataloaders.utils import indice_collate_wrapper | ||||
from fastNLP.io.data_bundle import DataBundle | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, UnrepeatedSampler, RandomSampler | ||||
@@ -79,35 +79,30 @@ class TorchDataLoader(DataLoader): | |||||
if sampler is None and batch_sampler is None: | if sampler is None and batch_sampler is None: | ||||
sampler = RandomSampler(dataset, shuffle=shuffle) | sampler = RandomSampler(dataset, shuffle=shuffle) | ||||
shuffle=False | shuffle=False | ||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | |||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers) | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | if isinstance(dataset.dataset, DataSet): # 使用了 fastnlp dataset | ||||
self._collate_fn = dataset.dataset.collator | |||||
self._collate_fn.set_backend(backend="torch") | |||||
collate_fn = dataset.dataset.collator | |||||
collate_fn.set_backend(backend="torch") | |||||
else: | else: | ||||
self._collate_fn = Collator(backend="torch") | |||||
collate_fn = Collator(backend="torch") | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
elif isinstance(collate_fn, Callable): | |||||
if collate_fn is not default_collate: | |||||
self._collate_fn = collate_fn | |||||
else: | |||||
self._collate_fn = default_collate | |||||
self.cur_indices_batch = None | |||||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||||
multiprocessing_context=multiprocessing_context, generator=generator, | |||||
prefetch_factor=prefetch_factor, | |||||
persistent_workers=persistent_workers) | |||||
self.cur_batch_indices = None | |||||
def __iter__(self): | def __iter__(self): | ||||
# 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 | # 如果没有auto_collator 也没有自定义collate_fn, 那么此时采用dataloader自带的collate_fn, 将数据打包即可。 | ||||
# if len(self._collate_fn.get_collators()) == 0: | # if len(self._collate_fn.get_collators()) == 0: | ||||
# self._collate_fn.add_collator(self.collate_fn) | # self._collate_fn.add_collator(self.collate_fn) | ||||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||||
self.collate_fn = indice_collate_wrapper(self.collate_fn) | |||||
for indices, data in super().__iter__(): | for indices, data in super().__iter__(): | ||||
self.cur_batch_indices = indices | self.cur_batch_indices = indices | ||||
yield data | yield data | ||||
@@ -132,12 +127,26 @@ class TorchDataLoader(DataLoader): | |||||
形式,输出将被直接作为结果输出。 | 形式,输出将被直接作为结果输出。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | |||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return self._collate_fn | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return collator | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | ||||
def _get_collator(self): | |||||
""" | |||||
如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None | |||||
:return: | |||||
""" | |||||
collator = None | |||||
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): | |||||
collator = self.collate_fn.__wrapped__ | |||||
elif isinstance(self.collate_fn, Collator): | |||||
collator = self.collate_fn | |||||
return collator | |||||
def set_ignore(self, *field_names) -> Collator: | def set_ignore(self, *field_names) -> Collator: | ||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
@@ -149,9 +158,10 @@ class TorchDataLoader(DataLoader): | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | |||||
self._collate_fn.set_ignore(*field_names) | |||||
return self._collate_fn | |||||
collator = self._get_collator() | |||||
if isinstance(collator, Collator): | |||||
collator.set_ignore(*field_names) | |||||
return collator | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | ||||
@@ -164,7 +174,7 @@ class TorchDataLoader(DataLoader): | |||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | |||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]], | |||||
batch_size: int = 16, | batch_size: int = 16, | ||||
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | ||||
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
@@ -197,7 +207,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||||
:param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler | :param non_train_sampler: 非 'train' 数据使用的 Sampler, 以及Sequence的第二个以上的ds使用的 Sampler | ||||
:param non_train_batch_size: | :param non_train_batch_size: | ||||
""" | """ | ||||
from fastNLP.io import DataBundle | |||||
if isinstance(ds_or_db, DataSet): | if isinstance(ds_or_db, DataSet): | ||||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | ||||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | ||||
@@ -208,7 +218,7 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS | |||||
) | ) | ||||
return dl | return dl | ||||
elif isinstance(ds_or_db, DataBundle): | |||||
elif type(ds_or_db, DataBundle): | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
@@ -5,12 +5,25 @@ def indice_collate_wrapper(func): | |||||
:param func: 需要修饰的函数 | :param func: 需要修饰的函数 | ||||
:return: | :return: | ||||
""" | """ | ||||
if func.__name__ == '_indice_collate_wrapper': # 如果已经被包裹过了 | |||||
return func | |||||
def wrapper(tuple_data): | |||||
def _indice_collate_wrapper(tuple_data): # 这里不能使用 functools.wraps ,否则会检测不到 | |||||
indice, ins_list = [], [] | indice, ins_list = [], [] | ||||
for idx, ins in tuple_data: | for idx, ins in tuple_data: | ||||
indice.append(idx) | indice.append(idx) | ||||
ins_list.append(ins) | ins_list.append(ins) | ||||
return indice, func(ins_list) | return indice, func(ins_list) | ||||
_indice_collate_wrapper.__wrapped__ = func # 记录对应的 | |||||
return wrapper | |||||
return _indice_collate_wrapper | |||||
if __name__ == '__main__': | |||||
def demo(*args, **kwargs): | |||||
pass | |||||
d = indice_collate_wrapper(demo) | |||||
print(d.__name__) | |||||
print(d.__wrapped__) |
@@ -8,6 +8,7 @@ __all__ = [ | |||||
from collections import Counter | from collections import Counter | ||||
from typing import Any, Union, List, Callable | from typing import Any, Union, List, Callable | ||||
from ..log import logger | |||||
import numpy as np | import numpy as np | ||||
@@ -21,7 +22,7 @@ class FieldArray: | |||||
try: | try: | ||||
_content = list(_content) | _content = list(_content) | ||||
except BaseException as e: | except BaseException as e: | ||||
print(f"Cannot convert content(of type:{type(content)}) into list.") | |||||
logger.error(f"Cannot convert content(of type:{type(content)}) into list.") | |||||
raise e | raise e | ||||
self.name = name | self.name = name | ||||
self.content = _content | self.content = _content | ||||
@@ -87,7 +88,7 @@ class FieldArray: | |||||
try: | try: | ||||
new_contents.append(cell.split(sep)) | new_contents.append(cell.split(sep)) | ||||
except Exception as e: | except Exception as e: | ||||
print(f"Exception happens when process value in index {index}.") | |||||
logger.error(f"Exception happens when process value in index {index}.") | |||||
raise e | raise e | ||||
return self._after_process(new_contents, inplace=inplace) | return self._after_process(new_contents, inplace=inplace) | ||||
@@ -111,7 +111,7 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
def warning_once(self, msg, *args, **kwargs): | def warning_once(self, msg, *args, **kwargs): | ||||
""" | """ | ||||
通过 warning 内容只会 warning 一次 | |||||
相同的 warning 内容只会 warning 一次 | |||||
:param msg: | :param msg: | ||||
:param args: | :param args: | ||||
@@ -124,6 +124,22 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
self._log(WARNING, msg, args, **kwargs) | self._log(WARNING, msg, args, **kwargs) | ||||
self._warning_msgs.add(msg) | self._warning_msgs.add(msg) | ||||
def rank_zero_warning(self, msg, *args, **kwargs): | |||||
""" | |||||
只在 rank 0 上 warning 。 | |||||
:param msg: | |||||
:param args: | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': | |||||
if msg not in self._warning_msgs: | |||||
if self.isEnabledFor(WARNING): | |||||
# kwargs = self._add_rank_info(kwargs) | |||||
self._log(WARNING, msg, args, **kwargs) | |||||
self._warning_msgs.add(msg) | |||||
def warn(self, msg, *args, **kwargs): | def warn(self, msg, *args, **kwargs): | ||||
if self.isEnabledFor(WARNING): | if self.isEnabledFor(WARNING): | ||||
kwargs = self._add_rank_info(kwargs) | kwargs = self._add_rank_info(kwargs) | ||||
@@ -156,8 +156,9 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
super().stop_task(task_id) | super().stop_task(task_id) | ||||
super().remove_task(task_id) | super().remove_task(task_id) | ||||
self.refresh() # 使得bar不残留 | self.refresh() # 使得bar不残留 | ||||
if len(self._tasks) == 0: | |||||
super().stop() | |||||
# 这里需要注释掉的原因是由于,在dataset多次apply的过程中会出现自动换行的问题。以前保留这个的原因应该是由于evaluate结束bar不消失。 | |||||
# if len(self._tasks) == 0: | |||||
# self.live.stop() | |||||
def start(self) -> None: | def start(self) -> None: | ||||
super().start() | super().start() | ||||
@@ -15,6 +15,7 @@ from functools import wraps | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.utils.utils import Option | from fastNLP.core.utils.utils import Option | ||||
from fastNLP.core.utils.utils import _is_iterable | from fastNLP.core.utils.utils import _is_iterable | ||||
from .log import logger | |||||
import io | import io | ||||
@@ -56,7 +57,7 @@ def _check_build_status(func): | |||||
if self.rebuild is False: | if self.rebuild is False: | ||||
self.rebuild = True | self.rebuild = True | ||||
if self.max_size is not None and len(self.word_count) >= self.max_size: | if self.max_size is not None and len(self.word_count) >= self.max_size: | ||||
print("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||||
logger.warning("Vocabulary has reached the max size {} when calling {} method. " | |||||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | "Adding more words may cause unexpected behaviour of Vocabulary. ".format( | ||||
self.max_size, func.__name__)) | self.max_size, func.__name__)) | ||||
return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
@@ -322,7 +323,7 @@ class Vocabulary(object): | |||||
for f_n, n_f_n in zip(field_name, new_field_name): | for f_n, n_f_n in zip(field_name, new_field_name): | ||||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | ||||
except Exception as e: | except Exception as e: | ||||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||||
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||||
raise e | raise e | ||||
else: | else: | ||||
raise RuntimeError("Only DataSet type is allowed.") | raise RuntimeError("Only DataSet type is allowed.") | ||||
@@ -378,7 +379,7 @@ class Vocabulary(object): | |||||
try: | try: | ||||
dataset.apply(construct_vocab) | dataset.apply(construct_vocab) | ||||
except BaseException as e: | except BaseException as e: | ||||
print("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||||
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||||
raise e | raise e | ||||
else: | else: | ||||
raise TypeError("Only DataSet type is allowed.") | raise TypeError("Only DataSet type is allowed.") | ||||
@@ -10,7 +10,7 @@ from typing import Union, List, Callable | |||||
from ..core.dataset import DataSet | from ..core.dataset import DataSet | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
# from ..core._logger import _logger | |||||
from fastNLP.core import logger | |||||
class DataBundle: | class DataBundle: | ||||
@@ -72,7 +72,7 @@ class DataBundle: | |||||
else: | else: | ||||
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ | error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ | ||||
f'It should be one of {self.datasets.keys()}.' | f'It should be one of {self.datasets.keys()}.' | ||||
print(error_msg) | |||||
logger.error(error_msg) | |||||
raise KeyError(error_msg) | raise KeyError(error_msg) | ||||
def delete_dataset(self, name: str): | def delete_dataset(self, name: str): | ||||
@@ -97,7 +97,7 @@ class DataBundle: | |||||
else: | else: | ||||
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ | error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ | ||||
f'It should be one of {self.vocabs.keys()}.' | f'It should be one of {self.vocabs.keys()}.' | ||||
print(error_msg) | |||||
logger.error(error_msg) | |||||
raise KeyError(error_msg) | raise KeyError(error_msg) | ||||
def delete_vocab(self, field_name: str): | def delete_vocab(self, field_name: str): | ||||
@@ -117,7 +117,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||||
yield line_idx, res | yield line_idx, res | ||||
except Exception as e: | except Exception as e: | ||||
if dropna: | if dropna: | ||||
print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) | |||||
logger.error('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) | |||||
sample = [] | sample = [] | ||||
continue | continue | ||||
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | ||||
@@ -132,5 +132,5 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||||
except Exception as e: | except Exception as e: | ||||
if dropna: | if dropna: | ||||
return | return | ||||
print('invalid instance ends at line: {}'.format(line_idx)) | |||||
logger.error('invalid instance ends at line: {}'.format(line_idx)) | |||||
raise e | raise e |
@@ -29,6 +29,7 @@ import warnings | |||||
from .loader import Loader | from .loader import Loader | ||||
from fastNLP.core.dataset import Instance, DataSet | from fastNLP.core.dataset import Instance, DataSet | ||||
from ...core import logger | |||||
# from ...core._logger import log | # from ...core._logger import log | ||||
@@ -86,7 +87,8 @@ class CLSBaseLoader(Loader): | |||||
if raw_words: | if raw_words: | ||||
ds.append(Instance(raw_words=raw_words, target=target)) | ds.append(Instance(raw_words=raw_words, target=target)) | ||||
except Exception as e: | except Exception as e: | ||||
print(f'Load file `{path}` failed for `{e}`') | |||||
logger.error(f'Fail to load `{path}`.') | |||||
raise e | |||||
return ds | return ds | ||||