
_split.py 8.2 kB

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Jun 24 11:13:26 2022
@author: ljia
"""
from abc import abstractmethod
import numbers
import warnings

import numpy as np

from sklearn.utils import check_random_state, check_array, column_or_1d, indexable
from sklearn.utils.validation import _num_samples
from sklearn.utils.multiclass import type_of_target
# _build_repr is a private sklearn helper used by __repr__ below.
from sklearn.model_selection._split import _build_repr


class BaseCrossValidatorWithValid(object):
    """Base class for all cross-validators.

    Implementations must define `_iter_valid_test_masks` or
    `_iter_valid_test_indices`.
    """

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training, validation, and test sets.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Training data, where `n_samples` is the number of samples
            and `n_features` is the number of features.
        y : array-like of shape (n_samples,)
            The target variable for supervised learning problems.
        groups : array-like of shape (n_samples,), default=None
            Group labels for the samples used while splitting the dataset into
            train/valid/test sets.

        Yields
        ------
        train : ndarray
            The training set indices for that split.
        valid : ndarray
            The validation set indices for that split.
        test : ndarray
            The test set indices for that split.
        """
        X, y, groups = indexable(X, y, groups)
        indices = np.arange(_num_samples(X))
        for valid_index, test_index in self._iter_valid_test_masks(X, y, groups):
            train_index = indices[np.logical_not(np.logical_or(valid_index, test_index))]
            valid_index = indices[valid_index]
            test_index = indices[test_index]
            yield train_index, valid_index, test_index

    # Since subclasses must implement either _iter_valid_test_masks or
    # _iter_valid_test_indices, neither can be abstract.
    def _iter_valid_test_masks(self, X=None, y=None, groups=None):
        """Generates boolean masks corresponding to valid and test sets.

        By default, delegates to _iter_valid_test_indices(X, y, groups).
        """
        for valid_index, test_index in self._iter_valid_test_indices(X, y, groups):
            valid_mask = np.zeros(_num_samples(X), dtype=bool)
            test_mask = np.zeros(_num_samples(X), dtype=bool)
            valid_mask[valid_index] = True
            test_mask[test_index] = True
            yield valid_mask, test_mask

    def _iter_valid_test_indices(self, X=None, y=None, groups=None):
        """Generates integer indices corresponding to valid and test sets."""
        raise NotImplementedError

    @abstractmethod
    def get_n_splits(self, X=None, y=None, groups=None):
        """Returns the number of splitting iterations in the cross-validator."""

    def __repr__(self):
        return _build_repr(self)


class _BaseKFoldWithValid(BaseCrossValidatorWithValid):
    """Base class for k-fold cross-validators that also produce a validation set."""

    @abstractmethod
    def __init__(self, n_splits, *, stratify, shuffle, random_state):
        if not isinstance(n_splits, numbers.Integral):
            raise ValueError(
                'The number of folds must be of Integral type. '
                '%s of type %s was passed.' % (n_splits, type(n_splits))
            )
        n_splits = int(n_splits)

        if n_splits <= 2:
            raise ValueError(
                'k-fold cross-validation requires at least one'
                ' train/valid/test split by setting n_splits=3 or more,'
                ' got n_splits={0}.'.format(n_splits)
            )

        if not isinstance(shuffle, bool):
            raise TypeError('shuffle must be True or False; got {0}'.format(shuffle))

        if not shuffle and random_state is not None:  # None is the default
            raise ValueError(
                'Setting a random_state has no effect since shuffle is '
                'False. You should leave '
                'random_state to its default (None), or set shuffle=True.',
            )

        self.n_splits = n_splits
        self.stratify = stratify
        self.shuffle = shuffle
        self.random_state = random_state

    def split(self, X, y=None, groups=None):
        """Generate indices to split data into training, validation, and test sets."""
        X, y, groups = indexable(X, y, groups)
        n_samples = _num_samples(X)
        if self.n_splits > n_samples:
            raise ValueError(
                (
                    'Cannot have number of splits n_splits={0} greater'
                    ' than the number of samples: n_samples={1}.'
                ).format(self.n_splits, n_samples)
            )

        for train, valid, test in super().split(X, y, groups):
            yield train, valid, test


class KFoldWithValid(_BaseKFoldWithValid):

    def __init__(
            self,
            n_splits=5,
            *,
            stratify=False,
            shuffle=False,
            random_state=None
    ):
        super().__init__(
            n_splits=n_splits,
            stratify=stratify,
            shuffle=shuffle,
            random_state=random_state
        )

    def _make_valid_test_folds(self, X, y=None):
        rng = check_random_state(self.random_state)
        y = np.asarray(y)
        type_of_target_y = type_of_target(y)
        allowed_target_types = ('binary', 'multiclass')
        if type_of_target_y not in allowed_target_types:
            raise ValueError(
                'Supported target types are: {}. Got {!r} instead.'.format(
                    allowed_target_types, type_of_target_y
                )
            )

        y = column_or_1d(y)

        _, y_idx, y_inv = np.unique(y, return_index=True, return_inverse=True)
        # y_inv encodes y according to lexicographic order. We invert y_idx to
        # map the classes so that they are encoded by order of appearance:
        # 0 represents the first label appearing in y, 1 the second, etc.
        _, class_perm = np.unique(y_idx, return_inverse=True)
        y_encoded = class_perm[y_inv]

        n_classes = len(y_idx)
        y_counts = np.bincount(y_encoded)
        min_groups = np.min(y_counts)
        if np.all(self.n_splits > y_counts):
            raise ValueError(
                "n_splits=%d cannot be greater than the"
                " number of members in each class." % (self.n_splits)
            )
        if self.n_splits > min_groups:
            warnings.warn(
                "The least populated class in y has only %d"
                " members, which is less than n_splits=%d."
                % (min_groups, self.n_splits),
                UserWarning,
            )

        # Determine the optimal number of samples from each class in each fold,
        # using round robin over the sorted y. (This can be done direct from
        # counts, but that code is unreadable.)
        y_order = np.sort(y_encoded)
        allocation = np.asarray(
            [
                np.bincount(y_order[i::self.n_splits], minlength=n_classes)
                for i in range(self.n_splits)
            ]
        )

        # To maintain the data order dependencies as best as possible within
        # the stratification constraint, we assign samples from each class in
        # blocks (and then mess that up when shuffle=True).
        test_folds = np.empty(len(y), dtype='i')
        for k in range(n_classes):
            # since the kth column of allocation stores the number of samples
            # of class k in each test set, this generates blocks of fold
            # indices corresponding to the allocation for class k.
            folds_for_class = np.arange(self.n_splits).repeat(allocation[:, k])
            if self.shuffle:
                rng.shuffle(folds_for_class)
            test_folds[y_encoded == k] = folds_for_class
        return test_folds

    def _iter_valid_test_masks(self, X, y=None, groups=None):
        test_folds = self._make_valid_test_folds(X, y)
        for i in range(self.n_splits):
            # The valid set is fold i and the test set is the next fold,
            # wrapping around to fold 0 for the last split.
            if i + 1 < self.n_splits:
                j = i + 1
            else:
                j = 0
            yield test_folds == i, test_folds == j

    def split(self, X, y, groups=None):
        y = check_array(y, input_name='y', ensure_2d=False, dtype=None)
        return super().split(X, y, groups)


class _RepeatedSplitsWithValid(object):

    def __init__(
            self,
            cv,
            *,
            n_repeats=10,
            random_state=None,
            **cvargs
    ):
        if not isinstance(n_repeats, int):
            raise ValueError('Number of repetitions must be of integer type.')
        if n_repeats <= 0:
            raise ValueError('Number of repetitions must be greater than 0.')
        self.cv = cv
        self.n_repeats = n_repeats
        self.random_state = random_state
        self.cvargs = cvargs

    def split(self, X, y=None, groups=None):
        n_repeats = self.n_repeats
        rng = check_random_state(self.random_state)
        for idx in range(n_repeats):
            cv = self.cv(random_state=rng, shuffle=True, **self.cvargs)
            for train_index, valid_index, test_index in cv.split(X, y, groups):
                yield train_index, valid_index, test_index


class RepeatedKFoldWithValid(_RepeatedSplitsWithValid):

    def __init__(
            self,
            *,
            n_splits=5,
            n_repeats=10,
            stratify=False,
            random_state=None
    ):
        super().__init__(
            KFoldWithValid,
            n_repeats=n_repeats,
            stratify=stratify,
            random_state=random_state,
            n_splits=n_splits,
        )
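
As a rough usage sketch (not part of the `_split.py` file itself), the snippet below shows how the splitters defined above could be driven on a small toy dataset; the names `X_toy`, `y_toy`, `cv` and `rcv` are made up for illustration.

import numpy as np

# Toy data: 12 samples, 2 features, 3 classes with 4 members each.
X_toy = np.arange(24).reshape(12, 2)
y_toy = np.array([0, 1, 2] * 4)

# Each split yields three disjoint index arrays: train, valid and test.
# The valid set is fold i and the test set is the following fold.
cv = KFoldWithValid(n_splits=3, shuffle=True, random_state=42)
for fold, (train_idx, valid_idx, test_idx) in enumerate(cv.split(X_toy, y_toy)):
    print(fold, len(train_idx), len(valid_idx), len(test_idx))

# RepeatedKFoldWithValid re-draws the folds n_repeats times with a fresh shuffle,
# giving n_splits * n_repeats (train, valid, test) triples in total.
rcv = RepeatedKFoldWithValid(n_splits=3, n_repeats=2, random_state=0)
n_triples = sum(1 for _ in rcv.split(X_toy, y_toy))  # 3 splits x 2 repeats = 6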

A Python package for graph kernels, graph edit distances and the graph pre-image problem.