Browse Source

Initial

master
Zephyr 1 year ago
parent
commit
e8d218dae8
12 changed files with 1069 additions and 4 deletions
  1. +15
    -3
      LICENSE
  2. +34
    -1
      README.md
  3. +35
    -0
      dreambooth/README.md
  4. +19
    -0
      dreambooth/read_json.py
  5. +54
    -0
      dreambooth/run_all.py
  6. +17
    -0
      dreambooth/settings/color.json
  7. +62
    -0
      dreambooth/settings/gauss_statistics.json
  8. +17
    -0
      dreambooth/settings/style.json
  9. +17
    -0
      dreambooth/settings/texture.json
  10. +6
    -0
      dreambooth/test_all.sh
  11. +730
    -0
      dreambooth/train.py
  12. +63
    -0
      dreambooth/train_all.sh

+ 15
- 3
LICENSE View File

@@ -2,8 +2,20 @@ MIT License

Copyright (c) 2024 yohu-cqu

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

+ 34
- 1
README.md View File

@@ -1,2 +1,35 @@
# jittor-SimpleTest-StableDiffusion-A
# DreamBooth-Lora

本项目参考自 [JDiffusion 的 DreamBooth-Lora](https://github.com/JittorRepos/JDiffusion/tree/master/examples/dreambooth) 。

## 环境安装

按照 [JDiffusion 的环境安装一节](https://github.com/JittorRepos/JDiffusion/blob/master/examples/dreambooth/README.md) 安装必要的依赖。

接着设置运行脚本的权限:
```
chmod u+x ./dreambooth/*.sh
```

## 训练

1. 首先从比赛云盘下载对应的数据集,推荐将数据集目录 `A` 下载到 `./A` 下
2. 将 `train_all.sh` 中的 `HF_HOME` 设置为本地模型路径, `root` 设置为项目所在目录, `BASE_INSTANCE_DIR` 设置为数据集对应的目录,`GPU_COUNT` 设置为对应可用的显卡数量,`MAX_NUM` 设置为数据集中的风格个数;
3. 然后进入目标文件夹下: `cd ./dreambooth/`, 运行 `bash train_all.sh` 即可训练。保存的模型会存放至 `./dreambooth/results/prompt_v1_color_test1/style_[训练epoch数]epoch` 目录下,例如:`./dreambooth/results/prompt_v1_color_test1/style_300epoch`。

## 推理

1. 将 `test_all.sh` 中的 `HF_HOME` 设置为本地模型路径,将 `run_all.py` 中的 `root` 设置为项目所在目录, `dataset_root` 修改为数据集对应的目录,将 `max_num` 修改为数据集中的风格个数;
2. 进入目标文件夹下: `cd ./dreambooth/`,运行 `bash test_all.sh` 进行推理。模型生成的图片会输出到 `./dreambooth/results/prompt_v1_color_test1/outputs_[保存点训练epoch数]ckpt_[推理轮数]steps_[种子值]seed` 文件夹下,例如:`./dreambooth/results/prompt_v1_color_test1/outputs_300ckpt_500steps_76587seed`。


## 参考文献

```
@inproceedings{ruiz2023dreambooth,
title={Dreambooth: Fine tuning text-to-image diffusion models for subject-driven generation},
author={Ruiz, Nataniel and Li, Yuanzhen and Jampani, Varun and Pritch, Yael and Rubinstein, Michael and Aberman, Kfir},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
```

+ 35
- 0
dreambooth/README.md View File

@@ -0,0 +1,35 @@
# DreamBooth-Lora

本项目参考自 [JDiffusion 的 DreamBooth-Lora](https://github.com/JittorRepos/JDiffusion/tree/master/examples/dreambooth) 。

## 环境安装

按照 [JDiffusion 的环境安装一节](https://github.com/JittorRepos/JDiffusion/blob/master/examples/dreambooth/README.md) 安装必要的依赖。

接着设置运行脚本的权限:
```
chmod u+x ./dreambooth/*.sh
```

## 训练

1. 首先从比赛云盘下载对应的数据集,推荐将数据集目录 `A` 下载到 `./dreambooth/A/` 下
2. 将 `train_all.sh` 中的 `HF_HOME` 设置为本地模型路径, `root` 设置为项目所在目录, `BASE_INSTANCE_DIR` 设置为数据集对应的目录,`GPU_COUNT` 设置为对应可用的显卡数量,`MAX_NUM` 设置为数据集中的风格个数;
3. 然后进入目标文件夹下: `cd ./dreambooth/`, 运行 `bash train_all.sh` 即可训练,保存的模型会存放至 `./dreambooth/results/prompt_v1_color_test1/style_[训练epoch数]epoch` 目录下,例如:`./dreambooth/results/prompt_v1_color_test1/style_300epoch`。

## 推理

1. 将 `test_all.sh` 中的 `HF_HOME` 设置为本地模型路径,将 `run_all.py` 中的 `root` 设置为项目所在目录, `dataset_root` 修改为数据集对应的目录,将 `max_num` 修改为数据集中的风格个数;
2. 进入目标文件夹下: `cd ./dreambooth/`,运行 `bash test_all.sh` 进行训练,对应的图片会输出到 `./dreambooth/results/prompt_v1_color_test1/outputs_[保存点训练epoch数]ckpt_[推理轮数]steps_[种子值]seed` 文件夹下,例如:`./dreambooth/results/prompt_v1_color_test1/outputs_300ckpt_500steps_76587seed`。


## 参考文献

```
@inproceedings{ruiz2023dreambooth,
title={Dreambooth: Fine tuning text-to-image diffusion models for subject-driven generation},
author={Ruiz, Nataniel and Li, Yuanzhen and Jampani, Varun and Pritch, Yael and Rubinstein, Michael and Aberman, Kfir},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
```

+ 19
- 0
dreambooth/read_json.py View File

@@ -0,0 +1,19 @@
import json
import sys

def get_value_from_json(file_path, key):
with open(file_path, 'r') as f:
data = json.load(f)
if key in data:
return data[key]
else:
return None

if __name__ == "__main__":
file_path = sys.argv[1]
key = sys.argv[2]
value = get_value_from_json(file_path, key)
if value is not None:
print(value)
else:
sys.exit(1)

+ 54
- 0
dreambooth/run_all.py View File

@@ -0,0 +1,54 @@
import os, sys

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import json, tqdm, torch
import jittor as jt
jt.flags.use_rocm = 1


from JDiffusion.pipelines import StableDiffusionPipeline
# from mypipeline_stable_diffusion_jittor import StableDiffusionPipeline

root = "/home/user1/jittor2024/jittor-A-commit"
save_root = f"{root}/dreambooth/results/" + "prompt_v1_color_test1"

max_num = 15
inference_steps = 500
checkpoint_epoch = 300
seed = 76587
jt.set_global_seed(seed)
dataset_root = f"{root}/A/"
style_file = f"{root}/dreambooth/settings/style.json"
texture_file = f"{root}/dreambooth/settings/texture.json"
color_file = f"{root}/dreambooth/settings/color.json"

with open(style_file, "r") as f:
style_dict = json.load(f)

with open(texture_file, "r") as f:
texture_dict = json.load(f)

with open(color_file, "r") as f:
color_dict = json.load(f)

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1").to("cuda")
with torch.no_grad():
for tempid in tqdm.tqdm(range(0, max_num)):
taskid = "{:0>2d}".format(tempid)
if checkpoint_epoch is None:
pipe.load_lora_weights(os.path.join(save_root, f"style/style_{taskid}"))
else:
pipe.load_lora_weights(os.path.join(save_root, f"style_{checkpoint_epoch}epoch/style_{taskid}"))

# load json
with open(f"{dataset_root}/{taskid}/prompt.json", "r") as file:
prompts = json.load(file)

for id, prompt in prompts.items():
new_prompt = f"A photo of {prompt} in {style_dict[taskid]} style, with a texture of {texture_dict[taskid]} and with a color style of {color_dict[taskid]}."
print(new_prompt)
image = pipe(prompt=new_prompt, num_inference_steps=inference_steps, width=512, height=512, seed=seed).images[0]
os.makedirs(os.path.join(save_root, f"outputs_{checkpoint_epoch}ckpt_{inference_steps}steps_{seed}seed/{taskid}"), exist_ok=True)
image.save(os.path.join(save_root, f"outputs_{checkpoint_epoch}ckpt_{inference_steps}steps_{seed}seed/{taskid}/{prompt}.png"))

+ 17
- 0
dreambooth/settings/color.json View File

@@ -0,0 +1,17 @@
{
"00": "which is dominated by shades of blue and white, with accents in yellow",
"01": "mainly bright and contrasting colors such as orange, red, yellow, blue, white, gray, and black",
"02": "which is characterized by muted, earthy tones and soft gradients and palette primarily including shades of brown, gray, beige, and subtle hints of orange and gree",
"03": "a vibrant and geometric color palette, featuring bright hues such as reds, blues, greens, yellows, and purples",
"04": "which is featured in muted colors such as light blues, greens, pinks, and yellows",
"05": "a rich palette including warm earth tones like browns and oranges, cool blues, greens, and reds",
"06": "an emphasis on rich blues, purples, pinks, and reds",
"07": "a warm, dark contrast and orange-yellow highlights",
"08": "a common black and white chalkboard and monochromatic",
"09": "which is dominated by colors include shades of orange, yellow, brown, black, white, and hints of blue",
"10": "a mix of cool tones like blue, purple, and pink",
"11": "earthy tones, muted grays and browns",
"12": "sepia tones, vintage textures, and monochromatic hues",
"13": "featuring bold blues, oranges, and greys",
"14": "a combination of warm tones, primarily pink and beige backgrounds with white origami objects"
}

+ 62
- 0
dreambooth/settings/gauss_statistics.json View File

@@ -0,0 +1,62 @@
{
"00": {
"mean_val": 0.0009739127941429615,
"std_val": 0.5314764261245728
},
"01": {
"mean_val": 0.46153217256069184,
"std_val": 0.7191255629062653
},
"02": {
"mean_val": 0.29199431985616686,
"std_val": 0.7220630288124085
},
"03": {
"mean_val": 0.33343668580055236,
"std_val": 0.5767290353775024
},
"04": {
"mean_val": 0.218064521253109,
"std_val": 0.5959844827651978
},
"05": {
"mean_val": 0.18604852110147477,
"std_val": 0.558800458908081
},
"06": {
"mean_val": 0.22389721274375915,
"std_val": 0.6452812850475311
},
"07": {
"mean_val": -0.09157774895429611,
"std_val": 1.0858253061771392
},
"08": {
"mean_val": -0.045120977237820624,
"std_val": 0.6651732742786407
},
"09": {
"mean_val": 0.08461952535435557,
"std_val": 0.670553731918335
},
"10": {
"mean_val": -0.03910434106364846,
"std_val": 0.665828937292099
},
"11": {
"mean_val": 0.02903309902176261,
"std_val": 0.5030718982219696
},
"12": {
"mean_val": 0.22357304841279985,
"std_val": 0.5448532164096832
},
"13": {
"mean_val": 0.1572735294699669,
"std_val": 0.5102830111980439
},
"14": {
"mean_val": 0.16023400127887727,
"std_val": 0.6309724390506745
}
}

+ 17
- 0
dreambooth/settings/style.json View File

@@ -0,0 +1,17 @@
{
"00": "Impressionist",
"01": "Geometric, Modern, Low-Poly, and Pop Art",
"02": "Watercolor",
"03": "Low-Poly, Abstract, Digital Art",
"04": "Watercolor, Illustration, Digital Art",
"05": "Realistic, Mosaic, and Artistic",
"06": "Digital Pixel Art",
"07": "Fiery, Dramatic",
"08": "Chalkboard Art",
"09": "Expressionist and Impressionist",
"10": "Low-Poly, Faceted, 3D Rendering and Geometric",
"11": "Relief",
"12": "Vintage, Hand-drawn, Sepia-toned and Weathered",
"13": "Impressionist, Expressionist, and Post-Impressionist",
"14": "Origami"
}

+ 17
- 0
dreambooth/settings/texture.json View File

@@ -0,0 +1,17 @@
{
"00": "Impressionist",
"01": "Geometric, Modern, Low-Poly, and Pop Art",
"02": "Watercolor",
"03": "Low-Poly, Abstract, Digital Art",
"04": "Watercolor, Illustration, Digital Art",
"05": "Realistic, Mosaic, and Artistic",
"06": "Digital Pixel Art",
"07": "Fiery, Dramatic",
"08": "Chalkboard Art",
"09": "Expressionist and Impressionist",
"10": "Low-Poly, Faceted, 3D Rendering and Geometric",
"11": "Relief",
"12": "Vintage, Hand-drawn, Sepia-toned and Weathered",
"13": "Impressionist, Expressionist, and Post-Impressionist",
"14": "Origami"
}

+ 6
- 0
dreambooth/test_all.sh View File

@@ -0,0 +1,6 @@
#!/bin/bash
export HF_ENDPOINT="https://hf-mirror.com"
export HF_HOME="/home/user1/jittor2024/JDiffusion/cached_path"

COMMAND="python run_all.py"
eval $COMMAND &

+ 730
- 0
dreambooth/train.py View File

@@ -0,0 +1,730 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import jittor as jt

jt.flags.use_rocm = 1
import jittor.nn as nn
import argparse
import copy
import logging
import math
import os
import warnings
from pathlib import Path

import numpy as np
import transformers
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from PIL import Image
from PIL.ImageOps import exif_transpose
from jittor import transform
from jittor.compatibility.optim import AdamW
from jittor.compatibility.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig

import diffusers
from JDiffusion import (
AutoencoderKL,
UNet2DConditionModel,
)
from diffusers import DDPMScheduler
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.utils import (
convert_state_dict_to_diffusers,
)


def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]

if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "T5EncoderModel":
from transformers import T5EncoderModel
return T5EncoderModel
else:
raise ValueError(f"{model_class} is not supported.")


def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--num_process",
type=int,
default=1
)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="lora-dreambooth-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_batch_size", type=int, default=5, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--tokenizer_max_length",
type=int,
default=None,
required=False,
help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
)
parser.add_argument(
"--text_encoder_use_attention_mask",
action="store_true",
required=False,
help="Whether to use attention mask for the text encoder",
)
parser.add_argument(
"--validation_images",
required=False,
default=None,
nargs="+",
help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
)
parser.add_argument(
"--class_labels_conditioning",
required=False,
default=None,
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
)
parser.add_argument(
"--rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--checkpoint_save_steps",
type=int,
default=500,
help="Frequency number of saving checkpoint steps to perform.",
)

if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()

# logger is not available yet
if args.class_data_dir is not None:
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

return args


class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
"""

def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
class_num=None,
size=512,
center_crop=False,
encoder_hidden_states=None,
class_prompt_encoder_hidden_states=None,
tokenizer_max_length=None,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.encoder_hidden_states = encoder_hidden_states
self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
self.tokenizer_max_length = tokenizer_max_length

self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")

self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images

if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
if class_num is not None:
self.num_class_images = min(len(self.class_images_path), class_num)
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None

self.image_transforms = transform.Compose(
[
transform.Resize(size),
transform.CenterCrop(size) if center_crop else transform.RandomCrop(size),
transform.ToTensor(),
transform.ImageNormalize([0.5], [0.5]),
]
)

def __len__(self):
return self._length

def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
instance_image = exif_transpose(instance_image)
prompt = "A photo of {}".format(
str(self.instance_images_path[index % self.num_instance_images]).
split("/")[-1].split(".")[-2].replace("_", " ")) + self.instance_prompt
print("Prompt:", prompt)

if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)

if self.encoder_hidden_states is not None:
example["instance_prompt_ids"] = self.encoder_hidden_states
else:
text_inputs = tokenize_prompt(
self.tokenizer, prompt, tokenizer_max_length=self.tokenizer_max_length
)
example["instance_prompt_ids"] = text_inputs.input_ids
example["instance_attention_mask"] = text_inputs.attention_mask

if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)

if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)

if self.class_prompt_encoder_hidden_states is not None:
example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
else:
class_text_inputs = tokenize_prompt(
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
)
example["class_prompt_ids"] = class_text_inputs.input_ids
example["class_attention_mask"] = class_text_inputs.attention_mask

return example


def collate_fn(examples, with_prior_preservation=False):
has_attention_mask = "instance_attention_mask" in examples[0]

input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]

if has_attention_mask:
attention_mask = [example["instance_attention_mask"] for example in examples]

pixel_values = jt.stack(pixel_values)
pixel_values = pixel_values.float()

input_ids = jt.cat(input_ids, dim=0)

batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}

if has_attention_mask:
batch["attention_mask"] = attention_mask

return batch


class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."

def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples

def __len__(self):
return self.num_samples

def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example


def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
if tokenizer_max_length is not None:
max_length = tokenizer_max_length
else:
max_length = tokenizer.model_max_length

text_inputs = tokenizer(
prompt,
truncation=True,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)

return text_inputs


def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
text_input_ids = input_ids.to(text_encoder.device)

if text_encoder_use_attention_mask:
attention_mask = attention_mask.to(text_encoder.device)
else:
attention_mask = None

prompt_embeds = text_encoder(
text_input_ids,
attention_mask=attention_mask,
return_dict=False,
)
prompt_embeds = prompt_embeds[0]

return prompt_embeds


def main(args):
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()

# Handle the repository creation
# if args.output_dir is not None:
# os.makedirs(args.output_dir, exist_ok=True)

# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)

# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
)

unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# We only train the additional adapter LoRA layers
# if vae is not None:
# vae.requires_grad_(False)
# text_encoder.requires_grad_(False)
# unet.requires_grad_(False)

# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = jt.float32

# Move unet, vae and text_encoder to device and cast to weight_dtype
# unet.to("cuda", dtype=weight_dtype)
# if vae is not None:
# vae.to("cuda", dtype=weight_dtype)
# text_encoder.to("cuda", dtype=weight_dtype)

for name, param in unet.named_parameters():
assert param.requires_grad == False, name
# now we will add new LoRA weights to the attention layers
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0", "add_k_proj", "add_v_proj"],
)
unet.add_adapter(unet_lora_config)

# Optimizer creation

optimizer = AdamW(
list(unet.parameters()),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

pre_computed_encoder_hidden_states = None
pre_computed_class_prompt_encoder_hidden_states = None

# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
class_num=args.num_class_images,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
encoder_hidden_states=pre_computed_encoder_hidden_states,
class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
tokenizer_max_length=args.tokenizer_max_length,
)

train_dataloader = DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, False),
num_workers=args.dataloader_num_workers,
)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True

lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.num_process,
num_training_steps=args.max_train_steps * args.num_process,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
tracker_config = vars(copy.deepcopy(args))
tracker_config.pop("validation_images")

# Train!
total_batch_size = args.train_batch_size * args.num_process * args.gradient_accumulation_steps

print("***** Running training *****")
print(f" Num examples = {len(train_dataset)}")
print(f" Num batches each epoch = {len(train_dataloader)}")
print(f" Num Epochs = {args.num_train_epochs}")
print(f" Instantaneous batch size per device = {args.train_batch_size}")
print(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
print(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0

initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=False,
)

# print("epoch_params:", first_epoch, args.num_train_epochs, args.checkpoint_save_steps)
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

# Convert images to latent space
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor

# Sample noise that we'll add to the latents
noise = jt.randn_like(model_input)
bsz, channels, height, width = model_input.shape
# Sample a random timestep for each image
timesteps = jt.randint(
0, noise_scheduler.config.num_train_timesteps, (bsz,),
).to(device=model_input.device)
timesteps = timesteps.long()

# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

# Get the text embedding for conditioning
encoder_hidden_states = encode_prompt(
text_encoder,
batch["input_ids"],
batch["attention_mask"],
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
)

if unet.config.in_channels == channels * 2:
noisy_model_input = jt.cat([noisy_model_input, noisy_model_input], dim=1)

if args.class_labels_conditioning == "timesteps":
class_labels = timesteps
else:
class_labels = None

# Predict the noise residual
model_pred = unet(
noisy_model_input,
timesteps,
encoder_hidden_states,
class_labels=class_labels,
return_dict=False,
)[0]

# if model predicts variance, throw away the prediction. we will only train on the
# simplified training objective. This means that all schedulers using the fine tuned
# model must be configured to use one of the fixed variance variance types.
if model_pred.shape[1] == 6:
model_pred, _ = jt.chunk(model_pred, 2, dim=1)

# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(model_input, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

loss = nn.mse_loss(model_pred, target)
loss.backward()

optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

progress_bar.update(1)
global_step += 1

logs = {"loss": loss.detach().item()}
# logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)

if global_step >= args.max_train_steps:
break

# print("epoch:", epoch, args.checkpoint_save_steps)
if (epoch + 1) % (args.checkpoint_save_steps / num_update_steps_per_epoch) == 0:
# Save the lora layers
unet = unet.to(jt.float32)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

base_dir, style_dir = os.path.split(args.output_dir)
style_output_dir = os.path.join(base_dir + f"_{epoch + 1}epoch", style_dir)
print("style_output_dir:", style_output_dir)
if style_output_dir is not None:
os.makedirs(style_output_dir, exist_ok=True)
text_encoder_state_dict = None
LoraLoaderMixin.save_lora_weights(
save_directory=style_output_dir,
unet_lora_layers=unet_lora_state_dict,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False
)


if __name__ == "__main__":
args = parse_args()
main(args)

+ 63
- 0
dreambooth/train_all.sh View File

@@ -0,0 +1,63 @@
#!/bin/bash
export HF_ENDPOINT="https://hf-mirror.com"
export HF_HOME="/home/user1/jittor2024/JDiffusion/cached_path"

root="/home/user1/jittor2024/jittor-A-commit"
save_root="${root}/dreambooth/results/prompt_v1_color_test1"
style_file="${root}/dreambooth/settings/style.json"
texture_file="${root}/dreambooth/settings/texture.json"
color_file="${root}/dreambooth/settings/color.json"

MODEL_NAME="stabilityai/stable-diffusion-2-1"
BASE_INSTANCE_DIR="${root}/A"
OUTPUT_DIR_PREFIX="${save_root}/style/style_"
RESOLUTION=512
TRAIN_BATCH_SIZE=1
GRADIENT_ACCUMULATION_STEPS=1
CHECKPOINTING_STEPS=500
LEARNING_RATE=1e-4
LR_SCHEDULER="constant"
LR_WARMUP_STEPS=0
MAX_TRAIN_STEPS=5000
SEED=0
GPU_COUNT=1
MAX_NUM=14

for ((folder_number = 0; folder_number <= $MAX_NUM; folder_number+=$GPU_COUNT)); do
for ((gpu_id = 0; gpu_id < GPU_COUNT; gpu_id++)); do
current_folder_number=$((folder_number + gpu_id))
if [ $current_folder_number -gt $MAX_NUM ]; then
break
fi
key=$(printf "%02d" $current_folder_number)
style_prompt=$(python read_json.py "$style_file" "$key")
texture_prompt=$(python read_json.py "$texture_file" "$key")
color_prompt=$(python read_json.py "$color_file" "$key")
INSTANCE_DIR="${BASE_INSTANCE_DIR}/$(printf "%02d" $current_folder_number)/images"
OUTPUT_DIR="${OUTPUT_DIR_PREFIX}$(printf "%02d" $current_folder_number)"
CUDA_VISIBLE_DEVICES=$gpu_id
# CUDA_VISIBLE_DEVICES="1,0"
# PROMPT=$(printf "style_%02d" $current_folder_number)
PROMPT=" in $style_prompt style, with a texture of $texture_prompt and with a color style of $color_prompt."
echo "current_folder_number: $current_folder_number, PROMPT: $PROMPT"

COMMAND="CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python train.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt='$PROMPT' \
--resolution=$RESOLUTION \
--train_batch_size=$TRAIN_BATCH_SIZE \
--gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
--learning_rate=$LEARNING_RATE \
--lr_scheduler=$LR_SCHEDULER \
--lr_warmup_steps=$LR_WARMUP_STEPS \
--max_train_steps=$MAX_TRAIN_STEPS \
--seed=$SEED \
--checkpoint_save_steps=$CHECKPOINTING_STEPS"

eval $COMMAND &
sleep 3
done
wait
done

Loading…
Cancel
Save