
nlp.py

#! /usr/bin/python
# -*- coding: utf-8 -*-

import collections
import os
import random
import re
import subprocess
import tempfile
import warnings
from collections import Counter

import numpy as np
import six as _six
import tensorflow as tf
from six.moves import urllib, xrange
from tensorflow.python.platform import gfile

import tensorlayer as tl
from tensorlayer.lazy_imports import LazyImport

nltk = LazyImport("nltk")

__all__ = [
    'generate_skip_gram_batch',
    'sample',
    'sample_top',
    'SimpleVocabulary',
    'Vocabulary',
    'process_sentence',
    'create_vocab',
    'simple_read_words',
    'read_words',
    'read_analogies_file',
    'build_vocab',
    'build_reverse_dictionary',
    'build_words_dataset',
    'words_to_word_ids',
    'word_ids_to_words',
    'save_vocab',
    'basic_tokenizer',
    'create_vocabulary',
    'initialize_vocabulary',
    'sentence_to_token_ids',
    'data_to_token_ids',
    'moses_multi_bleu',
]

def as_bytes(bytes_or_text, encoding='utf-8'):
    """Converts either bytes or unicode to `bytes`, using utf-8 encoding for text.

    Args:
        bytes_or_text: A `bytes`, `str`, or `unicode` object.
        encoding: A string indicating the charset for encoding unicode.

    Returns:
        A `bytes` object.

    Raises:
        TypeError: If `bytes_or_text` is not a binary or unicode string.

    """
    if isinstance(bytes_or_text, _six.text_type):
        return bytes_or_text.encode(encoding)
    elif isinstance(bytes_or_text, bytes):
        return bytes_or_text
    else:
        raise TypeError('Expected binary or unicode string, got %r' % (bytes_or_text, ))

def as_text(bytes_or_text, encoding='utf-8'):
    """Returns the given argument as a unicode string.

    Args:
        bytes_or_text: A `bytes`, `str`, or `unicode` object.
        encoding: A string indicating the charset for decoding unicode.

    Returns:
        A `unicode` (Python 2) or `str` (Python 3) object.

    Raises:
        TypeError: If `bytes_or_text` is not a binary or unicode string.

    """
    if isinstance(bytes_or_text, _six.text_type):
        return bytes_or_text
    elif isinstance(bytes_or_text, bytes):
        return bytes_or_text.decode(encoding)
    else:
        raise TypeError('Expected binary or unicode string, got %r' % bytes_or_text)
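
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name below is hypothetical): `as_bytes` and `as_text` round-trip bytes and str.
def _example_bytes_text_roundtrip():
    raw = as_bytes("héllo")   # b'h\xc3\xa9llo'
    text = as_text(raw)       # 'héllo'
    return raw, text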

def generate_skip_gram_batch(data, batch_size, num_skips, skip_window, data_index=0):
    """Generate a training batch for the Skip-Gram model.

    See `Word2Vec example <https://github.com/tensorlayer/tensorlayer/blob/master/example/tutorial_word2vec_basic.py>`__.

    Parameters
    ----------
    data : list of data
        The context, usually a list of integers.
    batch_size : int
        Batch size to return.
    num_skips : int
        How many times to reuse an input to generate a label.
    skip_window : int
        How many words to consider left and right.
    data_index : int
        Index of the context location. This function uses `data_index` to track the position instead of yielding as ``tl.iterate`` does.

    Returns
    -------
    batch : list of data
        Inputs.
    labels : list of data
        Labels.
    data_index : int
        Index of the context location.

    Examples
    --------
    Setting num_skips=2, skip_window=1 uses the words on the right and left.
    In the same way, num_skips=4, skip_window=2 means using the nearby 4 words.

    >>> data = [1,2,3,4,5,6,7,8,9,10,11]
    >>> batch, labels, data_index = tl.nlp.generate_skip_gram_batch(data=data, batch_size=8, num_skips=2, skip_window=1, data_index=0)
    >>> print(batch)
    [2 2 3 3 4 4 5 5]
    >>> print(labels)
    [[3]
    [1]
    [4]
    [2]
    [5]
    [3]
    [4]
    [6]]

    """
    # global data_index   # you can put data_index outside the function, then
    # modify the global data_index in the function without returning it.
    # note: instead of using yield, this code uses data_index to keep track of the position.
    if batch_size % num_skips != 0:
        raise Exception("batch_size should be divisible by num_skips.")
    if num_skips > 2 * skip_window:
        raise Exception("num_skips <= 2 * skip_window")
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)
    span = 2 * skip_window + 1  # [ skip_window target skip_window ]
    buffer = collections.deque(maxlen=span)
    for _ in range(span):
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    for i in range(batch_size // num_skips):
        target = skip_window  # target label at the center of the buffer
        targets_to_avoid = [skip_window]
        for j in range(num_skips):
            while target in targets_to_avoid:
                target = random.randint(0, span - 1)
            targets_to_avoid.append(target)
            batch[i * num_skips + j] = buffer[skip_window]
            labels[i * num_skips + j, 0] = buffer[target]
        buffer.append(data[data_index])
        data_index = (data_index + 1) % len(data)
    return batch, labels, data_index

def sample(a=None, temperature=1.0):
    """Sample an index from a probability array.

    Parameters
    ----------
    a : list of float
        List of probabilities.
    temperature : float or None
        The higher the temperature, the more uniform the distribution. When a = [0.1, 0.2, 0.7],
            - temperature = 0.7, the distribution becomes sharper: [0.05048273, 0.13588945, 0.81362782]
            - temperature = 1.0, the distribution stays the same: [0.1, 0.2, 0.7]
            - temperature = 1.5, the distribution becomes flatter: [0.16008435, 0.25411807, 0.58579758]
            - If None, it returns ``np.argmax(a)``

    Notes
    ------
    - Regardless of the temperature and the input list, the output probabilities always sum to one. Even if the input list is [1, 100, 200], they still sum to one.
    - For a large vocabulary size, choose a higher temperature or use ``tl.nlp.sample_top`` to avoid numerical errors.

    """
    if a is None:
        raise Exception("a : list of float")
    b = np.copy(a)
    try:
        if temperature == 1:
            return np.argmax(np.random.multinomial(1, a, 1))
        if temperature is None:
            return np.argmax(a)
        else:
            a = np.log(a) / temperature
            a = np.exp(a) / np.sum(np.exp(a))
            return np.argmax(np.random.multinomial(1, a, 1))
    except Exception:
        # np.set_printoptions(threshold=np.nan)
        # tl.logging.info(a)
        # tl.logging.info(np.sum(a))
        # tl.logging.info(np.max(a))
        # tl.logging.info(np.min(a))
        # exit()
        message = ("For large vocabulary_size, choose a higher temperature "
                   "to avoid log error. Hint: use ``sample_top``.")
        warnings.warn(message, Warning)
        # tl.logging.info(a)
        # tl.logging.info(b)
        return np.argmax(np.random.multinomial(1, b, 1))
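
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name is hypothetical): sampling from a toy distribution at different temperatures.
def _example_sample_usage():
    probs = [0.1, 0.2, 0.7]
    greedy_idx = sample(probs, temperature=None)  # always 2 (argmax)
    sharp_idx = sample(probs, temperature=0.7)    # strongly biased towards index 2
    flat_idx = sample(probs, temperature=1.5)     # closer to uniform
    return greedy_idx, sharp_idx, flat_idx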

def sample_top(a=None, top_k=10):
    """Sample from ``top_k`` probabilities.

    Parameters
    ----------
    a : list of float
        List of probabilities.
    top_k : int
        Number of candidates to be considered.

    """
    if a is None:
        a = []
    idx = np.argpartition(a, -top_k)[-top_k:]
    probs = a[idx]
    # tl.logging.info("new %f" % probs)
    probs = probs / np.sum(probs)
    choice = np.random.choice(idx, p=probs)
    return choice
    # old implementation
    # a = np.array(a)
    # idx = np.argsort(a)[::-1]
    # idx = idx[:top_k]
    # # a = a[idx]
    # probs = a[idx]
    # tl.logging.info("prev %f" % probs)
    # # probs = probs / np.sum(probs)
    # # choice = np.random.choice(idx, p=probs)
    # # return choice
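
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name is hypothetical): restrict sampling to the 3 most likely indices.
def _example_sample_top_usage():
    probs = np.array([0.05, 0.1, 0.15, 0.3, 0.4])
    idx = sample_top(probs, top_k=3)  # one of the indices {2, 3, 4}
    return idx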

# Vector representations of words (Advanced), undocumented.
class SimpleVocabulary(object):
    """Simple vocabulary wrapper, see create_vocab().

    Parameters
    ------------
    vocab : dictionary
        A dictionary that maps word to ID.
    unk_id : int
        The ID for the 'unknown' word.

    """

    def __init__(self, vocab, unk_id):
        """Initialize the vocabulary."""
        self._vocab = vocab
        self._unk_id = unk_id

    def word_to_id(self, word):
        """Returns the integer id of a word string."""
        if word in self._vocab:
            return self._vocab[word]
        else:
            return self._unk_id
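
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name is hypothetical): known words map to their IDs, unknown words to `unk_id`.
def _example_simple_vocabulary_usage():
    vocab = SimpleVocabulary({'dog': 0, 'cat': 1}, unk_id=2)
    known = vocab.word_to_id('dog')     # 0
    unknown = vocab.word_to_id('bird')  # 2
    return known, unknown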

class Vocabulary(object):
    """Create a Vocabulary class from a given vocabulary file; it provides the word-to-id and id-to-word conversions.
    See create_vocab() and ``tutorial_tfrecord3.py``.

    Parameters
    -----------
    vocab_file : str
        The file that contains the vocabulary (can be created via ``tl.nlp.create_vocab``), where the words are the first whitespace-separated token on each line (other tokens are ignored) and the word ids are the corresponding line numbers.
    start_word : str
        Special word denoting sentence start.
    end_word : str
        Special word denoting sentence end.
    unk_word : str
        Special word denoting unknown words.

    Attributes
    ------------
    vocab : dictionary
        A dictionary that maps word to ID.
    reverse_vocab : list of str
        A list that maps ID to word.
    start_id : int
        The start ID.
    end_id : int
        The end ID.
    unk_id : int
        The unknown ID.
    pad_id : int
        The padding ID.

    Examples
    -------------
    The vocabulary file looks as follows; it includes `start_word`, `end_word`, etc.

    >>> a 969108
    >>> <S> 586368
    >>> </S> 586368
    >>> . 440479
    >>> on 213612
    >>> of 202290
    >>> the 196219
    >>> in 182598
    >>> with 152984
    >>> and 139109
    >>> is 97322

    """

    def __init__(self, vocab_file, start_word="<S>", end_word="</S>", unk_word="<UNK>", pad_word="<PAD>"):
        if not tf.io.gfile.exists(vocab_file):
            tl.logging.fatal("Vocab file %s not found." % vocab_file)
        tl.logging.info("Initializing vocabulary from file: %s" % vocab_file)

        with tf.io.gfile.GFile(vocab_file, mode="r") as f:
            reverse_vocab = list(f.readlines())
        reverse_vocab = [line.split()[0] for line in reverse_vocab]
        # assert start_word in reverse_vocab
        # assert end_word in reverse_vocab
        if start_word not in reverse_vocab:  # haodong
            reverse_vocab.append(start_word)
        if end_word not in reverse_vocab:
            reverse_vocab.append(end_word)
        if unk_word not in reverse_vocab:
            reverse_vocab.append(unk_word)
        if pad_word not in reverse_vocab:
            reverse_vocab.append(pad_word)

        vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])

        tl.logging.info("Vocabulary from %s : %s %s %s" % (vocab_file, start_word, end_word, unk_word))
        tl.logging.info("    vocabulary with %d words (includes start_word, end_word, unk_word)" % len(vocab))
        # tl.logging.info("    vocabulary with %d words" % len(vocab))

        self.vocab = vocab  # vocab[word] = id
        self.reverse_vocab = reverse_vocab  # reverse_vocab[id] = word

        # Save special word ids.
        self.start_id = vocab[start_word]
        self.end_id = vocab[end_word]
        self.unk_id = vocab[unk_word]
        self.pad_id = vocab[pad_word]
        tl.logging.info("      start_id: %d" % self.start_id)
        tl.logging.info("      end_id  : %d" % self.end_id)
        tl.logging.info("      unk_id  : %d" % self.unk_id)
        tl.logging.info("      pad_id  : %d" % self.pad_id)

    def word_to_id(self, word):
        """Returns the integer word id of a word string."""
        if word in self.vocab:
            return self.vocab[word]
        else:
            return self.unk_id

    def id_to_word(self, word_id):
        """Returns the word string of an integer word id."""
        if word_id >= len(self.reverse_vocab):
            return self.reverse_vocab[self.unk_id]
        else:
            return self.reverse_vocab[word_id]
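
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name is hypothetical and 'vocab.txt' is assumed to exist, see create_vocab below).
def _example_vocabulary_usage():
    vocab = Vocabulary('vocab.txt', start_word="<S>", end_word="</S>", unk_word="<UNK>")
    word_id = vocab.word_to_id('the')  # ID of 'the', or vocab.unk_id if missing
    word = vocab.id_to_word(word_id)   # back to the word string
    return word_id, word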

def process_sentence(sentence, start_word="<S>", end_word="</S>"):
    """Separate a sentence string into a list of string words, add start_word and end_word,
    see ``create_vocab()`` and ``tutorial_tfrecord3.py``.

    Parameters
    ----------
    sentence : str
        A sentence.
    start_word : str or None
        The start word. If None, no start word will be appended.
    end_word : str or None
        The end word. If None, no end word will be appended.

    Returns
    ---------
    list of str
        The sentence separated into a list of word strings.

    Examples
    -----------
    >>> c = "how are you?"
    >>> c = tl.nlp.process_sentence(c)
    >>> print(c)
    ['<S>', 'how', 'are', 'you', '?', '</S>']

    Notes
    -------
    - You have to install the following packages.
    - `Installing NLTK <http://www.nltk.org/install.html>`__
    - `Installing NLTK data <http://www.nltk.org/data.html>`__

    """
    if start_word is not None:
        process_sentence = [start_word]
    else:
        process_sentence = []
    process_sentence.extend(nltk.tokenize.word_tokenize(sentence.lower()))
    if end_word is not None:
        process_sentence.append(end_word)
    return process_sentence

def create_vocab(sentences, word_counts_output_file, min_word_count=1):
    """Creates the vocabulary of word to word_id.

    See ``tutorial_tfrecord3.py``.

    The vocabulary is saved to disk in a text file of word counts. The id of each
    word in the file is its corresponding 0-based line number.

    Parameters
    ------------
    sentences : list of list of str
        All sentences for creating the vocabulary.
    word_counts_output_file : str
        The file name.
    min_word_count : int
        Minimum number of occurrences for a word.

    Returns
    --------
    :class:`SimpleVocabulary`
        The simple vocabulary object, see :class:`Vocabulary` for more.

    Examples
    --------
    Pre-process sentences

    >>> captions = ["one two , three", "four five five"]
    >>> processed_capts = []
    >>> for c in captions:
    >>>     c = tl.nlp.process_sentence(c, start_word="<S>", end_word="</S>")
    >>>     processed_capts.append(c)
    >>> print(processed_capts)
    ... [['<S>', 'one', 'two', ',', 'three', '</S>'], ['<S>', 'four', 'five', 'five', '</S>']]

    Create vocabulary

    >>> tl.nlp.create_vocab(processed_capts, word_counts_output_file='vocab.txt', min_word_count=1)
    Creating vocabulary.
    Total words: 8
    Words in vocabulary: 8
    Wrote vocabulary file: vocab.txt

    Get vocabulary object

    >>> vocab = tl.nlp.Vocabulary('vocab.txt', start_word="<S>", end_word="</S>", unk_word="<UNK>")
    INFO:tensorflow:Initializing vocabulary from file: vocab.txt
    [TL] Vocabulary from vocab.txt : <S> </S> <UNK>
    vocabulary with 10 words (includes start_word, end_word, unk_word)
    start_id: 2
    end_id: 3
    unk_id: 9
    pad_id: 0

    """
    tl.logging.info("Creating vocabulary.")

    counter = Counter()
    for c in sentences:
        counter.update(c)
        # tl.logging.info('c', c)
    tl.logging.info("    Total words: %d" % len(counter))

    # Filter uncommon words and sort by descending count.
    word_counts = [x for x in counter.items() if x[1] >= min_word_count]
    word_counts.sort(key=lambda x: x[1], reverse=True)
    word_counts = [("<PAD>", 0)] + word_counts  # 1st id should be reserved for padding
    # tl.logging.info(word_counts)
    tl.logging.info("    Words in vocabulary: %d" % len(word_counts))

    # Write out the word counts file.
    with tf.io.gfile.GFile(word_counts_output_file, "w") as f:
        f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts]))
    tl.logging.info("    Wrote vocabulary file: %s" % word_counts_output_file)

    # Create the vocabulary dictionary.
    reverse_vocab = [x[0] for x in word_counts]
    unk_id = len(reverse_vocab)
    vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
    vocab = SimpleVocabulary(vocab_dict, unk_id)

    return vocab

# Vector representations of words
def simple_read_words(filename="nietzsche.txt"):
    """Read context from file without any preprocessing.

    Parameters
    ----------
    filename : str
        A file path (like a .txt file).

    Returns
    --------
    str
        The context in a string.

    """
    with open(filename, "r") as f:
        words = f.read()
    return words

def read_words(filename="nietzsche.txt", replace=None):
    """Read list-format context from a file.

    For a customized read_words method, see ``tutorial_generate_text.py``.

    Parameters
    ----------
    filename : str
        A file path.
    replace : list of str
        The pair [original string, target string]; the original string is replaced by the target string.

    Returns
    -------
    list of str
        The context in a list (split using spaces).

    """
    if replace is None:
        replace = ['\n', '<eos>']

    with tf.io.gfile.GFile(filename, "r") as f:
        try:  # python 3.4 or older
            context_list = f.read().replace(*replace).split()
        except Exception:  # python 3.5
            f.seek(0)
            replace = [x.encode('utf-8') for x in replace]
            context_list = f.read().replace(*replace).split()
        return context_list
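
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name is hypothetical and 'nietzsche.txt' is assumed to exist):
# newlines are mapped to the '<eos>' token before splitting on whitespace.
def _example_read_words_usage():
    words = read_words("nietzsche.txt", replace=['\n', '<eos>'])
    return words[:10]  # first ten whitespace-separated tokens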

def read_analogies_file(eval_file='questions-words.txt', word2id=None):
    """Reads through an analogy question file and returns the questions in ID format.

    Parameters
    ----------
    eval_file : str
        The file name.
    word2id : dictionary
        A dictionary that maps word to ID.

    Returns
    --------
    numpy.array
        A ``[n_examples, 4]`` numpy array containing the analogy question's word IDs.

    Examples
    ---------
    The file should be in this format

    >>> : capital-common-countries
    >>> Athens Greece Baghdad Iraq
    >>> Athens Greece Bangkok Thailand
    >>> Athens Greece Beijing China
    >>> Athens Greece Berlin Germany
    >>> Athens Greece Bern Switzerland
    >>> Athens Greece Cairo Egypt
    >>> Athens Greece Canberra Australia
    >>> Athens Greece Hanoi Vietnam
    >>> Athens Greece Havana Cuba

    Get the tokenized analogy question data

    >>> words = tl.files.load_matt_mahoney_text8_dataset()
    >>> data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size, True)
    >>> analogy_questions = tl.nlp.read_analogies_file(eval_file='questions-words.txt', word2id=dictionary)
    >>> print(analogy_questions)
    [[ 3068  1248  7161  1581]
    [ 3068  1248 28683  5642]
    [ 3068  1248  3878   486]
    ...,
    [ 1216  4309 19982 25506]
    [ 1216  4309  3194  8650]
    [ 1216  4309   140   312]]

    """
    if word2id is None:
        word2id = {}

    questions = []
    questions_skipped = 0
    with open(eval_file, "rb") as analogy_f:
        for line in analogy_f:
            if line.startswith(b":"):  # Skip comments.
                continue
            words = line.strip().lower().split(b" ")  # lowercase
            ids = [word2id.get(w.strip().decode()) for w in words]
            if None in ids or len(ids) != 4:
                questions_skipped += 1
            else:
                questions.append(np.array(ids))
    tl.logging.info("Eval analogy file: %s" % eval_file)
    tl.logging.info("Questions: %d", len(questions))
    tl.logging.info("Skipped: %d", questions_skipped)
    analogy_questions = np.array(questions, dtype=np.int32)
    return analogy_questions

def build_vocab(data):
    """Build vocabulary.

    Given the context in list format,
    return the vocabulary, which is a dictionary for word to id,
    e.g. {'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 .... }

    Parameters
    ----------
    data : list of str
        The context in list format.

    Returns
    --------
    dictionary
        that maps word to unique ID. e.g. {'campbell': 2587, 'atlantic': 2247, 'aoun': 6746 .... }

    References
    ---------------
    - `tensorflow.models.rnn.ptb.reader <https://github.com/tensorflow/tensorflow/tree/master/tensorflow/models/rnn/ptb>`_

    Examples
    --------
    >>> data_path = os.getcwd() + '/simple-examples/data'
    >>> train_path = os.path.join(data_path, "ptb.train.txt")
    >>> word_to_id = build_vocab(read_txt_words(train_path))

    """
    # data = _read_words(filename)
    counter = collections.Counter(data)
    # tl.logging.info('counter %s' % counter)  # dictionary for the occurrence number of each word, e.g. 'banknote': 1, 'photography': 1, 'kia': 1
    count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
    # tl.logging.info('count_pairs %s' % count_pairs)  # convert dictionary to list of tuple, e.g. ('ssangyong', 1), ('swapo', 1), ('wachter', 1)
    words, _ = list(zip(*count_pairs))
    word_to_id = dict(zip(words, range(len(words))))
    # tl.logging.info(words)       # list of words
    # tl.logging.info(word_to_id)  # dictionary for word to id, e.g. 'campbell': 2587, 'atlantic': 2247, 'aoun': 6746
    return word_to_id

def build_reverse_dictionary(word_to_id):
    """Given a dictionary that maps word to integer id,
    returns a reverse dictionary that maps an id to word.

    Parameters
    ----------
    word_to_id : dictionary
        that maps word to ID.

    Returns
    --------
    dictionary
        A dictionary that maps IDs to words.

    """
    reverse_dictionary = dict(zip(word_to_id.values(), word_to_id.keys()))
    return reverse_dictionary
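
# Minimal usage sketch (editor's illustration, not part of the original module; the
# helper name is hypothetical): build a word-to-id mapping from a toy corpus and invert it.
def _example_build_vocab_usage():
    data = ['the', 'cat', 'sat', 'on', 'the', 'mat']
    word_to_id = build_vocab(data)                      # e.g. {'the': 0, 'cat': 1, ...}
    id_to_word = build_reverse_dictionary(word_to_id)   # e.g. {0: 'the', 1: 'cat', ...}
    return word_to_id, id_to_word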

def build_words_dataset(words=None, vocabulary_size=50000, printable=True, unk_key='UNK'):
    """Build the words dictionary and replace rare words with the 'UNK' token.
    The most common word has the smallest integer id.

    Parameters
    ----------
    words : list of str or byte
        The context in list format. You may need to do preprocessing on the words, such as lower case, remove marks etc.
    vocabulary_size : int
        The maximum vocabulary size, limiting the vocabulary size. The script then replaces rare words with the 'UNK' token.
    printable : boolean
        Whether to print the read vocabulary size of the given words.
    unk_key : str
        Represents unknown words.

    Returns
    --------
    data : list of int
        The context in a list of IDs.
    count : list of tuple and list
        Pairs of words and IDs.
            - count[0] is a list : the number of rare words
            - count[1:] are tuples : the number of occurrences of each word
            - e.g. [['UNK', 418391], (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764)]
    dictionary : dictionary
        It is `word_to_id` that maps word to ID.
    reverse_dictionary : a dictionary
        It is `id_to_word` that maps ID to word.

    Examples
    --------
    >>> words = tl.files.load_matt_mahoney_text8_dataset()
    >>> vocabulary_size = 50000
    >>> data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size)

    References
    -----------------
    - `tensorflow/examples/tutorials/word2vec/word2vec_basic.py <https://github.com/tensorflow/tensorflow/blob/r0.7/tensorflow/examples/tutorials/word2vec/word2vec_basic.py>`__

    """
    if words is None:
        raise Exception("words : list of str or byte")

    count = [[unk_key, -1]]
    count.extend(collections.Counter(words).most_common(vocabulary_size - 1))
    dictionary = dict()
    for word, _ in count:
        dictionary[word] = len(dictionary)
    data = list()
    unk_count = 0
    for word in words:
        if word in dictionary:
            index = dictionary[word]
        else:
            index = 0  # dictionary['UNK']
            unk_count += 1
        data.append(index)
    count[0][1] = unk_count
    reverse_dictionary = dict(zip(dictionary.values(), dictionary.keys()))
    if printable:
        tl.logging.info('Real vocabulary size    %d' % len(collections.Counter(words).keys()))
        tl.logging.info('Limited vocabulary size {}'.format(vocabulary_size))
    if len(collections.Counter(words).keys()) < vocabulary_size:
        raise Exception(
            "len(collections.Counter(words).keys()) >= vocabulary_size, the limited vocabulary_size must be less than or equal to the read vocabulary_size"
        )
    return data, count, dictionary, reverse_dictionary

def words_to_word_ids(data=None, word_to_id=None, unk_key='UNK'):
    """Convert a list of strings (words) to IDs.

    Parameters
    ----------
    data : list of string or byte
        The context in list format.
    word_to_id : a dictionary
        that maps word to ID.
    unk_key : str
        Represents unknown words.

    Returns
    --------
    list of int
        A list of IDs to represent the context.

    Examples
    --------
    >>> words = tl.files.load_matt_mahoney_text8_dataset()
    >>> vocabulary_size = 50000
    >>> data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size, True)
    >>> context = [b'hello', b'how', b'are', b'you']
    >>> ids = tl.nlp.words_to_word_ids(words, dictionary)
    >>> context = tl.nlp.word_ids_to_words(ids, reverse_dictionary)
    >>> print(ids)
    [6434, 311, 26, 207]
    >>> print(context)
    [b'hello', b'how', b'are', b'you']

    References
    ---------------
    - `tensorflow.models.rnn.ptb.reader <https://github.com/tensorflow/tensorflow/tree/master/tensorflow/models/rnn/ptb>`__

    """
    if data is None:
        raise Exception("data : list of string or byte")
    if word_to_id is None:
        raise Exception("word_to_id : a dictionary")
    # if isinstance(data[0], six.string_types):
    #     tl.logging.info(type(data[0]))
    #     # exit()
    #     tl.logging.info(data[0])
    #     tl.logging.info(word_to_id)
    #     return [word_to_id[str(word)] for word in data]
    # else:

    word_ids = []
    for word in data:
        if word_to_id.get(word) is not None:
            word_ids.append(word_to_id[word])
        else:
            word_ids.append(word_to_id[unk_key])
    return word_ids
    # return [word_to_id[word] for word in data]  # this one

    # if isinstance(data[0], str):
    #     # tl.logging.info('is a string object')
    #     return [word_to_id[word] for word in data]
    # else:  # if isinstance(s, bytes):
    #     # tl.logging.info('is a unicode object')
    #     # tl.logging.info(data[0])
    #     return [word_to_id[str(word)] for word in data]

def word_ids_to_words(data, id_to_word):
    """Convert a list of integers to strings (words).

    Parameters
    ----------
    data : list of int
        The context in list format.
    id_to_word : dictionary
        a dictionary that maps ID to word.

    Returns
    --------
    list of str
        A list of string or byte to represent the context.

    Examples
    ---------
    see ``tl.nlp.words_to_word_ids``

    """
    return [id_to_word[i] for i in data]

def save_vocab(count=None, name='vocab.txt'):
    """Save the vocabulary to a file so the model can be reloaded.

    Parameters
    ----------
    count : a list of tuple and list
        count[0] is a list : the number of rare words,
        count[1:] are tuples : the number of occurrences of each word,
        e.g. [['UNK', 418391], (b'the', 1061396), (b'of', 593677), (b'and', 416629), (b'one', 411764)]
    name : str
        The output file name.

    Examples
    ---------
    >>> words = tl.files.load_matt_mahoney_text8_dataset()
    >>> vocabulary_size = 50000
    >>> data, count, dictionary, reverse_dictionary = tl.nlp.build_words_dataset(words, vocabulary_size, True)
    >>> tl.nlp.save_vocab(count, name='vocab_text8.txt')

    The file `vocab_text8.txt` then contains:

    UNK 418391
    the 1061396
    of 593677
    and 416629
    one 411764
    in 372201
    a 325873
    to 316376

    """
    if count is None:
        count = []

    pwd = os.getcwd()
    vocabulary_size = len(count)
    with open(os.path.join(pwd, name), "w") as f:
        for i in xrange(vocabulary_size):
            f.write("%s %d\n" % (as_text(count[i][0]), count[i][1]))
    tl.logging.info("%d vocab saved to %s in %s" % (vocabulary_size, name, pwd))

# Functions for translation
def basic_tokenizer(sentence, _WORD_SPLIT=re.compile(b"([.,!?\"':;)(])")):
    """Very basic tokenizer: split the sentence into a list of tokens.

    Parameters
    -----------
    sentence : bytes or str
        A line of text, e.g. read from a ``tensorflow.python.platform.gfile.GFile`` object.
    _WORD_SPLIT : regular expression for word splitting.

    Examples
    --------
    >>> see create_vocabulary
    >>> from tensorflow.python.platform import gfile
    >>> train_path = "wmt/giga-fren.release2"
    >>> with gfile.GFile(train_path + ".en", mode="rb") as f:
    >>>     for line in f:
    >>>         tokens = tl.nlp.basic_tokenizer(line)
    >>>         tl.logging.info(tokens)
    >>>         exit()
    [b'Changing', b'Lives', b'|', b'Changing', b'Society', b'|', b'How',
    b'It', b'Works', b'|', b'Technology', b'Drives', b'Change', b'Home',
    b'|', b'Concepts', b'|', b'Teachers', b'|', b'Search', b'|', b'Overview',
    b'|', b'Credits', b'|', b'HHCC', b'Web', b'|', b'Reference', b'|',
    b'Feedback', b'Virtual', b'Museum', b'of', b'Canada', b'Home', b'Page']

    References
    ----------
    - Code from ``/tensorflow/models/rnn/translation/data_utils.py``

    """
    words = []
    sentence = as_bytes(sentence)
    for space_separated_fragment in sentence.strip().split():
        words.extend(re.split(_WORD_SPLIT, space_separated_fragment))
    return [w for w in words if w]

def create_vocabulary(
        vocabulary_path, data_path, max_vocabulary_size, tokenizer=None, normalize_digits=True,
        _DIGIT_RE=re.compile(br"\d"), _START_VOCAB=None
):
    r"""Create a vocabulary file (if it does not exist yet) from a data file.

    The data file is assumed to contain one sentence per line. Each sentence is
    tokenized and digits are normalized (if normalize_digits is set).
    The vocabulary contains the most-frequent tokens up to max_vocabulary_size.
    We write it to vocabulary_path in a one-token-per-line format, so that the
    token in the first line gets id=0, the one in the second line gets id=1, and so on.

    Parameters
    -----------
    vocabulary_path : str
        Path where the vocabulary will be created.
    data_path : str
        Data file that will be used to create the vocabulary.
    max_vocabulary_size : int
        Limit on the size of the created vocabulary.
    tokenizer : function
        A function to use to tokenize each data sentence. If None, basic_tokenizer will be used.
    normalize_digits : boolean
        If true, all digits are replaced by `0`.
    _DIGIT_RE : regular expression function
        Default is ``re.compile(br"\d")``.
    _START_VOCAB : list of str
        The pad, go, eos and unk tokens, default is ``[b"_PAD", b"_GO", b"_EOS", b"_UNK"]``.

    References
    ----------
    - Code from ``/tensorflow/models/rnn/translation/data_utils.py``

    """
    if _START_VOCAB is None:
        _START_VOCAB = [b"_PAD", b"_GO", b"_EOS", b"_UNK"]
    if not gfile.Exists(vocabulary_path):
        tl.logging.info("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
        vocab = {}
        with gfile.GFile(data_path, mode="rb") as f:
            counter = 0
            for line in f:
                counter += 1
                if counter % 100000 == 0:
                    tl.logging.info("  processing line %d" % counter)
                tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
                for w in tokens:
                    word = re.sub(_DIGIT_RE, b"0", w) if normalize_digits else w
                    if word in vocab:
                        vocab[word] += 1
                    else:
                        vocab[word] = 1
            vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
            if len(vocab_list) > max_vocabulary_size:
                vocab_list = vocab_list[:max_vocabulary_size]
            with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
                for w in vocab_list:
                    vocab_file.write(w + b"\n")
    else:
        tl.logging.info("Vocabulary %s from data %s exists" % (vocabulary_path, data_path))

def initialize_vocabulary(vocabulary_path):
    """Initialize vocabulary from file, return the `word_to_id` (dictionary)
    and `id_to_word` (list).

    We assume the vocabulary is stored one item per line; for example, a file containing
    ``dog`` and ``cat`` on separate lines will result in the vocabulary {"dog": 0, "cat": 1},
    and this function will also return the reversed vocabulary ["dog", "cat"].

    Parameters
    -----------
    vocabulary_path : str
        Path to the file containing the vocabulary.

    Returns
    --------
    vocab : dictionary
        a dictionary that maps word to ID.
    rev_vocab : list of str
        a list that maps ID to word.

    Examples
    ---------
    >>> Assume 'test' contains
    dog
    cat
    bird
    >>> vocab, rev_vocab = tl.nlp.initialize_vocabulary("test")
    >>> print(vocab)
    >>> {b'cat': 1, b'dog': 0, b'bird': 2}
    >>> print(rev_vocab)
    >>> [b'dog', b'cat', b'bird']

    Raises
    -------
    ValueError : if the provided vocabulary_path does not exist.

    """
    if gfile.Exists(vocabulary_path):
        rev_vocab = []
        with gfile.GFile(vocabulary_path, mode="rb") as f:
            rev_vocab.extend(f.readlines())
        rev_vocab = [as_bytes(line.strip()) for line in rev_vocab]
        vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
        return vocab, rev_vocab
    else:
        raise ValueError("Vocabulary file %s not found." % vocabulary_path)

def sentence_to_token_ids(
        sentence, vocabulary, tokenizer=None, normalize_digits=True, UNK_ID=3, _DIGIT_RE=re.compile(br"\d")
):
    """Convert a string to a list of integers representing token-ids.

    For example, a sentence "I have a dog" may become tokenized into
    ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
    "a": 4, "dog": 7} this function will return [1, 2, 4, 7].

    Parameters
    -----------
    sentence : bytes
        The sentence in bytes format to convert to token-ids, see ``basic_tokenizer()`` and ``data_to_token_ids()``.
    vocabulary : dictionary
        A dictionary mapping tokens to integers.
    tokenizer : function
        A function to use to tokenize each sentence. If None, ``basic_tokenizer`` will be used.
    normalize_digits : boolean
        If true, all digits are replaced by 0.

    Returns
    --------
    list of int
        The token-ids for the sentence.

    """
    if tokenizer:
        words = tokenizer(sentence)
    else:
        words = basic_tokenizer(sentence)
    if not normalize_digits:
        return [vocabulary.get(w, UNK_ID) for w in words]
    # Normalize digits by 0 before looking words up in the vocabulary.
    return [vocabulary.get(re.sub(_DIGIT_RE, b"0", w), UNK_ID) for w in words]
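
# Minimal usage sketch (editor's illustration, not part of the original module; the helper
# name and the toy byte-keyed vocabulary are hypothetical): unknown tokens fall back to UNK_ID.
def _example_sentence_to_token_ids_usage():
    vocabulary = {b'i': 4, b'have': 5, b'a': 6, b'dog': 7}
    ids = sentence_to_token_ids(b"i have a dog", vocabulary, UNK_ID=3)  # [4, 5, 6, 7]
    return ids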

def data_to_token_ids(
        data_path, target_path, vocabulary_path, tokenizer=None, normalize_digits=True, UNK_ID=3,
        _DIGIT_RE=re.compile(br"\d")
):
    """Tokenize a data file and turn it into token-ids using the given vocabulary file.

    This function loads data line-by-line from data_path, calls the above
    sentence_to_token_ids, and saves the result to target_path. See the comment
    for sentence_to_token_ids on the details of the token-ids format.

    Parameters
    -----------
    data_path : str
        Path to the data file in one-sentence-per-line format.
    target_path : str
        Path where the file with token-ids will be created.
    vocabulary_path : str
        Path to the vocabulary file.
    tokenizer : function
        A function to use to tokenize each sentence. If None, ``basic_tokenizer`` will be used.
    normalize_digits : boolean
        If true, all digits are replaced by 0.

    References
    ----------
    - Code from ``/tensorflow/models/rnn/translation/data_utils.py``

    """
    if not gfile.Exists(target_path):
        tl.logging.info("Tokenizing data in %s" % data_path)
        vocab, _ = initialize_vocabulary(vocabulary_path)
        with gfile.GFile(data_path, mode="rb") as data_file:
            with gfile.GFile(target_path, mode="w") as tokens_file:
                counter = 0
                for line in data_file:
                    counter += 1
                    if counter % 100000 == 0:
                        tl.logging.info("  tokenizing line %d" % counter)
                    token_ids = sentence_to_token_ids(
                        line, vocab, tokenizer, normalize_digits, UNK_ID=UNK_ID, _DIGIT_RE=_DIGIT_RE
                    )
                    tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")
    else:
        tl.logging.info("Target path %s exists" % target_path)

def moses_multi_bleu(hypotheses, references, lowercase=False):
    """Calculate the BLEU score for hypotheses and references
    using the MOSES multi-bleu.perl script.

    Parameters
    ------------
    hypotheses : numpy.array.string
        A numpy array of strings where each string is a single example.
    references : numpy.array.string
        A numpy array of strings where each string is a single example.
    lowercase : boolean
        If True, pass the "-lc" flag to the multi-bleu script.

    Returns
    --------
    float
        The BLEU score.

    Examples
    ---------
    >>> hypotheses = ["a bird is flying on the sky"]
    >>> references = ["two birds are flying on the sky", "a bird is on the top of the tree", "an airplane is on the sky",]
    >>> score = tl.nlp.moses_multi_bleu(hypotheses, references)

    References
    ----------
    - `Google/seq2seq/metric/bleu <https://github.com/google/seq2seq>`__

    """
    if np.size(hypotheses) == 0:
        return np.float32(0.0)

    # Get the MOSES multi-bleu script
    try:
        multi_bleu_path, _ = urllib.request.urlretrieve(
            "https://raw.githubusercontent.com/moses-smt/mosesdecoder/"
            "master/scripts/generic/multi-bleu.perl"
        )
        os.chmod(multi_bleu_path, 0o755)
    except Exception:  # pylint: disable=W0702
        tl.logging.info("Unable to fetch multi-bleu.perl script, using local.")
        metrics_dir = os.path.dirname(os.path.realpath(__file__))
        bin_dir = os.path.abspath(os.path.join(metrics_dir, "..", "..", "bin"))
        multi_bleu_path = os.path.join(bin_dir, "tools/multi-bleu.perl")

    # Dump hypotheses and references to tempfiles
    hypothesis_file = tempfile.NamedTemporaryFile()
    hypothesis_file.write("\n".join(hypotheses).encode("utf-8"))
    hypothesis_file.write(b"\n")
    hypothesis_file.flush()
    reference_file = tempfile.NamedTemporaryFile()
    reference_file.write("\n".join(references).encode("utf-8"))
    reference_file.write(b"\n")
    reference_file.flush()

    # Calculate BLEU using the multi-bleu script
    with open(hypothesis_file.name, "r") as read_pred:
        bleu_cmd = [multi_bleu_path]
        if lowercase:
            bleu_cmd += ["-lc"]
        bleu_cmd += [reference_file.name]
        try:
            bleu_out = subprocess.check_output(bleu_cmd, stdin=read_pred, stderr=subprocess.STDOUT)
            bleu_out = bleu_out.decode("utf-8")
            bleu_score = re.search(r"BLEU = (.+?),", bleu_out).group(1)
            bleu_score = float(bleu_score)
        except subprocess.CalledProcessError as error:
            if error.output is not None:
                tl.logging.warning("multi-bleu.perl script returned non-zero exit code")
                tl.logging.warning(error.output)
            bleu_score = np.float32(0.0)

    # Close temp files
    hypothesis_file.close()
    reference_file.close()

    return np.float32(bleu_score)

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computation backends. It plans to be compatible with TensorFlow, PyTorch, MindSpore, and Paddle.