Browse Source

修改了 core 部分 import 的顺序,__all__ 暴露的内容

tags/v0.4.10
ChenXin 5 years ago
parent
commit
32fdb48754
14 changed files with 336 additions and 263 deletions
  1. +10
    -7
      fastNLP/core/batch.py
  2. +13
    -10
      fastNLP/core/callback.py
  3. +5
    -3
      fastNLP/core/dataset.py
  4. +54
    -46
      fastNLP/core/field.py
  5. +3
    -1
      fastNLP/core/instance.py
  6. +14
    -3
      fastNLP/core/losses.py
  7. +108
    -100
      fastNLP/core/metrics.py
  8. +14
    -6
      fastNLP/core/optimizer.py
  9. +8
    -3
      fastNLP/core/predictor.py
  10. +8
    -4
      fastNLP/core/sampler.py
  11. +16
    -13
      fastNLP/core/tester.py
  12. +6
    -7
      fastNLP/core/trainer.py
  13. +72
    -60
      fastNLP/core/utils.py
  14. +5
    -0
      fastNLP/core/vocabulary.py

+ 10
- 7
fastNLP/core/batch.py View File

@@ -2,15 +2,19 @@
batch 模块实现了 fastNLP 所需的 Batch 类。 batch 模块实现了 fastNLP 所需的 Batch 类。


""" """
__all__ = ["Batch"]
import atexit
import numpy as np import numpy as np
import torch import torch
import atexit

from .sampler import RandomSampler, Sampler
import torch.multiprocessing as mp import torch.multiprocessing as mp

from queue import Empty, Full from queue import Empty, Full


from .sampler import RandomSampler

__all__ = [
"Batch"
]

_python_is_exit = False _python_is_exit = False




@@ -120,7 +124,7 @@ class Batch(object):
:return list(int) indexes: 下标序列 :return list(int) indexes: 下标序列
""" """
return self.cur_batch_indices return self.cur_batch_indices
@staticmethod @staticmethod
def _run_fetch(batch, q): def _run_fetch(batch, q):
try: try:
@@ -145,7 +149,7 @@ class Batch(object):
q.put(e) q.put(e)
finally: finally:
q.join() q.join()
@staticmethod @staticmethod
def _run_batch_iter(batch): def _run_batch_iter(batch):
q = mp.JoinableQueue(maxsize=10) q = mp.JoinableQueue(maxsize=10)
@@ -182,4 +186,3 @@ def _to_tensor(batch, dtype):
except: except:
pass pass
return batch return batch


+ 13
- 10
fastNLP/core/callback.py View File

@@ -49,6 +49,18 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:
trainer.train() trainer.train()


""" """
import os
import torch

try:
from tensorboardX import SummaryWriter
tensorboardX_flag = True
except:
tensorboardX_flag = False

from ..io.model_io import ModelSaver, ModelLoader

__all__ = [ __all__ = [
"Callback", "Callback",
"GradientClipCallback", "GradientClipCallback",
@@ -60,15 +72,6 @@ __all__ = [
"CallbackException", "CallbackException",
"EarlyStopError" "EarlyStopError"
] ]
import os
import torch
from ..io.model_io import ModelSaver, ModelLoader

try:
from tensorboardX import SummaryWriter
tensorboardX_flag = True
except:
tensorboardX_flag = False




class Callback(object): class Callback(object):
@@ -587,7 +590,7 @@ class TensorboardCallback(Callback):
self._summary_writer = SummaryWriter(path) self._summary_writer = SummaryWriter(path)
else: else:
self._summary_writer = None self._summary_writer = None
def on_batch_begin(self, batch_x, batch_y, indices): def on_batch_begin(self, batch_x, batch_y, indices):
if "model" in self.options and self.graph_added is False: if "model" in self.options and self.graph_added is False:
# tesorboardX 这里有大bug,暂时没法画模型图 # tesorboardX 这里有大bug,暂时没法画模型图


+ 5
- 3
fastNLP/core/dataset.py View File

@@ -272,9 +272,7 @@




""" """
__all__ = ["DataSet"]
import _pickle as pickle import _pickle as pickle

import numpy as np import numpy as np
import warnings import warnings


@@ -283,6 +281,10 @@ from .field import FieldArray
from .instance import Instance from .instance import Instance
from .utils import _get_func_signature from .utils import _get_func_signature


__all__ = [
"DataSet"
]



class DataSet(object): class DataSet(object):
""" """
@@ -854,4 +856,4 @@ class DataSet(object):
with open(path, 'rb') as f: with open(path, 'rb') as f:
d = pickle.load(f) d = pickle.load(f)
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d))
return d
return d

+ 54
- 46
fastNLP/core/field.py View File

@@ -3,11 +3,17 @@ field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fas
原理部分请参考 :doc:`fastNLP.core.dataset` 原理部分请参考 :doc:`fastNLP.core.dataset`


""" """


import numpy as np import numpy as np

from copy import deepcopy from copy import deepcopy


__all__ = [
"FieldArray",
"Padder",
"AutoPadder",
"EngChar2DPadder"
]



class FieldArray(object): class FieldArray(object):
""" """
@@ -24,6 +30,7 @@ class FieldArray(object):
:param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor, :param bool ignore_type: 是否忽略该field的type,一般如果这个field不需要转为torch.FloatTensor或torch.LongTensor,
就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。 就可以设置为True。具体意义请参考 :class:`~fastNLP.DataSet` 。
""" """
def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False): def __init__(self, name, content, is_target=None, is_input=None, padder=None, ignore_type=False):
self.name = name self.name = name
if isinstance(content, list): if isinstance(content, list):
@@ -41,7 +48,7 @@ class FieldArray(object):
raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content))) raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
if len(content) == 0: if len(content) == 0:
raise RuntimeError("Cannot initialize FieldArray with empty list.") raise RuntimeError("Cannot initialize FieldArray with empty list.")
self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐 self.content = content # 1维 或 2维 或 3维 list, 形状可能不对齐
self.content_dim = None # 表示content是多少维的list self.content_dim = None # 表示content是多少维的list
if padder is None: if padder is None:
@@ -51,27 +58,27 @@ class FieldArray(object):
padder = deepcopy(padder) padder = deepcopy(padder)
self.set_padder(padder) self.set_padder(padder)
self.ignore_type = ignore_type self.ignore_type = ignore_type
self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array self.BASIC_TYPES = (int, float, str) # content中可接受的Python基本类型,这里没有np.array
self.pytype = None self.pytype = None
self.dtype = None self.dtype = None
self._is_input = None self._is_input = None
self._is_target = None self._is_target = None
if is_input is not None or is_target is not None: if is_input is not None or is_target is not None:
self.is_input = is_input self.is_input = is_input
self.is_target = is_target self.is_target = is_target
def _set_dtype(self): def _set_dtype(self):
if self.ignore_type is False: if self.ignore_type is False:
self.pytype = self._type_detection(self.content) self.pytype = self._type_detection(self.content)
self.dtype = self._map_to_np_type(self.pytype) self.dtype = self._map_to_np_type(self.pytype)
@property @property
def is_input(self): def is_input(self):
return self._is_input return self._is_input
@is_input.setter @is_input.setter
def is_input(self, value): def is_input(self, value):
""" """
@@ -80,11 +87,11 @@ class FieldArray(object):
if value is True: if value is True:
self._set_dtype() self._set_dtype()
self._is_input = value self._is_input = value
@property @property
def is_target(self): def is_target(self):
return self._is_target return self._is_target
@is_target.setter @is_target.setter
def is_target(self, value): def is_target(self, value):
""" """
@@ -93,7 +100,7 @@ class FieldArray(object):
if value is True: if value is True:
self._set_dtype() self._set_dtype()
self._is_target = value self._is_target = value
def _type_detection(self, content): def _type_detection(self, content):
""" """
当该field被设置为is_input或者is_target时被调用 当该field被设置为is_input或者is_target时被调用
@@ -101,9 +108,9 @@ class FieldArray(object):
""" """
if len(content) == 0: if len(content) == 0:
raise RuntimeError("Empty list in Field {}.".format(self.name)) raise RuntimeError("Empty list in Field {}.".format(self.name))
type_set = set([type(item) for item in content]) type_set = set([type(item) for item in content])
if list in type_set: if list in type_set:
if len(type_set) > 1: if len(type_set) > 1:
# list 跟 非list 混在一起 # list 跟 非list 混在一起
@@ -139,7 +146,7 @@ class FieldArray(object):
self.name, self.BASIC_TYPES, content_type)) self.name, self.BASIC_TYPES, content_type))
self.content_dim = 1 self.content_dim = 1
return self._basic_type_detection(type_set) return self._basic_type_detection(type_set)
def _basic_type_detection(self, type_set): def _basic_type_detection(self, type_set):
""" """
:param type_set: a set of Python types :param type_set: a set of Python types
@@ -158,7 +165,7 @@ class FieldArray(object):
else: else:
# str, int, float混在一起 # str, int, float混在一起
raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set))) raise RuntimeError("Mixed data types in Field {}: {}".format(self.name, list(type_set)))
def _1d_list_check(self, val): def _1d_list_check(self, val):
"""如果不是1D list就报错 """如果不是1D list就报错
""" """
@@ -168,7 +175,7 @@ class FieldArray(object):
self._basic_type_detection(type_set) self._basic_type_detection(type_set)
# otherwise: _basic_type_detection will raise error # otherwise: _basic_type_detection will raise error
return True return True
def _2d_list_check(self, val): def _2d_list_check(self, val):
"""如果不是2D list 就报错 """如果不是2D list 就报错
""" """
@@ -181,15 +188,15 @@ class FieldArray(object):
inner_type_set.add(type(obj)) inner_type_set.add(type(obj))
self._basic_type_detection(inner_type_set) self._basic_type_detection(inner_type_set)
return True return True
@staticmethod @staticmethod
def _map_to_np_type(basic_type): def _map_to_np_type(basic_type):
type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray} type_mapping = {int: np.int64, float: np.float64, str: np.str, np.ndarray: np.ndarray}
return type_mapping[basic_type] return type_mapping[basic_type]
def __repr__(self): def __repr__(self):
return "FieldArray {}: {}".format(self.name, self.content.__repr__()) return "FieldArray {}: {}".format(self.name, self.content.__repr__())
def append(self, val): def append(self, val):
"""将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有 """将val append到这个field的尾部。如果这个field已经被设置为input或者target,则在append之前会检查该类型是否与已有
的内容是匹配的。 的内容是匹配的。
@@ -208,7 +215,7 @@ class FieldArray(object):
else: else:
raise RuntimeError( raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
if self.is_input is True or self.is_target is True: if self.is_input is True or self.is_target is True:
if type(val) == list: if type(val) == list:
if len(val) == 0: if len(val) == 0:
@@ -231,14 +238,14 @@ class FieldArray(object):
raise RuntimeError( raise RuntimeError(
"Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES)) "Unexpected data type {}. Should be list, np.array, or {}".format(type(val), self.BASIC_TYPES))
self.content.append(val) self.content.append(val)
def __getitem__(self, indices): def __getitem__(self, indices):
return self.get(indices, pad=False) return self.get(indices, pad=False)
def __setitem__(self, idx, val): def __setitem__(self, idx, val):
assert isinstance(idx, int) assert isinstance(idx, int)
self.content[idx] = val self.content[idx] = val
def get(self, indices, pad=True): def get(self, indices, pad=True):
""" """
根据给定的indices返回内容 根据给定的indices返回内容
@@ -251,13 +258,13 @@ class FieldArray(object):
return self.content[indices] return self.content[indices]
if self.is_input is False and self.is_target is False: if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name)) raise RuntimeError("Please specify either is_input or is_target is True for {}".format(self.name))
contents = [self.content[i] for i in indices] contents = [self.content[i] for i in indices]
if self.padder is None or pad is False: if self.padder is None or pad is False:
return np.array(contents) return np.array(contents)
else: else:
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype) return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype)
def set_padder(self, padder): def set_padder(self, padder):
""" """
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。
@@ -269,7 +276,7 @@ class FieldArray(object):
self.padder = deepcopy(padder) self.padder = deepcopy(padder)
else: else:
self.padder = None self.padder = None
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
""" """
修改padder的pad_val. 修改padder的pad_val.
@@ -279,8 +286,7 @@ class FieldArray(object):
if self.padder is not None: if self.padder is not None:
self.padder.set_pad_val(pad_val) self.padder.set_pad_val(pad_val)
return self return self


def __len__(self): def __len__(self):
""" """
Returns the size of FieldArray. Returns the size of FieldArray.
@@ -288,7 +294,7 @@ class FieldArray(object):
:return int length: :return int length:
""" """
return len(self.content) return len(self.content)
def to(self, other): def to(self, other):
""" """
将other的属性复制给本FieldArray(other必须为FieldArray类型). 将other的属性复制给本FieldArray(other必须为FieldArray类型).
@@ -298,14 +304,15 @@ class FieldArray(object):
:return: :class:`~fastNLP.FieldArray` :return: :class:`~fastNLP.FieldArray`
""" """
assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other)) assert isinstance(other, FieldArray), "Only support FieldArray type, not {}.".format(type(other))
self.is_input = other.is_input self.is_input = other.is_input
self.is_target = other.is_target self.is_target = other.is_target
self.padder = other.padder self.padder = other.padder
self.ignore_type = other.ignore_type self.ignore_type = other.ignore_type
return self return self



def _is_iterable(content): def _is_iterable(content):
try: try:
_ = (e for e in content) _ = (e for e in content)
@@ -331,13 +338,13 @@ class Padder:
:return: np.array([padded_element]) :return: np.array([padded_element])
""" """
def __init__(self, pad_val=0, **kwargs): def __init__(self, pad_val=0, **kwargs):
self.pad_val = pad_val self.pad_val = pad_val
def set_pad_val(self, pad_val): def set_pad_val(self, pad_val):
self.pad_val = pad_val self.pad_val = pad_val
def __call__(self, contents, field_name, field_ele_dtype): def __call__(self, contents, field_name, field_ele_dtype):
""" """
传入的是List内容。假设有以下的DataSet。 传入的是List内容。假设有以下的DataSet。
@@ -396,13 +403,13 @@ class AutoPadder(Padder):
2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。 2.2 如果该field的内容为List, 那么会将Batch中的List pad为一样长。若该List下还有里层的List需要padding,请使用其它padder。
即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad 即如果Instance中field形如[1, 2, 3, ...],则可以pad;若为[[1,2], [3,4, ...]]则不能进行pad
""" """
def __init__(self, pad_val=0): def __init__(self, pad_val=0):
""" """
:param pad_val: int, padding的位置使用该index :param pad_val: int, padding的位置使用该index
""" """
super().__init__(pad_val=pad_val) super().__init__(pad_val=pad_val)
def _is_two_dimension(self, contents): def _is_two_dimension(self, contents):
""" """
判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度 判断contents是不是只有两个维度。[[1,2], [3]]是两个维度. [[[1,2], [3, 4, 5]], [[4,5]]]有三个维度
@@ -416,7 +423,7 @@ class AutoPadder(Padder):
return False return False
return True return True
return False return False
def __call__(self, contents, field_name, field_ele_dtype): def __call__(self, contents, field_name, field_ele_dtype):
if not _is_iterable(contents[0]): if not _is_iterable(contents[0]):
@@ -458,6 +465,7 @@ class EngChar2DPadder(Padder):
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder


""" """
def __init__(self, pad_val=0, pad_length=0): def __init__(self, pad_val=0, pad_length=0):
""" """
:param pad_val: int, pad的位置使用该index :param pad_val: int, pad的位置使用该index
@@ -465,9 +473,9 @@ class EngChar2DPadder(Padder):
都pad或截取到该长度. 都pad或截取到该长度.
""" """
super().__init__(pad_val=pad_val) super().__init__(pad_val=pad_val)
self.pad_length = pad_length self.pad_length = pad_length
def _exactly_three_dims(self, contents, field_name): def _exactly_three_dims(self, contents, field_name):
""" """
检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character 检查传入的contents是否刚好是3维,如果不是3维就报错。理论上,第一个维度是batch,第二个维度是word,第三个维度是character
@@ -486,10 +494,10 @@ class EngChar2DPadder(Padder):
value = value[0] value = value[0]
except: except:
raise ValueError("Field:{} only has two dimensions.".format(field_name)) raise ValueError("Field:{} only has two dimensions.".format(field_name))
if _is_iterable(value): if _is_iterable(value):
raise ValueError("Field:{} has more than 3 dimension.".format(field_name)) raise ValueError("Field:{} has more than 3 dimension.".format(field_name))
def __call__(self, contents, field_name, field_ele_dtype): def __call__(self, contents, field_name, field_ele_dtype):
""" """
期望输入类似于 期望输入类似于
@@ -516,12 +524,12 @@ class EngChar2DPadder(Padder):
max_sent_length = max(len(word_lst) for word_lst in contents) max_sent_length = max(len(word_lst) for word_lst in contents)
batch_size = len(contents) batch_size = len(contents)
dtype = type(contents[0][0][0]) dtype = type(contents[0][0][0])
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
dtype=dtype)
dtype=dtype)
for b_idx, word_lst in enumerate(contents): for b_idx, word_lst in enumerate(contents):
for c_idx, char_lst in enumerate(word_lst): for c_idx, char_lst in enumerate(word_lst):
chars = char_lst[:max_char_length] chars = char_lst[:max_char_length]
padded_array[b_idx, c_idx, :len(chars)] = chars padded_array[b_idx, c_idx, :len(chars)] = chars
return padded_array
return padded_array

+ 3
- 1
fastNLP/core/instance.py View File

@@ -3,7 +3,9 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格


""" """
__all__ = ["Instance"]
__all__ = [
"Instance"
]




class Instance(object): class Instance(object):


+ 14
- 3
fastNLP/core/losses.py View File

@@ -2,13 +2,12 @@
losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。 losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。


""" """
__all__ = ["LossBase", "L1Loss", "LossFunc", "LossInForward", "BCELoss", "CrossEntropyLoss", "NLLLoss"]
import inspect import inspect
from collections import defaultdict

import torch import torch
import torch.nn.functional as F import torch.nn.functional as F


from collections import defaultdict

from .utils import _CheckError from .utils import _CheckError
from .utils import _CheckRes from .utils import _CheckRes
from .utils import _build_args from .utils import _build_args
@@ -16,6 +15,18 @@ from .utils import _check_arg_dict_list
from .utils import _check_function_or_method from .utils import _check_function_or_method
from .utils import _get_func_signature from .utils import _get_func_signature


__all__ = [
"LossBase",
"LossFunc",
"LossInForward",
"CrossEntropyLoss",
"BCELoss",
"L1Loss",
"NLLLoss"
]



class LossBase(object): class LossBase(object):
""" """


+ 108
- 100
fastNLP/core/metrics.py View File

@@ -3,11 +3,11 @@ metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为


""" """
import inspect import inspect
from collections import defaultdict

import numpy as np import numpy as np
import torch import torch


from collections import defaultdict

from .utils import _CheckError from .utils import _CheckError
from .utils import _CheckRes from .utils import _CheckRes
from .utils import _build_args from .utils import _build_args
@@ -16,6 +16,13 @@ from .utils import _get_func_signature
from .utils import seq_len_to_mask from .utils import seq_len_to_mask
from .vocabulary import Vocabulary from .vocabulary import Vocabulary


__all__ = [
"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric"
]



class MetricBase(object): class MetricBase(object):
""" """
@@ -106,16 +113,17 @@ class MetricBase(object):
self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值 self.get_metric将统计当前的评价指标并返回评价结果, 返回值需要是一个dict, key是指标名称,value是指标的值


""" """
def __init__(self): def __init__(self):
self.param_map = {} # key is param in function, value is input param. self.param_map = {} # key is param in function, value is input param.
self._checked = False self._checked = False
def evaluate(self, *args, **kwargs): def evaluate(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
def get_metric(self, reset=True): def get_metric(self, reset=True):
raise NotImplemented raise NotImplemented
def _init_param_map(self, key_map=None, **kwargs): def _init_param_map(self, key_map=None, **kwargs):
"""检查key_map和其他参数map,并将这些映射关系添加到self.param_map """检查key_map和其他参数map,并将这些映射关系添加到self.param_map


@@ -148,7 +156,7 @@ class MetricBase(object):
for value, key_set in value_counter.items(): for value, key_set in value_counter.items():
if len(key_set) > 1: if len(key_set) > 1:
raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
# check consistence between signature and param_map # check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate) func_spect = inspect.getfullargspec(self.evaluate)
func_args = [arg for arg in func_spect.args if arg != 'self'] func_args = [arg for arg in func_spect.args if arg != 'self']
@@ -157,7 +165,7 @@ class MetricBase(object):
raise NameError( raise NameError(
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the " f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change its signature.") f"initialization parameters, or change its signature.")
def _fast_param_map(self, pred_dict, target_dict): def _fast_param_map(self, pred_dict, target_dict):
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. """Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
such as pred_dict has one element, target_dict has one element such as pred_dict has one element, target_dict has one element
@@ -172,7 +180,7 @@ class MetricBase(object):
fast_param['target'] = list(target_dict.values())[0] fast_param['target'] = list(target_dict.values())[0]
return fast_param return fast_param
return fast_param return fast_param
def __call__(self, pred_dict, target_dict): def __call__(self, pred_dict, target_dict):
""" """
这个方法会调用self.evaluate 方法. 这个方法会调用self.evaluate 方法.
@@ -187,12 +195,12 @@ class MetricBase(object):
:param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容) :param target_dict: DataSet.batch_y里的键-值对所组成的dict(即is_target=True的fields的内容)
:return: :return:
""" """
fast_param = self._fast_param_map(pred_dict, target_dict) fast_param = self._fast_param_map(pred_dict, target_dict)
if fast_param: if fast_param:
self.evaluate(**fast_param) self.evaluate(**fast_param)
return return
if not self._checked: if not self._checked:
if not callable(self.evaluate): if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
@@ -202,14 +210,14 @@ class MetricBase(object):
for func_arg, input_arg in self.param_map.items(): for func_arg, input_arg in self.param_map.items():
if func_arg not in func_args: if func_arg not in func_args:
raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.") raise NameError(f"`{func_arg}` not in {_get_func_signature(self.evaluate)}.")
# 2. only part of the param_map are passed, left are not # 2. only part of the param_map are passed, left are not
for arg in func_args: for arg in func_args:
if arg not in self.param_map: if arg not in self.param_map:
self.param_map[arg] = arg # This param does not need mapping. self.param_map[arg] = arg # This param does not need mapping.
self._evaluate_args = func_args self._evaluate_args = func_args
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}
# need to wrap inputs in dict. # need to wrap inputs in dict.
mapped_pred_dict = {} mapped_pred_dict = {}
mapped_target_dict = {} mapped_target_dict = {}
@@ -229,7 +237,7 @@ class MetricBase(object):
not_duplicate_flag += 1 not_duplicate_flag += 1
if not_duplicate_flag == 3: if not_duplicate_flag == 3:
duplicated.append(input_arg) duplicated.append(input_arg)
# missing # missing
if not self._checked: if not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
@@ -240,23 +248,23 @@ class MetricBase(object):
for idx, func_arg in enumerate(missing): for idx, func_arg in enumerate(missing):
# Don't delete `` in this information, nor add `` # Don't delete `` in this information, nor add ``
replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \ replaced_missing[idx] = f"{self.param_map[func_arg]}" + f"(assign to `{func_arg}` " \
f"in `{self.__class__.__name__}`)"
f"in `{self.__class__.__name__}`)"
check_res = _CheckRes(missing=replaced_missing, check_res = _CheckRes(missing=replaced_missing,
unused=check_res.unused, unused=check_res.unused,
duplicated=duplicated, duplicated=duplicated,
required=check_res.required, required=check_res.required,
all_needed=check_res.all_needed, all_needed=check_res.all_needed,
varargs=check_res.varargs) varargs=check_res.varargs)
if check_res.missing or check_res.duplicated: if check_res.missing or check_res.duplicated:
raise _CheckError(check_res=check_res, raise _CheckError(check_res=check_res,
func_signature=_get_func_signature(self.evaluate)) func_signature=_get_func_signature(self.evaluate))
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
self.evaluate(**refined_args) self.evaluate(**refined_args)
self._checked = True self._checked = True
return return




@@ -271,15 +279,16 @@ class AccuracyMetric(MetricBase):
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
""" """
def __init__(self, pred=None, target=None, seq_len=None): def __init__(self, pred=None, target=None, seq_len=None):
super().__init__() super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len) self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.total = 0 self.total = 0
self.acc_count = 0 self.acc_count = 0
def evaluate(self, pred, target, seq_len=None): def evaluate(self, pred, target, seq_len=None):
""" """
evaluate函数将针对一个批次的预测结果做评价指标的累计 evaluate函数将针对一个批次的预测结果做评价指标的累计
@@ -299,16 +308,16 @@ class AccuracyMetric(MetricBase):
if not isinstance(target, torch.Tensor): if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.") f"got {type(target)}.")
if seq_len is not None and not isinstance(seq_len, torch.Tensor): if seq_len is not None and not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.") f"got {type(seq_len)}.")
if seq_len is not None: if seq_len is not None:
masks = seq_len_to_mask(seq_len=seq_len) masks = seq_len_to_mask(seq_len=seq_len)
else: else:
masks = None masks = None
if pred.size() == target.size(): if pred.size() == target.size():
pass pass
elif len(pred.size()) == len(target.size()) + 1: elif len(pred.size()) == len(target.size()) + 1:
@@ -317,7 +326,7 @@ class AccuracyMetric(MetricBase):
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or " f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.") f"{pred.size()[:-1]}, got {target.size()}.")
target = target.to(pred) target = target.to(pred)
if masks is not None: if masks is not None:
self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item() self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks.eq(0), 0)).item()
@@ -325,7 +334,7 @@ class AccuracyMetric(MetricBase):
else: else:
self.acc_count += torch.sum(torch.eq(pred, target)).item() self.acc_count += torch.sum(torch.eq(pred, target)).item()
self.total += np.prod(list(pred.size())) self.total += np.prod(list(pred.size()))
def get_metric(self, reset=True): def get_metric(self, reset=True):
""" """
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.
@@ -350,7 +359,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
""" """
ignore_labels = set(ignore_labels) if ignore_labels else set() ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = [] spans = []
prev_bmes_tag = None prev_bmes_tag = None
for idx, tag in enumerate(tags): for idx, tag in enumerate(tags):
@@ -358,14 +367,14 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
bmes_tag, label = tag[:1], tag[2:] bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'): if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx spans[-1][1][1] = idx
else: else:
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1]+1))
for span in spans
if span[0] not in ignore_labels
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
] ]




@@ -379,7 +388,7 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
""" """
ignore_labels = set(ignore_labels) if ignore_labels else set() ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = [] spans = []
prev_bmes_tag = None prev_bmes_tag = None
for idx, tag in enumerate(tags): for idx, tag in enumerate(tags):
@@ -387,16 +396,16 @@ def _bmeso_tag_to_spans(tags, ignore_labels=None):
bmes_tag, label = tag[:1], tag[2:] bmes_tag, label = tag[:1], tag[2:]
if bmes_tag in ('b', 's'): if bmes_tag in ('b', 's'):
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label==spans[-1][0]:
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]:
spans[-1][1][1] = idx spans[-1][1][1] = idx
elif bmes_tag == 'o': elif bmes_tag == 'o':
pass pass
else: else:
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
prev_bmes_tag = bmes_tag prev_bmes_tag = bmes_tag
return [(span[0], (span[1][0], span[1][1]+1))
for span in spans
if span[0] not in ignore_labels
return [(span[0], (span[1][0], span[1][1] + 1))
for span in spans
if span[0] not in ignore_labels
] ]




@@ -410,7 +419,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] :return: List[Tuple[str, List[int, int]]]. [(label,[start, end])]
""" """
ignore_labels = set(ignore_labels) if ignore_labels else set() ignore_labels = set(ignore_labels) if ignore_labels else set()
spans = [] spans = []
prev_bio_tag = None prev_bio_tag = None
for idx, tag in enumerate(tags): for idx, tag in enumerate(tags):
@@ -418,14 +427,14 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
bio_tag, label = tag[:1], tag[2:] bio_tag, label = tag[:1], tag[2:]
if bio_tag == 'b': if bio_tag == 'b':
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label==spans[-1][0]:
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]:
spans[-1][1][1] = idx spans[-1][1][1] = idx
elif bio_tag == 'o': # o tag does not count
elif bio_tag == 'o': # o tag does not count
pass pass
else: else:
spans.append((label, [idx, idx])) spans.append((label, [idx, idx]))
prev_bio_tag = bio_tag prev_bio_tag = bio_tag
return [(span[0], (span[1][0], span[1][1]+1)) for span in spans if span[0] not in ignore_labels]
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]




class SpanFPreRecMetric(MetricBase): class SpanFPreRecMetric(MetricBase):
@@ -470,16 +479,17 @@ class SpanFPreRecMetric(MetricBase):
:param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5 :param float beta: f_beta分数,f_beta = (1 + beta^2)*(pre*rec)/(beta^2*pre + rec). 常用为beta=0.5, 1, 2. 若为0.5
则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
""" """
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None,
only_gross=True, f_type='micro', beta=1):
only_gross=True, f_type='micro', beta=1):
encoding_type = encoding_type.lower() encoding_type = encoding_type.lower()
if not isinstance(tag_vocab, Vocabulary): if not isinstance(tag_vocab, Vocabulary):
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
if f_type not in ('micro', 'macro'): if f_type not in ('micro', 'macro'):
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type))
self.encoding_type = encoding_type self.encoding_type = encoding_type
if self.encoding_type == 'bmes': if self.encoding_type == 'bmes':
self.tag_to_span_func = _bmes_tag_to_spans self.tag_to_span_func = _bmes_tag_to_spans
@@ -489,22 +499,22 @@ class SpanFPreRecMetric(MetricBase):
self.tag_to_span_func = _bmeso_tag_to_spans self.tag_to_span_func = _bmeso_tag_to_spans
else: else:
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.")
self.ignore_labels = ignore_labels self.ignore_labels = ignore_labels
self.f_type = f_type self.f_type = f_type
self.beta = beta self.beta = beta
self.beta_square = self.beta**2
self.beta_square = self.beta ** 2
self.only_gross = only_gross self.only_gross = only_gross
super().__init__() super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len) self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.tag_vocab = tag_vocab self.tag_vocab = tag_vocab
self._true_positives = defaultdict(int) self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int) self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int) self._false_negatives = defaultdict(int)
def evaluate(self, pred, target, seq_len): def evaluate(self, pred, target, seq_len):
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 """evaluate函数将针对一个批次的预测结果做评价指标的累计


@@ -519,11 +529,11 @@ class SpanFPreRecMetric(MetricBase):
if not isinstance(target, torch.Tensor): if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.") f"got {type(target)}.")
if not isinstance(seq_len, torch.Tensor): if not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.") f"got {type(seq_len)}.")
if pred.size() == target.size() and len(target.size()) == 2: if pred.size() == target.size() and len(target.size()) == 2:
pass pass
elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
@@ -536,20 +546,20 @@ class SpanFPreRecMetric(MetricBase):
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or " f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.") f"{pred.size()[:-1]}, got {target.size()}.")
batch_size = pred.size(0) batch_size = pred.size(0)
pred = pred.tolist() pred = pred.tolist()
target = target.tolist() target = target.tolist()
for i in range(batch_size): for i in range(batch_size):
pred_tags = pred[i][:int(seq_len[i])] pred_tags = pred[i][:int(seq_len[i])]
gold_tags = target[i][:int(seq_len[i])] gold_tags = target[i][:int(seq_len[i])]
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels)
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels)
for span in pred_spans: for span in pred_spans:
if span in gold_spans: if span in gold_spans:
self._true_positives[span[0]] += 1 self._true_positives[span[0]] += 1
@@ -558,7 +568,7 @@ class SpanFPreRecMetric(MetricBase):
self._false_positives[span[0]] += 1 self._false_positives[span[0]] += 1
for span in gold_spans: for span in gold_spans:
self._false_negatives[span[0]] += 1 self._false_negatives[span[0]] += 1
def get_metric(self, reset=True): def get_metric(self, reset=True):
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
evaluate_result = {} evaluate_result = {}
@@ -577,19 +587,19 @@ class SpanFPreRecMetric(MetricBase):
f_sum += f f_sum += f
pre_sum += pre pre_sum += pre
rec_sum + rec rec_sum + rec
if not self.only_gross and tag!='': # tag!=''防止无tag的情况
if not self.only_gross and tag != '': # tag!=''防止无tag的情况
f_key = 'f-{}'.format(tag) f_key = 'f-{}'.format(tag)
pre_key = 'pre-{}'.format(tag) pre_key = 'pre-{}'.format(tag)
rec_key = 'rec-{}'.format(tag) rec_key = 'rec-{}'.format(tag)
evaluate_result[f_key] = f evaluate_result[f_key] = f
evaluate_result[pre_key] = pre evaluate_result[pre_key] = pre
evaluate_result[rec_key] = rec evaluate_result[rec_key] = rec
if self.f_type == 'macro': if self.f_type == 'macro':
evaluate_result['f'] = f_sum/len(tags)
evaluate_result['pre'] = pre_sum/len(tags)
evaluate_result['rec'] = rec_sum/len(tags)
evaluate_result['f'] = f_sum / len(tags)
evaluate_result['pre'] = pre_sum / len(tags)
evaluate_result['rec'] = rec_sum / len(tags)
if self.f_type == 'micro': if self.f_type == 'micro':
f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
sum(self._false_negatives.values()), sum(self._false_negatives.values()),
@@ -597,17 +607,17 @@ class SpanFPreRecMetric(MetricBase):
evaluate_result['f'] = f evaluate_result['f'] = f
evaluate_result['pre'] = pre evaluate_result['pre'] = pre
evaluate_result['rec'] = rec evaluate_result['rec'] = rec
if reset: if reset:
self._true_positives = defaultdict(int) self._true_positives = defaultdict(int)
self._false_positives = defaultdict(int) self._false_positives = defaultdict(int)
self._false_negatives = defaultdict(int) self._false_negatives = defaultdict(int)
for key, value in evaluate_result.items(): for key, value in evaluate_result.items():
evaluate_result[key] = round(value, 6) evaluate_result[key] = round(value, 6)
return evaluate_result return evaluate_result
def _compute_f_pre_rec(self, tp, fn, fp): def _compute_f_pre_rec(self, tp, fn, fp):
""" """


@@ -619,11 +629,10 @@ class SpanFPreRecMetric(MetricBase):
pre = tp / (fp + tp + 1e-13) pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13) rec = tp / (fn + tp + 1e-13)
f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13) f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
return f, pre, rec return f, pre, rec





def _prepare_metrics(metrics): def _prepare_metrics(metrics):
""" """


@@ -705,33 +714,33 @@ class SQuADMetric(MetricBase):
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 :param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出
""" """
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, def __init__(self, pred1=None, pred2=None, target1=None, target2=None,
beta=1, right_open=True, print_predict_stat=False): beta=1, right_open=True, print_predict_stat=False):
super(SQuADMetric, self).__init__() super(SQuADMetric, self).__init__()
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2)
self.print_predict_stat = print_predict_stat self.print_predict_stat = print_predict_stat
self.no_ans_correct = 0 self.no_ans_correct = 0
self.no_ans_wrong = 0 self.no_ans_wrong = 0
self.has_ans_correct = 0 self.has_ans_correct = 0
self.has_ans_wrong = 0 self.has_ans_wrong = 0
self.has_ans_f = 0. self.has_ans_f = 0.
self.no2no = 0 self.no2no = 0
self.no2yes = 0 self.no2yes = 0
self.yes2no = 0 self.yes2no = 0
self.yes2yes = 0 self.yes2yes = 0
self.f_beta = beta self.f_beta = beta
self.right_open = right_open self.right_open = right_open
def evaluate(self, pred1, pred2, target1, target2): def evaluate(self, pred1, pred2, target1, target2):
"""evaluate函数将针对一个批次的预测结果做评价指标的累计 """evaluate函数将针对一个批次的预测结果做评价指标的累计


@@ -745,7 +754,7 @@ class SQuADMetric(MetricBase):
pred_end = pred2 pred_end = pred2
target_start = target1 target_start = target1
target_end = target2 target_end = target2
if len(pred_start.size()) == 2: if len(pred_start.size()) == 2:
start_inference = pred_start.max(dim=-1)[1].cpu().tolist() start_inference = pred_start.max(dim=-1)[1].cpu().tolist()
else: else:
@@ -754,12 +763,12 @@ class SQuADMetric(MetricBase):
end_inference = pred_end.max(dim=-1)[1].cpu().tolist() end_inference = pred_end.max(dim=-1)[1].cpu().tolist()
else: else:
end_inference = pred_end.cpu().tolist() end_inference = pred_end.cpu().tolist()
start, end = [], [] start, end = [], []
max_len = pred_start.size(1) max_len = pred_start.size(1)
t_start = target_start.cpu().tolist() t_start = target_start.cpu().tolist()
t_end = target_end.cpu().tolist() t_end = target_end.cpu().tolist()
for s, e in zip(start_inference, end_inference): for s, e in zip(start_inference, end_inference):
start.append(min(s, e)) start.append(min(s, e))
end.append(max(s, e)) end.append(max(s, e))
@@ -779,7 +788,7 @@ class SQuADMetric(MetricBase):
self.yes2no += 1 self.yes2no += 1
else: else:
self.yes2yes += 1 self.yes2yes += 1
if s == ts and e == te: if s == ts and e == te:
self.has_ans_correct += 1 self.has_ans_correct += 1
else: else:
@@ -787,29 +796,29 @@ class SQuADMetric(MetricBase):
a = [0] * s + [1] * (e - s) + [0] * (max_len - e) a = [0] * s + [1] * (e - s) + [0] * (max_len - e)
b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te) b = [0] * ts + [1] * (te - ts) + [0] * (max_len - te)
a, b = torch.tensor(a), torch.tensor(b) a, b = torch.tensor(a), torch.tensor(b)
TP = int(torch.sum(a * b)) TP = int(torch.sum(a * b))
pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0 pre = TP / int(torch.sum(a)) if int(torch.sum(a)) > 0 else 0
rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0 rec = TP / int(torch.sum(b)) if int(torch.sum(b)) > 0 else 0
if pre + rec > 0: if pre + rec > 0:
f = (1 + (self.f_beta**2)) * pre * rec / ((self.f_beta**2) * pre + rec)
f = (1 + (self.f_beta ** 2)) * pre * rec / ((self.f_beta ** 2) * pre + rec)
else: else:
f = 0 f = 0
self.has_ans_f += f self.has_ans_f += f
def get_metric(self, reset=True): def get_metric(self, reset=True):
"""get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.""" """get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果."""
evaluate_result = {} evaluate_result = {}
if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0: if self.no_ans_correct + self.no_ans_wrong + self.has_ans_correct + self.no_ans_wrong <= 0:
return evaluate_result return evaluate_result
evaluate_result['EM'] = 0 evaluate_result['EM'] = 0
evaluate_result[f'f_{self.f_beta}'] = 0 evaluate_result[f'f_{self.f_beta}'] = 0
flag = 0 flag = 0
if self.no_ans_correct + self.no_ans_wrong > 0: if self.no_ans_correct + self.no_ans_wrong > 0:
evaluate_result[f'noAns-f_{self.f_beta}'] = \ evaluate_result[f'noAns-f_{self.f_beta}'] = \
round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3) round(100 * self.no_ans_correct / (self.no_ans_correct + self.no_ans_wrong), 3)
@@ -818,7 +827,7 @@ class SQuADMetric(MetricBase):
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}'] evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'noAns-f_{self.f_beta}']
evaluate_result['EM'] += evaluate_result['noAns-EM'] evaluate_result['EM'] += evaluate_result['noAns-EM']
flag += 1 flag += 1
if self.has_ans_correct + self.has_ans_wrong > 0: if self.has_ans_correct + self.has_ans_wrong > 0:
evaluate_result[f'hasAns-f_{self.f_beta}'] = \ evaluate_result[f'hasAns-f_{self.f_beta}'] = \
round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3) round(100 * self.has_ans_f / (self.has_ans_correct + self.has_ans_wrong), 3)
@@ -827,32 +836,31 @@ class SQuADMetric(MetricBase):
evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}'] evaluate_result[f'f_{self.f_beta}'] += evaluate_result[f'hasAns-f_{self.f_beta}']
evaluate_result['EM'] += evaluate_result['hasAns-EM'] evaluate_result['EM'] += evaluate_result['hasAns-EM']
flag += 1 flag += 1
if self.print_predict_stat: if self.print_predict_stat:
evaluate_result['no2no'] = self.no2no evaluate_result['no2no'] = self.no2no
evaluate_result['no2yes'] = self.no2yes evaluate_result['no2yes'] = self.no2yes
evaluate_result['yes2no'] = self.yes2no evaluate_result['yes2no'] = self.yes2no
evaluate_result['yes2yes'] = self.yes2yes evaluate_result['yes2yes'] = self.yes2yes
if flag <= 0: if flag <= 0:
return evaluate_result return evaluate_result
evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3) evaluate_result[f'f_{self.f_beta}'] = round(evaluate_result[f'f_{self.f_beta}'] / flag, 3)
evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3) evaluate_result['EM'] = round(evaluate_result['EM'] / flag, 3)
if reset: if reset:
self.no_ans_correct = 0 self.no_ans_correct = 0
self.no_ans_wrong = 0 self.no_ans_wrong = 0
self.has_ans_correct = 0 self.has_ans_correct = 0
self.has_ans_wrong = 0 self.has_ans_wrong = 0
self.has_ans_f = 0. self.has_ans_f = 0.
self.no2no = 0 self.no2no = 0
self.no2yes = 0 self.no2yes = 0
self.yes2no = 0 self.yes2no = 0
self.yes2yes = 0 self.yes2yes = 0
return evaluate_result return evaluate_result


+ 14
- 6
fastNLP/core/optimizer.py View File

@@ -4,6 +4,12 @@ optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :cl
""" """
import torch import torch


__all__ = [
"Optimizer",
"SGD",
"Adam"
]



class Optimizer(object): class Optimizer(object):
""" """
@@ -12,15 +18,16 @@ class Optimizer(object):
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
:param kwargs: additional parameters. :param kwargs: additional parameters.
""" """
def __init__(self, model_params, **kwargs): def __init__(self, model_params, **kwargs):
if model_params is not None and not hasattr(model_params, "__next__"): if model_params is not None and not hasattr(model_params, "__next__"):
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
self.model_params = model_params self.model_params = model_params
self.settings = kwargs self.settings = kwargs
def construct_from_pytorch(self, model_params): def construct_from_pytorch(self, model_params):
raise NotImplementedError raise NotImplementedError
def _get_require_grads_param(self, params): def _get_require_grads_param(self, params):
""" """
将params中不需要gradient的删除 将params中不需要gradient的删除
@@ -29,6 +36,7 @@ class Optimizer(object):
""" """
return [param for param in params if param.requires_grad] return [param for param in params if param.requires_grad]



class SGD(Optimizer): class SGD(Optimizer):
""" """
别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` 别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD`
@@ -37,12 +45,12 @@ class SGD(Optimizer):
:param float momentum: momentum. Default: 0 :param float momentum: momentum. Default: 0
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
""" """
def __init__(self, lr=0.001, momentum=0, model_params=None): def __init__(self, lr=0.001, momentum=0, model_params=None):
if not isinstance(lr, float): if not isinstance(lr, float):
raise TypeError("learning rate has to be float.") raise TypeError("learning rate has to be float.")
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
def construct_from_pytorch(self, model_params): def construct_from_pytorch(self, model_params):
if self.model_params is None: if self.model_params is None:
# careful! generator cannot be assigned. # careful! generator cannot be assigned.
@@ -59,13 +67,13 @@ class Adam(Optimizer):
:param float weight_decay: :param float weight_decay:
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
""" """
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None):
if not isinstance(lr, float): if not isinstance(lr, float):
raise TypeError("learning rate has to be float.") raise TypeError("learning rate has to be float.")
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
weight_decay=weight_decay) weight_decay=weight_decay)
def construct_from_pytorch(self, model_params): def construct_from_pytorch(self, model_params):
if self.model_params is None: if self.model_params is None:
# careful! generator cannot be assigned. # careful! generator cannot be assigned.


+ 8
- 3
fastNLP/core/predictor.py View File

@@ -1,7 +1,11 @@
from collections import defaultdict

"""
..todo::
检查这个类是否需要
"""
import torch import torch


from collections import defaultdict

from . import Batch from . import Batch
from . import DataSet from . import DataSet
from . import SequentialSampler from . import SequentialSampler
@@ -9,7 +13,8 @@ from .utils import _build_args




class Predictor(object): class Predictor(object):
"""An interface for predicting outputs based on trained models.
"""
An interface for predicting outputs based on trained models.


It does not care about evaluations of the model, which is different from Tester. It does not care about evaluations of the model, which is different from Tester.
This is a high-level model wrapper to be called by FastNLP. This is a high-level model wrapper to be called by FastNLP.


+ 8
- 4
fastNLP/core/sampler.py View File

@@ -1,12 +1,16 @@
""" """
sampler 子类实现了 fastNLP 所需的各种采样器。 sampler 子类实现了 fastNLP 所需的各种采样器。


""" """
__all__ = ["Sampler", "BucketSampler", "SequentialSampler", "RandomSampler"]
import numpy as np

from itertools import chain from itertools import chain


import numpy as np
__all__ = [
"Sampler",
"BucketSampler",
"SequentialSampler",
"RandomSampler"
]




class Sampler(object): class Sampler(object):


+ 16
- 13
fastNLP/core/tester.py View File

@@ -33,9 +33,8 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation


""" """
import warnings import warnings

import torch import torch
from torch import nn
import torch.nn as nn


from .batch import Batch from .batch import Batch
from .dataset import DataSet from .dataset import DataSet
@@ -49,6 +48,10 @@ from .utils import _get_func_signature
from .utils import _get_model_device from .utils import _get_model_device
from .utils import _move_model_to_device from .utils import _move_model_to_device


__all__ = [
"Tester"
]



class Tester(object): class Tester(object):
""" """
@@ -77,29 +80,29 @@ class Tester(object):
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 :param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
""" """
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
super(Tester, self).__init__() super(Tester, self).__init__()
if not isinstance(data, DataSet): if not isinstance(data, DataSet):
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
if not isinstance(model, nn.Module): if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
self.metrics = _prepare_metrics(metrics) self.metrics = _prepare_metrics(metrics)
self.data = data self.data = data
self._model = _move_model_to_device(model, device=device) self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size self.batch_size = batch_size
self.verbose = verbose self.verbose = verbose
# 如果是DataParallel将没有办法使用predict方法 # 如果是DataParallel将没有办法使用predict方法
if isinstance(self._model, nn.DataParallel): if isinstance(self._model, nn.DataParallel):
if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'): if hasattr(self._model.module, 'predict') and not hasattr(self._model, 'predict'):
warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function," warnings.warn("Cannot use DataParallel to test your model, because your model offer predict() function,"
" while DataParallel has no predict() function.") " while DataParallel has no predict() function.")
self._model = self._model.module self._model = self._model.module
# check predict # check predict
if hasattr(self._model, 'predict'): if hasattr(self._model, 'predict'):
self._predict_func = self._model.predict self._predict_func = self._model.predict
@@ -109,7 +112,7 @@ class Tester(object):
f"for evaluation, not `{type(self._predict_func)}`.") f"for evaluation, not `{type(self._predict_func)}`.")
else: else:
self._predict_func = self._model.forward self._predict_func = self._model.forward
def test(self): def test(self):
"""开始进行验证,并返回验证结果。 """开始进行验证,并返回验证结果。


@@ -144,12 +147,12 @@ class Tester(object):
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0) dataset=self.data, check_level=0)
if self.verbose >= 1: if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results))) print("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False) self._mode(network, is_test=False)
return eval_results return eval_results
def _mode(self, model, is_test=False): def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently. """Train mode or Test mode. This is for PyTorch currently.


@@ -161,13 +164,13 @@ class Tester(object):
model.eval() model.eval()
else: else:
model.train() model.train()
def _data_forward(self, func, x): def _data_forward(self, func, x):
"""A forward pass of the model. """ """A forward pass of the model. """
x = _build_args(func, **x) x = _build_args(func, **x)
y = func(**x) y = func(**x)
return y return y
def _format_eval_results(self, results): def _format_eval_results(self, results):
"""Override this method to support more print formats. """Override this method to support more print formats.




+ 6
- 7
fastNLP/core/trainer.py View File

@@ -297,13 +297,12 @@ Example2.3
""" """


import os import os
import time
from datetime import datetime
from datetime import timedelta

import numpy as np import numpy as np
import time
import torch import torch
from torch import nn
import torch.nn as nn

from datetime import datetime, timedelta


try: try:
from tqdm.auto import tqdm from tqdm.auto import tqdm
@@ -315,6 +314,7 @@ from .callback import CallbackManager, CallbackException
from .dataset import DataSet from .dataset import DataSet
from .losses import _prepare_losser from .losses import _prepare_losser
from .metrics import _prepare_metrics from .metrics import _prepare_metrics
from .optimizer import Optimizer
from .sampler import Sampler from .sampler import Sampler
from .sampler import RandomSampler from .sampler import RandomSampler
from .sampler import SequentialSampler from .sampler import SequentialSampler
@@ -326,7 +326,6 @@ from .utils import _check_loss_evaluate
from .utils import _move_dict_value_to_device from .utils import _move_dict_value_to_device
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import _get_model_device from .utils import _get_model_device
from .optimizer import Optimizer
from .utils import _move_model_to_device from .utils import _move_model_to_device




@@ -464,7 +463,7 @@ class Trainer(object):
len(self.train_data) % self.batch_size != 0)) * self.n_epochs len(self.train_data) % self.batch_size != 0)) * self.n_epochs
self.model = _move_model_to_device(self.model, device=device) self.model = _move_model_to_device(self.model, device=device)
if isinstance(optimizer, torch.optim.Optimizer): if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer self.optimizer = optimizer
elif isinstance(optimizer, Optimizer): elif isinstance(optimizer, Optimizer):


+ 72
- 60
fastNLP/core/utils.py View File

@@ -1,20 +1,25 @@
""" """
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
""" """
__all__ = ["cache_results", "seq_len_to_mask"]
import _pickle import _pickle
import inspect import inspect
import numpy as np
import os import os
import torch
import torch.nn as nn
import warnings import warnings

from collections import Counter from collections import Counter
from collections import namedtuple from collections import namedtuple


import numpy as np
import torch
from torch import nn
__all__ = [
"cache_results",
"seq_len_to_mask"
]


_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])
'varargs'])



def _prepare_cache_filepath(filepath): def _prepare_cache_filepath(filepath):
""" """
@@ -83,11 +88,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
:param int _verbose: 是否打印cache的信息。 :param int _verbose: 是否打印cache的信息。
:return: :return:
""" """
def wrapper_(func): def wrapper_(func):
signature = inspect.signature(func) signature = inspect.signature(func)
for key, _ in signature.parameters.items(): for key, _ in signature.parameters.items():
if key in ('_cache_fp', '_refresh', '_verbose'): if key in ('_cache_fp', '_refresh', '_verbose'):
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if '_cache_fp' in kwargs: if '_cache_fp' in kwargs:
cache_filepath = kwargs.pop('_cache_fp') cache_filepath = kwargs.pop('_cache_fp')
@@ -95,7 +102,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
else: else:
cache_filepath = _cache_fp cache_filepath = _cache_fp
if '_refresh' in kwargs: if '_refresh' in kwargs:
refresh = kwargs.pop('_refresh')
refresh = kwargs.pop('_refresh')
assert isinstance(refresh, bool), "_refresh can only be bool." assert isinstance(refresh, bool), "_refresh can only be bool."
else: else:
refresh = _refresh refresh = _refresh
@@ -105,16 +112,16 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
else: else:
verbose = _verbose verbose = _verbose
refresh_flag = True refresh_flag = True
if cache_filepath is not None and refresh is False: if cache_filepath is not None and refresh is False:
# load data # load data
if os.path.exists(cache_filepath): if os.path.exists(cache_filepath):
with open(cache_filepath, 'rb') as f: with open(cache_filepath, 'rb') as f:
results = _pickle.load(f) results = _pickle.load(f)
if verbose==1:
if verbose == 1:
print("Read cache from {}.".format(cache_filepath)) print("Read cache from {}.".format(cache_filepath))
refresh_flag = False refresh_flag = False
if refresh_flag: if refresh_flag:
results = func(*args, **kwargs) results = func(*args, **kwargs)
if cache_filepath is not None: if cache_filepath is not None:
@@ -124,11 +131,14 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
with open(cache_filepath, 'wb') as f: with open(cache_filepath, 'wb') as f:
_pickle.dump(results, f) _pickle.dump(results, f)
print("Save cache to {}.".format(cache_filepath)) print("Save cache to {}.".format(cache_filepath))
return results return results
return wrapper return wrapper
return wrapper_ return wrapper_



# def save_pickle(obj, pickle_path, file_name): # def save_pickle(obj, pickle_path, file_name):
# """Save an object into a pickle file. # """Save an object into a pickle file.
# #
@@ -196,7 +206,7 @@ def _move_model_to_device(model, device):
""" """
if isinstance(model, torch.nn.parallel.DistributedDataParallel): if isinstance(model, torch.nn.parallel.DistributedDataParallel):
raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
if device is None: if device is None:
if isinstance(model, torch.nn.DataParallel): if isinstance(model, torch.nn.DataParallel):
model.cuda() model.cuda()
@@ -205,34 +215,35 @@ def _move_model_to_device(model, device):
if not torch.cuda.is_available() and ( if not torch.cuda.is_available() and (
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
if isinstance(model, torch.nn.DataParallel): if isinstance(model, torch.nn.DataParallel):
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
if isinstance(device, int): if isinstance(device, int):
assert device>-1, "device can only be non-negative integer"
assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
device)
assert device > -1, "device can only be non-negative integer"
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format(
torch.cuda.device_count(),
device)
device = torch.device('cuda:{}'.format(device)) device = torch.device('cuda:{}'.format(device))
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
if device.type == 'cuda' and device.index is not None: if device.type == 'cuda' and device.index is not None:
assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
torch.cuda.device_count(),
device)
assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
torch.cuda.device_count(),
device)
elif isinstance(device, torch.device): elif isinstance(device, torch.device):
if device.type == 'cuda' and device.index is not None: if device.type == 'cuda' and device.index is not None:
assert device.index<torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
torch.cuda.device_count(),
device)
assert device.index < torch.cuda.device_count(), "Only has {} gpus, cannot use device cuda:{}.".format(
torch.cuda.device_count(),
device)
elif isinstance(device, list): elif isinstance(device, list):
types = set([type(d) for d in device]) types = set([type(d) for d in device])
assert len(types)==1, "Mixed type in device, only `int` allowed."
assert len(types) == 1, "Mixed type in device, only `int` allowed."
assert list(types)[0] == int, "Only int supported for multiple devices." assert list(types)[0] == int, "Only int supported for multiple devices."
assert len(set(device))==len(device), "Duplicated device id found in device."
assert len(set(device)) == len(device), "Duplicated device id found in device."
for d in device: for d in device:
assert d>-1, "Only non-negative device id allowed."
if len(device)>1:
assert d > -1, "Only non-negative device id allowed."
if len(device) > 1:
output_device = device[0] output_device = device[0]
model = nn.DataParallel(model, device_ids=device, output_device=output_device) model = nn.DataParallel(model, device_ids=device, output_device=output_device)
device = torch.device(device[0]) device = torch.device(device[0])
@@ -250,9 +261,9 @@ def _get_model_device(model):
:return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。 :return: torch.device,None 如果返回值为None,说明这个模型没有任何参数。
""" """
assert isinstance(model, nn.Module) assert isinstance(model, nn.Module)
parameters = list(model.parameters()) parameters = list(model.parameters())
if len(parameters)==0:
if len(parameters) == 0:
return None return None
else: else:
return parameters[0].device return parameters[0].device
@@ -407,7 +418,7 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
if not isinstance(device, torch.device): if not isinstance(device, torch.device):
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
for arg in args: for arg in args:
if isinstance(arg, dict): if isinstance(arg, dict):
for key, value in arg.items(): for key, value in arg.items():
@@ -422,10 +433,10 @@ class _CheckError(Exception):


_CheckError. Used in losses.LossBase, metrics.MetricBase. _CheckError. Used in losses.LossBase, metrics.MetricBase.
""" """
def __init__(self, check_res: _CheckRes, func_signature: str): def __init__(self, check_res: _CheckRes, func_signature: str):
errs = [f'Problems occurred when calling `{func_signature}`'] errs = [f'Problems occurred when calling `{func_signature}`']
if check_res.varargs: if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
if check_res.missing: if check_res.missing:
@@ -434,9 +445,9 @@ class _CheckError(Exception):
errs.append(f"\tduplicated param: {check_res.duplicated}") errs.append(f"\tduplicated param: {check_res.duplicated}")
if check_res.unused: if check_res.unused:
errs.append(f"\tunused param: {check_res.unused}") errs.append(f"\tunused param: {check_res.unused}")
Exception.__init__(self, '\n'.join(errs)) Exception.__init__(self, '\n'.join(errs))
self.check_res = check_res self.check_res = check_res
self.func_signature = func_signature self.func_signature = func_signature


@@ -456,7 +467,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
# if check_res.varargs: # if check_res.varargs:
# errs.append(f"\tvarargs: *{check_res.varargs}") # errs.append(f"\tvarargs: *{check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
if check_res.unused: if check_res.unused:
for _unused in check_res.unused: for _unused in check_res.unused:
if _unused in target_dict: if _unused in target_dict:
@@ -466,8 +477,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
if _unused_field: if _unused_field:
unuseds.append(f"\tunused field: {_unused_field}") unuseds.append(f"\tunused field: {_unused_field}")
if _unused_param: if _unused_param:
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward
module_name = func_signature.split('.')[0] module_name = func_signature.split('.')[0]
if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}") errs.append(f"\tmissing param: {check_res.missing}")
@@ -488,14 +499,14 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
mapped_missing.append(_miss) mapped_missing.append(_miss)
else: else:
unmapped_missing.append(_miss) unmapped_missing.append(_miss)
for _miss in mapped_missing + unmapped_missing: for _miss in mapped_missing + unmapped_missing:
if _miss in dataset: if _miss in dataset:
suggestions.append(f"Set `{_miss}` as target.") suggestions.append(f"Set `{_miss}` as target.")
else: else:
_tmp = '' _tmp = ''
if check_res.unused: if check_res.unused:
_tmp = f"Check key assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}."
_tmp = f"Check key assignment for `{input_func_map.get(_miss,_miss)}` when initialize {module_name}."
if _tmp: if _tmp:
_tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.'
else: else:
@@ -513,25 +524,25 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
# else: # else:
# _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' # _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.'
# suggestions.append(_tmp) # suggestions.append(_tmp)
if check_res.duplicated: if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}.") errs.append(f"\tduplicated param: {check_res.duplicated}.")
suggestions.append(f"Delete {check_res.duplicated} in the output of " suggestions.append(f"Delete {check_res.duplicated} in the output of "
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
if len(errs)>0:
if len(errs) > 0:
errs.extend(unuseds) errs.extend(unuseds)
elif check_level == STRICT_CHECK_LEVEL: elif check_level == STRICT_CHECK_LEVEL:
errs.extend(unuseds) errs.extend(unuseds)
if len(errs) > 0: if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}') errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = "" sugg_str = ""
if len(suggestions) > 1: if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions): for idx, sugg in enumerate(suggestions):
if idx>0:
if idx > 0:
sugg_str += '\t\t\t' sugg_str += '\t\t\t'
sugg_str += f'({idx+1}). {sugg}\n'
sugg_str += f'({idx + 1}). {sugg}\n'
sugg_str = sugg_str[:-1] sugg_str = sugg_str[:-1]
else: else:
sugg_str += suggestions[0] sugg_str += suggestions[0]
@@ -546,14 +557,15 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
_unused_warn = f'{check_res.unused} is not used by {module_name}.' _unused_warn = f'{check_res.unused} is not used by {module_name}.'
warnings.warn(message=_unused_warn) warnings.warn(message=_unused_warn)



def _check_forward_error(forward_func, batch_x, dataset, check_level): def _check_forward_error(forward_func, batch_x, dataset, check_level):
check_res = _check_arg_dict_list(forward_func, batch_x) check_res = _check_arg_dict_list(forward_func, batch_x)
func_signature = _get_func_signature(forward_func) func_signature = _get_func_signature(forward_func)
errs = [] errs = []
suggestions = [] suggestions = []
_unused = [] _unused = []
# if check_res.varargs: # if check_res.varargs:
# errs.append(f"\tvarargs: {check_res.varargs}") # errs.append(f"\tvarargs: {check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
@@ -574,20 +586,20 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
# f"rename the field in `unused field:`." # f"rename the field in `unused field:`."
suggestions.append(_tmp) suggestions.append(_tmp)
if check_res.unused: if check_res.unused:
_unused = [f"\tunused field: {check_res.unused}"] _unused = [f"\tunused field: {check_res.unused}"]
if len(errs)>0:
if len(errs) > 0:
errs.extend(_unused) errs.extend(_unused)
elif check_level == STRICT_CHECK_LEVEL: elif check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused) errs.extend(_unused)
if len(errs) > 0: if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}') errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = "" sugg_str = ""
if len(suggestions) > 1: if len(suggestions) > 1:
for idx, sugg in enumerate(suggestions): for idx, sugg in enumerate(suggestions):
sugg_str += f'({idx+1}). {sugg}'
sugg_str += f'({idx + 1}). {sugg}'
else: else:
sugg_str += suggestions[0] sugg_str += suggestions[0]
err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str err_str = '\n' + '\n'.join(errs) + '\n\tSuggestion: ' + sugg_str
@@ -622,8 +634,8 @@ def seq_len_to_mask(seq_len):
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
max_len = int(seq_len.max()) max_len = int(seq_len.max())
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
mask = broad_cast_seq_len<seq_len.reshape(-1, 1)
mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
elif isinstance(seq_len, torch.Tensor): elif isinstance(seq_len, torch.Tensor):
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
batch_size = seq_len.size(0) batch_size = seq_len.size(0)
@@ -632,7 +644,7 @@ def seq_len_to_mask(seq_len):
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
else: else:
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
return mask return mask




@@ -640,24 +652,24 @@ class _pseudo_tqdm:
""" """
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
pass pass
def write(self, info): def write(self, info):
print(info) print(info)
def set_postfix_str(self, info): def set_postfix_str(self, info):
print(info) print(info)
def __getattr__(self, item): def __getattr__(self, item):
def pass_func(*args, **kwargs): def pass_func(*args, **kwargs):
pass pass
return pass_func return pass_func
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
del self del self

+ 5
- 0
fastNLP/core/vocabulary.py View File

@@ -1,7 +1,12 @@
from functools import wraps from functools import wraps
from collections import Counter from collections import Counter

from .dataset import DataSet from .dataset import DataSet


__all__ = [
"Vocabulary"
]



def _check_build_vocab(func): def _check_build_vocab(func):
"""A decorator to make sure the indexing is built before used. """A decorator to make sure the indexing is built before used.


Loading…
Cancel
Save