
utils.py
import logging
from collections import OrderedDict
import json

import numpy as np

from . import ontology


def clean_replace(s, r, t, forward=True, backward=False):
    """Replace occurrences of r in s with t, respecting word boundaries.

    With forward=True (the default) a match is extended rightwards over
    trailing alphanumerics before being replaced; with backward=True it
    is extended leftwards up to the previous space.
    """

    def clean_replace_single(s, r, t, forward, backward, sidx=0):
        idx = s.find(r)
        if idx == -1:
            return s, -1
        idx_r = idx + len(r)
        if backward:
            # Extend the match leftwards up to the previous space.
            while idx > 0 and s[idx - 1] != ' ':
                idx -= 1
        elif idx > 0 and s[idx - 1] != ' ':
            # Match is glued to a preceding word: skip this replacement.
            return s, -1
        if forward:
            # Extend the match rightwards over trailing alphanumerics.
            while idx_r < len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
                idx_r += 1
        elif idx_r != len(s) and (s[idx_r].isalpha() or s[idx_r].isdigit()):
            return s, -1
        return s[:idx] + t + s[idx_r:], idx_r

    # Repeat until no further occurrence of r can be replaced.
    sidx = 0
    while sidx != -1:
        s, sidx = clean_replace_single(s, r, t, forward, backward, sidx)
    return s
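

# Illustrative sketch (not part of the original file): exercise the
# word-boundary behaviour of clean_replace on made-up strings.
def _demo_clean_replace():
    # The forward scan also consumes the trailing 'ly' of 'cheaply'.
    assert clean_replace('a cheaply priced cheap hotel', 'cheap',
                         'moderate') == 'a moderate priced moderate hotel'
    # A match glued to a preceding word ('supercheap') is skipped.
    assert clean_replace('supercheap deal', 'cheap',
                         'moderate') == 'supercheap deal'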

def py2np(lst):
    """Convert a Python list to a numpy array."""
    return np.array(lst)


def write_dict(fn, dic):
    """Dump a dict to a JSON file with 2-space indentation."""
    with open(fn, 'w') as f:
        json.dump(dic, f, indent=2)

def f1_score(label_list, pred_list):
    """Set-style F1 between predictions and gold labels, smoothed with
    1e-10 so empty inputs do not divide by zero."""
    tp = len([t for t in pred_list if t in label_list])
    fp = max(0, len(pred_list) - tp)
    fn = max(0, len(label_list) - tp)
    precision = tp / (tp + fp + 1e-10)
    recall = tp / (tp + fn + 1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return f1
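

# Illustrative sketch (not part of the original file): with 2 of 3
# predictions correct and one gold label missed, precision = recall = 2/3,
# so F1 is 2/3 up to the 1e-10 smoothing.
def _demo_f1_score():
    labels = ['hotel-area', 'hotel-price', 'taxi-dest']
    preds = ['hotel-area', 'hotel-price', 'taxi-leave']
    assert abs(f1_score(labels, preds) - 2 / 3) < 1e-6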

class MultiWOZVocab(object):
    """Vocabulary for the MultiWOZ dataset."""

    def __init__(self, vocab_size=0):
        self.vocab_size = vocab_size
        self.vocab_size_oov = 0  # set after construct() / load_vocab()
        self._idx2word = {}  # in-vocab words + oov
        self._word2idx = {}  # in-vocab words only
        self._freq_dict = {}  # in-vocab words + oov
        # Special tokens get the first, fixed indices.
        for w in [
                '[PAD]', '<go_r>', '[UNK]', '<go_b>', '<go_a>', '<eos_u>',
                '<eos_r>', '<eos_b>', '<eos_a>', '<go_d>', '<eos_d>'
        ]:
            self._absolute_add_word(w)

    def _absolute_add_word(self, w):
        idx = len(self._idx2word)
        self._idx2word[idx] = w
        self._word2idx[w] = idx

    def add_word(self, word):
        if word not in self._freq_dict:
            self._freq_dict[word] = 0
        self._freq_dict[word] += 1

    def has_word(self, word):
        """Return the word's corpus frequency (falsy if unseen)."""
        return self._freq_dict.get(word)

    def _add_to_vocab(self, word):
        if word not in self._word2idx:
            idx = len(self._idx2word)
            self._idx2word[idx] = word
            self._word2idx[word] = idx

    def construct(self):
        """Build the index: domain, act, slot and [value_*] tokens get
        priority, then corpus words in descending frequency order."""
        l = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
        print('Vocabulary size including oov: %d' %
              (len(l) + len(self._idx2word)))
        if len(l) + len(self._idx2word) < self.vocab_size:
            logging.warning(
                'actual vocabulary smaller than configured: {}/{}'.format(
                    len(l) + len(self._idx2word), self.vocab_size))
        for word in ontology.all_domains + ['general']:
            word = '[' + word + ']'
            self._add_to_vocab(word)
        for word in ontology.all_acts:
            word = '[' + word + ']'
            self._add_to_vocab(word)
        for word in ontology.all_slots:
            self._add_to_vocab(word)
        for word in l:
            if word.startswith('[value_') and word.endswith(']'):
                self._add_to_vocab(word)
        for word in l:
            self._add_to_vocab(word)
        self.vocab_size_oov = len(self._idx2word)

    def load_vocab(self, vocab_path):
        self._freq_dict = json.loads(
            open(vocab_path + '.freq.json', 'r').read())
        self._word2idx = json.loads(
            open(vocab_path + '.word2idx.json', 'r').read())
        self._idx2word = {}
        for w, idx in self._word2idx.items():
            self._idx2word[idx] = w
        self.vocab_size_oov = len(self._idx2word)
        print('vocab file loaded from "' + vocab_path + '"')
        print('Vocabulary size including oov: %d' % self.vocab_size_oov)

    def save_vocab(self, vocab_path):
        _freq_dict = OrderedDict(
            sorted(self._freq_dict.items(), key=lambda kv: kv[1],
                   reverse=True))
        write_dict(vocab_path + '.word2idx.json', self._word2idx)
        write_dict(vocab_path + '.freq.json', _freq_dict)

    def encode(self, word, include_oov=True):
        if include_oov:
            if self._word2idx.get(word, None) is None:
                raise ValueError(
                    'Unknown word: %s. Vocabulary should include oovs here.' %
                    word)
            return self._word2idx[word]
        else:
            # Map unknown words to the [UNK] token (index 2).
            word = '[UNK]' if word not in self._word2idx else word
            return self._word2idx[word]

    def sentence_encode(self, word_list):
        return [self.encode(_) for _ in word_list]

    def oov_idx_map(self, idx):
        # Map out-of-vocabulary indices to 2, the index of [UNK].
        return 2 if idx > self.vocab_size else idx

    def sentence_oov_map(self, index_list):
        return [self.oov_idx_map(_) for _ in index_list]

    def decode(self, idx, indicate_oov=False):
        if not self._idx2word.get(idx):
            raise ValueError(
                'Error idx: %d. Vocabulary should include oovs here.' % idx)
        if not indicate_oov or idx < self.vocab_size:
            return self._idx2word[idx]
        else:
            # Mark out-of-vocabulary words with an '(o)' suffix.
            return self._idx2word[idx] + '(o)'

    def sentence_decode(self, index_list, eos=None, indicate_oov=False):
        l = [self.decode(_, indicate_oov) for _ in index_list]
        if not eos or eos not in l:
            return ' '.join(l)
        else:
            # Truncate the sentence at the first end-of-sequence token.
            idx = l.index(eos)
            return ' '.join(l[:idx])

    def nl_decode(self, l, eos=None):
        return [self.sentence_decode(_, eos) + '\n' for _ in l]
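

# Illustrative sketch (not part of the original file): build a tiny
# vocabulary from a made-up corpus and round-trip a sentence through
# encode/decode. Assumes `ontology` provides all_domains, all_acts and
# all_slots, as construct() above requires. Because of the relative
# import at the top, run this as a module (python -m <package>.utils).
if __name__ == '__main__':
    vocab = MultiWOZVocab(vocab_size=3000)
    for sent in [['i', 'need', 'a', 'cheap', 'hotel'],
                 ['book', 'a', 'taxi', 'please']]:
        for w in sent:
            vocab.add_word(w)
    vocab.construct()  # freeze the word -> index mapping
    ids = vocab.sentence_encode(['i', 'need', 'a', 'taxi', '<eos_u>'])
    print(vocab.sentence_decode(ids, eos='<eos_u>'))  # -> 'i need a taxi'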
