
test.py 11 kB

import csv
from pathlib import Path
import safetensors.torch
from nltk import tokenize
from sklearn.model_selection import train_test_split
from safetensors.torch import load
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from CVSSDataset import CVSSDataset, read_cvss_csv, read_cvss_txt
import numpy as np
import argparse
from sklearn.metrics import f1_score, precision_score, recall_score, balanced_accuracy_score, accuracy_score
from lemmatization import lemmatize, lemmatize_noun
from remove_stop_words import remove_stop_words
from stemmatization import stemmatize
import pandas as pd

# -------------------------------------- MODEL -------------------------------------
def load_model(model_path, model):
    # Read the safetensors checkpoint from disk and copy its weights into the model.
    with open(model_path, "rb") as f:
        data = f.read()
    loaded = load(data)
    model.load_state_dict(loaded)
    return model

def select_tokenizer_model(model_name, extra_tokens, token_file, model_path, config_path):
    global lemmatization
    if model_name == 'distilbert':
        from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, DistilBertConfig
        config = DistilBertConfig.from_pretrained(config_path)
        tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
        model = DistilBertForSequenceClassification(config)
    elif model_name == 'bert':
        from transformers import BertTokenizerFast, BertForSequenceClassification, BertConfig
        config = BertConfig.from_pretrained(config_path)
        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
        model = BertForSequenceClassification(config)
    elif model_name == 'deberta':
        from transformers import DebertaConfig, DebertaTokenizerFast, DebertaForSequenceClassification
        config = DebertaConfig.from_pretrained(config_path)
        tokenizer = DebertaTokenizerFast.from_pretrained('microsoft/deberta-base')
        model = DebertaForSequenceClassification(config)
    elif model_name == 'albert':
        from transformers import AlbertConfig, AlbertTokenizerFast, AlbertForSequenceClassification
        config = AlbertConfig.from_pretrained(config_path)
        tokenizer = AlbertTokenizerFast.from_pretrained('albert-base-v1')
        model = AlbertForSequenceClassification(config)
    elif model_name == 'roberta':
        from transformers import RobertaConfig, RobertaTokenizerFast, RobertaForSequenceClassification
        config = RobertaConfig.from_pretrained(config_path)
        tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
        model = RobertaForSequenceClassification(config)
    elif model_name == 'Llama':
        from transformers import LlamaConfig, LlamaTokenizerFast, LlamaForSequenceClassification
        config = LlamaConfig.from_pretrained(config_path)
        tokenizer = LlamaTokenizerFast.from_pretrained('meta-llama/Prompt-Guard-86M')
        model = LlamaForSequenceClassification(config)

    ### Add Tokens
    if extra_tokens:
        add_tokens_from_file(token_file, tokenizer, lemmatization)
    number_tokens = len(tokenizer)
    print("### Number of tokens in Tokenizer: " + str(number_tokens))
    model.resize_token_embeddings(number_tokens)
    return tokenizer, model

def add_tokens_from_file(token_file, tokenizer, lemmatize=False):
    print("### Adding Tokens")
    file_ = open(token_file, 'r', encoding='UTF-8')
    token_list = []
    for line in file_:
        if lemmatize:
            token_list.append(lemmatize_noun(line.rstrip("\n")))
        else:
            token_list.append(line.rstrip("\n"))
    file_.close()
    tokenizer.add_tokens(token_list)

# -------------------------------------- METRICS -----------------------------------
def get_pred_accuracy(target, output):
    output = output.argmax(axis=1)  # -> multi label
    tot_right = np.sum(target == output)
    tot = target.size
    return (tot_right / tot) * 100

def get_accuracy_score(target, output):
    return accuracy_score(target, output)

def get_f1_score(target, output):
    return f1_score(target, output, average='weighted')

def get_precision_score(target, output):
    return precision_score(target, output, average='weighted')

def get_recall_score(target, output):
    return recall_score(target, output, average='weighted')

def get_mean_accuracy(target, output):
    eps = 1e-20
    output = output.argmax(axis=1)
    # TP + FN
    gt_pos = np.sum((target == 1), axis=0).astype(float)
    # TN + FP
    gt_neg = np.sum((target == 0), axis=0).astype(float)
    # TP
    true_pos = np.sum((target == 1) * (output == 1), axis=0).astype(float)
    # TN
    true_neg = np.sum((target == 0) * (output == 0), axis=0).astype(float)
    label_pos_recall = 1.0 * true_pos / (gt_pos + eps)  # true positive
    label_neg_recall = 1.0 * true_neg / (gt_neg + eps)  # true negative
    # mean accuracy
    return (label_pos_recall + label_neg_recall) / 2

def get_balanced_accuracy(target, output):
    return balanced_accuracy_score(target, output)

def disLabel(classNames, pred_props, labelPosition):
    # Map predicted class indices to class names, cache them as a one-row CSV,
    # then write them as a column of output/output1_last.csv (appending a new
    # column when a previous label position has already been written).
    data = []
    for i in pred_props:
        data.append(classNames[i])
    file_name = "../data_process_cache/output.csv"
    with open(file_name, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows([data])
    df = pd.read_csv('../data_process_cache/output.csv', header=None)
    if labelPosition > 1:
        df_exist = pd.read_csv('output/output1_last.csv', header=None)
        # print(df_exist.head(10))
        df_exist_transposed = df_exist.T
        # print(df_exist_transposed.head(10))
        df_combined = pd.concat([df_exist_transposed, df], ignore_index=True)
        df_combined = df_combined.T
        df_combined.to_csv('output/output1_last.csv', index=False, header=False)
    else:
        df_transposed = df.T
        df_transposed.to_csv('output/output1_last.csv', index=False, header=False)
    return

# -------------------------------------- MAIN -----------------------------------
def main():
    global lemmatization

    parser = argparse.ArgumentParser()
    parser.add_argument('--classes_names', type=str, required=True, help='Names used to distinguish class values')
    parser.add_argument('--label_position', type=int, required=True, help='The label position in CSV file')
    parser.add_argument('--root_dir', type=str, required=True, help='Path to model and config files')
    parser.add_argument('--model', type=str, help='The name of the model to use')
    parser.add_argument('--test_batch', type=int, help='Batch size for test')
    parser.add_argument('--extra_tokens', type=int, help='Extra tokens')
    parser.add_argument('--lemmatization', type=int, help='Lemmatization')
    parser.add_argument('--stemming', type=int, help='Stemming')
    parser.add_argument('--rem_stop_words', type=int, help='Remove Stop Words')
    parser.add_argument('--token_file', type=str, help='Tokens file')
    args = parser.parse_args()

    string = args.classes_names
    classNames = string.split(',')
    labelPosition = args.label_position
    print(classNames)

    model_name = args.model if args.model else 'distilbert'
    extra_tokens = bool(args.extra_tokens) if args.extra_tokens else False
    token_file = args.token_file
    lemmatization = bool(args.lemmatization) if args.lemmatization else False
    stemming = bool(args.stemming) if args.stemming else False
    rem_stop_words = bool(args.rem_stop_words) if args.rem_stop_words else False
    root_dir = args.root_dir
    model_path = root_dir + 'model.safetensors'
    config_path = root_dir + 'config.json'
    batch_size = args.test_batch if args.test_batch else 2
    list_classes = args.classes_names.rsplit(",")
    label_position = args.label_position

    print("### modelName: " + model_name)

    # device = torch.device('cpu')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print("### Device: ", device)

    ### Select Model
    tokenizer, model = select_tokenizer_model(model_name, extra_tokens, token_file, model_path, config_path)

    ### Load Dataset
    print("### Loading Dataset")
    test_texts, test_labels = read_cvss_csv(r'E:\pythonProject_open\dataset_more\test.csv', label_position, list_classes)

    ### Lemmatize Sentences
    if lemmatization:
        print("### Lemmatizing Sentences")
        lemmatized_test, _ = lemmatize(test_texts)
    if stemming:
        print("### Stemmatize Sentences")
        stemmatized_test, _ = stemmatize(test_texts)
    if rem_stop_words:
        print("### Remove Stop Words from Sentences")
        filtered_test, _ = remove_stop_words(test_texts)

    ### Tokenize Sentences
    print("### Tokenizing Sentences")
    if lemmatization:
        test_encodings = tokenizer(lemmatized_test, truncation=True, padding=True)
    elif stemming:
        test_encodings = tokenizer(stemmatized_test, truncation=True, padding=True)
    elif rem_stop_words:
        test_encodings = tokenizer(filtered_test, truncation=True, padding=True)
    else:
        test_encodings = tokenizer(test_texts, truncation=True, padding=True)

    ### Dataset Encodings
    test_dataset = CVSSDataset(test_encodings, test_labels)
    print("### Dataset Encodings")

    model = load_model(model_path, model)
    model.to(device)

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model.eval()
    pred_probs = []
    gt_list = []
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        soft = torch.nn.Softmax(dim=1)
        output_logits = soft(outputs.logits)
        gt_list.append(labels.cpu().detach().numpy())
        pred_probs.append(output_logits.cpu().detach().numpy())
        # print(pred_probs)  # 0001001000

    gt_list = np.concatenate(gt_list, axis=0)
    pred_probs = np.concatenate(pred_probs, axis=0)
    pred_probs = pred_probs.argmax(axis=1)
    # print(pred_probs)
    disLabel(classNames, pred_probs, labelPosition)

    print(
        "Accuracy = {:.6f} F1-score = {:.6f} Precision = {:.6f} Recall = {:.6f} mean Accuracy = {:.6f}".format(
            get_accuracy_score(gt_list, pred_probs), get_f1_score(gt_list, pred_probs),
            get_precision_score(gt_list, pred_probs), get_recall_score(gt_list, pred_probs),
            balanced_accuracy_score(gt_list, pred_probs)))


if __name__ == '__main__':
    main()

In information security, vulnerability assessment and management is one of the key tasks. This work explores how pretrained text models can be used to assess and classify the severity level of vulnerabilities, based on the Common Vulnerability Scoring System (CVSS). Traditional vulnerability scoring relies on manual analysis and expert review, whereas NLP-based pretrained text models can use their deep learning capabilities to automatically process and analyze large volumes of security-related text, improving both the efficiency and the accuracy of vulnerability assessment. Combining stemming and lemmatization further improves the predictive power and accuracy of these models.
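The test.py script above evaluates a fine-tuned classifier over an entire CSV test set. As a minimal sketch of the same inference path for a single vulnerability description, the snippet below shows how lemmatization, tokenization, and softmax-based classification fit together; it assumes a hypothetical fine-tuned DistilBERT checkpoint at ./finetuned_model (with num_labels matching class_names) and that NLTK's WordNet data has been downloaded, and the class names and example description are placeholders.

import torch
from nltk.stem import WordNetLemmatizer
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

lemmatizer = WordNetLemmatizer()

def lemmatize_text(text):
    # Lemmatize each whitespace-separated token as a noun, in the spirit of lemmatize_noun above.
    return " ".join(lemmatizer.lemmatize(tok) for tok in text.split())

# Hypothetical checkpoint path and severity classes; substitute your own.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
model = DistilBertForSequenceClassification.from_pretrained("./finetuned_model")
model.eval()

class_names = ["LOW", "MEDIUM", "HIGH", "CRITICAL"]  # example CVSS severity labels
description = "A buffer overflow allows remote attackers to execute arbitrary code."

# Lemmatize, tokenize, and pick the class with the highest softmax probability.
inputs = tokenizer(lemmatize_text(description), truncation=True, padding=True, return_tensors="pt")
with torch.no_grad():
    probs = torch.softmax(model(**inputs).logits, dim=1)
print(class_names[int(probs.argmax(dim=1))])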