| @@ -0,0 +1,149 @@ | |||
import os
# Pin the process to GPU 0. NOTE(review): this must run BEFORE importing
# jittor so the framework only ever sees that one device — verify this is
# the intended single-GPU setup.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import jittor as jt
from tqdm import tqdm
from tabulate import tabulate
from tensorboardX import SummaryWriter
from utils.logger import setup_logger
from utils.helper import (EarlyStop, get_optimizer, get_save_path,
                          accuracy, compute_loss)
from dataset import get_dataloader
from jittor.lr_scheduler import CosineAnnealingLR
from test_clip import zeroshot_classifier
from models import load_clip
# Enable CUDA execution globally for Jittor.
jt.flags.use_cuda = 1
class Args:
    """Training hyper-parameters for CLIP fine-tuning.

    Converting an instance to ``str`` renders all parameters as a
    two-column grid table via ``tabulate``.
    """

    def __init__(self):
        self.seed = 123                 # global random seed
        self.optimizer = 'Adan'
        self.lr = 3e-6
        self.betas = (0.9, 0.98, 0.98)
        self.eps = 1e-8
        self.weight_decay = 0.2
        self.batch_size = 256
        self.num_workers = 8            # dataloader worker threads
        self.epochs = 100               # total number of training epochs
        self.early_stop = False
        self.patience = 10              # early-stop patience (epochs without improvement)
        self.delta = 0.0001             # early-stop improvement threshold
        self.caption_version = 2        # image caption template version:
                                        #   1. "a photo of xxx"
                                        #   2. dataset-specific custom captions
        self.data_augment = 0           # data-augmentation version
        self.model_save_path = 'ckptFE'
        self.log_save_path = 'logs'
        self.compute_acc_frequency = 0  # compute train acc every N epochs; 0 disables
        self.use_scheduler = False      # whether to use the LR scheduler
        self.save_frequency = 5         # save a checkpoint every N epochs
        self.freeze_version = 2

    def __str__(self):
        """Return the parameters formatted as a grid table.

        Works on a *copy* of ``__dict__``: the previous implementation
        aliased ``self.__dict__`` directly, so printing the args object
        permanently mutated ``betas`` from a tuple into a list.
        """
        args_dict = dict(self.__dict__)
        # Convert the tuple to a list so tabulate renders it cleanly.
        if isinstance(args_dict['betas'], tuple):
            args_dict['betas'] = list(args_dict['betas'])
        table = tabulate([(key, value) for key, value in args_dict.items()],
                         tablefmt="grid", headers=["Parameter", "Value"])
        return table
def train(model, optimizer, train_loader, scheduler, args, log_dir, save_dir, logger):
    """Run the full fine-tuning loop.

    Per epoch: one pass over ``train_loader`` accumulating the contrastive
    loss, optional train-accuracy evaluation, optional LR-scheduler step,
    checkpointing (best-loss and periodic), TensorBoard/file logging, and
    optional early stopping on the epoch loss.

    Args:
        model: CLIP model returning (logits_per_image, logits_per_text).
        optimizer: Jittor optimizer; ``optimizer.step(loss)`` performs
            backward + update in one call.
        train_loader: yields (images, label, texts) batches.
        scheduler: LR scheduler, stepped only if ``args.use_scheduler``.
        args: Args instance with all hyper-parameters.
        log_dir: directory for TensorBoard event files.
        save_dir: directory for ``.pth`` checkpoints.
        logger: configured logging.Logger.
    """
    print('model parameters: \n', args)
    logger.info('model parameters: \n')
    logger.info(args)
    logger.info('\n\nStart training...')
    # Early-stop tracker is only created when enabled; the guard at the
    # bottom of the loop re-checks args.early_stop before using it.
    if args.early_stop:
        early_stop = EarlyStop(patience=args.patience, delta=args.delta)
    writer = SummaryWriter(log_dir=log_dir)
    min_loss = float('inf')
    acc = None
    # Text-side classifier weights, computed once up front.
    # NOTE(review): computed from the model *before* training and never
    # refreshed — confirm accuracy() is meant to use the initial text
    # embeddings throughout.
    zeroshot_weights = zeroshot_classifier(model)
    pbar = tqdm(range(1, args.epochs+1))
    for epoch in pbar:
        model.train()
        running_loss = 0.0
        for images, label, texts in train_loader:
            # label is unused here; contrastive loss pairs images/texts.
            logits_per_image, logits_per_text = model(images, texts)
            loss = compute_loss(logits_per_image, logits_per_text)
            # Jittor idiom: step(loss) runs backward and the update.
            optimizer.step(loss)
            running_loss += loss.item()
        # Periodic (expensive) train-set accuracy; 0 disables entirely.
        if args.compute_acc_frequency != 0 and epoch % args.compute_acc_frequency == 0 :
            acc = accuracy(model, train_loader, zeroshot_weights)
        if args.use_scheduler:
            scheduler.step()
        # Checkpoint whenever the summed epoch loss improves.
        if running_loss < min_loss:
            min_loss = running_loss
            jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'min_loss')))
        # Additionally checkpoint every save_frequency epochs.
        if epoch % args.save_frequency == 0:
            jt.save(model.state_dict(), '{}.pth'.format(os.path.join(save_dir, 'epoch_{}'.format(epoch))))
        # `acc` keeps its last computed value between accuracy epochs, so
        # it is logged (possibly stale) on every epoch once computed.
        if acc is not None:
            # pbar.set_description(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
            writer.add_scalar('Train/Acc', acc, epoch)
            logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}, Acc: {acc:.4f}")
        else:
            # pbar.set_description(f"Epoch: {epoch}, Loss: {running_loss:.4f}, Min Loss: {min_loss:.4f}")
            logger.info(f"Epoch: {epoch}, Loss: {running_loss:.4f} Min Loss: {min_loss:.4f}")
        writer.add_scalar('Train/Loss', running_loss, epoch)
        if args.early_stop and early_stop(running_loss):
            logger.info(f'Early stop triggered..., epoch: {epoch}')
            # print(f'early stop triggered..., epoch: {epoch}')
            break
    pbar.close()
    logger.info('\n\nFinish training...')
    writer.close()
def main(args):
    """Wire up model, data, optimizer and logging, then start training."""
    # Logging / checkpoint destinations, derived from the optimizer name.
    ckpt_dir = get_save_path(args.model_save_path, args.optimizer)
    log_dir = get_save_path(args.log_save_path, args.optimizer)
    logger = setup_logger(log_dir)

    # Model (with the chosen freeze strategy) and its preprocessing.
    model, transforms = load_clip(freeze_version=args.freeze_version)
    # model, transforms = clip.load('pretrained/ViT-B-32.pkl')

    loader = get_dataloader(transforms, args.batch_size,
                            args.num_workers, shuffle=True,
                            version=args.caption_version)
    opt = get_optimizer(args, model)
    # Cosine decay down to 10% of the base LR over the full run.
    lr_sched = CosineAnnealingLR(opt, T_max=args.epochs, eta_min=args.lr * 0.1)

    print(f'Model will be saved at {ckpt_dir}')
    logger.info('Model will be saved at {}'.format(ckpt_dir))

    train(model, opt, loader, lr_sched, args, log_dir, ckpt_dir, logger)
if __name__ == "__main__":
    # Build the hyper-parameter set, seed Jittor's global RNG for
    # reproducibility, then launch training.
    args = Args()
    jt.misc.set_global_seed(args.seed)
    main(args)