|
-
- __all__ = [  # Names exported by a star-import of this module; only the NumPy padding entry point is listed.
- 'get_padded_numpy_array'
- ]
-
-
- from typing import Sequence, List
- import re
- from inspect import isclass
-
- import numpy as np
- np_str_obj_array_pattern = re.compile(r'[SaUO]')  # Searched against `np.dtype.str`; a hit on S/a (byte string), U (unicode) or O (object) marks the dtype as non-numeric.
-
-
def get_shape(batch_field: List, shape=None):
    """
    Return the shape that ``batch_field`` will have once it has been padded.

    Examples::

        [[1, 2, 3], [3]]                      -> [2, 3]
        [[[1], [2], [3, 4]], [[2, 3, 4]]]     -> [2, 3, 3]
        []                                    -> [0]
        ['ab', 'c']                           -> [2]   (strings are scalars)

    Only the first element of each level is inspected to decide whether that
    level holds nested sequences.

    :param batch_field: nested sequence; dimension 0 is normally the batch dimension.
    :param shape: shape prefix accumulated by the recursion. Callers should not pass it.
    :return: list of ints, the maximum length found along every nesting level.
    """
    if shape is None:
        shape = []

    # str/bytes are Sequences, but every item of a str is itself a str, so
    # descending into one would recurse forever. Treat them as scalar values.
    if isinstance(batch_field, (str, bytes)) or not isinstance(batch_field, Sequence):
        return shape

    num_ele = len(batch_field)
    _shape = shape + [num_ele]
    if num_ele == 0:
        # Empty sequence: nothing to descend into.
        return _shape

    first = batch_field[0]
    if isinstance(first, (str, bytes)) or not isinstance(first, Sequence):
        # Innermost level reached: the elements are scalars.
        return _shape

    shapes = [get_shape(_field, _shape) for _field in batch_field]
    if len(shapes) == 1:
        return shapes[0]
    # Element-wise maximum over the children's shapes.
    return [max(dims) for dims in zip(*shapes)]
-
-
def fill_array(batch_field: List, padded_batch: np.ndarray):
    """
    Copy the values of ``batch_field`` into ``padded_batch``.

    For 1-D to 3-D targets the array is filled in place and returned. For a 4-D
    target the data is first converted with ``np.array``; when that yields an
    array of exactly the target shape (nothing needs padding) that new array is
    returned as is, so its dtype comes from the data. Otherwise the target is
    filled in place element by element.

    :param batch_field: nested sequence whose values are written into the array.
    :param padded_batch: pre-allocated ``np.ndarray`` already holding the pad value.
    :return: the filled array.
    :raises RuntimeError: if ``padded_batch.ndim`` is not 1, 2, 3 or 4.
    """
    ndim = padded_batch.ndim
    if ndim == 1:
        padded_batch[:] = batch_field
    elif ndim == 2:
        for i, content_i in enumerate(batch_field):
            padded_batch[i, :len(content_i)] = content_i
    elif ndim == 3:
        for i, content_i in enumerate(batch_field):
            for j, content_ii in enumerate(content_i):
                padded_batch[i, j, :len(content_ii)] = content_ii
    elif ndim == 4:
        # Fast path: 4-D input is expected to be image-like and already rectangular.
        try:
            dense = np.array(batch_field)
        except (ValueError, TypeError):
            # Recent NumPy raises for ragged input; fall back to the slow path.
            dense = None
        # Older NumPy returns a lower-rank object array for ragged input instead
        # of raising, so the shape has to be checked explicitly.
        if dense is not None and dense.shape == padded_batch.shape:
            padded_batch = dense
        else:
            for i, content_i in enumerate(batch_field):
                for j, content_ii in enumerate(content_i):
                    for k, content_iii in enumerate(content_ii):
                        padded_batch[i, j, k, :len(content_iii)] = content_iii
    else:
        raise RuntimeError("fastNLP does not support padding for more than 4 dimensions. If you need this, please "
                           "report.")
    return padded_batch
-
-
def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray:
    """
    Pad a nested list to a rectangular ``np.ndarray``.

    Example::

        [[1, 2], [3]] -> np.array([[1, 2], [3, 0]])

    :param batch_field: the object to pad; it must be paddable. 1-D (e.g. sentence
        lengths), 2-D (e.g. token sequences), 3-D (e.g. character sequences) and
        4-D (e.g. images) inputs are supported.
    :param dtype: dtype handed to ``np.full`` when allocating the result.
    :param pad_val: value written to every padded position.
    :return: the padded array.
    """
    target_shape = get_shape(batch_field)
    # Allocate an array full of the pad value, then overwrite the real entries.
    padded = np.full(target_shape, fill_value=pad_val, dtype=dtype)
    return fill_array(batch_field, padded)
-
-
def get_padded_nest_list(batch_field: List, pad_val=0) -> List:
    """
    Pad a nested list and return the result as a nested Python list.

    Example::

        [[1, 2], [3]] -> [[1, 2], [3, 0]]

    :param batch_field: the object to pad; it must be paddable. 1-D (e.g. sentence
        lengths), 2-D (e.g. token sequences), 3-D (e.g. character sequences) and
        4-D (e.g. images) inputs are supported.
    :param pad_val: value written to every padded position.
    :return: the padded nested list.
    """
    padded = get_padded_numpy_array(batch_field, dtype=None, pad_val=pad_val)
    return padded.tolist()
-
-
def is_number_or_numpy_number(dtype):
    """
    Tell whether ``dtype`` is a Python number type or a NumPy number type.

    Examples::

        is_number_or_numpy_number(type(3))                                   # True
        is_number_or_numpy_number(type(3.1))                                 # True
        is_number_or_numpy_number(type('3'))                                 # False
        is_number_or_numpy_number(type(True))                                # True
        is_number_or_numpy_number(type(np.zeros(3)[0]))                      # True
        is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)            # True
        is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)              # True
        is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)              # False
        is_number_or_numpy_number(np.array([1, [2]], dtype=object).dtype)    # False

    :param dtype: a class or an ``np.dtype`` instance.
    :return: ``True`` if ``dtype`` is accepted as a number type, else ``False``.
    """
    # Builtin number classes are accepted directly.
    if is_number(dtype):
        return True
    # Any other class is accepted only if it is a NumPy scalar class.
    if isclass(dtype):
        return is_numpy_generic_class(dtype)
    # What remains must be an np.dtype instance to qualify at all.
    if not isinstance(dtype, np.dtype):
        return False
    # Reject dtypes whose type string marks bytes / unicode / object data.
    return np_str_obj_array_pattern.search(dtype.str) is None
-
-
def is_numpy_number_dtype(dtype):
    """
    Tell whether ``dtype`` is an ``np.dtype`` instance that is not a string or object dtype.

    :param dtype: the value to inspect.
    :return: ``True`` when ``dtype`` is an ``np.dtype`` instance whose ``.str`` has no
        bytes/unicode/object type code, else ``False``.
    """
    if isclass(dtype) or not isinstance(dtype, np.dtype):
        return False
    return np_str_obj_array_pattern.search(dtype.str) is None
-
-
def is_numpy_generic_class(dtype):
    """
    Tell whether ``dtype`` is a NumPy scalar class, e.g. ``np.int64`` or
    ``np.zeros(1).dtype.type``.

    :param dtype: the value to inspect.
    :return: ``True`` when ``dtype`` is a class deriving from ``np.generic``, else ``False``.
    """
    return isclass(dtype) and issubclass(dtype, np.generic)
-
-
def is_number(dtype):
    """
    Tell whether ``dtype`` is one of the builtin number classes
    ``float``, ``int``, ``complex`` or ``bool``.

    The comparison is by identity. ``==`` is overloaded by NumPy dtypes
    (``np.dtype('float64') == float`` is ``True``) and may raise for arbitrary
    objects, so identity keeps NumPy dtypes and NumPy scalar classes out and
    never raises.

    :param dtype: the value to inspect.
    :return: ``True`` for the four builtin number classes, else ``False``.
    """
    return any(dtype is builtin for builtin in (float, int, complex, bool))
-
-
-
if __name__ == '__main__':
    # Manual smoke checks for the type predicates; the expected output follows each call.
    print(is_number(type('a')))  # False
    print(is_number_or_numpy_number(type(3)))  # True
    print(is_number_or_numpy_number(type(3.1)))  # True
    print(is_number_or_numpy_number(type('3')))  # False
    print(is_number_or_numpy_number(type(True)))  # True
    print(is_number_or_numpy_number(type(np.zeros(3)[0])))  # True
    print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype))  # True
    print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype))  # True
    print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype))  # False
    # dtype=object is spelled out because NumPy >= 1.24 raises ValueError when
    # asked to build an array from ragged input without it.
    print(is_number_or_numpy_number(np.array([1, [2]], dtype=object).dtype))  # False
|