hetu_transformer.py

import hetu as ht
from hetu import init
import numpy as np


def layer_norm(
    input_tensor,
    feature_size,
    eps=1e-8
):
    scale = init.ones(name='layer_norm_scale', shape=(feature_size, ))
    bias = init.zeros(name='layer_norm_bias', shape=(feature_size, ))
    return ht.layer_normalization_op(input_tensor, scale, bias, eps=eps)


def dense(
    input_tensor,
    fan_in,
    fan_out,
    activation=None,
    kernel_initializer=init.xavier_normal,
    bias_initializer=init.zeros
):
    weights = kernel_initializer(name='dense_weights', shape=(fan_in, fan_out))
    bias = bias_initializer(name='dense_bias', shape=(fan_out,))
    outputs = ht.matmul_op(input_tensor, weights)
    outputs = outputs + ht.broadcastto_op(bias, outputs)
    if activation is not None:
        outputs = activation(outputs)
    return outputs


def dropout(
    input_tensor,
    dropout_prob
):
    if dropout_prob is None or dropout_prob == 0.0:
        return input_tensor
    # ht.dropout_op is passed the keep probability (1 - dropout_prob) here.
    output = ht.dropout_op(input_tensor, 1.0 - dropout_prob)
    return output


def get_token_embeddings(vocab_size, num_units, initializer=init.xavier_normal, zero_pad=True):
    if zero_pad:
        # Row 0 of the embedding table is a fixed, non-trainable all-zero
        # vector, so padding tokens (id 0) contribute nothing downstream.
        embedding_part = initializer(
            name='embedding_table', shape=(vocab_size-1, num_units))
        padding_zero = init.zeros(
            name='padding_zero', shape=(1, num_units), trainable=False)
        embeddings = ht.concat_op(padding_zero, embedding_part)
    else:
        embeddings = initializer(
            name='embedding_table', shape=(vocab_size, num_units))
    return embeddings


def multihead_attention(
        queries, keys, values,
        config,
        query_act=None, key_act=None, value_act=None,
        attention_mask=None,
        causality=False):

    def transpose_for_scores(input_tensor):
        output_tensor = ht.array_reshape_op(
            input_tensor, [config.batch_size, -1, config.num_heads, config.d_model // config.num_heads])
        output_tensor = ht.transpose_op(output_tensor, [0, 2, 1, 3])
        return output_tensor

    batch_size = config.batch_size
    hidden_size = config.d_model
    num_attention_heads = config.num_heads
    caus_len = config.maxlen2 - 1
    attention_probs_dropout_prob = config.dropout_rate
    size_per_head = hidden_size // num_attention_heads

    # reshape to 2d
    queries2d = ht.array_reshape_op(
        queries, [-1, hidden_size])  # (N * T_q, d_model)
    keys2d = ht.array_reshape_op(keys, [-1, hidden_size])  # (N * T_k, d_model)
    values2d = ht.array_reshape_op(
        values, [-1, hidden_size])  # (N * T_k, d_model)

    # linear transformation
    query_layer = dense(queries2d, hidden_size, hidden_size,
                        query_act)  # (N * T_q, d_model)
    key_layer = dense(keys2d, hidden_size, hidden_size,
                      key_act)  # (N * T_k, d_model)
    value_layer = dense(values2d, hidden_size, hidden_size,
                        value_act)  # (N * T_k, d_model)

    # transpose
    query_layer = transpose_for_scores(query_layer)  # (N, h, T_q, d_model/h)
    key_layer = transpose_for_scores(key_layer)  # (N, h, T_k, d_model/h)
    value_layer = transpose_for_scores(value_layer)  # (N, h, T_k, d_model/h)

    # scaled dot-product scores
    attention_scores = ht.batch_matmul_op(
        query_layer, key_layer, trans_B=True)  # (N, h, T_q, T_k)
    attention_scores = attention_scores * (1.0 / np.sqrt(float(size_per_head)))

    # padding mask: keep scores where the mask is nonzero, otherwise add -2**32+1
    if attention_mask is not None:
        zeros = ht.Variable('no_mask', value=np.array(
            (0,), dtype=np.float32), trainable=False)
        adder = ht.Variable('attention_mask', value=np.array(
            (-2**32+1,), dtype=np.float32), trainable=False)
        zeros = ht.broadcastto_op(zeros, attention_mask)
        adder = ht.broadcastto_op(adder, attention_mask)
        attention_mask = ht.where_op(attention_mask, zeros, adder)  # (N, T)
        attention_mask = ht.array_reshape_op(
            attention_mask, [batch_size, 1, 1, -1])
        attention_scores = attention_scores + \
            ht.broadcastto_op(attention_mask, attention_scores)

    # causal mask: block attention to future positions in the decoder
    if causality:
        tril = ht.Variable(name='tril', value=np.tril(
            np.ones((caus_len, caus_len))), trainable=False)  # (T, T)
        future_masks = ht.broadcast_shape_op(
            tril, [batch_size, num_attention_heads, caus_len, caus_len])
        adder = ht.Variable('future_mask', value=np.array(
            (-2**32+1,), dtype=np.float32), trainable=False)
        adder = ht.broadcastto_op(adder, future_masks)
        attention_scores = ht.where_op(
            future_masks, attention_scores, adder)  # (N, h, T, T)

    # probs
    attention_probs = ht.softmax_op(attention_scores)
    attention_probs = dropout(attention_probs, attention_probs_dropout_prob)

    # weighted sum over values, then merge the heads back together
    context_layer = ht.batch_matmul_op(attention_probs, value_layer)
    context_layer = ht.transpose_op(context_layer, [0, 2, 1, 3])
    outputs = ht.array_reshape_op(
        context_layer,
        [batch_size, -1, num_attention_heads * size_per_head])

    # Residual connection
    outputs = outputs + queries  # (N, T_q, d_model)
    # Normalize
    outputs = layer_norm(outputs, hidden_size)  # (N, T_q, d_model)
    return outputs
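
# multihead_attention above is standard multi-head scaled dot-product
# attention, softmax(Q K^T / sqrt(d_model / num_heads)) V: padded key
# positions (attention_mask == 0) are driven towards -2**32 before the
# softmax, and with causality=True future positions are masked out via a
# lower-triangular matrix. The result goes through a residual connection
# and layer normalization.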


def ff(inputs, config):
    outputs = ht.array_reshape_op(inputs, [-1, config.d_model])
    outputs = dense(outputs, config.d_model,
                    config.d_ff, activation=ht.relu_op)
    outputs = dense(outputs, config.d_ff, config.d_model)
    outputs = ht.array_reshape_op(
        outputs, [config.batch_size, -1, config.d_model])
    outputs = outputs + inputs
    outputs = layer_norm(outputs, config.d_model)
    return outputs


def label_smoothing(inputs, V, epsilon=0.1):
    # V = inputs.shape[-1] # number of channels
    return ((1-epsilon) * inputs) + (epsilon / V)
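
# Worked example for label_smoothing: with V = 4 and epsilon = 0.1, a one-hot
# row [1, 0, 0, 0] becomes [0.925, 0.025, 0.025, 0.025]; the smoothed
# distribution still sums to 1 because (1 - epsilon) + V * (epsilon / V) = 1.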


def positional_encoding(
    inputs,
    inputs_shape,
    maxlen,
    masking=True
):
    N, T, E = tuple(inputs_shape)
    position_enc = np.array([
        [pos / np.power(10000, (i & -2)/E) for i in range(E)]
        for pos in range(maxlen)])
    position_enc[:, 0::2] = np.sin(position_enc[:, 0::2])  # dim 2i
    position_enc[:, 1::2] = np.cos(position_enc[:, 1::2])  # dim 2i+1
    position_enc = position_enc[:T, :]
    outputs = ht.Variable(name='position_enc', value=np.tile(
        position_enc, [N, 1, 1]), trainable=False)
    zeros = ht.Variable(name='zeros', value=np.zeros(
        inputs_shape), trainable=False)
    if masking:
        outputs = ht.where_op(inputs, outputs, zeros)
    return outputs
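
# positional_encoding above builds the usual sinusoidal table
#   PE(pos, 2i)   = sin(pos / 10000^(2i / E))
#   PE(pos, 2i+1) = cos(pos / 10000^(2i / E))
# (`i & -2` clears the lowest bit of i, i.e. equals 2 * (i // 2), giving the
# shared exponent for each sin/cos pair). With masking=True, entries whose
# input value is zero, e.g. padding tokens under a zero-padded embedding
# table, receive no positional encoding.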


class Transformer(object):
    def __init__(self, hp):
        self.hp = hp
        self.embeddings = get_token_embeddings(
            self.hp.vocab_size, self.hp.d_model, zero_pad=True)

    def encode(self, xs):
        x = xs
        # embedding
        enc = ht.embedding_lookup_op(self.embeddings, x)  # (N, T1, d_model)
        enc = enc * self.hp.d_model**0.5  # scale
        enc += positional_encoding(enc, (self.hp.batch_size,
                                         self.hp.maxlen1, self.hp.d_model), self.hp.maxlen1)
        enc = dropout(enc, self.hp.dropout_rate)
        # Blocks
        for i in range(self.hp.num_blocks):
            # self-attention
            enc = multihead_attention(
                queries=enc, keys=enc, values=enc,
                config=self.hp,
                attention_mask=x,
                causality=False
            )
            # feed forward
            enc = ff(enc, config=self.hp)
        memory = enc
        return memory

    def decode(self, ys, memory, src_masks):
        decoder_inputs = ys
        # embedding
        dec = ht.embedding_lookup_op(
            self.embeddings, decoder_inputs)  # (N, T2, d_model)
        dec = dec * self.hp.d_model ** 0.5  # scale
        dec += positional_encoding(dec, (self.hp.batch_size,
                                         self.hp.maxlen2-1, self.hp.d_model), self.hp.maxlen2)
        dec = dropout(dec, self.hp.dropout_rate)
        # Blocks
        for i in range(self.hp.num_blocks):
            # Masked self-attention (note that causality is True here)
            dec = multihead_attention(
                queries=dec, keys=dec, values=dec,
                config=self.hp,
                attention_mask=decoder_inputs,
                causality=True,
            )
            # Vanilla attention over the encoder memory
            dec = multihead_attention(
                queries=dec, keys=memory, values=memory,
                config=self.hp,
                attention_mask=src_masks,
                causality=False,
            )
            # Feed forward
            dec = ff(dec, config=self.hp)
        # Final projection: output weights are tied to the input embedding
        # table (multiplied transposed).
        dec = ht.array_reshape_op(
            dec, [-1, self.hp.d_model])  # (N * T, d_model)
        logits = ht.array_reshape_op(ht.matmul_op(dec, self.embeddings, trans_B=True), [
            self.hp.batch_size, -1, self.hp.vocab_size])  # (N, T, vocab)
        return logits

    def train(self, xs, ys):
        # forward
        memory = self.encode(xs)
        logits = self.decode(ys[0], memory, xs)
        # train scheme
        y = ys[1]
        y_ = label_smoothing(ht.one_hot_op(
            y, self.hp.vocab_size), self.hp.vocab_size)  # (N, T, vocab)
        loss = ht.softmaxcrossentropy_op(logits, y_)
        return loss
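

# Usage sketch (illustrative only): one way this module might be wired up.
# The hyperparameter values below are assumptions, and so is feeding token ids
# through non-trainable ht.Variable nodes; a real training script would use
# Hetu's data-feeding, optimizer and executor APIs, which are omitted here.
if __name__ == '__main__':
    from types import SimpleNamespace

    hp = SimpleNamespace(
        vocab_size=32000, d_model=512, d_ff=2048,
        num_heads=8, num_blocks=6, dropout_rate=0.1,
        batch_size=32, maxlen1=100, maxlen2=100,
    )
    model = Transformer(hp)

    # Placeholder token ids: source sentence, shifted decoder input, target.
    xs = ht.Variable(name='xs', value=np.zeros(
        (hp.batch_size, hp.maxlen1)), trainable=False)
    ys_in = ht.Variable(name='ys_in', value=np.zeros(
        (hp.batch_size, hp.maxlen2 - 1)), trainable=False)
    ys_out = ht.Variable(name='ys_out', value=np.zeros(
        (hp.batch_size, hp.maxlen2 - 1)), trainable=False)

    # Builds the symbolic cross-entropy loss node defined in Transformer.train.
    loss = model.train(xs, (ys_in, ys_out))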