tf_transformer.py

import numpy as np
import tensorflow as tf
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO)


def ln(inputs, epsilon=1e-8, scope="ln"):
    '''Applies layer normalization. See https://arxiv.org/abs/1607.06450.
    inputs: A tensor with 2 or more dimensions, where the first dimension has `batch_size`.
    epsilon: A small floating point number for preventing division by zero.
    scope: Optional scope for `variable_scope`.

    Returns:
      A tensor with the same shape and dtype as `inputs`.
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        inputs_shape = inputs.get_shape()
        params_shape = inputs_shape[-1:]

        # Normalize over the last (feature) axis, then scale and shift.
        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        beta = tf.get_variable("beta", params_shape,
                               initializer=tf.zeros_initializer())
        gamma = tf.get_variable("gamma", params_shape,
                                initializer=tf.ones_initializer())
        normalized = (inputs - mean) / ((variance + epsilon) ** (.5))
        outputs = gamma * normalized + beta
    return outputs
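
# Usage sketch (not part of the original module; shapes are illustrative):
# `ln` normalizes each feature vector over its last axis, so with the default
# gamma=1, beta=0 initialization the per-position mean is ~0 and variance is ~1.
#
#   x = tf.random_normal([32, 10, 512])          # (batch, time, features)
#   y = ln(x)                                    # same shape, normalized over axis -1
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       mean, var = sess.run(tf.nn.moments(y, [-1]))
#       # mean ~ 0 and var ~ 1 per position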


def get_token_embeddings(vocab_size, num_units, initializer=tf.contrib.layers.xavier_initializer(), zero_pad=True):
    '''Constructs a token embedding matrix.
    Note that the row of index 0 is set to zeros.
    vocab_size: scalar. V.
    num_units: embedding dimensionality. E.
    zero_pad: Boolean. If True, all the values of the first row (id = 0) are constant zeros.
        Zero padding is turned on so that query/key masks can be applied easily.

    Returns
      weight variable: (V, E)
    '''
    with tf.variable_scope("shared_weight_matrix"):
        embeddings = tf.get_variable('weight_mat',
                                     dtype=tf.float32,
                                     shape=(vocab_size, num_units),
                                     initializer=initializer)
        if zero_pad:
            # Replace the first row (id = 0, the padding token) with a constant zero vector.
            embeddings = tf.concat((tf.zeros(shape=[1, num_units]),
                                    embeddings[1:, :]), 0)
    return embeddings
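
# Usage sketch (not part of the original module; vocab_size/num_units below are
# made-up values): with zero_pad=True, row 0 of the returned matrix is all zeros,
# so the padding token (id 0) embeds to the zero vector.
#
#   emb = get_token_embeddings(vocab_size=100, num_units=8, zero_pad=True)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       print(sess.run(emb[0]))   # all zeros (reserved for padding)
#       print(sess.run(emb[1]))   # a randomly initialized embedding row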


def multihead_attention(queries, keys, values,
                        batch_size, hidden_size,
                        num_attention_heads=8,
                        query_act=None, key_act=None, value_act=None,
                        attention_mask=None,
                        attention_probs_dropout_prob=0.0,
                        training=True, causality=False,
                        scope="multihead_attention"):
    '''Applies multi-head scaled dot-product attention. See 3.2.
    queries: 3d tensor. (N, T_q, d_model)
    keys: 3d tensor. (N, T_k, d_model)
    values: 3d tensor. (N, T_k, d_model)
    attention_mask: 2d tensor. (N, T_k). 1/True marks padding positions to be masked out.
    causality: Boolean. If True, each position may only attend to earlier positions.

    Returns
      3d tensor. (N, T_q, d_model)
    '''
    def transpose_for_scores(input_tensor):
        # (N, T, d_model) -> (N, h, T, d_model/h)
        output_tensor = tf.reshape(
            input_tensor, [batch_size, -1, num_attention_heads, hidden_size // num_attention_heads])
        output_tensor = tf.transpose(output_tensor, [0, 2, 1, 3])
        return output_tensor

    size_per_head = hidden_size // num_attention_heads

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # linear transformation
        query_layer = tf.layers.dense(
            queries, hidden_size, activation=query_act)  # (N, T_q, d_model)
        key_layer = tf.layers.dense(
            keys, hidden_size, activation=key_act)  # (N, T_k, d_model)
        value_layer = tf.layers.dense(
            values, hidden_size, activation=value_act)  # (N, T_k, d_model)

        # split heads and transpose
        query_layer = transpose_for_scores(query_layer)  # (N, h, T_q, d_model/h)
        key_layer = transpose_for_scores(key_layer)      # (N, h, T_k, d_model/h)
        value_layer = transpose_for_scores(value_layer)  # (N, h, T_k, d_model/h)

        # scaled dot-product scores
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True)  # (N, h, T_q, T_k)
        attention_scores /= size_per_head ** 0.5

        # key (padding) mask: push padded positions to a large negative value
        if attention_mask is not None:
            attention_mask = tf.to_float(attention_mask)  # (N, T_k)
            attention_mask = tf.reshape(
                attention_mask, [batch_size, 1, 1, -1])  # (N, 1, 1, T_k)
            attention_scores += attention_mask * (-2 ** 32 + 1)  # (N, h, T_q, T_k)

        # future (causal) mask
        if causality:
            diag_vals = tf.ones_like(
                attention_scores[0, 0, :, :])  # (T_q, T_k)
            tril = tf.linalg.LinearOperatorLowerTriangular(
                diag_vals).to_dense()  # (T_q, T_k)
            future_masks = tf.broadcast_to(
                tril, [batch_size, num_attention_heads, tril.shape[0], tril.shape[1]])  # (N, h, T_q, T_k)
            paddings = tf.ones_like(future_masks) * (-2 ** 32 + 1)
            attention_scores = tf.where(
                tf.equal(future_masks, 0), paddings, attention_scores)

        # attention probabilities
        attention_probs = tf.nn.softmax(attention_scores)  # (N, h, T_q, T_k)
        attention_probs = tf.layers.dropout(
            attention_probs, rate=attention_probs_dropout_prob, training=training)

        # weighted sum of values
        context_layer = tf.matmul(attention_probs, value_layer)  # (N, h, T_q, d_model/h)
        context_layer = tf.transpose(
            context_layer, [0, 2, 1, 3])  # (N, T_q, h, d_model/h)
        outputs = tf.reshape(context_layer, [
            batch_size, -1, num_attention_heads * size_per_head])  # (N, T_q, d_model)

        # Residual connection
        outputs += queries  # (N, T_q, d_model)

        # Normalize
        outputs = ln(outputs)  # (N, T_q, d_model)

    return outputs
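
# Shape walk-through for a self-attention call (illustrative values:
# N=32, T=10, d_model=512, h=8, so d_model/h = 64):
#   projections + transpose_for_scores:   (32, 8, 10, 64)
#   attention_scores = Q.K^T / sqrt(64):  (32, 8, 10, 10)
#   context = softmax(scores).V:          (32, 8, 10, 64)
#   concat heads + residual + ln:         (32, 10, 512)
#
#   q = tf.random_normal([32, 10, 512])
#   out = multihead_attention(q, q, q, batch_size=32, hidden_size=512,
#                             num_attention_heads=8)  # (32, 10, 512)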


def ff(inputs, num_units, scope="positionwise_feedforward"):
    '''position-wise feed forward net. See 3.3
    inputs: A 3d tensor with shape of [N, T, C].
    num_units: A list of two integers.
    scope: Optional scope for `variable_scope`.

    Returns:
      A 3d tensor with the same shape and dtype as inputs.
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Inner layer
        outputs = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu)

        # Outer layer
        outputs = tf.layers.dense(outputs, num_units[1])

        # Residual connection
        outputs += inputs

        # Normalize
        outputs = ln(outputs)
    return outputs


def label_smoothing(inputs, epsilon=0.1):
    '''Applies label smoothing. See 5.4 and https://arxiv.org/abs/1512.00567.
    inputs: 3d tensor. [N, T, V], where V is the vocabulary size.
    epsilon: Smoothing rate.

    For example,

    ```
    import tensorflow as tf
    inputs = tf.convert_to_tensor([[[0, 0, 1],
                                    [0, 1, 0],
                                    [1, 0, 0]],
                                   [[1, 0, 0],
                                    [1, 0, 0],
                                    [0, 1, 0]]], tf.float32)
    outputs = label_smoothing(inputs)

    with tf.Session() as sess:
        print(sess.run([outputs]))
    >>
    [array([[[ 0.03333334,  0.03333334,  0.93333334],
             [ 0.03333334,  0.93333334,  0.03333334],
             [ 0.93333334,  0.03333334,  0.03333334]],
            [[ 0.93333334,  0.03333334,  0.03333334],
             [ 0.93333334,  0.03333334,  0.03333334],
             [ 0.03333334,  0.93333334,  0.03333334]]], dtype=float32)]
    ```
    '''
    V = inputs.get_shape().as_list()[-1]  # number of channels
    return ((1 - epsilon) * inputs) + (epsilon / V)


def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope="positional_encoding"):
    '''Sinusoidal positional encoding. See 3.5
    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, padding positions are set to zeros.
    scope: Optional scope for `variable_scope`.

    returns
      3d tensor that has the same shape as inputs.
    '''
    E = inputs.get_shape().as_list()[-1]  # static
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1]  # dynamic
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # position indices
        position_ind = tf.tile(tf.expand_dims(
            tf.range(T), 0), [N, 1])  # (N, T)

        # First part of the PE function: sin and cos argument
        position_enc = np.array([
            [pos / np.power(10000, (i - i % 2) / E) for i in range(E)]
            for pos in range(maxlen)])

        # Second part, apply sin to even columns and cos to odd ones.
        position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
        position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
        position_enc = tf.convert_to_tensor(
            position_enc, tf.float32)  # (maxlen, E)

        # lookup
        outputs = tf.nn.embedding_lookup(position_enc, position_ind)

        # masks
        if masking:
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

        return tf.to_float(outputs)
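
# Usage sketch (not part of the original module; shapes are illustrative): the
# encoding depends only on position and dimension, so the same (maxlen, E) table
# is looked up for every batch element.
#
#   x = tf.random_normal([32, 10, 512])
#   pe = positional_encoding(x, maxlen=50, masking=False)  # (32, 10, 512)
#   # pe[:, pos, 2i]   == sin(pos / 10000 ** (2i / 512))
#   # pe[:, pos, 2i+1] == cos(pos / 10000 ** (2i / 512))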


# def noam_scheme(init_lr, global_step, warmup_steps=4000.):
#     '''Noam scheme learning rate decay.
#     init_lr: initial learning rate. scalar.
#     global_step: scalar.
#     warmup_steps: scalar. During warmup_steps, the learning rate increases
#         until it reaches init_lr.
#     '''
#     step = tf.cast(global_step + 1, dtype=tf.float32)
#     return init_lr * warmup_steps ** 0.5 * tf.minimum(step * warmup_steps ** -1.5, step ** -0.5)


class Transformer(object):
    '''
    xs: int32 tensor. (N, T1). Source token ids.
    ys: tuple of
        decoder_input: int32 tensor. (N, T2)
        y: int32 tensor. (N, T2). Target token ids.
    training: boolean.
    '''

    def __init__(self, hp):
        self.hp = hp
        # self.token2idx, self.idx2token = load_vocab(hp.vocab)
        self.embeddings = get_token_embeddings(
            self.hp.vocab_size, self.hp.d_model, zero_pad=True)

    def encode(self, xs, training=True):
        '''
        xs: int32 tensor. (N, T1). Source token ids.

        Returns
        memory: encoder outputs. (N, T1, d_model)
        src_masks: (N, T1). True at padding positions.
        '''
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            x = xs

            # src_masks
            src_masks = tf.math.equal(x, 0)  # (N, T1)

            # embedding
            enc = tf.nn.embedding_lookup(
                self.embeddings, x)  # (N, T1, d_model)
            enc *= self.hp.d_model ** 0.5  # scale

            enc += positional_encoding(enc, self.hp.maxlen1)
            enc = tf.layers.dropout(
                enc, self.hp.dropout_rate, training=training)

            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # self-attention
                    enc = multihead_attention(
                        queries=enc, keys=enc, values=enc,
                        batch_size=self.hp.batch_size, hidden_size=self.hp.d_model,
                        num_attention_heads=self.hp.num_heads,
                        attention_mask=src_masks,
                        attention_probs_dropout_prob=self.hp.dropout_rate,
                        training=training,
                        causality=False
                    )
                    # feed forward
                    enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
        memory = enc
        return memory, src_masks

    def decode(self, ys, memory, src_masks, training=True):
        '''
        ys: int32 tensor. (N, T2). Decoder input token ids.
        memory: encoder outputs. (N, T1, d_model)
        src_masks: (N, T1)

        Returns
        logits: (N, T2, V). float32.
        '''
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            decoder_inputs = ys

            # tgt_masks
            tgt_masks = tf.math.equal(decoder_inputs, 0)  # (N, T2)

            # embedding
            dec = tf.nn.embedding_lookup(
                self.embeddings, decoder_inputs)  # (N, T2, d_model)
            dec *= self.hp.d_model ** 0.5  # scale

            dec += positional_encoding(dec, self.hp.maxlen2)
            dec = tf.layers.dropout(
                dec, self.hp.dropout_rate, training=training)

            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # Masked self-attention (note that causality is True at this time)
                    dec = multihead_attention(
                        queries=dec, keys=dec, values=dec,
                        batch_size=self.hp.batch_size, hidden_size=self.hp.d_model,
                        num_attention_heads=self.hp.num_heads,
                        attention_mask=tgt_masks,
                        attention_probs_dropout_prob=self.hp.dropout_rate,
                        training=training,
                        causality=True,
                        scope="self_attention"
                    )

                    # Vanilla (encoder-decoder) attention
                    dec = multihead_attention(
                        queries=dec, keys=memory, values=memory,
                        batch_size=self.hp.batch_size, hidden_size=self.hp.d_model,
                        num_attention_heads=self.hp.num_heads,
                        attention_mask=src_masks,
                        attention_probs_dropout_prob=self.hp.dropout_rate,
                        training=training,
                        causality=False,
                        scope="vanilla_attention"
                    )
                    # Feed Forward
                    dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])

        # Final linear projection (embedding weights are shared)
        weights = tf.transpose(self.embeddings)  # (d_model, vocab_size)
        logits = tf.einsum('ntd,dk->ntk', dec, weights)  # (N, T2, vocab_size)
        # y_hat = tf.to_int32(tf.argmax(logits, axis=-1))

        return logits

    def train(self, xs, ys):
        '''
        xs: int32 tensor. (N, T1). Source token ids.
        ys: tuple of (decoder_input, y), both int32 tensors of shape (N, T2).

        Returns
        loss: (N, T2). Per-token cross-entropy against the label-smoothed targets.
        '''
        # forward
        memory, src_masks = self.encode(xs)
        logits = self.decode(ys[0], memory, src_masks)

        # train scheme
        y = ys[1]
        y_ = label_smoothing(tf.one_hot(y, depth=self.hp.vocab_size))
        loss = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=y_)
        return loss
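
    # Usage sketch (not part of the original module): wiring the per-token loss
    # into a training op. `hp`, the placeholders, and the Adam settings below are
    # assumptions for illustration, not this repository's actual training script.
    #
    #   m = Transformer(hp)
    #   xs = tf.placeholder(tf.int32, [hp.batch_size, None])             # source ids
    #   decoder_input = tf.placeholder(tf.int32, [hp.batch_size, None])  # shifted target ids
    #   y = tf.placeholder(tf.int32, [hp.batch_size, None])              # target ids
    #   per_token_loss = m.train(xs, (decoder_input, y))                 # (N, T2)
    #   nonpadding = tf.to_float(tf.not_equal(y, 0))                     # mask out padding
    #   loss = tf.reduce_sum(per_token_loss * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)
    #   train_op = tf.train.AdamOptimizer(hp.lr).minimize(loss)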

    # def eval(self, xs, ys):
    #     '''Predicts autoregressively.
    #     At inference, input ys is ignored.
    #     Returns
    #     y_hat: (N, T2)
    #     '''
    #     decoder_inputs, y, y_seqlen, sents2 = ys
    #     decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx["<s>"]
    #     ys = (decoder_inputs, y, y_seqlen, sents2)
    #     memory, sents1, src_masks = self.encode(xs, False)
    #     logging.info("Inference graph is being built. Please be patient.")
    #     for _ in tqdm(range(self.hp.maxlen2)):
    #         logits, y_hat, y, sents2 = self.decode(ys, memory, src_masks, False)
    #         if tf.reduce_sum(y_hat, 1) == self.token2idx["<pad>"]: break
    #         _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
    #         ys = (_decoder_inputs, y, y_seqlen, sents2)
    #     # monitor a random sample
    #     n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
    #     sent1 = sents1[n]
    #     pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
    #     sent2 = sents2[n]
    #     tf.summary.text("sent1", sent1)
    #     tf.summary.text("pred", pred)
    #     tf.summary.text("sent2", sent2)
    #     summaries = tf.summary.merge_all()
    #     return y_hat, summaries


# def convert_idx_to_token_tensor(inputs, idx2token):
#     '''Converts an int32 tensor to a string tensor.
#     inputs: 1d int32 tensor. indices.
#     idx2token: dictionary
#     Returns
#     1d string tensor.
#     '''
#     def my_func(inputs):
#         return " ".join(idx2token[elem] for elem in inputs)
#     return tf.py_func(my_func, [inputs], tf.string)


# def load_vocab(vocab_fpath):
#     '''Loads a vocabulary file and returns idx<->token maps.
#     vocab_fpath: string. vocabulary file path.
#     Note that these ids are reserved:
#     0: <pad>, 1: <unk>, 2: <s>, 3: </s>
#     Returns
#     two dictionaries.
#     '''
#     vocab = [line.split()[0] for line in open(vocab_fpath, 'r', encoding='utf-8').read().splitlines()]
#     token2idx = {token: idx for idx, token in enumerate(vocab)}
#     idx2token = {idx: token for idx, token in enumerate(vocab)}
#     return token2idx, idx2token