bt365软件下载

DDIM模型代码解析(一)

DDIM模型代码解析(一)

目录

预备知识

main.py

解析命令行参数

解析配置文件

预备知识

由于代码中除了一些必要的对模型、数据进行操作的PyTorch函数外,还有一些辅助显示训练等过程有关信息的,或辅助对文件目录进行操作的库。因此,建议读者先对这些库进行了解,试着写一写示例代码,理解库中函数的使用方法后再阅读下面的讲解,这样可以更顺畅。

import argparse

import traceback

import shutil

import logging

import yaml

import sys

import os

main.py

首先对输出的选项进行设定,让输出的内容不按科学计数法模式。

torch.set_printoptions(sci_mode=False) # 设置为不按照科学计数法表示输出

然后程序进入main()函数中,在main函数中完成了以下任务:

解析命令行参数解析配置文件打印相关信息扩散过程实例化完成采样 / 测试 / 训练过程

后面我们逐一进行代码分析。

def main():

args, config = parse_args_and_config() # 解析命令行参数和配置文件

logging.info("Writing log file to {}".format(args.log_path)) # 显示日志存储路径信息

logging.info("Exp instance id = {}".format(os.getpid())) # 显示进程id信息

logging.info("Exp comment = {}".format(args.comment)) # 显示实验注释信息

try:

runner = Diffusion(args, config) # 构建扩散运行实例对象

if args.sample: # 如果是采样操作,就执行采样函数

runner.sample()

elif args.test: # 如果是测试模型,就执行测试函数

runner.test()

else: # 否则就执行训练函数

runner.train()

except Exception: # 如果报错就输出错误信息日志

logging.error(traceback.format_exc())

return 0

解析命令行参数

对命令行参数的解析在parse_args_and_config函数中完成,每一个参数的含义以注释的形式标明,如果有异议欢迎在评论中指出。

def parse_args_and_config():

parser = argparse.ArgumentParser(description=globals()["__doc__"])

parser.add_argument( # config文件路径

"--config", type=str, required=True, help="Path to the config file"

)

parser.add_argument("--seed", type=int, default=1234, help="Random seed") # 随机种子

parser.add_argument( # 用于保存运行相关数据的路径

"--exp", type=str, default="exp", help="Path for saving running related data."

)

parser.add_argument( # log日志文件夹名称

"--doc",

type=str,

required=True,

help="A string for documentation purpose. "

"Will be the name of the log folder.",

)

parser.add_argument( # 实验注释

"--comment", type=str, default="", help="A string for experiment comment"

)

parser.add_argument( # logging日志的级别: info, debug, warning, critical

"--verbose",

type=str,

default="info",

help="Verbose level: info | debug | warning | critical",

)

parser.add_argument("--test", action="store_true", help="Whether to test the model") # 是否测试模型

parser.add_argument( # 是否从模型产生采样

"--sample",

action="store_true",

help="Whether to produce samples from the model",

)

parser.add_argument("--fid", action="store_true") # FID指标

parser.add_argument("--interpolation", action="store_true") # 插值

parser.add_argument( # 是否为继续训练

"--resume_training", action="store_true", help="Whether to resume training"

)

parser.add_argument( # 采样的文件夹名称

"-i",

"--image_folder",

type=str,

default="images",

help="The folder name of samples",

)

parser.add_argument( # 无交互

"--ni",

action="store_true",

help="No interaction. Suitable for Slurm Job launcher",

)

parser.add_argument("--use_pretrained", action="store_true") # 使用预训练

parser.add_argument( # 采样类型

"--sample_type",

type=str,

default="generalized",

help="sampling approach (generalized or ddpm_noisy)",

)

parser.add_argument( # 跳跃类型

"--skip_type",

type=str,

default="uniform",

help="skip according to (uniform or quadratic)",

)

parser.add_argument( # 步数

"--timesteps", type=int, default=1000, help="number of steps involved"

)

parser.add_argument( # \eta超参数用于控制方差

"--eta",

type=float,

default=0.0,

help="eta used to control the variances of sigma",

)

parser.add_argument("--sequence", action="store_true") # 是否为序列

args = parser.parse_args() # 解析参数

args.log_path = os.path.join(args.exp, "logs", args.doc) # log日志路径: exp/logs/$doc$

...

解析配置文件

解析配置文件的过程也是在parse_args_and_config函数中,args.config应该是bedroom,celeba,church,cifar10中的一个。这样我们可以直接打开文件夹configs中对应数据集的yaml配置文件,此时config为字典类型。经过dict2namespace函数,将字典类型转换为argparse中命名空间的形式。

def parse_args_and_config():

...

# parse config file

with open(os.path.join("configs", args.config), "r") as f:

config = yaml.safe_load(f)

new_config = dict2namespace(config)

...

转换函数如下:

def dict2namespace(config):

namespace = argparse.Namespace()

for key, value in config.items():

if isinstance(value, dict):

new_value = dict2namespace(value)

else:

new_value = value

setattr(namespace, key, new_value)

return namespace

之后还有一步设定tensorboard日志的路径,可以在训练时用tensorboard查看训练进度信息:

def parse_args_and_config():

...

tb_path = os.path.join(args.exp, "tensorboard", args.doc) # tensorboard日志路径: exp/tensorboard/$doc$

...

之后会执行训练 / 采样 / 测试不同的代码部分:

首先看一下对于训练会执行的代码:

创建log日志文件夹创建tensorboard日志文件夹设置logging的logger

def parse_args_and_config():

...

if not args.test and not args.sample:

if not args.resume_training:

if os.path.exists(args.log_path): # 如果log输出路径存在的话

overwrite = False # 选择不覆盖

if args.ni: # 如果ni为True

overwrite = True # 选择覆盖

else:

response = input("Folder already exists. Overwrite? (Y/N)") # 询问是否覆盖

if response.upper() == "Y": # 如果Y, 则选择覆盖原有log

overwrite = True

if overwrite: # 如果选择覆盖

shutil.rmtree(args.log_path) # 删除原有log文件路径

shutil.rmtree(tb_path) # 删除原有tensorboard文件路径

os.makedirs(args.log_path) # 创建新的log文件路径

if os.path.exists(tb_path): # 如果tensorboard文件路径存在, 就删除它

shutil.rmtree(tb_path)

else: # 如果选择不覆盖, 则提示文件夹存在, 程序停止

print("Folder exists. Program halted.")

sys.exit(0)

else: # 如果log输出路径不存在就创建路径

os.makedirs(args.log_path)

with open(os.path.join(args.log_path, "config.yml"), "w") as f:

yaml.dump(new_config, f, default_flow_style=False)

new_config.tb_logger = tb.SummaryWriter(log_dir=tb_path)

# setup logger

level = getattr(logging, args.verbose.upper(), None) # 20 (logging.INFO) 或者其它的级别

if not isinstance(level, int): # 如果为None的话就会报错

raise ValueError("level {} not supported".format(args.verbose))

handler1 = logging.StreamHandler() # 将log在CLI输出的handler

handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) # 将log在文件输出的handler

formatter = logging.Formatter( # 控制log输出格式的formatter

"%(levelname)s - %(filename)s - %(asctime)s - %(message)s" # INFO - __main__ - ... - ....

)

handler1.setFormatter(formatter) # 设置CLI输出handler的格式

handler2.setFormatter(formatter) # 设置文件输出handler的格式

logger = logging.getLogger() # root logger

logger.addHandler(handler1) # 添加CLI输出handler

logger.addHandler(handler2) # 添加文件输出handler

logger.setLevel(level) # 设定root logger的级别

...

然后是采样 / 测试会执行的代码:

设置logging的logger对于采样,会创建图像文件夹

def parse_args_and_config():

...

else:

level = getattr(logging, args.verbose.upper(), None)

if not isinstance(level, int):

raise ValueError("level {} not supported".format(args.verbose))

handler1 = logging.StreamHandler()

formatter = logging.Formatter(

"%(levelname)s - %(filename)s - %(asctime)s - %(message)s"

)

handler1.setFormatter(formatter)

logger = logging.getLogger()

logger.addHandler(handler1)

logger.setLevel(level)

if args.sample: # 如果是采样

os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) # 创建目录: exp/image_samples

args.image_folder = os.path.join( # 添加图像文件夹参数: exp/image_samples/$image_folder$

args.exp, "image_samples", args.image_folder

)

if not os.path.exists(args.image_folder): # 如果图像文件夹不存在就创建一个

os.makedirs(args.image_folder)

else: # 如果图像文件夹存在

if not (args.fid or args.interpolation):

overwrite = False

if args.ni:

overwrite = True

else:

response = input(

f"Image folder {args.image_folder} already exists. Overwrite? (Y/N)"

)

if response.upper() == "Y":

overwrite = True

if overwrite: # 如果覆盖, 删除并新建文件夹

shutil.rmtree(args.image_folder)

os.makedirs(args.image_folder)

else:

print("Output image folder exists. Program halted.")

sys.exit(0)

...

最后是对PyTorch进行设置:

device随机种子causes cuDNN to benchmark multiple convolution algorithms and select the fastest.

def parse_args_and_config():

...

# add device

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

logging.info("Using device: {}".format(device))

new_config.device = device

# set random seed

torch.manual_seed(args.seed)

np.random.seed(args.seed)

if torch.cuda.is_available():

torch.cuda.manual_seed_all(args.seed)

torch.backends.cudnn.benchmark = True

return args, new_config

至此,就基本结束main.py的学习了,后面讲进入Diffusion类中查看具体初始化、训练、采样、测试这些函数是如何实现的了。

相关推荐