
hetu_bert.py 35 kB

import hetu as ht
import numpy as np

'''
Bert Module Architecture & Input/Output Tensor Size

BertModel Inputs:
    input_ids: [batch_size, seq_len], word token indices in the vocabulary
BertModel Outputs:
    sequence_output: [batch_size, seq_len, hidden_size] (from BertEncoder)
    pooled_output: [batch_size, hidden_size] (from BertPooler)

BertModel:
    --[batch_size, seq_len]--
    BertEmbeddings:
        Embedding(word/position/token_type)
        LayerNorm
        Dropout
    --[batch_size, seq_len, hidden_size]--
    --[batch_size, seq_len, hidden_size]--
    BertEncoder:
        BertLayer(num_hidden_layers):
            BertAttention:
                BertSelfAttention
                --[batch_size, seq_len, hidden_size]--
                BertSelfOutput:
                    Linear
                    Dropout
                    Add & LayerNorm
            --[batch_size, seq_len, hidden_size]--
            BertIntermediate:
                Linear + Act(gelu)
            --[batch_size, seq_len, intermediate_size]--
            BertOutput:
                Linear
                Dropout
                Add & LayerNorm
            --[batch_size, seq_len, hidden_size]--
    --[batch_size, seq_len, hidden_size]--
    BertPooler:
        (Slice, select [cls])
        --[batch_size, hidden_size]--
        Linear + Act(Tanh)
    --[batch_size, hidden_size]--
Bert
'''
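# Note: every class below reads its hyperparameters from a `config` object (a BertConfig
# instance, referenced in the docstrings but defined elsewhere in the repository). A minimal
# sketch of the fields this file actually accesses, with BERT-base-like values, might look as
# follows; the concrete BertConfig class may differ, and batch_size here is purely illustrative:
#
#   from types import SimpleNamespace
#   config = SimpleNamespace(
#       vocab_size=30522, hidden_size=768, num_hidden_layers=12,
#       num_attention_heads=12, intermediate_size=3072,
#       hidden_act="relu",               # only "relu" is usable below; "gelu" asserts
#       hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
#       max_position_embeddings=512, type_vocab_size=2,
#       batch_size=8, output_hidden_states=False)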
'''
BertEmbeddings:
--------------------------------------------------------------------------------------------------'''
class BertEmbeddings(object):
    """Construct the embeddings from word, position and token_type embeddings.
    """
    def __init__(self, config):
        self.seq_len = config.max_position_embeddings
        self.batch_size = config.batch_size
        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size, "word_embeddings")
        self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size, 'position_embeddings')
        self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size, 'token_type_embeddings')
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = Dropout(config.hidden_dropout_prob)

    def __call__(self, input_ids, token_type_ids):
        '''
        inputs:
            input_ids: [batch_size, seq_len]
            token_type_ids: [batch_size, seq_len]
        outputs:
            embeddings: [batch_size, seq_len, hidden_size]
        '''
        seq_length = self.seq_len
        batch_size = self.batch_size
        position_ids = ht.Variable('position_ids',
                                   value=np.arange(seq_length).reshape((1, -1)).repeat(batch_size, axis=0),
                                   dtype=np.int64, trainable=False, ctx=input_ids.ctx)
        '''Embedding Size
        input_ids: [batch_size, seq_len], embedding_table: [vocab_size, hidden_size]
        position_ids: [batch_size, seq_len], embedding_table: [seq_len, hidden_size]
        token_type_ids: [batch_size, seq_len], embedding_table: [type_vocab_size, hidden_size]
        --> embeddings: [batch_size, seq_len, hidden_size]
        '''
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
'''-----------------------------------------------------------------------------------------------'''
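# Worked example of the position_ids construction above: for seq_length=4 and batch_size=2,
#   np.arange(4).reshape((1, -1)).repeat(2, axis=0)
# yields
#   [[0, 1, 2, 3],
#    [0, 1, 2, 3]]
# i.e. every sample in the batch shares the same absolute position indices.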
'''
BertEncoder & BertLayer:
--------------------------------------------------------------------------------------------------'''
class BertEncoder(object):
    def __init__(self, config):
        self.output_hidden_states = config.output_hidden_states
        self.layer = [BertLayer(config) for _ in range(config.num_hidden_layers)]

    def __call__(self, hidden_states, attention_mask=None):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, num_heads, seq_len, seq_len]
        outputs:
            hidden_states: [batch_size, seq_len, hidden_size]
            all_hidden_states: optional, num_hidden_layers * [batch_size, seq_len, hidden_size]
        '''
        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states  # last-layer hidden state

class BertLayer(object):
    def __init__(self, config):
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def __call__(self, hidden_states, attention_mask):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, num_heads, seq_len, seq_len]
        outputs:
            layer_output: [batch_size, seq_len, hidden_size]
        '''
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output
'''-----------------------------------------------------------------------------------------------'''
'''
BertAttention & BertSelfAttention & BertSelfOutput
--------------------------------------------------------------------------------------------------'''
class BertAttention(object):
    def __init__(self, config):
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def __call__(self, input_tensor, attention_mask):
        '''
        inputs:
            input_tensor: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, num_heads, seq_len, seq_len]
        outputs:
            attention_output: [batch_size, seq_len, hidden_size]
        '''
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output

class BertSelfAttention(object):
    def __init__(self, config):
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads))
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size  # all_head_size == hidden_size
        self.hidden_size = config.hidden_size
        self.seq_len = config.max_position_embeddings
        self.batch_size = config.batch_size
        linear_input_shape = [self.batch_size, self.seq_len, self.hidden_size]
        self.query = Linear(config.hidden_size, self.all_head_size, input_shape=linear_input_shape)
        self.key = Linear(config.hidden_size, self.all_head_size, input_shape=linear_input_shape)
        self.value = Linear(config.hidden_size, self.all_head_size, input_shape=linear_input_shape)
        self.dropout = Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, input_tensor):
        # [batch_size, seq_len, hidden_size] -> [batch_size, num_heads, seq_len, head_size]
        output_tensor = ht.array_reshape_op(
            input_tensor, [self.batch_size, self.seq_len, self.num_attention_heads, self.attention_head_size])
        output_tensor = ht.transpose_op(output_tensor, [0, 2, 1, 3])
        return output_tensor

    def __call__(self, hidden_states, attention_mask):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
            attention_mask: [batch_size, 1, 1, seq_len]
        outputs:
            context_layer: [batch_size, seq_len, hidden_size]
        '''
        # linear transformation
        mixed_query_layer = self.query(hidden_states)  # [batch_size, seq_len, hidden_size]
        mixed_key_layer = self.key(hidden_states)      # [batch_size, seq_len, hidden_size]
        mixed_value_layer = self.value(hidden_states)  # [batch_size, seq_len, hidden_size]
        # transpose
        query_layer = self.transpose_for_scores(mixed_query_layer)  # [batch_size, num_heads, seq_len, head_size]
        key_layer = self.transpose_for_scores(mixed_key_layer)      # [batch_size, num_heads, seq_len, head_size]
        value_layer = self.transpose_for_scores(mixed_value_layer)  # [batch_size, num_heads, seq_len, head_size]
        # scaled dot-product attention scores
        key_layer_scaled = key_layer * (1.0 / np.sqrt(float(self.attention_head_size)))
        attention_scores = ht.batch_matmul_op(query_layer, key_layer_scaled, trans_B=True)  # [batch_size, num_heads, seq_len, seq_len]
        # Apply the attention mask (precomputed for all layers in BertModel's __call__).
        attention_scores = attention_scores + ht.broadcastto_op(attention_mask, attention_scores)  # [batch_size, num_heads, seq_len, seq_len]
        # Normalize the attention scores to probabilities.
        attention_probs = ht.softmax_op(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)
        context_layer = ht.batch_matmul_op(attention_probs, value_layer)  # [batch_size, num_heads, seq_len, head_size]
        context_layer = ht.transpose_op(context_layer, [0, 2, 1, 3])      # [batch_size, seq_len, num_heads, head_size]
        context_layer = ht.array_reshape_op(context_layer, [-1, self.seq_len, self.all_head_size])  # [batch_size, seq_len, hidden_size]
        return context_layer
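# Shape sketch of the self-attention computation above, as a plain-NumPy reference check
# (independent of hetu; the sizes B, H, S, D are illustrative only):
#
#   B, H, S, D = 2, 12, 128, 64                                    # batch, heads, seq_len, head_size
#   q = np.random.randn(B, H, S, D)
#   k = np.random.randn(B, H, S, D)
#   v = np.random.randn(B, H, S, D)
#   scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(D)              # [B, H, S, S]
#   probs = np.exp(scores) / np.exp(scores).sum(-1, keepdims=True) # softmax over the last axis
#   context = probs @ v                                            # [B, H, S, D]
#   context = context.transpose(0, 2, 1, 3).reshape(B, S, H * D)   # [B, S, hidden_size]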
class BertSelfOutput(object):
    def __init__(self, config):
        linear_input_shape = [config.batch_size, config.max_position_embeddings, config.hidden_size]
        self.dense = Linear(config.hidden_size, config.hidden_size, input_shape=linear_input_shape)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = Dropout(config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
            input_tensor: [batch_size, seq_len, hidden_size]
        outputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        '''
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
'''-----------------------------------------------------------------------------------------------'''
'''
BertIntermediate & BertOutput (2-layer FeedForward)
--------------------------------------------------------------------------------------------------'''
class BertIntermediate(object):
    def __init__(self, config):
        if config.hidden_act == "relu":
            self.intermediate_act_fn = ht.relu_op
        elif config.hidden_act == "gelu":
            self.intermediate_act_fn = ht.gelu_op
            print("Gelu activation is not implemented yet.")
            assert(False)
        linear_input_shape = [config.batch_size, config.max_position_embeddings, config.hidden_size]
        self.dense = Linear(config.hidden_size, config.intermediate_size, activation=self.intermediate_act_fn, input_shape=linear_input_shape)

    def __call__(self, hidden_states):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        outputs:
            hidden_states: [batch_size, seq_len, intermediate_size]
        '''
        hidden_states = self.dense(hidden_states)
        return hidden_states

class BertOutput(object):
    def __init__(self, config):
        linear_input_shape = [config.batch_size, config.max_position_embeddings, config.intermediate_size]
        self.dense = Linear(config.intermediate_size, config.hidden_size, input_shape=linear_input_shape)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = Dropout(config.hidden_dropout_prob)

    def __call__(self, hidden_states, input_tensor):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, intermediate_size]
            input_tensor: [batch_size, seq_len, hidden_size]
        outputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        '''
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
'''-----------------------------------------------------------------------------------------------'''
'''
BertPooler
--------------------------------------------------------------------------------------------------'''
class BertPooler(object):
    def __init__(self, config):
        self.dense = Linear(config.hidden_size, config.hidden_size, activation=ht.tanh_op)
        self.batch_size = config.batch_size
        self.hidden_size = config.hidden_size

    def __call__(self, hidden_states):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        outputs:
            pooled_output: [batch_size, hidden_size]
        '''
        # "Pool" the sequence by slicing out the hidden state of the first token ([CLS]).
        first_token_tensor = ht.slice_op(hidden_states, (0, 0, 0), (self.batch_size, 1, self.hidden_size))
        first_token_tensor = ht.array_reshape_op(first_token_tensor, [self.batch_size, self.hidden_size])
        pooled_output = self.dense(first_token_tensor)
        return pooled_output
'''-----------------------------------------------------------------------------------------------'''
'''
Bert Downstream Heads
--------------------------------------------------------------------------------------------------'''
class BertPredictionHeadTransform(object):
    def __init__(self, config):
        if config.hidden_act == "relu":
            self.hidden_act = ht.relu_op
        elif config.hidden_act == "gelu":
            self.hidden_act = ht.gelu_op
            print("Gelu activation is not implemented yet.")
            assert(False)
        linear_input_shape = [config.batch_size, config.max_position_embeddings, config.hidden_size]
        self.dense_act = Linear(config.hidden_size, config.hidden_size, activation=self.hidden_act, input_shape=linear_input_shape)
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)

    def __call__(self, hidden_states):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        outputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        '''
        hidden_states = self.dense_act(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states

class BertLMPredictionHead(object):
    def __init__(self, config, bert_model_embedding_weights):
        '''
        bert_model_embedding_weights: [vocab_size, hidden_size]
        '''
        self.transform = BertPredictionHeadTransform(config)
        linear_input_shape = [config.batch_size, config.max_position_embeddings, config.hidden_size]
        self.decoder = Linear(config.hidden_size, config.vocab_size, bias_initializer=ht.init.zeros, input_shape=linear_input_shape)
        # Tie the decoder weights to the (transposed) word embedding table.
        self.decoder.weights = ht.transpose_op(bert_model_embedding_weights)

    def __call__(self, hidden_states):
        '''
        inputs:
            hidden_states: [batch_size, seq_len, hidden_size]
        outputs:
            hidden_states: [batch_size, seq_len, vocab_size]
        '''
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states

class BertOnlyMLMHead(object):
    def __init__(self, config, bert_model_embedding_weights):
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)

    def __call__(self, sequence_output):
        '''
        inputs:
            sequence_output: [batch_size, seq_len, hidden_size]
        outputs:
            prediction_scores: [batch_size, seq_len, vocab_size]
        '''
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores

class BertOnlyNSPHead(object):
    def __init__(self, config):
        self.seq_relationship = Linear(config.hidden_size, 2)

    def __call__(self, pooled_output):
        '''
        inputs:
            pooled_output: [batch_size, hidden_size]
        outputs:
            seq_relationship_score: [batch_size, 2]
        '''
        seq_relationship_score = self.seq_relationship(pooled_output)
        return seq_relationship_score

class BertPreTrainingHeads(object):
    def __init__(self, config, bert_model_embedding_weights):
        self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)
        self.seq_relationship = Linear(config.hidden_size, 2)

    def __call__(self, sequence_output, pooled_output):
        '''
        inputs:
            sequence_output: [batch_size, seq_len, hidden_size]
            pooled_output: [batch_size, hidden_size]
        outputs:
            prediction_scores: [batch_size, seq_len, vocab_size]
            seq_relationship_score: [batch_size, 2]
        '''
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score
'''-----------------------------------------------------------------------------------------------'''
'''
BertModel:
--------------------------------------------------------------------------------------------------'''
class BertModel(object):
    """BERT model ("Bidirectional Encoder Representations from Transformers").
    Params:
        config: a BertConfig class instance with the configuration to build a new model
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
    Outputs: Tuple of (encoded_layers, pooled_output)
        `encoded_layers`: controlled by the `output_all_encoded_layers` argument:
            - `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded hidden states at the end
                of each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), each
                encoded hidden state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],
            - `output_all_encoded_layers=False`: outputs only the full sequence of hidden states corresponding
                to the last attention block, of shape [batch_size, sequence_length, hidden_size],
        `pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of a
            classifier pretrained on top of the hidden state associated with the first character of the
            input (`CLS`) to train on the Next-Sentence task (see BERT's paper).
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = modeling.BertModel(config=config)
    all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)
        self.batch_size = config.batch_size
        self.seq_len = config.max_position_embeddings

    def __call__(self, input_ids, token_type_ids, attention_mask):
        # Turn the [batch_size, seq_len] padding mask into an additive mask broadcast over the
        # attention scores: 1 -> 0.0 (keep), 0 -> -10000.0 (mask out).
        extended_attention_mask = ht.array_reshape_op(attention_mask, [self.batch_size, 1, 1, self.seq_len])
        extended_attention_mask = (extended_attention_mask + (-1.0)) * 10000.0
        embedding_output = self.embeddings(input_ids, token_type_ids)
        sequence_output = self.encoder(embedding_output, extended_attention_mask)
        pooled_output = self.pooler(sequence_output)
        return sequence_output, pooled_output
'''-----------------------------------------------------------------------------------------------'''
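# Worked example of the additive attention mask built in BertModel.__call__ above: a padding
# mask row such as [1, 1, 1, 0] (three real tokens, one padding token) is reshaped to
# [batch_size, 1, 1, seq_len] and mapped by (mask - 1.0) * 10000.0 to [0, 0, 0, -10000].
# After broadcasting onto the attention scores, real positions are left unchanged while padded
# positions get a large negative score, so they receive (near-)zero weight after the softmax
# in BertSelfAttention.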
'''
BertForPreTraining:
--------------------------------------------------------------------------------------------------'''
class BertForPreTraining(object):
    """BERT model with pre-training heads.
    This module comprises the BERT model followed by the two pre-training heads:
        - the masked language modeling head, and
        - the next sentence classification head.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `masked_lm_labels`: optional masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
            is only computed for the labels set in [0, ..., vocab_size].
        `next_sentence_label`: optional next sentence classification label: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.
    Outputs:
        Always outputs `prediction_scores` (the masked language modeling logits of shape
        [batch_size, sequence_length, vocab_size]) and `seq_relationship_score` (the next sentence
        classification logits of shape [batch_size, 2]).
        If `masked_lm_labels` and `next_sentence_label` are both not `None`, additionally outputs the
        masked language modeling loss and the next sentence classification loss.
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForPreTraining(config)
    masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
        self.vocab_size = config.vocab_size

    def __call__(self, input_ids, token_type_ids, attention_mask, masked_lm_labels=None, next_sentence_label=None):
        sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
        return_op = [prediction_scores, seq_relationship_score]
        if masked_lm_labels is not None and next_sentence_label is not None:
            '''
            masked_lm_labels: [batch_size, seq_len, vocab_size], one-hot form, masked places are filled with 0
            prediction_scores: [batch_size, seq_len, vocab_size]
            next_sentence_label: [batch_size, 2], one-hot form, masked places are filled with 0
            seq_relationship_score: [batch_size, 2]
            masked_lm_loss: [batch_size*seq_len]
            next_sentence_loss: [batch_size]
            '''
            masked_lm_loss = ht.softmaxcrossentropy_sparse_op(prediction_scores, masked_lm_labels, ignored_index=-1)
            next_sentence_loss = ht.softmaxcrossentropy_sparse_op(seq_relationship_score, next_sentence_label, ignored_index=-1)
            return_op += [masked_lm_loss, next_sentence_loss]
        return return_op
class BertForMaskedLM(object):
    """BERT model with the masked language modeling head.
    This module comprises the BERT model followed by the masked language modeling head.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `masked_lm_labels`: masked language modeling labels: torch.LongTensor of shape [batch_size, sequence_length]
            with indices selected in [-1, 0, ..., vocab_size]. All labels set to -1 are ignored (masked); the loss
            is only computed for the labels set in [0, ..., vocab_size].
    Outputs:
        Always outputs the masked language modeling logits of shape [batch_size, sequence_length, vocab_size].
        If `masked_lm_labels` is not `None`, additionally outputs the masked language modeling loss.
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForMaskedLM(config)
    masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        self.bert = BertModel(config)
        self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
        self.vocab_size = config.vocab_size

    def __call__(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
        sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask)
        prediction_scores = self.cls(sequence_output)
        return_op = [prediction_scores]
        if masked_lm_labels is not None:
            '''
            masked_lm_labels: [batch_size, seq_len, vocab_size], one-hot form, masked places are filled with 0
            prediction_scores: [batch_size, seq_len, vocab_size]
            masked_lm_loss: [batch_size*seq_len]
            '''
            masked_lm_loss = ht.softmaxcrossentropy_sparse_op(prediction_scores, masked_lm_labels, ignored_index=-1)
            return_op += [masked_lm_loss]
        return return_op
class BertForNextSentencePrediction(object):
    """BERT model with the next sentence prediction head.
    This module comprises the BERT model followed by the next sentence classification head.
    Params:
        config: a BertConfig class instance with the configuration to build a new model.
    Inputs:
        `input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]
            with the word token indices in the vocabulary (see the tokens preprocessing logic in the scripts
            `extract_features.py`, `run_classifier.py` and `run_squad.py`)
        `token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the token
            type indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds to
            a `sentence B` token (see the BERT paper for more details).
        `attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indices
            selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max
            input sequence length in the current batch. It's the mask that we typically use for attention when
            a batch has varying length sentences.
        `next_sentence_label`: next sentence classification label: torch.LongTensor of shape [batch_size]
            with indices selected in [0, 1].
            0 => next sentence is the continuation, 1 => next sentence is a random sentence.
    Outputs:
        Always outputs the next sentence classification logits of shape [batch_size, 2].
        If `next_sentence_label` is not `None`, additionally outputs the next sentence classification loss.
    Example usage:
    ```python
    # Already been converted into WordPiece token ids
    input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
    input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
    token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
    config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
        num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)
    model = BertForNextSentencePrediction(config)
    seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
    ```
    """
    def __init__(self, config):
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

    def __call__(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
        _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        seq_relationship_score = self.cls(pooled_output)
        return_op = [seq_relationship_score]
        if next_sentence_label is not None:
            '''
            next_sentence_label: [batch_size, 2], one-hot form, masked places are filled with 0
            seq_relationship_score: [batch_size, 2]
            next_sentence_loss: [batch_size]
            '''
            next_sentence_loss = ht.softmaxcrossentropy_sparse_op(seq_relationship_score, next_sentence_label, ignored_index=-1)
            return_op += [next_sentence_loss]
        return return_op
'''-----------------------------------------------------------------------------------------------'''
'''
Bert Layer utils (Embedding & BertLayerNorm & Dropout & Linear)
--------------------------------------------------------------------------------------------------'''
class Embedding(object):
    def __init__(self, num_embeddings, embedding_dim, embedding_name=None, initializer=ht.init.xavier_normal):
        self.weight = initializer(name=embedding_name, shape=(num_embeddings, embedding_dim))

    def __call__(self, input_tensor):
        return ht.embedding_lookup_op(self.weight, input_tensor)

class BertLayerNorm(object):
    def __init__(self, hidden_size, eps=1e-12):
        self.eps = eps
        self.scale = ht.init.ones(name='layer_norm_scale', shape=(hidden_size, ))
        self.bias = ht.init.zeros(name='layer_norm_bias', shape=(hidden_size, ))

    def __call__(self, input_tensor):
        return ht.layer_normalization_op(input_tensor, self.scale, self.bias, eps=self.eps)

class Dropout(object):
    def __init__(self, dropout_prob=None):
        self.dropout_prob = dropout_prob

    def __call__(self, input_tensor):
        if self.dropout_prob is None or self.dropout_prob == 0.0:
            return input_tensor
        # dropout_op is called with 1.0 - dropout_prob here, i.e. the keep probability.
        output = ht.dropout_op(input_tensor, 1.0 - self.dropout_prob)
        return output

class Linear(object):
    def __init__(self, in_features, out_features, bias=True, activation=None, kernel_initializer=ht.init.xavier_normal,
                 bias_initializer=ht.init.zeros, input_shape=None):
        self.bias_flag = bias
        self.activation = activation
        self.weights = kernel_initializer(name='dense_weights', shape=(in_features, out_features))
        if self.bias_flag:
            self.bias = bias_initializer(name='dense_bias', shape=(out_features,))
        self.input_shape = input_shape
        self.in_features = in_features
        self.out_features = out_features
        if self.input_shape is not None and self.input_shape[-1] != in_features:
            print("Specified in_features is not equal to input_shape[-1].")
            assert(False)

    def __call__(self, input_tensor):
        # For inputs with more than 2 dimensions, flatten to 2-D for the matmul and
        # reshape back to the original leading dimensions afterwards.
        if self.input_shape is not None and len(self.input_shape) != 2:
            input_tensor = ht.array_reshape_op(input_tensor, [-1, self.in_features])
        outputs = ht.matmul_op(input_tensor, self.weights)
        if self.bias_flag:
            outputs = outputs + ht.broadcastto_op(self.bias, outputs)
        if self.activation is not None:
            outputs = self.activation(outputs)
        if self.input_shape is not None and len(self.input_shape) != 2:
            outputs = ht.array_reshape_op(outputs, self.input_shape[:-1] + [self.out_features])
        return outputs
'''-----------------------------------------------------------------------------------------------'''
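# A minimal usage sketch for pre-training (assumptions: BertConfig is the config class
# referenced in the docstrings above, and input_ids / token_type_ids / attention_mask /
# masked_lm_labels / next_sentence_label are fed as hetu tensors of the shapes documented
# in each class; executor setup and data feeding follow the training scripts, not this file):
#
#   config = BertConfig(...)                      # hyperparameters as sketched near the top
#   model = BertForPreTraining(config)
#   prediction_scores, seq_relationship_score, masked_lm_loss, next_sentence_loss = model(
#       input_ids, token_type_ids, attention_mask,
#       masked_lm_labels=masked_lm_labels,
#       next_sentence_label=next_sentence_label)
#   # prediction_scores:      [batch_size, seq_len, vocab_size]
#   # seq_relationship_score: [batch_size, 2]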
