From e8d218dae885ea99f5dafd2ca7179d4f005fa434 Mon Sep 17 00:00:00 2001 From: Zephyr Date: Tue, 20 Aug 2024 19:23:07 +0800 Subject: [PATCH] Initial --- LICENSE | 18 +- README.md | 35 +- dreambooth/README.md | 35 ++ dreambooth/read_json.py | 19 + dreambooth/run_all.py | 54 ++ dreambooth/settings/color.json | 17 + dreambooth/settings/gauss_statistics.json | 62 ++ dreambooth/settings/style.json | 17 + dreambooth/settings/texture.json | 17 + dreambooth/test_all.sh | 6 + dreambooth/train.py | 730 ++++++++++++++++++++++ dreambooth/train_all.sh | 63 ++ 12 files changed, 1069 insertions(+), 4 deletions(-) create mode 100644 dreambooth/README.md create mode 100644 dreambooth/read_json.py create mode 100644 dreambooth/run_all.py create mode 100644 dreambooth/settings/color.json create mode 100644 dreambooth/settings/gauss_statistics.json create mode 100644 dreambooth/settings/style.json create mode 100644 dreambooth/settings/texture.json create mode 100644 dreambooth/test_all.sh create mode 100644 dreambooth/train.py create mode 100644 dreambooth/train_all.sh diff --git a/LICENSE b/LICENSE index de14edb..ed917bf 100644 --- a/LICENSE +++ b/LICENSE @@ -2,8 +2,20 @@ MIT License Copyright (c) 2024 yohu-cqu -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 614bdc6..80b648a 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,35 @@ -# jittor-SimpleTest-StableDiffusion-A +# DreamBooth-Lora +本项目参考自 [JDiffusion 的 DreamBooth-Lora](https://github.com/JittorRepos/JDiffusion/tree/master/examples/dreambooth) 。 + +## 环境安装 + +按照 [JDiffusion 的环境安装一节](https://github.com/JittorRepos/JDiffusion/blob/master/examples/dreambooth/README.md) 安装必要的依赖。 + +接着设置运行脚本的权限: +``` +chmod u+x ./dreambooth/*.sh +``` + +## 训练 + +1. 首先从比赛云盘下载对应的数据集,推荐将数据集目录 `A` 下载到 `./A` 下 +2. 将 `train_all.sh` 中的 `HF_HOME` 设置为本地模型路径, `root` 设置为项目所在目录, `BASE_INSTANCE_DIR` 设置为数据集对应的目录,`GPU_COUNT` 设置为对应可用的显卡数量,`MAX_NUM` 设置为数据集中的风格个数; +3. 然后进入目标文件夹下: `cd ./dreambooth/`, 运行 `bash train_all.sh` 即可训练。保存的模型会存放至 `./dreambooth/results/prompt_v1_color_test1/style_[训练epoch数]epoch` 目录下,例如:`./dreambooth/results/prompt_v1_color_test1/style_300epoch`。 + +## 推理 + +1. 将 `test_all.sh` 中的 `HF_HOME` 设置为本地模型路径,将 `run_all.py` 中的 `root` 设置为项目所在目录, `dataset_root` 修改为数据集对应的目录,将 `max_num` 修改为数据集中的风格个数; +2. 进入目标文件夹下: `cd ./dreambooth/`,运行 `bash test_all.sh` 进行推理。模型生成的图片会输出到 `./dreambooth/results/prompt_v1_color_test1/outputs_[保存点训练epoch数]ckpt_[推理轮数]steps_[种子值]seed` 文件夹下,例如:`./dreambooth/results/prompt_v1_color_test1/outputs_300ckpt_500steps_76587seed`。 + + +## 参考文献 + +``` +@inproceedings{ruiz2023dreambooth, + title={Dreambooth: Fine tuning text-to-image diffusion models for subject-driven generation}, + author={Ruiz, Nataniel and Li, Yuanzhen and Jampani, Varun and Pritch, Yael and Rubinstein, Michael and Aberman, Kfir}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2023} +} +``` diff --git a/dreambooth/README.md b/dreambooth/README.md new file mode 100644 index 0000000..332845e --- /dev/null +++ b/dreambooth/README.md @@ -0,0 +1,35 @@ +# DreamBooth-Lora + +本项目参考自 [JDiffusion 的 DreamBooth-Lora](https://github.com/JittorRepos/JDiffusion/tree/master/examples/dreambooth) 。 + +## 环境安装 + +按照 [JDiffusion 的环境安装一节](https://github.com/JittorRepos/JDiffusion/blob/master/examples/dreambooth/README.md) 安装必要的依赖。 + +接着设置运行脚本的权限: +``` +chmod u+x ./dreambooth/*.sh +``` + +## 训练 + +1. 首先从比赛云盘下载对应的数据集,推荐将数据集目录 `A` 下载到 `./dreambooth/A/` 下 +2. 将 `train_all.sh` 中的 `HF_HOME` 设置为本地模型路径, `root` 设置为项目所在目录, `BASE_INSTANCE_DIR` 设置为数据集对应的目录,`GPU_COUNT` 设置为对应可用的显卡数量,`MAX_NUM` 设置为数据集中的风格个数; +3. 然后进入目标文件夹下: `cd ./dreambooth/`, 运行 `bash train_all.sh` 即可训练,保存的模型会存放至 `./dreambooth/results/prompt_v1_color_test1/style_[训练epoch数]epoch` 目录下,例如:`./dreambooth/results/prompt_v1_color_test1/style_300epoch`。 + +## 推理 + +1. 将 `test_all.sh` 中的 `HF_HOME` 设置为本地模型路径,将 `run_all.py` 中的 `root` 设置为项目所在目录, `dataset_root` 修改为数据集对应的目录,将 `max_num` 修改为数据集中的风格个数; +2. 进入目标文件夹下: `cd ./dreambooth/`,运行 `bash test_all.sh` 进行训练,对应的图片会输出到 `./dreambooth/results/prompt_v1_color_test1/outputs_[保存点训练epoch数]ckpt_[推理轮数]steps_[种子值]seed` 文件夹下,例如:`./dreambooth/results/prompt_v1_color_test1/outputs_300ckpt_500steps_76587seed`。 + + +## 参考文献 + +``` +@inproceedings{ruiz2023dreambooth, + title={Dreambooth: Fine tuning text-to-image diffusion models for subject-driven generation}, + author={Ruiz, Nataniel and Li, Yuanzhen and Jampani, Varun and Pritch, Yael and Rubinstein, Michael and Aberman, Kfir}, + booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, + year={2023} +} +``` \ No newline at end of file diff --git a/dreambooth/read_json.py b/dreambooth/read_json.py new file mode 100644 index 0000000..c3f1e6c --- /dev/null +++ b/dreambooth/read_json.py @@ -0,0 +1,19 @@ +import json +import sys + +def get_value_from_json(file_path, key): + with open(file_path, 'r') as f: + data = json.load(f) + if key in data: + return data[key] + else: + return None + +if __name__ == "__main__": + file_path = sys.argv[1] + key = sys.argv[2] + value = get_value_from_json(file_path, key) + if value is not None: + print(value) + else: + sys.exit(1) diff --git a/dreambooth/run_all.py b/dreambooth/run_all.py new file mode 100644 index 0000000..3e99690 --- /dev/null +++ b/dreambooth/run_all.py @@ -0,0 +1,54 @@ +import os, sys + +os.environ["HF_ENDPOINT"] = "https://hf-mirror.com" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +import json, tqdm, torch +import jittor as jt +jt.flags.use_rocm = 1 + + +from JDiffusion.pipelines import StableDiffusionPipeline +# from mypipeline_stable_diffusion_jittor import StableDiffusionPipeline + +root = "/home/user1/jittor2024/jittor-A-commit" +save_root = f"{root}/dreambooth/results/" + "prompt_v1_color_test1" + +max_num = 15 +inference_steps = 500 +checkpoint_epoch = 300 +seed = 76587 +jt.set_global_seed(seed) +dataset_root = f"{root}/A/" +style_file = f"{root}/dreambooth/settings/style.json" +texture_file = f"{root}/dreambooth/settings/texture.json" +color_file = f"{root}/dreambooth/settings/color.json" + +with open(style_file, "r") as f: + style_dict = json.load(f) + +with open(texture_file, "r") as f: + texture_dict = json.load(f) + +with open(color_file, "r") as f: + color_dict = json.load(f) + +pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1").to("cuda") +with torch.no_grad(): + for tempid in tqdm.tqdm(range(0, max_num)): + taskid = "{:0>2d}".format(tempid) + if checkpoint_epoch is None: + pipe.load_lora_weights(os.path.join(save_root, f"style/style_{taskid}")) + else: + pipe.load_lora_weights(os.path.join(save_root, f"style_{checkpoint_epoch}epoch/style_{taskid}")) + + # load json + with open(f"{dataset_root}/{taskid}/prompt.json", "r") as file: + prompts = json.load(file) + + for id, prompt in prompts.items(): + new_prompt = f"A photo of {prompt} in {style_dict[taskid]} style, with a texture of {texture_dict[taskid]} and with a color style of {color_dict[taskid]}." + print(new_prompt) + image = pipe(prompt=new_prompt, num_inference_steps=inference_steps, width=512, height=512, seed=seed).images[0] + os.makedirs(os.path.join(save_root, f"outputs_{checkpoint_epoch}ckpt_{inference_steps}steps_{seed}seed/{taskid}"), exist_ok=True) + image.save(os.path.join(save_root, f"outputs_{checkpoint_epoch}ckpt_{inference_steps}steps_{seed}seed/{taskid}/{prompt}.png")) diff --git a/dreambooth/settings/color.json b/dreambooth/settings/color.json new file mode 100644 index 0000000..a0f7b3f --- /dev/null +++ b/dreambooth/settings/color.json @@ -0,0 +1,17 @@ +{ + "00": "which is dominated by shades of blue and white, with accents in yellow", + "01": "mainly bright and contrasting colors such as orange, red, yellow, blue, white, gray, and black", + "02": "which is characterized by muted, earthy tones and soft gradients and palette primarily including shades of brown, gray, beige, and subtle hints of orange and gree", + "03": "a vibrant and geometric color palette, featuring bright hues such as reds, blues, greens, yellows, and purples", + "04": "which is featured in muted colors such as light blues, greens, pinks, and yellows", + "05": "a rich palette including warm earth tones like browns and oranges, cool blues, greens, and reds", + "06": "an emphasis on rich blues, purples, pinks, and reds", + "07": "a warm, dark contrast and orange-yellow highlights", + "08": "a common black and white chalkboard and monochromatic", + "09": "which is dominated by colors include shades of orange, yellow, brown, black, white, and hints of blue", + "10": "a mix of cool tones like blue, purple, and pink", + "11": "earthy tones, muted grays and browns", + "12": "sepia tones, vintage textures, and monochromatic hues", + "13": "featuring bold blues, oranges, and greys", + "14": "a combination of warm tones, primarily pink and beige backgrounds with white origami objects" +} \ No newline at end of file diff --git a/dreambooth/settings/gauss_statistics.json b/dreambooth/settings/gauss_statistics.json new file mode 100644 index 0000000..44cdd90 --- /dev/null +++ b/dreambooth/settings/gauss_statistics.json @@ -0,0 +1,62 @@ +{ + "00": { + "mean_val": 0.0009739127941429615, + "std_val": 0.5314764261245728 + }, + "01": { + "mean_val": 0.46153217256069184, + "std_val": 0.7191255629062653 + }, + "02": { + "mean_val": 0.29199431985616686, + "std_val": 0.7220630288124085 + }, + "03": { + "mean_val": 0.33343668580055236, + "std_val": 0.5767290353775024 + }, + "04": { + "mean_val": 0.218064521253109, + "std_val": 0.5959844827651978 + }, + "05": { + "mean_val": 0.18604852110147477, + "std_val": 0.558800458908081 + }, + "06": { + "mean_val": 0.22389721274375915, + "std_val": 0.6452812850475311 + }, + "07": { + "mean_val": -0.09157774895429611, + "std_val": 1.0858253061771392 + }, + "08": { + "mean_val": -0.045120977237820624, + "std_val": 0.6651732742786407 + }, + "09": { + "mean_val": 0.08461952535435557, + "std_val": 0.670553731918335 + }, + "10": { + "mean_val": -0.03910434106364846, + "std_val": 0.665828937292099 + }, + "11": { + "mean_val": 0.02903309902176261, + "std_val": 0.5030718982219696 + }, + "12": { + "mean_val": 0.22357304841279985, + "std_val": 0.5448532164096832 + }, + "13": { + "mean_val": 0.1572735294699669, + "std_val": 0.5102830111980439 + }, + "14": { + "mean_val": 0.16023400127887727, + "std_val": 0.6309724390506745 + } +} \ No newline at end of file diff --git a/dreambooth/settings/style.json b/dreambooth/settings/style.json new file mode 100644 index 0000000..7661e0b --- /dev/null +++ b/dreambooth/settings/style.json @@ -0,0 +1,17 @@ +{ + "00": "Impressionist", + "01": "Geometric, Modern, Low-Poly, and Pop Art", + "02": "Watercolor", + "03": "Low-Poly, Abstract, Digital Art", + "04": "Watercolor, Illustration, Digital Art", + "05": "Realistic, Mosaic, and Artistic", + "06": "Digital Pixel Art", + "07": "Fiery, Dramatic", + "08": "Chalkboard Art", + "09": "Expressionist and Impressionist", + "10": "Low-Poly, Faceted, 3D Rendering and Geometric", + "11": "Relief", + "12": "Vintage, Hand-drawn, Sepia-toned and Weathered", + "13": "Impressionist, Expressionist, and Post-Impressionist", + "14": "Origami" +} \ No newline at end of file diff --git a/dreambooth/settings/texture.json b/dreambooth/settings/texture.json new file mode 100644 index 0000000..7661e0b --- /dev/null +++ b/dreambooth/settings/texture.json @@ -0,0 +1,17 @@ +{ + "00": "Impressionist", + "01": "Geometric, Modern, Low-Poly, and Pop Art", + "02": "Watercolor", + "03": "Low-Poly, Abstract, Digital Art", + "04": "Watercolor, Illustration, Digital Art", + "05": "Realistic, Mosaic, and Artistic", + "06": "Digital Pixel Art", + "07": "Fiery, Dramatic", + "08": "Chalkboard Art", + "09": "Expressionist and Impressionist", + "10": "Low-Poly, Faceted, 3D Rendering and Geometric", + "11": "Relief", + "12": "Vintage, Hand-drawn, Sepia-toned and Weathered", + "13": "Impressionist, Expressionist, and Post-Impressionist", + "14": "Origami" +} \ No newline at end of file diff --git a/dreambooth/test_all.sh b/dreambooth/test_all.sh new file mode 100644 index 0000000..59b0be4 --- /dev/null +++ b/dreambooth/test_all.sh @@ -0,0 +1,6 @@ +#!/bin/bash +export HF_ENDPOINT="https://hf-mirror.com" +export HF_HOME="/home/user1/jittor2024/JDiffusion/cached_path" + +COMMAND="python run_all.py" +eval $COMMAND & \ No newline at end of file diff --git a/dreambooth/train.py b/dreambooth/train.py new file mode 100644 index 0000000..a62ca81 --- /dev/null +++ b/dreambooth/train.py @@ -0,0 +1,730 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import jittor as jt + +jt.flags.use_rocm = 1 +import jittor.nn as nn +import argparse +import copy +import logging +import math +import os +import warnings +from pathlib import Path + +import numpy as np +import transformers +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from jittor import transform +from jittor.compatibility.optim import AdamW +from jittor.compatibility.utils.data import Dataset, DataLoader +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from JDiffusion import ( + AutoencoderKL, + UNet2DConditionModel, +) +from diffusers import DDPMScheduler +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.utils import ( + convert_state_dict_to_diffusers, +) + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + return CLIPTextModel + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--num_process", + type=int, + default=1 + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=5, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--tokenizer_max_length", + type=int, + default=None, + required=False, + help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", + ) + parser.add_argument( + "--text_encoder_use_attention_mask", + action="store_true", + required=False, + help="Whether to use attention mask for the text encoder", + ) + parser.add_argument( + "--validation_images", + required=False, + default=None, + nargs="+", + help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", + ) + parser.add_argument( + "--class_labels_conditioning", + required=False, + default=None, + help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--checkpoint_save_steps", + type=int, + default=500, + help="Frequency number of saving checkpoint steps to perform.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + class_num=None, + size=512, + center_crop=False, + encoder_hidden_states=None, + class_prompt_encoder_hidden_states=None, + tokenizer_max_length=None, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + self.encoder_hidden_states = encoder_hidden_states + self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states + self.tokenizer_max_length = tokenizer_max_length + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transform.Compose( + [ + transform.Resize(size), + transform.CenterCrop(size) if center_crop else transform.RandomCrop(size), + transform.ToTensor(), + transform.ImageNormalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = exif_transpose(instance_image) + prompt = "A photo of {}".format( + str(self.instance_images_path[index % self.num_instance_images]). + split("/")[-1].split(".")[-2].replace("_", " ")) + self.instance_prompt + print("Prompt:", prompt) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.encoder_hidden_states is not None: + example["instance_prompt_ids"] = self.encoder_hidden_states + else: + text_inputs = tokenize_prompt( + self.tokenizer, prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["instance_prompt_ids"] = text_inputs.input_ids + example["instance_attention_mask"] = text_inputs.attention_mask + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + + if self.class_prompt_encoder_hidden_states is not None: + example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states + else: + class_text_inputs = tokenize_prompt( + self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length + ) + example["class_prompt_ids"] = class_text_inputs.input_ids + example["class_attention_mask"] = class_text_inputs.attention_mask + + return example + + +def collate_fn(examples, with_prior_preservation=False): + has_attention_mask = "instance_attention_mask" in examples[0] + + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + if has_attention_mask: + attention_mask = [example["instance_attention_mask"] for example in examples] + + pixel_values = jt.stack(pixel_values) + pixel_values = pixel_values.float() + + input_ids = jt.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + + if has_attention_mask: + batch["attention_mask"] = attention_mask + + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): + if tokenizer_max_length is not None: + max_length = tokenizer_max_length + else: + max_length = tokenizer.model_max_length + + text_inputs = tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=max_length, + return_tensors="pt", + ) + + return text_inputs + + +def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): + text_input_ids = input_ids.to(text_encoder.device) + + if text_encoder_use_attention_mask: + attention_mask = attention_mask.to(text_encoder.device) + else: + attention_mask = None + + prompt_embeds = text_encoder( + text_input_ids, + attention_mask=attention_mask, + return_dict=False, + ) + prompt_embeds = prompt_embeds[0] + + return prompt_embeds + + +def main(args): + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + + # Handle the repository creation + # if args.output_dir is not None: + # os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + + # We only train the additional adapter LoRA layers + # if vae is not None: + # vae.requires_grad_(False) + # text_encoder.requires_grad_(False) + # unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = jt.float32 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # unet.to("cuda", dtype=weight_dtype) + # if vae is not None: + # vae.to("cuda", dtype=weight_dtype) + # text_encoder.to("cuda", dtype=weight_dtype) + + for name, param in unet.named_parameters(): + assert param.requires_grad == False, name + # now we will add new LoRA weights to the attention layers + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"], + ) + unet.add_adapter(unet_lora_config) + + # Optimizer creation + + optimizer = AdamW( + list(unet.parameters()), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + pre_computed_encoder_hidden_states = None + pre_computed_class_prompt_encoder_hidden_states = None + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + class_num=args.num_class_images, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + encoder_hidden_states=pre_computed_encoder_hidden_states, + class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states, + tokenizer_max_length=args.tokenizer_max_length, + ) + + train_dataloader = DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, False), + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.num_process, + num_training_steps=args.max_train_steps * args.num_process, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + tracker_config = vars(copy.deepcopy(args)) + tracker_config.pop("validation_images") + + # Train! + total_batch_size = args.train_batch_size * args.num_process * args.gradient_accumulation_steps + + print("***** Running training *****") + print(f" Num examples = {len(train_dataset)}") + print(f" Num batches each epoch = {len(train_dataloader)}") + print(f" Num Epochs = {args.num_train_epochs}") + print(f" Instantaneous batch size per device = {args.train_batch_size}") + print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + print(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=False, + ) + + # print("epoch_params:", first_epoch, args.num_train_epochs, args.checkpoint_save_steps) + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + + # Sample noise that we'll add to the latents + noise = jt.randn_like(model_input) + bsz, channels, height, width = model_input.shape + # Sample a random timestep for each image + timesteps = jt.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), + ).to(device=model_input.device) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = encode_prompt( + text_encoder, + batch["input_ids"], + batch["attention_mask"], + text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, + ) + + if unet.config.in_channels == channels * 2: + noisy_model_input = jt.cat([noisy_model_input, noisy_model_input], dim=1) + + if args.class_labels_conditioning == "timesteps": + class_labels = timesteps + else: + class_labels = None + + # Predict the noise residual + model_pred = unet( + noisy_model_input, + timesteps, + encoder_hidden_states, + class_labels=class_labels, + return_dict=False, + )[0] + + # if model predicts variance, throw away the prediction. we will only train on the + # simplified training objective. This means that all schedulers using the fine tuned + # model must be configured to use one of the fixed variance variance types. + if model_pred.shape[1] == 6: + model_pred, _ = jt.chunk(model_pred, 2, dim=1) + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + loss = nn.mse_loss(model_pred, target) + loss.backward() + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item()} + # logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= args.max_train_steps: + break + + # print("epoch:", epoch, args.checkpoint_save_steps) + if (epoch + 1) % (args.checkpoint_save_steps / num_update_steps_per_epoch) == 0: + # Save the lora layers + unet = unet.to(jt.float32) + unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + + base_dir, style_dir = os.path.split(args.output_dir) + style_output_dir = os.path.join(base_dir + f"_{epoch + 1}epoch", style_dir) + print("style_output_dir:", style_output_dir) + if style_output_dir is not None: + os.makedirs(style_output_dir, exist_ok=True) + text_encoder_state_dict = None + LoraLoaderMixin.save_lora_weights( + save_directory=style_output_dir, + unet_lora_layers=unet_lora_state_dict, + text_encoder_lora_layers=text_encoder_state_dict, + safe_serialization=False + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/dreambooth/train_all.sh b/dreambooth/train_all.sh new file mode 100644 index 0000000..40d27ee --- /dev/null +++ b/dreambooth/train_all.sh @@ -0,0 +1,63 @@ +#!/bin/bash +export HF_ENDPOINT="https://hf-mirror.com" +export HF_HOME="/home/user1/jittor2024/JDiffusion/cached_path" + +root="/home/user1/jittor2024/jittor-A-commit" +save_root="${root}/dreambooth/results/prompt_v1_color_test1" +style_file="${root}/dreambooth/settings/style.json" +texture_file="${root}/dreambooth/settings/texture.json" +color_file="${root}/dreambooth/settings/color.json" + +MODEL_NAME="stabilityai/stable-diffusion-2-1" +BASE_INSTANCE_DIR="${root}/A" +OUTPUT_DIR_PREFIX="${save_root}/style/style_" +RESOLUTION=512 +TRAIN_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=1 +CHECKPOINTING_STEPS=500 +LEARNING_RATE=1e-4 +LR_SCHEDULER="constant" +LR_WARMUP_STEPS=0 +MAX_TRAIN_STEPS=5000 +SEED=0 +GPU_COUNT=1 +MAX_NUM=14 + +for ((folder_number = 0; folder_number <= $MAX_NUM; folder_number+=$GPU_COUNT)); do + for ((gpu_id = 0; gpu_id < GPU_COUNT; gpu_id++)); do + current_folder_number=$((folder_number + gpu_id)) + if [ $current_folder_number -gt $MAX_NUM ]; then + break + fi + key=$(printf "%02d" $current_folder_number) + style_prompt=$(python read_json.py "$style_file" "$key") + texture_prompt=$(python read_json.py "$texture_file" "$key") + color_prompt=$(python read_json.py "$color_file" "$key") + INSTANCE_DIR="${BASE_INSTANCE_DIR}/$(printf "%02d" $current_folder_number)/images" + OUTPUT_DIR="${OUTPUT_DIR_PREFIX}$(printf "%02d" $current_folder_number)" + CUDA_VISIBLE_DEVICES=$gpu_id +# CUDA_VISIBLE_DEVICES="1,0" +# PROMPT=$(printf "style_%02d" $current_folder_number) + PROMPT=" in $style_prompt style, with a texture of $texture_prompt and with a color style of $color_prompt." + echo "current_folder_number: $current_folder_number, PROMPT: $PROMPT" + + COMMAND="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python train.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --instance_prompt='$PROMPT' \ + --resolution=$RESOLUTION \ + --train_batch_size=$TRAIN_BATCH_SIZE \ + --gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \ + --learning_rate=$LEARNING_RATE \ + --lr_scheduler=$LR_SCHEDULER \ + --lr_warmup_steps=$LR_WARMUP_STEPS \ + --max_train_steps=$MAX_TRAIN_STEPS \ + --seed=$SEED \ + --checkpoint_save_steps=$CHECKPOINTING_STEPS" + + eval $COMMAND & + sleep 3 + done + wait +done