You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.9 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import numpy as np
  2. import torch
  3. import torch.nn as nn
  4. import torch.utils.data.sampler as sampler
  5. class InfiniteSampler(sampler.Sampler):
  6. def __init__(self, num_examples, batch_size=1):
  7. self.num_examples = num_examples
  8. self.batch_size = batch_size
  9. def __iter__(self):
  10. while True:
  11. order = np.random.permutation(self.num_examples)
  12. for i in range(self.num_examples):
  13. yield order[i : i + self.batch_size]
  14. i += self.batch_size
  15. def __len__(self):
  16. return None
  17. def gen_mappings(chars, symbs):
  18. n_char = len(chars)
  19. n_symbs = len(symbs)
  20. if n_char != n_symbs:
  21. print("Characters and symbols size dosen't match.")
  22. return
  23. from itertools import permutations
  24. mappings = []
  25. # returned mappings
  26. perms = permutations(symbs)
  27. for p in perms:
  28. if p.index(1) < p.index(0):
  29. continue
  30. mappings.append(dict(zip(chars, list(p))))
  31. return mappings
  32. def mapping_res(original_pred_res, m):
  33. return [[m[symbol] for symbol in formula] for formula in original_pred_res]
  34. def remapping_res(pred_res, m):
  35. remapping = {}
  36. for key, value in m.items():
  37. remapping[value] = key
  38. return [[remapping[symbol] for symbol in formula] for formula in pred_res]
  39. def extract_feature(img):
  40. extractor = nn.AvgPool2d(2, stride=2)
  41. feature_map = np.array(extractor(torch.Tensor(img)))
  42. return feature_map.reshape((-1,))
  43. def reduce_dimension(data):
  44. for truth_value in [0, 1]:
  45. for equation_len in range(5, 27):
  46. equations = data[truth_value][equation_len]
  47. reduced_equations = [
  48. [extract_feature(symbol_img) for symbol_img in equation] for equation in equations
  49. ]
  50. data[truth_value][equation_len] = reduced_equations

An efficient Python toolkit for Abductive Learning (ABL), a novel paradigm that integrates machine learning and logical reasoning in a unified framework.