
1. Fix import issues in the new test cases  2. Replace warnings.warn with logger.warn

tags/v1.0.0alpha
x54-729 2 years ago
commit edda3d3196
15 changed files with 30 additions and 35 deletions
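The bulk of the change is mechanical: modules that previously emitted runtime notices through the standard library's warnings module now route them through fastNLP's own logger. A minimal sketch of the before/after pattern follows (illustration only, not part of the diff); it assumes fastNLP.core.log exposes a logging.Logger-like object, as the import added in these files suggests, whose warn() behaves like logging.Logger.warning().

    # Before: standard-library warnings
    import warnings
    warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")

    # After: routed through fastNLP's logger (sketch; assumes a logging.Logger-like object)
    from fastNLP.core.log import logger
    logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")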
  1. fastNLP/core/metrics/accuracy.py (+2, -2)
  2. fastNLP/core/metrics/classify_f1_pre_rec_metric.py (+2, -3)
  3. fastNLP/core/metrics/span_f1_pre_rec_metric.py (+2, -2)
  4. fastNLP/core/utils/utils.py (+1, -2)
  5. fastNLP/embeddings/torch/static_embedding.py (+3, -3)
  6. fastNLP/io/embed_loader.py (+3, -3)
  7. fastNLP/io/loader/classification.py (+2, -3)
  8. fastNLP/io/loader/matching.py (+4, -4)
  9. fastNLP/io/pipe/classification.py (+1, -1)
  10. fastNLP/io/pipe/matching.py (+3, -4)
  11. fastNLP/io/pipe/utils.py (+2, -2)
  12. fastNLP/modules/mix_modules/utils.py (+1, -2)
  13. fastNLP/transformers/torch/models/auto/modeling_auto.py (+2, -3)
  14. tests/modules/torch/encoder/test_seq2seq_encoder.py (+1, -0)
  15. tests/modules/torch/generator/test_seq2seq_generator.py (+1, -1)

fastNLP/core/metrics/accuracy.py (+2, -2)

@@ -4,13 +4,13 @@ __all__ = [
 ]

 from typing import Union
-import warnings

 import numpy as np

 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.metrics.backend import Backend
 from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
+from fastNLP.core.log import logger


 class Accuracy(Metric):
@@ -69,7 +69,7 @@ class Accuracy(Metric):
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
-                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+                logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
         else:
             raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "


fastNLP/core/metrics/classify_f1_pre_rec_metric.py (+2, -3)

@@ -4,7 +4,6 @@ __all__ = [

 from typing import Union, List
 from collections import Counter
-import warnings
 import numpy as np

 from .metric import Metric
@@ -12,7 +11,7 @@ from .backend import Backend
 from fastNLP.core.vocabulary import Vocabulary
 from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 from .utils import _compute_f_pre_rec
+from fastNLP.core.log import logger

 class ClassifyFPreRecMetric(Metric):
     def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
@@ -157,7 +156,7 @@ class ClassifyFPreRecMetric(Metric):
         elif pred.ndim == target.ndim + 1:
             pred = pred.argmax(axis=-1)
             if seq_len is None and target.ndim > 1:
-                warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
+                logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
         else:
             raise RuntimeError(f"when pred have "
                                f"size:{pred.shape}, target should have size: {pred.shape} or "


fastNLP/core/metrics/span_f1_pre_rec_metric.py (+2, -2)

@@ -3,12 +3,12 @@ __all__ = [
 ]

 from typing import Union, List, Optional
-import warnings
 from collections import Counter

 from fastNLP.core.metrics.backend import Backend
 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.log import logger
 from .utils import _compute_f_pre_rec


@@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod
             f"encoding_type."
         tags = tags.replace(tag, '')  # 删除该值
     if tags:  # 如果不为空,说明出现了未使用的tag
-        warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
+        logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
                       "encoding_type.")






fastNLP/core/utils/utils.py (+1, -2)

@@ -2,7 +2,6 @@ import functools
 import inspect
 from inspect import Parameter
 import dataclasses
-import warnings
 from dataclasses import is_dataclass
 from copy import deepcopy
 from collections import defaultdict, OrderedDict
@@ -555,7 +554,7 @@ def deprecated(help_message: Optional[str] = None):
         def wrapper(*args, **kwargs):
             func_hash = hash(deprecated_function)
             if func_hash not in _emitted_deprecation_warnings:
-                warnings.warn(warning_msg, category=FutureWarning, stacklevel=2)
+                logger.warn(warning_msg, category=FutureWarning, stacklevel=2)
                 _emitted_deprecation_warnings.add(func_hash)
             return deprecated_function(*args, **kwargs)




fastNLP/embeddings/torch/static_embedding.py (+3, -3)

@@ -7,7 +7,6 @@ __all__ = [
     "StaticEmbedding"
 ]
 import os
-import warnings
 from collections import defaultdict
 from copy import deepcopy
 import json
@@ -15,6 +14,7 @@ from typing import Union

 import numpy as np

+from fastNLP.core.log import logger
 from .embedding import TokenEmbedding
 from ...core import logger
 from ...core.vocabulary import Vocabulary
@@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding):
                 if word in vocab:
                     index = vocab.to_index(word)
                     if index in matrix:
-                        warnings.warn(f"Word has more than one vector in embedding file. Set logger level to "
+                        logger.warn(f"Word has more than one vector in embedding file. Set logger level to "
                                       f"DEBUG for detail.")
                         logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
                     matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
@@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding):
                     found_count += 1
             except Exception as e:
                 if error == 'ignore':
-                    warnings.warn("Error occurred at the {} line.".format(idx))
+                    logger.warn("Error occurred at the {} line.".format(idx))
                 else:
                     logger.error("Error occurred at the {} line.".format(idx))
                     raise e


fastNLP/io/embed_loader.py (+3, -3)

@@ -9,12 +9,12 @@ __all__ = [

 import logging
 import os
-import warnings

 import numpy as np

 from fastNLP.core.utils.utils import Option
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.log import logger


 class EmbeddingOption(Option):
@@ -91,7 +91,7 @@ class EmbedLoader:
                     hit_flags[index] = True
                 except Exception as e:
                     if error == 'ignore':
-                        warnings.warn("Error occurred at the {} line.".format(idx))
+                        logger.warn("Error occurred at the {} line.".format(idx))
                     else:
                         logging.error("Error occurred at the {} line.".format(idx))
                         raise e
@@ -156,7 +156,7 @@ class EmbedLoader:
                         found_pad = True
                 except Exception as e:
                     if error == 'ignore':
-                        warnings.warn("Error occurred at the {} line.".format(idx))
+                        logger.warn("Error occurred at the {} line.".format(idx))
                         pass
                     else:
                         logging.error("Error occurred at the {} line.".format(idx))


fastNLP/io/loader/classification.py (+2, -3)

@@ -25,11 +25,10 @@ import os
 import random
 import shutil
 import time
-import warnings

 from .loader import Loader
 from fastNLP.core.dataset import Instance, DataSet
-from ...core import logger
+from fastNLP.core.log import logger


 # from ...core._logger import log
@@ -346,7 +345,7 @@ class SST2Loader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if 'test' in os.path.split(path)[1]:
-                warnings.warn("SST2's test file has no target.")
+                logger.warn("SST2's test file has no target.")
                 for line in f:
                     line = line.strip()
                     if line:


fastNLP/io/loader/matching.py (+4, -4)

@@ -12,7 +12,6 @@ __all__ = [
 ]

 import os
-import warnings
 from typing import Union, Dict

 from .csv import CSVLoader
@@ -22,6 +21,7 @@ from fastNLP.io.data_bundle import DataBundle
 from ..utils import check_loader_paths
 # from ...core.const import Const
 from fastNLP.core.dataset import DataSet, Instance
+from fastNLP.core.log import logger


 class MNLILoader(Loader):
@@ -55,7 +55,7 @@ class MNLILoader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
-                warnings.warn("MNLI's test file has no target.")
+                logger.warn("MNLI's test file has no target.")
                 for line in f:
                     line = line.strip()
                     if line:
@@ -227,7 +227,7 @@ class QNLILoader(JsonLoader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if path.endswith("test.tsv"):
-                warnings.warn("QNLI's test file has no target.")
+                logger.warn("QNLI's test file has no target.")
                 for line in f:
                     line = line.strip()
                     if line:
@@ -289,7 +289,7 @@ class RTELoader(Loader):
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if path.endswith("test.tsv"):
-                warnings.warn("RTE's test file has no target.")
+                logger.warn("RTE's test file has no target.")
                 for line in f:
                     line = line.strip()
                     if line:


fastNLP/io/pipe/classification.py (+1, -1)

@@ -16,7 +16,6 @@ __all__ = [
 ]

 import re
-import warnings

 try:
     from nltk import Tree
@@ -33,6 +32,7 @@ from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2L
 # from ...core._logger import log
 # from ...core.const import Const
 from fastNLP.core.dataset import DataSet, Instance
+from fastNLP.core.log import logger


 class CLSBasePipe(Pipe):


fastNLP/io/pipe/matching.py (+3, -4)

@@ -24,8 +24,7 @@ __all__ = [
     "MachingTruncatePipe",
 ]

-import warnings
-
+from fastNLP.core.log import logger
 from .pipe import Pipe
 from .utils import get_tokenizer
 from ..data_bundle import DataBundle
@@ -147,7 +146,7 @@ class MatchingBertPipe(Pipe):
                 warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                            f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                            f"data set but not in train data set!."
-                warnings.warn(warn_msg)
+                logger.warn(warn_msg)
                 print(warn_msg)

             has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
@@ -296,7 +295,7 @@ class MatchingPipe(Pipe):
                 warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
                            f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                            f"data set but not in train data set!."
-                warnings.warn(warn_msg)
+                logger.warn(warn_msg)
                 print(warn_msg)

             has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if


fastNLP/io/pipe/utils.py (+2, -2)

@@ -7,11 +7,11 @@ __all__ = [
 ]

 from typing import List
-import warnings

 # from ...core.const import Const
 from ...core.vocabulary import Vocabulary
 # from ...core._logger import log
+from fastNLP.core.log import logger
 from pkg_resources import parse_version


@@ -138,7 +138,7 @@ def _indexize(data_bundle, input_field_names='words', target_field_names='target
                        f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
                        f"data set but not in train data set!.\n" \
                        f"These label(s) are {tgt_vocab._no_create_word}"
-            warnings.warn(warn_msg)
+            logger.warn(warn_msg)
             # log.warning(warn_msg)
         tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name)
         data_bundle.set_vocab(tgt_vocab, target_field_name)


fastNLP/modules/mix_modules/utils.py (+1, -2)

@@ -1,4 +1,3 @@
-import warnings
 from typing import Any, Optional, Union

 import numpy as np
@@ -113,7 +112,7 @@ def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] =
     # 如果outputs有_grad键,可以实现求导
     no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
     if no_gradient == False:
-        warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
+        logger.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
     jittor_numpy = jittor_var.numpy()
     if not np.issubdtype(jittor_numpy.dtype, np.inexact):
         no_gradient = True


fastNLP/transformers/torch/models/auto/modeling_auto.py (+2, -3)

@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Auto Model class. """

-import warnings
 from collections import OrderedDict

 from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
@@ -307,7 +306,7 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
 class AutoModelWithLMHead(_AutoModelWithLMHead):
     @classmethod
     def from_config(cls, config):
-        warnings.warn(
+        logger.warn(
             "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
             "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
             "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
@@ -317,7 +316,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        warnings.warn(
+        logger.warn(
             "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
             "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
             "`AutoModelForSeq2SeqLM` for encoder-decoder models.",


tests/modules/torch/encoder/test_seq2seq_encoder.py (+1, -0)

@@ -10,6 +10,7 @@ if _NEED_IMPORT_TORCH:
     from fastNLP.embeddings.torch import StaticEmbedding


+@pytest.mark.torch
 class TestTransformerSeq2SeqEncoder:
     def test_case(self):
         vocab = Vocabulary().add_word_lst("This is a test .".split())


tests/modules/torch/generator/test_seq2seq_generator.py (+1, -1)

@@ -43,7 +43,7 @@ class DummyState(State):
         super().__init__()
         self.decoder = decoder

-    def reorder_state(self, indices: torch.LongTensor):
+    def reorder_state(self, indices: "torch.LongTensor"):
         self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)
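The two test-file tweaks address the commit's first point: keeping the test modules importable when torch is absent. Below is a small illustrative sketch (not taken from the repository) of why quoting the annotation helps; the `torch` pytest marker is assumed to be registered in the project's pytest configuration, and the class/function names are hypothetical.

    import pytest

    try:
        import torch
    except ImportError:  # torch may be missing in a minimal environment
        torch = None

    @pytest.mark.torch  # assumed custom marker; lets torch-dependent tests be deselected
    class TestReorder:
        def reorder_state(self, indices: "torch.LongTensor"):
            # The quoted annotation is not evaluated at definition time, so this
            # module can still be imported and collected even when torch is None;
            # an unquoted torch.LongTensor would raise at import time.
            return indices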





