Browse Source

1.处理新测试例的import问题 2.将warnings.warn替换为logger.warn

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
edda3d3196
15 changed files with 30 additions and 35 deletions
  1. +2
    -2
      fastNLP/core/metrics/accuracy.py
  2. +2
    -3
      fastNLP/core/metrics/classify_f1_pre_rec_metric.py
  3. +2
    -2
      fastNLP/core/metrics/span_f1_pre_rec_metric.py
  4. +1
    -2
      fastNLP/core/utils/utils.py
  5. +3
    -3
      fastNLP/embeddings/torch/static_embedding.py
  6. +3
    -3
      fastNLP/io/embed_loader.py
  7. +2
    -3
      fastNLP/io/loader/classification.py
  8. +4
    -4
      fastNLP/io/loader/matching.py
  9. +1
    -1
      fastNLP/io/pipe/classification.py
  10. +3
    -4
      fastNLP/io/pipe/matching.py
  11. +2
    -2
      fastNLP/io/pipe/utils.py
  12. +1
    -2
      fastNLP/modules/mix_modules/utils.py
  13. +2
    -3
      fastNLP/transformers/torch/models/auto/modeling_auto.py
  14. +1
    -0
      tests/modules/torch/encoder/test_seq2seq_encoder.py
  15. +1
    -1
      tests/modules/torch/generator/test_seq2seq_generator.py

+ 2
- 2
fastNLP/core/metrics/accuracy.py View File

@@ -4,13 +4,13 @@ __all__ = [
]

from typing import Union
import warnings

import numpy as np

from fastNLP.core.metrics.metric import Metric
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
from fastNLP.core.log import logger


class Accuracy(Metric):
@@ -69,7 +69,7 @@ class Accuracy(Metric):
elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")

else:
raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "


+ 2
- 3
fastNLP/core/metrics/classify_f1_pre_rec_metric.py View File

@@ -4,7 +4,6 @@ __all__ = [

from typing import Union, List
from collections import Counter
import warnings
import numpy as np

from .metric import Metric
@@ -12,7 +11,7 @@ from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
from .utils import _compute_f_pre_rec
from fastNLP.core.log import logger

class ClassifyFPreRecMetric(Metric):
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
@@ -157,7 +156,7 @@ class ClassifyFPreRecMetric(Metric):
elif pred.ndim == target.ndim + 1:
pred = pred.argmax(axis=-1)
if seq_len is None and target.ndim > 1:
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
logger.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")
else:
raise RuntimeError(f"when pred have "
f"size:{pred.shape}, target should have size: {pred.shape} or "


+ 2
- 2
fastNLP/core/metrics/span_f1_pre_rec_metric.py View File

@@ -3,12 +3,12 @@ __all__ = [
]

from typing import Union, List, Optional
import warnings
from collections import Counter

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.log import logger
from .utils import _compute_f_pre_rec


@@ -39,7 +39,7 @@ def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encod
f"encoding_type."
tags = tags.replace(tag, '') # 删除该值
if tags: # 如果不为空,说明出现了未使用的tag
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
logger.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your "
"encoding_type.")




+ 1
- 2
fastNLP/core/utils/utils.py View File

@@ -2,7 +2,6 @@ import functools
import inspect
from inspect import Parameter
import dataclasses
import warnings
from dataclasses import is_dataclass
from copy import deepcopy
from collections import defaultdict, OrderedDict
@@ -555,7 +554,7 @@ def deprecated(help_message: Optional[str] = None):
def wrapper(*args, **kwargs):
func_hash = hash(deprecated_function)
if func_hash not in _emitted_deprecation_warnings:
warnings.warn(warning_msg, category=FutureWarning, stacklevel=2)
logger.warn(warning_msg, category=FutureWarning, stacklevel=2)
_emitted_deprecation_warnings.add(func_hash)
return deprecated_function(*args, **kwargs)



+ 3
- 3
fastNLP/embeddings/torch/static_embedding.py View File

@@ -7,7 +7,6 @@ __all__ = [
"StaticEmbedding"
]
import os
import warnings
from collections import defaultdict
from copy import deepcopy
import json
@@ -15,6 +14,7 @@ from typing import Union

import numpy as np

from fastNLP.core.log import logger
from .embedding import TokenEmbedding
from ...core import logger
from ...core.vocabulary import Vocabulary
@@ -286,7 +286,7 @@ class StaticEmbedding(TokenEmbedding):
if word in vocab:
index = vocab.to_index(word)
if index in matrix:
warnings.warn(f"Word has more than one vector in embedding file. Set logger level to "
logger.warn(f"Word has more than one vector in embedding file. Set logger level to "
f"DEBUG for detail.")
logger.debug(f"Word:{word} occurs again in line:{idx}(starts from 0)")
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
@@ -295,7 +295,7 @@ class StaticEmbedding(TokenEmbedding):
found_count += 1
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
logger.warn("Error occurred at the {} line.".format(idx))
else:
logger.error("Error occurred at the {} line.".format(idx))
raise e


+ 3
- 3
fastNLP/io/embed_loader.py View File

@@ -9,12 +9,12 @@ __all__ = [

import logging
import os
import warnings

import numpy as np

from fastNLP.core.utils.utils import Option
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.log import logger


class EmbeddingOption(Option):
@@ -91,7 +91,7 @@ class EmbedLoader:
hit_flags[index] = True
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
logger.warn("Error occurred at the {} line.".format(idx))
else:
logging.error("Error occurred at the {} line.".format(idx))
raise e
@@ -156,7 +156,7 @@ class EmbedLoader:
found_pad = True
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
logger.warn("Error occurred at the {} line.".format(idx))
pass
else:
logging.error("Error occurred at the {} line.".format(idx))


+ 2
- 3
fastNLP/io/loader/classification.py View File

@@ -25,11 +25,10 @@ import os
import random
import shutil
import time
import warnings

from .loader import Loader
from fastNLP.core.dataset import Instance, DataSet
from ...core import logger
from fastNLP.core.log import logger


# from ...core._logger import log
@@ -346,7 +345,7 @@ class SST2Loader(Loader):
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if 'test' in os.path.split(path)[1]:
warnings.warn("SST2's test file has no target.")
logger.warn("SST2's test file has no target.")
for line in f:
line = line.strip()
if line:


+ 4
- 4
fastNLP/io/loader/matching.py View File

@@ -12,7 +12,6 @@ __all__ = [
]

import os
import warnings
from typing import Union, Dict

from .csv import CSVLoader
@@ -22,6 +21,7 @@ from fastNLP.io.data_bundle import DataBundle
from ..utils import check_loader_paths
# from ...core.const import Const
from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.log import logger


class MNLILoader(Loader):
@@ -55,7 +55,7 @@ class MNLILoader(Loader):
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
warnings.warn("MNLI's test file has no target.")
logger.warn("MNLI's test file has no target.")
for line in f:
line = line.strip()
if line:
@@ -227,7 +227,7 @@ class QNLILoader(JsonLoader):
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
warnings.warn("QNLI's test file has no target.")
logger.warn("QNLI's test file has no target.")
for line in f:
line = line.strip()
if line:
@@ -289,7 +289,7 @@ class RTELoader(Loader):
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if path.endswith("test.tsv"):
warnings.warn("RTE's test file has no target.")
logger.warn("RTE's test file has no target.")
for line in f:
line = line.strip()
if line:


+ 1
- 1
fastNLP/io/pipe/classification.py View File

@@ -16,7 +16,6 @@ __all__ = [
]

import re
import warnings

try:
from nltk import Tree
@@ -33,6 +32,7 @@ from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2L
# from ...core._logger import log
# from ...core.const import Const
from fastNLP.core.dataset import DataSet, Instance
from fastNLP.core.log import logger


class CLSBasePipe(Pipe):


+ 3
- 4
fastNLP/io/pipe/matching.py View File

@@ -24,8 +24,7 @@ __all__ = [
"MachingTruncatePipe",
]

import warnings

from fastNLP.core.log import logger
from .pipe import Pipe
from .utils import get_tokenizer
from ..data_bundle import DataBundle
@@ -147,7 +146,7 @@ class MatchingBertPipe(Pipe):
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
f"data set but not in train data set!."
warnings.warn(warn_msg)
logger.warn(warn_msg)
print(warn_msg)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
@@ -296,7 +295,7 @@ class MatchingPipe(Pipe):
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
f"data set but not in train data set!."
warnings.warn(warn_msg)
logger.warn(warn_msg)
print(warn_msg)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if


+ 2
- 2
fastNLP/io/pipe/utils.py View File

@@ -7,11 +7,11 @@ __all__ = [
]

from typing import List
import warnings

# from ...core.const import Const
from ...core.vocabulary import Vocabulary
# from ...core._logger import log
from fastNLP.core.log import logger
from pkg_resources import parse_version


@@ -138,7 +138,7 @@ def _indexize(data_bundle, input_field_names='words', target_field_names='target
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \
f"data set but not in train data set!.\n" \
f"These label(s) are {tgt_vocab._no_create_word}"
warnings.warn(warn_msg)
logger.warn(warn_msg)
# log.warning(warn_msg)
tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name)
data_bundle.set_vocab(tgt_vocab, target_field_name)


+ 1
- 2
fastNLP/modules/mix_modules/utils.py View File

@@ -1,4 +1,3 @@
import warnings
from typing import Any, Optional, Union

import numpy as np
@@ -113,7 +112,7 @@ def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] =
# 如果outputs有_grad键,可以实现求导
no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
if no_gradient == False:
warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
logger.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
jittor_numpy = jittor_var.numpy()
if not np.issubdtype(jittor_numpy.dtype, np.inexact):
no_gradient = True


+ 2
- 3
fastNLP/transformers/torch/models/auto/modeling_auto.py View File

@@ -14,7 +14,6 @@
# limitations under the License.
""" Auto Model class. """

import warnings
from collections import OrderedDict

from .auto_factory import _BaseAutoModelClass, _LazyAutoMapping, auto_class_update
@@ -307,7 +306,7 @@ AutoModelForSpeechSeq2Seq = auto_class_update(
class AutoModelWithLMHead(_AutoModelWithLMHead):
@classmethod
def from_config(cls, config):
warnings.warn(
logger.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",
@@ -317,7 +316,7 @@ class AutoModelWithLMHead(_AutoModelWithLMHead):

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
logger.warn(
"The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use "
"`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and "
"`AutoModelForSeq2SeqLM` for encoder-decoder models.",


+ 1
- 0
tests/modules/torch/encoder/test_seq2seq_encoder.py View File

@@ -10,6 +10,7 @@ if _NEED_IMPORT_TORCH:
from fastNLP.embeddings.torch import StaticEmbedding


@pytest.mark.torch
class TestTransformerSeq2SeqEncoder:
def test_case(self):
vocab = Vocabulary().add_word_lst("This is a test .".split())


+ 1
- 1
tests/modules/torch/generator/test_seq2seq_generator.py View File

@@ -43,7 +43,7 @@ class DummyState(State):
super().__init__()
self.decoder = decoder

def reorder_state(self, indices: torch.LongTensor):
def reorder_state(self, indices: "torch.LongTensor"):
self.decoder.decoder_output = self._reorder_state(self.decoder.decoder_output, indices, dim=0)




Loading…
Cancel
Save