You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

logger.py 1.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import os
  2. import sys
  3. import logging
  4. from time import strftime
  5. # 设置日志格式#和时间格式
  6. # FMT = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s: %(message)s'
  7. FMT = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s: %(message)s'
  8. DATEFMT = '%Y-%m-%d %H:%M:%S'
  9. class MyLog(object):
  10. def __init__(self, log_path, type_='train'):
  11. self.logger = logging.getLogger()
  12. self.formatter = logging.Formatter(fmt=FMT, datefmt=DATEFMT)
  13. self.log_filename = os.path.join(log_path, 'train.log' if type_=='train' else 'test.log')
  14. # self.log_filename = '{0}{1}.log'.format(log_path, strftime("%Y-%m-%d"))
  15. # 输出到文件
  16. self.logger.addHandler(self.get_file_handler(self.log_filename))
  17. # 输出到控制台
  18. # self.logger.addHandler(self.get_console_handler())
  19. # 设置日志的默认级别
  20. # 打印DEBUG级别以及以上的日志
  21. # 级别排序为:CRITICAL > ERROR > WARNING > INFO > DEBUG
  22. self.logger.setLevel(logging.INFO)
  23. # 输出到文件handler的函数定义
  24. def get_file_handler(self, filename):
  25. filehandler = logging.FileHandler(filename, encoding="utf-8")
  26. filehandler.setFormatter(self.formatter)
  27. return filehandler
  28. # 输出到控制台handler的函数定义
  29. def get_console_handler(self):
  30. console_handler = logging.StreamHandler(sys.stdout)
  31. console_handler.setFormatter(self.formatter)
  32. return console_handler
  33. def setup_logger(log_path=None, type_='train'):
  34. # 输出日志路径
  35. if log_path is not None:
  36. os.makedirs(log_path, exist_ok=True)
  37. return MyLog(log_path, type_).logger
  38. log_path = os.path.abspath('.') + '/logs/'
  39. return MyLog(log_path, type_).logger

冻结ViT-B/32版本的CLIP模型中的全部图像层,用Adan优化器训练模型,训练100个epoch,每隔5个epoch对模型进行保存;完成CLIP模型训练后,运行test_clip.py用测试集中的数据和自定义的提示词对保存的模型进行测试,选取测试精度最好的模型和对应的提示词,运行predict.py文件,选择“min_loss.pth”模型,提交官方系统测试,top1的精度是0.6788。

Contributors (1)