里面用到dataset.py
运行make train_demo命令
train_demo:
python bin/train.py demo_data/filelist.txt output/bayer \
--pretrained pretrained_models/bayer \
--val_data demo_data/filelist.txt --batch_size 1
运行train.py ,后面是输入的参变量filelist.txt
预训练权重通常是 .npy文件,是numpy专用的二进制文件。
然后就是看train.py文件咯
各种导入模块
import argparse
#日志模块
import logging
import os
#更改进程名用的
import setproctitle
#和时刻相关的操作模块
import time
#torch系列模块
import numpy as np
import torch as th
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms
from torchlib.trainer import Trainer
#demasaic系列包
import demosaic.dataset as dset
import demosaic.modules as modules
import demosaic.losses as losses
import demosaic.callbacks as callbacks
import demosaic.converter as converter
import torchlib.callbacks as default_callbacks
import torchlib.optim as toptim
用getLogger方法间接实例化一个logger
log = logging.getLogger("demosaick")
接下来是主函数
def main(args, model_params):
# 默认设置fix_seed是false
if args.fix_seed:
# 设置了seed,相同seed下,每次生成的随机数序列都是一样的
np.random.seed(0)
# 为CPU设置种子用于生成随机数,以使得结果是确定的
th.manual_seed(0)
# ------------ Set up datasets ----------------------------------------------
# 使用demosaic.dataset的ToTensor方法,,把dataset.py里面的sample给变成tensor
# xform包含了masaic,mask和im,但是这几个都没有值
xforms = [dset.ToTensor()]
# 默认下不用
if args.green_only:
xforms.append(dset.GreenOnly())
xforms = transforms.Compose(xforms)
#默认下不用
if args.xtrans:
data = dset.XtransDataset(args.data_dir, transform=xforms, augment=True, linearize=args.linear)
else:
data = dset.BayerDataset(args.data_dir, transform=xforms, augment=True, linearize=args.linear)
# 默认情况下data是一个空列表
data[0]
if args.val_data is not None:
if args.xtrans:
val_data = dset.XtransDataset(args.val_data, transform=xforms, augment=False)
else:
val_data = dset.BayerDataset(args.val_data, transform=xforms, augment=False)
else:
val_data = None
# ---------------------------------------------------------------------------
model = modules.get(model_params)
log.info("Model configuration: {}".format(model_params))
if args.pretrained:
log.info("Loading Caffe weights")
if args.xtrans:
model_ref = modules.get({"model": "XtransNetwork"})
cvt = converter.Converter(args.pretrained, "XtransNetwork")
else:
model_ref = modules.get({"model": "BayerNetwork"})
cvt = converter.Converter(args.pretrained, "BayerNetwork")
cvt.convert(model_ref)
model_ref.cuda()
else:
model_ref = None
if args.green_only:
model = modules.GreenOnly(model)
model_ref = modules.GreenOnly(model_ref)
if args.subsample:
dx = 1
dy = 0
if args.xtrans:
period = 6
else:
period = 2
model = modules.Subsample(model, period, dx=dx, dy=dy)
model_ref = modules.Subsample(model_ref, period, dx=dx, dy=dy)
if args.linear:
model = modules.DeLinearize(model)
model_ref = modules.DeLinearize(model_ref)
name = os.path.basename(args.output)
cbacks = [
default_callbacks.LossCallback(env=name),
callbacks.DemosaicVizCallback(val_data, model, model_ref, cuda=True,
shuffle=False, env=name),
callbacks.PSNRCallback(env=name),
]
metrics = {
"psnr": losses.PSNR(crop=4)
}
log.info("Using {} loss".format(args.loss))
if args.loss == "l2":
criteria = { "l2": losses.L2Loss(), }
elif args.loss == "l1":
criteria = { "l1": losses.L1Loss(), }
elif args.loss == "gradient":
criteria = {
"gradient": losses.GradientLoss(),
}
elif args.loss == "laplacian":
criteria = {
"laplacian": losses.LaplacianLoss(),
}
elif args.loss == "vgg":
criteria = { "vgg": losses.VGGLoss(), }
else:
raise ValueError("not implemented")
optimizer = optim.Adam
optimizer_params = {}
if args.optimizer == "sgd":
optimizer = optim.SGD
optimizer_params = {"momentum": 0.9}
train_params = Trainer.Parameters(
viz_step=100, lr=args.lr, batch_size=args.batch_size,
optimizer=optimizer, optimizer_params=optimizer_params)
trainer = Trainer(
data, model, criteria, output=args.output,
params = train_params,
model_params=model_params, verbose=args.debug,
callbacks=cbacks,
metrics=metrics,
valset=val_data, cuda=True)
trainer.train()
运行主函数
#因为是作为脚本运行的,所以这个if条件为真
if __name__ == "__main__":
# 创建 ArgumentParser() 对象
parser = argparse.ArgumentParser()
# I/O params
# 设置要输入输入的参数
parser.add_argument('data_dir')
parser.add_argument('output')
parser.add_argument('--val_data')
parser.add_argument('--checkpoint')
parser.add_argument('--pretrained')
# Training
# 设置训练网络参数
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--fix_seed', dest="fix_seed", action="store_true")
parser.add_argument('--loss', default="l2", choices=["l1", "vgg", "l2", "gradient", "laplacian"])
parser.add_argument('--optimizer', default="adam", choices=["adam", "sgd"])
# Monitoring
parser.add_argument('--debug', dest="debug", action="store_true")
# Model
#dest是函数参数的名字,例如默认情况下,xtrans=false
#action意思是一旦有这个参数,就将它设置为action的值
parser.add_argument('--xtrans', dest="xtrans", action="store_true")
parser.add_argument('--green_only', dest="green_only", action="store_true")
parser.add_argument('--subsample', dest="subsample", action="store_true")
parser.add_argument('--linear', dest="linear", action="store_true")
parser.add_argument(
'--params', nargs="*", default=["model=BayerNetwork"])
# 设置了一些参数的默认值
parser.set_defaults(debug=False, fix_seed=False, xtrans=False,
green_only=False, subsample=False, linear=False)
# 属性给与args实例: 把parser中设置的所有"add_argument"给返回到args子类实例当中
args = parser.parse_args()
# 定义一个空字典
params = {}
# 如果args实例不是空的
if args.params is not None:
for p in args.params:
# string类中的方法,以=为分界,key给k,value给v
k, v = p.split("=")
# 查看v是不是数字,如果是,就转换成整数;
if v.isdigit():
v = int(v)
# 判断v是不是bool型,如果是,转换成bool
elif v == "False":
v = False
elif v == "True":
v = True
# 将处理后的value和key加进空字典params里
params[k] = v
# 设置日志,显示process id,日志等级,文件名,行号,信息
logging.basicConfig(
format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
# log中debug<info<warning<error<critial,若等级设置成info,则debug信息将无法被记录
if args.debug:
log.setLevel(logging.DEBUG)
else:
log.setLevel(logging.INFO)
# os.path.basename(args.output)返回最后的文件名,不包括后缀,这里是bayer(在output文件夹下)
# format用来格式化字符串,官方推荐,结果就是demosaic_bayer
# setproctitle.setproctitle修改进程名字,不用这个语句,进程显示是python。用这个进程显示‘进程别名’。
setproctitle.setproctitle(
'demosaic_{}'.format(os.path.basename(args.output)))
运行主函数
main(args, params)
if name == “main”:的作用python3 if x 和 if x is not None 区别
split用法的小试验