"""Command-line launcher for ``bert_base.train.bert_lstm_ner.train``.

Declares the script's flags and hands the resulting ``FLAGS`` object to the
training entry point.
"""
import tensorflow as tf

from bert_base.train.bert_lstm_ner import train


flags = tf.flags

FLAGS = flags.FLAGS


def _register_flags():
    """Register every command-line flag this script exposes, in table order.

    Each row is ``(define_fn, name, default, help_text)``. Help strings are
    kept verbatim, since they are what ``--help`` prints.
    """
    define_string = flags.DEFINE_string
    define_bool = flags.DEFINE_bool
    define_integer = flags.DEFINE_integer
    define_float = flags.DEFINE_float

    bert_dir = 'chinese_L-12_H-768_A-12/'

    specs = (
        # Input / output locations (help text: "dataset path", "output path").
        (define_string, 'data_dir', 'data', '数据集地址'),
        (define_string, 'output_dir', 'output', '输出地址'),

        # Pretrained BERT assets: config file, vocabulary, initial checkpoint.
        (define_string, 'bert_config_file', bert_dir + 'bert_config.json', 'Bert配置文件'),
        (define_string, 'vocab_file', bert_dir + 'vocab.txt', 'vocab_file'),
        (define_string, 'init_checkpoint', bert_dir + 'bert_model.ckpt', 'init_checkpoint'),

        # Phase switches (help text: "whether to train / validate / test").
        # NOTE(review): with these defaults a bare run does not train and only
        # asks for the test phase -- confirm this is intended, and that
        # train() actually reads do_dev / do_test under these names.
        (define_bool, 'do_train', False, '是否开始训练'),
        (define_bool, 'do_dev', False, '是否开始校验'),
        (define_bool, 'do_test', True, '是否开始测试'),

        # Help text: "whether to convert to lowercase".
        (define_bool, 'do_lower_case', True, '是否转换小写'),

        # Model-related sizes, batch sizes and step intervals.
        (define_integer, 'lstm_size', 128, 'lstm_size'),
        (define_integer, 'num_layers', 1, 'num_layers'),
        (define_integer, 'max_seq_length', 128, 'max_seq_length'),
        (define_integer, 'train_batch_size', 64, 'train_batch_size'),
        (define_integer, 'dev_batch_size', 64, 'dev_batch_size'),
        (define_integer, 'test_batch_size', 32, 'test_batch_size'),
        (define_integer, 'save_checkpoints_steps', 500, 'save_checkpoints_steps'),
        (define_integer, 'iterations_per_loop', 500, 'iterations_per_loop'),
        (define_integer, 'save_summary_steps', 500, 'save_summary_steps'),

        # Recurrent cell name; which values are accepted is decided by train().
        (define_string, 'cell', 'lstm', 'cell'),

        # Optimisation / regularisation settings.
        (define_float, 'learning_rate', 5e-5, 'learning_rate'),
        (define_float, 'dropout_rate', 0.5, 'dropout_rate'),
        (define_float, 'clip', 0.5, 'clip'),
        (define_float, 'num_train_epochs', 10.0, 'num_train_epochs'),
        (define_float, 'warmup_proportion', 0.1, 'warmup_proportion'),
    )

    for define, name, default, help_text in specs:
        define(name, default, help_text)


_register_flags()
-
def train_ner():
    """Invoke ``bert_lstm_ner.train`` with this module's flag values.

    The module-level ``FLAGS`` object is passed through unchanged and nothing
    is returned. Which phases run (train/dev/test) is decided inside
    ``train``. NOTE(review): presumably it consults the ``do_*`` flags --
    confirm there.
    """
    flag_values = FLAGS
    train(flag_values)
-
if __name__ == '__main__':
    # Only launch when executed as a script; importing this module merely
    # registers the flags.
    train_ner()