
utils.py 2.9 kB

import numpy as np
from itertools import permutations
from .plog import INFO
from collections import OrderedDict


# for multiple predictions, adapted from `learn_add.py`
def flatten(l):
    """Flatten one level of nesting; assumes elements are either all nested or all flat."""
    # return [item for sublist in l for item in flatten(sublist)] if isinstance(l, (list, tuple)) else [l]
    if not l or not isinstance(l[0], (list, tuple)):
        return l
    # TODO: compare the speed against itertools.chain.from_iterable(nested_list) and see which is faster
    return [item for sublist in l for item in sublist] if isinstance(l, (list, tuple)) else [l]


# for multiple predictions, adapted from `learn_add.py`
def reform_idx(flatten_pred_res, save_pred_res):
    """Split a flat prediction list back into sublists shaped like save_pred_res."""
    re = []
    i = 0
    for e in save_pred_res:
        re.append(flatten_pred_res[i:i + len(e)])
        i += len(e)
    return re


def hamming_dist(A, B):
    """Hamming distance between a single sequence A and each candidate sequence in B."""
    B = np.array(B)
    A = np.expand_dims(A, axis=0).repeat(axis=0, repeats=len(B))
    return np.sum(A != B, axis=1)


def confidence_dist(A, B):
    """1 minus the product of the probabilities that A assigns to each candidate labelling in B."""
    B = np.array(B)
    A = np.clip(A, 1e-9, 1)
    A = np.expand_dims(A, axis=0)
    A = A.repeat(axis=0, repeats=len(B))
    rows = np.array(range(len(B)))
    rows = np.expand_dims(rows, axis=1).repeat(axis=1, repeats=len(B[0]))
    cols = np.array(range(len(B[0])))
    cols = np.expand_dims(cols, axis=0).repeat(axis=0, repeats=len(B))
    return 1 - np.prod(A[rows, cols, B], axis=1)


def block_sample(X, Z, Y, sample_num, epoch_idx):
    """Return the slice of (X, Z, Y) belonging to the current epoch's segment."""
    part_num = len(X) // sample_num
    if part_num == 0:
        part_num = 1
    seg_idx = epoch_idx % part_num
    INFO("seg_idx:", seg_idx, ", part num:", part_num, ", data num:", len(X))
    X = X[sample_num * seg_idx : sample_num * (seg_idx + 1)]
    Z = Z[sample_num * seg_idx : sample_num * (seg_idx + 1)]
    Y = Y[sample_num * seg_idx : sample_num * (seg_idx + 1)]
    return X, Z, Y


def gen_mappings(chars, symbs):
    """Generate every one-to-one mapping from chars to symbs."""
    n_char = len(chars)
    n_symbs = len(symbs)
    if n_char != n_symbs:
        print("Characters and symbols size doesn't match.")
        return
    mappings = []  # returned mappings
    for p in permutations(symbs):
        mappings.append(dict(zip(chars, list(p))))
    return mappings


def mapping_res(original_pred_res, m):
    """Apply mapping m to every symbol of every predicted formula."""
    return [[m[symbol] for symbol in formula] for formula in original_pred_res]


def remapping_res(pred_res, m):
    """Apply the inverse of mapping m to every symbol of every predicted formula."""
    remapping = {value: key for key, value in m.items()}
    return [[remapping[symbol] for symbol in formula] for formula in pred_res]


def check_equal(a, b, max_err=0):
    """Recursively compare a and b; numbers are compared with tolerance max_err."""
    if isinstance(a, (int, float)) and isinstance(b, (int, float)):
        return abs(a - b) <= max_err
    if isinstance(a, list) and isinstance(b, list):
        if len(a) != len(b):
            return False
        for i in range(len(a)):
            if not check_equal(a[i], b[i], max_err):  # propagate the tolerance into nested lists
                return False
        return True
    else:
        return a == b
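
A minimal usage sketch for the two distance helpers above (the arrays below are made-up illustrations, not data shipped with the toolkit): hamming_dist counts per-position mismatches between the current prediction and each abduced candidate, while confidence_dist scores candidates by how improbable the learner's per-position probabilities make them.

# assumes hamming_dist and confidence_dist from the module above are in scope
import numpy as np

pred_label = np.array([1, 0, 1])                  # current predicted class indices
candidates = [[1, 0, 1], [1, 1, 1], [0, 0, 1]]    # abduced candidate label sequences
probs = np.array([[0.1, 0.9],                     # per-position class probabilities
                  [0.8, 0.2],
                  [0.3, 0.7]])

print(hamming_dist(pred_label, candidates))       # [0 1 1]
print(confidence_dist(probs, candidates))         # approx. [0.496 0.874 0.944]

In an abductive-learning loop, the candidate with the smallest distance would typically be the one fed back to the learner as revised labels.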

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.
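
As a rough illustration of that integration using the helpers above, the learner's raw class indices can be re-interpreted under every possible symbol assignment before a logical consistency check; the symbols and the single predicted formula below are hypothetical, not the toolkit's actual pipeline.

# assumes gen_mappings, mapping_res, remapping_res and check_equal from utils.py are in scope
chars = [0, 1, 2]            # pseudo-label classes produced by the learner
symbs = ['+', '0', '1']      # symbols understood by the logical reasoner
pred_res = [[1, 0, 2]]       # one predicted formula, as class indices

for m in gen_mappings(chars, symbs):          # all 3! = 6 one-to-one assignments
    formulas = mapping_res(pred_res, m)       # e.g. [['0', '+', '1']] under the first assignment
    # a real pipeline would ask the reasoner whether `formulas` is consistent
    # with the background knowledge; here we only demonstrate the round trip
    assert check_equal(remapping_res(formulas, m), pred_res)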