
bert_config.py 3.0 kB

'''
BERT Config:
--------------------------------------------------------------------------------------------------'''
class BertConfig(object):
    """Configuration class to store the configuration of a `BertModel`."""

    def __init__(self,
                 vocab_size,
                 hidden_size=768,
                 num_hidden_layers=12,
                 num_attention_heads=12,
                 intermediate_size=3072,
                 hidden_act="relu",
                 hidden_dropout_prob=0.1,
                 attention_probs_dropout_prob=0.1,
                 max_position_embeddings=512,
                 type_vocab_size=2,
                 initializer_range=0.02,
                 output_hidden_states=False,
                 batch_size=100,
                 ):
        """Constructs BertConfig.

        Args:
            vocab_size: Vocabulary size of `inputs_ids` in `BertModel`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `BertModel`.
            initializer_range: The stddev of the truncated_normal_initializer for
                initializing all weight matrices.
            output_hidden_states: Whether to also return the hidden states of all
                encoder layers in the model output.
            batch_size: Batch size associated with this configuration.
        """
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.output_hidden_states = output_hidden_states
        self.batch_size = batch_size
'''-----------------------------------------------------------------------------------------------'''
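
For reference, a minimal usage sketch (assuming this file is importable as `bert_config` and using an illustrative vocabulary size; neither is specified in this file):

from bert_config import BertConfig  # assumes bert_config.py is on the import path

# Build a config with the defaults above, overriding only the vocabulary size.
config = BertConfig(vocab_size=30522)

print(config.hidden_size)        # 768
print(config.num_hidden_layers)  # 12
print(config.hidden_act)         # "relu"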

Distributed deep learning system