
train.py

from transformers import Trainer, TrainingArguments, AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from pathlib import Path
import torch
import nltk
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from CVSSDataset import CVSSDataset, read_cvss_csv
from lemmatization import lemmatize, lemmatize_word, lemmatize_noun
from remove_stop_words import remove_stop_words
from stemmatization import stemmatize
import numpy as np
import argparse
import os

def select_tokenizer_model(model_name, extra_tokens, token_file, num_labels):
    global lemmatization
    print("### Selecting Model and Tokenizer")
    # each branch loads the pretrained config and tokenizer; the classification model
    # itself is instantiated from the config, so its weights start from random initialization
    if model_name == 'distilbert':
        from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
        config = DistilBertConfig.from_pretrained('distilbert-base-cased', num_labels=num_labels)  # size the classification head to --num_labels
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
        model = DistilBertForSequenceClassification(config)
    elif model_name == 'bert':
        from transformers import BertTokenizerFast, BertForSequenceClassification, BertConfig
        config = BertConfig.from_pretrained('bert-base-uncased', num_labels=num_labels)
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification(config)
    elif model_name == 'deberta':
        from transformers import DebertaConfig, DebertaTokenizerFast, DebertaForSequenceClassification
        config = DebertaConfig.from_pretrained('microsoft/deberta-base', num_labels=num_labels)
        tokenizer = DebertaTokenizerFast.from_pretrained('microsoft/deberta-base')
        model = DebertaForSequenceClassification(config)
    elif model_name == 'albert':
        from transformers import AlbertConfig, AlbertTokenizerFast, AlbertForSequenceClassification
        config = AlbertConfig.from_pretrained('albert-base-v1', num_labels=num_labels)
        tokenizer = AlbertTokenizerFast.from_pretrained('albert-base-v1')
        model = AlbertForSequenceClassification(config)
    elif model_name == 'roberta':
        from transformers import RobertaConfig, RobertaTokenizerFast, RobertaForSequenceClassification
        config = RobertaConfig.from_pretrained('roberta-base', num_labels=num_labels)
        tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        model = RobertaForSequenceClassification(config)

    ### Add Tokens
    if extra_tokens:
        add_tokens_from_file(token_file, tokenizer, lemmatization)
    number_tokens = len(tokenizer)
    print("### Number of tokens in Tokenizer")
    print(number_tokens)
    # print("### Configuration")
    # print(model.config)
    model.resize_token_embeddings(number_tokens)
    return tokenizer, model

def add_tokens_from_file(token_file, tokenizer, lemmatize=False):
    print("### Adding Tokens")
    file_ = open(token_file, 'r', encoding='UTF-8')
    token_list = []
    for line in file_:
        if lemmatize:
            token_list.append(lemmatize_noun(line.rstrip("\n")))
        else:
            token_list.append(line.rstrip("\n"))
    file_.close()
    tokenizer.add_tokens(token_list)

def get_pred_accuracy(target, output):
    output = output.argmax(axis=1)  # reduce class probabilities to the predicted class index
    tot_right = np.sum(target == output)
    tot = target.size
    return (tot_right / tot) * 100


def get_binary_mean_accuracy(target, output):
    eps = 1e-20
    output = output.argmax(axis=1)
    # TP + FN
    gt_pos = np.sum((target == 1), axis=0).astype(float)
    # TN + FP
    gt_neg = np.sum((target == 0), axis=0).astype(float)
    # TP
    true_pos = np.sum((target == 1) * (output == 1), axis=0).astype(float)
    # TN
    true_neg = np.sum((target == 0) * (output == 0), axis=0).astype(float)
    label_pos_recall = 1.0 * true_pos / (gt_pos + eps)  # true positive rate
    label_neg_recall = 1.0 * true_neg / (gt_neg + eps)  # true negative rate
    # mean accuracy
    return (label_pos_recall + label_neg_recall) / 2


def get_evaluation_metrics(target, output, num_labels):
    accuracy = get_pred_accuracy(target, output)  # get_pred_accuracy takes only target and output
    return accuracy


def infer(trainer, test_loader, num_labels):
    predicts = trainer.predict(test_loader)
    soft = torch.nn.Softmax(dim=1)
    pred_probs = torch.from_numpy(predicts.predictions)
    pred_probs = soft(pred_probs).numpy()
    gt_list = predicts.label_ids
    return get_pred_accuracy(gt_list, pred_probs)

def main():
    # nltk.download('stopwords')
    # nltk.download('punkt_tab')
    global lemmatization

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_labels', type=int, required=True, default=2, help='Number of classes in 1 label')
    parser.add_argument('--classes_names', type=str, required=True, help='Names used to distinguish class values')
    parser.add_argument('--label_position', type=int, required=True, help='The label position in CSV file')
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--model', type=str, help='The name of the model to use')
    parser.add_argument('--train_batch', type=int, help='Batch size for training')
    parser.add_argument('--epochs', type=int, help='Epochs for training')
    parser.add_argument('--lr', type=float, help='Learning rate for training')
    parser.add_argument('--weight_decay', type=float, help='Weight decay for training')
    parser.add_argument('--warmup_steps', type=int, help='Warmup steps for training')
    parser.add_argument('--warmup_ratio', type=float, help='Warmup ratio for training')
    parser.add_argument('--extra_tokens', type=int, help='Extra tokens')
    parser.add_argument('--lemmatization', type=int, help='Lemmatization')
    parser.add_argument('--stemming', type=int, help='Stemming')
    parser.add_argument('--rem_stop_words', type=int, help='Remove Stop Words')
    parser.add_argument('--token_file', type=str, help='Tokens file')
    args = parser.parse_args()

    extra_tokens = bool(args.extra_tokens) if args.extra_tokens else False
    token_file = args.token_file
    lemmatization = bool(args.lemmatization) if args.lemmatization else False
    stemming = bool(args.stemming) if args.stemming else False
    rem_stop_words = bool(args.rem_stop_words) if args.rem_stop_words else False

    # Automatic
    list_classes = args.classes_names.rsplit(",")
    label_position = args.label_position
    output_dir = args.output_dir
    model_name = args.model if args.model else 'distilbert'
    print("### modelName: " + model_name)
    num_labels = args.num_labels
    train_batch_size = args.train_batch if args.train_batch else 8
    test_batch_size = 4
    epochs = args.epochs if args.epochs else 4
    learning_rate = args.lr if args.lr else 5e-5
    weight_decay = args.weight_decay if args.weight_decay else 0
    warmup_steps = args.warmup_steps if args.warmup_steps else 0
    warmup_ratio = args.warmup_ratio if args.warmup_ratio else 0

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print("### Device: ", device)
    os.makedirs(output_dir, exist_ok=True)

    ### Select Model
    tokenizer, model = select_tokenizer_model(model_name, extra_tokens, token_file, num_labels)

    ### Split Dataset
    print("### Splitting Dataset")
    train_texts, train_labels = read_cvss_csv(r'../dataset/SIR_train_set.csv', label_position, list_classes)
    test_texts, test_labels = read_cvss_csv(r'../dataset/SIR_test_set.csv', label_position, list_classes)

    ### Remove Stop Words from Sentences
    if rem_stop_words:
        print("### Remove Stop Words from Sentences")
        filtered_train, filtered_test = remove_stop_words(train_texts, test_texts)

    ### Lemmatize Sentences
    if lemmatization:
        print("### Lemmatizing Sentences")
        if rem_stop_words:
            lemmatized_train, lemmatized_test = lemmatize(filtered_train, filtered_test)
        else:
            lemmatized_train, lemmatized_test = lemmatize(train_texts, test_texts)

    ### Stemmatize Sentences
    if stemming:
        print("### Stemmatize Sentences")
        stemmatized_train, stemmatized_test = stemmatize(train_texts, test_texts)

    ### Tokenize Sentences
    print("### Tokenizing Sentences")
    # truncate to the model max length and pad all sentences to the same size
    if lemmatization:
        train_encodings = tokenizer(lemmatized_train, truncation=True, padding=True)
        test_encodings = tokenizer(lemmatized_test, truncation=True, padding=True)
    elif rem_stop_words:
        train_encodings = tokenizer(filtered_train, truncation=True, padding=True)
        test_encodings = tokenizer(filtered_test, truncation=True, padding=True)
    elif stemming:
        train_encodings = tokenizer(stemmatized_train, truncation=True, padding=True)
        test_encodings = tokenizer(stemmatized_test, truncation=True, padding=True)
    else:
        train_encodings = tokenizer(train_texts, truncation=True, padding=True)
        test_encodings = tokenizer(test_texts, truncation=True, padding=True)

    ### Dataset Encodings
    print("### Encoding Dataset")
    train_dataset = CVSSDataset(train_encodings, train_labels)
    test_dataset = CVSSDataset(test_encodings, test_labels)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=learning_rate,
        epochs=epochs,
        # the scheduler is stepped once per optimizer step, so steps_per_epoch is the
        # number of batches per epoch rather than the number of samples
        steps_per_epoch=(len(train_dataset) + train_batch_size - 1) // train_batch_size,
    )
    print("### Training")
    training_args = TrainingArguments(
        output_dir=output_dir,                         # output directory
        num_train_epochs=epochs,                       # total number of training epochs
        per_device_train_batch_size=train_batch_size,  # batch size per device during training
        per_device_eval_batch_size=test_batch_size,    # batch size for evaluation
        learning_rate=learning_rate,                   # learning rate
        save_strategy="epoch",
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        warmup_ratio=warmup_ratio,
    )
    trainer = Trainer(
        model=model,                        # the instantiated 🤗 Transformers model to be trained
        args=training_args,                 # training arguments, defined above
        train_dataset=train_dataset,        # training dataset
        eval_dataset=test_dataset,          # evaluation dataset
        optimizers=(optimizer, scheduler),  # optimizer and scheduler
    )
    trainer.train()
    print(len(train_dataset))
    trainer.save_model()

    acc = infer(trainer, test_dataset, num_labels)
    print("Accuracy = {:.6f}".format(acc))


if __name__ == '__main__':
    main()
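Usage note: the script is driven entirely by the command-line arguments defined in main(), and the dataset paths ../dataset/SIR_train_set.csv and ../dataset/SIR_test_set.csv are hard-coded relative to the working directory. A hypothetical invocation could look like the following (the class names, label position, output directory and hyperparameter values are illustrative, not taken from the repository):

python train.py --num_labels 3 --classes_names LOW,MEDIUM,HIGH --label_position 1 --output_dir ./output_sir --model distilbert --train_batch 8 --epochs 4 --lr 5e-5 --rem_stop_words 1 --lemmatization 1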

In the field of information security, vulnerability assessment and management is one of the key tasks. This work explores how pre-trained text models can be used to assess and grade the severity of vulnerabilities, based on the Common Vulnerability Scoring System (CVSS). Traditional vulnerability scoring relies on manual analysis and expert review, whereas NLP-based pre-trained text models can use their deep-learning capabilities to automatically process and analyze large volumes of security-related text, improving both the efficiency and the accuracy of vulnerability assessment. Combining stemming and lemmatization further helps these models deliver their full predictive power and accuracy.
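As a minimal sketch of how such a model could be applied after training, the snippet below preprocesses a vulnerability description with plain NLTK (stop-word removal and lemmatization, standing in for the repository's remove_stop_words and lemmatization helpers) and classifies its severity with the fine-tuned checkpoint. The checkpoint directory ./output_sir, the base tokenizer name and the class list are illustrative assumptions, not values taken from the repository.

# Illustrative inference sketch; the checkpoint directory, base tokenizer and
# class names below are assumptions, not values from the repository.
import torch
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# one-time NLTK resources for tokenization, stop words and lemmatization
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('stopwords')
nltk.download('wordnet')

CLASS_NAMES = ["LOW", "MEDIUM", "HIGH"]  # assumed order of --classes_names used at training time
BASE_MODEL = "distilbert-base-cased"     # tokenizer base; trainer.save_model() stores only the model
CHECKPOINT = "./output_sir"              # assumed --output_dir passed to train.py


def preprocess(text):
    """Remove English stop words and lemmatize the remaining tokens."""
    lemmatizer = WordNetLemmatizer()
    stop_set = set(stopwords.words('english'))
    tokens = [t for t in word_tokenize(text) if t.lower() not in stop_set]
    return " ".join(lemmatizer.lemmatize(t) for t in tokens)


def predict_severity(description):
    # reload the tokenizer from the base checkpoint (valid as long as --extra_tokens was not used)
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
    model = AutoModelForSequenceClassification.from_pretrained(CHECKPOINT)
    model.eval()
    inputs = tokenizer(preprocess(description), truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1).squeeze(0)
    return CLASS_NAMES[int(probs.argmax())], probs.tolist()


if __name__ == '__main__':
    label, probs = predict_severity(
        "A buffer overflow in the packet parser allows remote attackers to execute arbitrary code."
    )
    print(label, probs)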