@@ -4,13 +4,13 @@ __all__ = [
 ]
 from typing import Union
-import warnings
 import numpy as np
 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.metrics.backend import Backend
 from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
+from fastNLP.core.log import logger
 class Accuracy(Metric):
@@ -69,7 +69,7 @@ class Accuracy(Metric):
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
-                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+                logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
         else:
             raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "
@@ -4,7 +4,6 @@ __all__ = [
 from typing import Union, List
 from collections import Counter
-import warnings
 import numpy as np
 from .metric import Metric
@@ -12,7 +11,7 @@ from .backend import Backend
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 from .utils import _compute_f_pre_rec
+from fastNLP.core.log import logger
 class ClassifyFPreRecMetric(Metric):
     def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
@@ -157,7 +156,7 @@ class ClassifyFPreRecMetric(Metric):
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
-                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+                logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
         else:
             raise RuntimeError(f"when pred have "
                                f"size:{pred.shape}, target should have size: {pred.shape} or "
@@ -3,12 +3,12 @@ __all__ = [
 ]
 from typing import Union, List, Optional
-import warnings
 from collections import Counter
 from fastNLP.core.metrics.backend import Backend
 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.log import logger
 from .utils import _compute_f_pre_rec
@@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod
                f"encoding_type."
         tags = tags.replace(tag, '')  # remove this tag
     if tags:  # if anything is left, some tags in the encoding type were never used
-        warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
+        logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
                     "encoding_type.")
@@ -2,7 +2,6 @@ import functools
 import inspect
 from inspect import Parameter
 import dataclasses
-import warnings
 from dataclasses import is_dataclass
 from copy import deepcopy
 from collections import defaultdict, OrderedDict
@@ -555,7 +554,7 @@ def deprecated(help_message: Optional[str] = None):
     def wrapper(*args, **kwargs):
         func_hash = hash(deprecated_function)
         if func_hash not in _emitted_deprecation_warnings:
-            warnings.warn(warning_msg, category=FutureWarning, stacklevel=2)
+            logger.warn(warning_msg)
             _emitted_deprecation_warnings.add(func_hash)
         return deprecated_function(*args, **kwargs)
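A note on the hunk above: `warnings.warn`'s `category=` keyword has no counterpart on a `logging.Logger`, so forwarding it to `logger.warn` would raise a `TypeError` the first time a deprecated function runs; the migrated call therefore drops the extra arguments. A minimal sketch of the resulting emit-once pattern, with stand-in names for the decorator's internals:

```python
import logging

logger = logging.getLogger("fastNLP")  # stand-in for fastNLP.core.log.logger
_emitted_deprecation_warnings = set()

def emit_deprecation(func, warning_msg):
    """Log a deprecation message only once per decorated function."""
    func_hash = hash(func)
    if func_hash not in _emitted_deprecation_warnings:
        logger.warning(warning_msg)  # logging accepts no category= keyword
        _emitted_deprecation_warnings.add(func_hash)
```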
@@ -7,7 +7,6 @@ __all__ = [
     "StaticEmbedding"
 ]
 import os
-import warnings
 from collections import defaultdict
 from copy import deepcopy
 import json
@@ -15,6 +14,7 @@ from typing import Union
 import numpy as np
+from fastNLP.core.log import logger
 from .embedding import TokenEmbedding
 from ...core import logger
 from ...core.vocabulary import Vocabulary
@@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding):
                     if word in vocab:
                         index = vocab.to_index(word)
                         if index in matrix:
-                            warnings.warn(f"Word has more than one vector in embedding file. Set logger level to "
+                            logger.warn(f"Word has more than one vector in embedding file. Set logger level to "
                                         f"DEBUG for detail.")
                             logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
                         matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
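For readers unfamiliar with the parsing in the last context line, a toy sketch of how one embedding row becomes the tensor stored in `matrix[index]`; the file content below is made up:

```python
import numpy as np
import torch

line = "the 0.1 0.2 0.3\n"          # one row of a GloVe-style embedding file
word, *nums = line.rstrip().split(' ')
vec = np.fromstring(' '.join(nums), sep=' ', dtype=np.float32, count=3)
print(word, torch.from_numpy(vec))  # the tensor([0.1000, 0.2000, 0.3000])
```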
@@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding):
                         found_count += 1
                 except Exception as e:
                     if error == 'ignore':
-                        warnings.warn("Error occurred at the {} line.".format(idx))
+                        logger.warn("Error occurred at the {} line.".format(idx))
                     else:
                         logger.error("Error occurred at the {} line.".format(idx))
                         raise e
@@ -9,12 +9,12 @@ __all__ = [
 import logging
 import os
-import warnings
 import numpy as np
 from fastNLP.core.utils.utils import Option
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.log import logger
 class EmbeddingOption(Option):
@@ -91,7 +91,7 @@ class EmbedLoader:
                     hit_flags[index] = True
             except Exception as e:
                 if error == 'ignore':
-                    warnings.warn("Error occurred at the {} line.".format(idx))
+                    logger.warn("Error occurred at the {} line.".format(idx))
                 else:
                     logging.error("Error occurred at the {} line.".format(idx))
                     raise e
@@ -156,7 +156,7 @@ class EmbedLoader:
                         found_pad = True
             except Exception as e:
                 if error == 'ignore':
-                    warnings.warn("Error occurred at the {} line.".format(idx))
+                    logger.warn("Error occurred at the {} line.".format(idx))
                     pass
                 else:
                     logging.error("Error occurred at the {} line.".format(idx))
@@ -25,11 +25,10 @@ import os
 import random
 import shutil
 import time
-import warnings
 from .loader import Loader
 from fastNLP.core.dataset import Instance, DataSet
-from ...core import logger
+from fastNLP.core.log import logger
 # from ...core._logger import log
@@ -346,7 +345,7 @@ class SST2Loader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if 'test' in os.path.split(path)[1]:
-                warnings.warn("SST2's test file has no target.")
+                logger.warn("SST2's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
@@ -12,7 +12,6 @@ __all__ = [
 ]
 import os
-import warnings
 from typing import Union, Dict
 from .csv import CSVLoader
@@ -22,6 +21,7 @@ from fastNLP.io.data_bundle import DataBundle
 from ..utils import check_loader_paths
 # from ...core.const import Const
 from fastNLP.core.dataset import DataSet, Instance
+from fastNLP.core.log import logger
 class MNLILoader(Loader):
@@ -55,7 +55,7 @@ class MNLILoader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
-                warnings.warn("MNLI's test file has no target.")
+                logger.warn("MNLI's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
@@ -227,7 +227,7 @@ class QNLILoader(JsonLoader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if path.endswith("test.tsv"):
-                warnings.warn("QNLI's test file has no target.")
+                logger.warn("QNLI's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
@@ -289,7 +289,7 @@ class RTELoader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # skip the header
             if path.endswith("test.tsv"):
-                warnings.warn("RTE's test file has no target.")
+                logger.warn("RTE's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
@@ -16,7 +16,6 @@ __all__ = [
 ]
 import re
-import warnings
 try:
     from nltk import Tree
@@ -33,6 +32,7 @@ from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2L
 # from ...core._logger import log
 # from ...core.const import Const
 from fastNLP.core.dataset import DataSet, Instance
+from fastNLP.core.log import logger
 class CLSBasePipe(Pipe):
@@ -24,8 +24,7 @@ __all__ = [
     "MachingTruncatePipe",
 ]
-import warnings
+from fastNLP.core.log import logger
 from .pipe import Pipe
 from .utils import get_tokenizer
 from ..data_bundle import DataBundle
@@ -147,7 +146,7 @@ class MatchingBertPipe(Pipe):
             warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                        f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                        f"data set but not in train data set!."
-            warnings.warn(warn_msg)
+            logger.warn(warn_msg)
             print(warn_msg)
         has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
@@ -296,7 +295,7 @@ class MatchingPipe(Pipe):
             warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                        f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                        f"data set but not in train data set!."
-            warnings.warn(warn_msg)
+            logger.warn(warn_msg)
             print(warn_msg)
         has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
@@ -7,11 +7,11 @@ __all__ = [
 ]
 from typing import List
-import warnings
 # from ...core.const import Const
 from ...core.vocabulary import Vocabulary
 # from ...core._logger import log
+from fastNLP.core.log import logger
 from pkg_resources import parse_version
@@ -138,7 +138,7 @@ def _indexize(data_bundle, input_field_names='words', target_field_names='target
                    f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                    f"data set but not in train data set!.\n" \
                    f"These label(s) are {tgt_vocab._no_create_word}"
-        warnings.warn(warn_msg)
+        logger.warn(warn_msg)
         # log.warning(warn_msg)
         tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name)
         data_bundle.set_vocab(tgt_vocab, target_field_name)
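The `_no_create_word` counter checked in `_indexize` is populated when a vocabulary is built with `no_create_entry_dataset`: labels seen only outside the train split land there instead of becoming ordinary entries. A hedged sketch of how that happens (the dataset contents are made up):

```python
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary

train_ds = DataSet({'target': ['pos', 'neg']})
dev_ds = DataSet({'target': ['neutral']})   # label absent from train

tgt_vocab = Vocabulary(padding=None, unknown=None)
tgt_vocab.from_dataset(train_ds, field_name='target',
                       no_create_entry_dataset=[dev_ds])
print(tgt_vocab._no_create_word)  # e.g. Counter({'neutral': 1})
```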
@@ -1,4 +1,3 @@
-import warnings
 from typing import Any, Optional, Union
 import numpy as np
@@ -113,7 +112,7 @@ def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] =
     # if outputs carries a `_grad` key, gradients can be computed
     no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
     if no_gradient == False:
-        warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
+        logger.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
     jittor_numpy = jittor_var.numpy()
     if not np.issubdtype(jittor_numpy.dtype, np.inexact):
         no_gradient = True
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Auto Model class. """
-import warnings
 from collections import OrderedDict
 from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
@@ -307,7 +306,7 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
 class AutoModelWithLMHead(_AutoModelWithLMHead):
     @classmethod
     def from_config(cls, config):
-        warnings.warn(
+        logger.warn(
             "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
             "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
             "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
@@ -317,7 +316,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        warnings.warn(
+        logger.warn(
             "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
             "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
             "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
@@ -10,6 +10,7 @@ if _NEED_IMPORT_TORCH:
     from fastNLP.embeddings.torch import StaticEmbedding
 @pytest.mark.torch
 class TestTransformerSeq2SeqEncoder:
     def test_case(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())
@@ -43,7 +43,7 @@ class DummyState(State):
         super().__init__()
         self.decoder = decoder
-    def reorder_state(self, indices: torch.LongTensor):
+    def reorder_state(self, indices: "torch.LongTensor"):
         self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)
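The quoted annotation in the final hunk is a forward reference: the string is stored unevaluated instead of being resolved at definition time, so the module stays importable when torch is not installed. A minimal sketch of the pattern using the standard `TYPE_CHECKING` guard (the repository's own guard flag may differ):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # resolved by type checkers only, never at runtime
    import torch

class DummyState:
    def reorder_state(self, indices: "torch.LongTensor"):
        # the annotation stays a plain string, so torch is not needed here
        ...
```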