里面用到dataset.py

运行make train_demo命令

train_demo:
	python bin/train.py demo_data/filelist.txt output/bayer \
	--pretrained pretrained_models/bayer \
	--val_data demo_data/filelist.txt --batch_size 1

运行train.py ,后面是输入的参变量filelist.txt
预训练权重通常是 .npy文件,是numpy专用的二进制文件。
然后就是看train.py文件咯
各种导入模块

import argparse

#日志模块
import logging
import os

#更改进程名用的
import setproctitle

#和时刻相关的操作模块
import time

#torch系列模块
import numpy as np
import torch as th
import torch.optim as optim
from torch.autograd import Variable
from torchvision import transforms

from torchlib.trainer import Trainer

#demasaic系列包
import demosaic.dataset as dset
import demosaic.modules as modules
import demosaic.losses as losses
import demosaic.callbacks as callbacks
import demosaic.converter as converter
import torchlib.callbacks as default_callbacks
import torchlib.optim as toptim

用getLogger方法间接实例化一个logger
log = logging.getLogger("demosaick")

接下来是主函数

def main(args, model_params):
  # 默认设置fix_seed是false
  if args.fix_seed:
  	# 设置了seed,相同seed下,每次生成的随机数序列都是一样的
    np.random.seed(0)
    # 为CPU设置种子用于生成随机数,以使得结果是确定的
    th.manual_seed(0)

  # ------------ Set up datasets ----------------------------------------------
  # 使用demosaic.dataset的ToTensor方法,,把dataset.py里面的sample给变成tensor
  # xform包含了masaic,mask和im,但是这几个都没有值

  xforms = [dset.ToTensor()]
  # 默认下不用
  if args.green_only:
    xforms.append(dset.GreenOnly())
  xforms = transforms.Compose(xforms)
  #默认下不用
  if args.xtrans:
    data = dset.XtransDataset(args.data_dir, transform=xforms, augment=True, linearize=args.linear)
  else:
    data = dset.BayerDataset(args.data_dir, transform=xforms, augment=True, linearize=args.linear)
  # 默认情况下data是一个空列表
  data[0]

  if args.val_data is not None:
    if args.xtrans:
      val_data = dset.XtransDataset(args.val_data, transform=xforms, augment=False)
    else:
      val_data = dset.BayerDataset(args.val_data, transform=xforms, augment=False)
  else:
    val_data = None
  # ---------------------------------------------------------------------------

  model = modules.get(model_params)
  log.info("Model configuration: {}".format(model_params))

  if args.pretrained:
    log.info("Loading Caffe weights")
    if args.xtrans:
      model_ref = modules.get({"model": "XtransNetwork"})
      cvt = converter.Converter(args.pretrained, "XtransNetwork")
    else:
      model_ref = modules.get({"model": "BayerNetwork"})
      cvt = converter.Converter(args.pretrained, "BayerNetwork")
    cvt.convert(model_ref)
    model_ref.cuda()
  else:
    model_ref = None

  if args.green_only:
    model = modules.GreenOnly(model)
    model_ref = modules.GreenOnly(model_ref)

  if args.subsample:
    dx = 1
    dy = 0
    if args.xtrans:
      period = 6
    else:
      period = 2
    model = modules.Subsample(model, period, dx=dx, dy=dy)
    model_ref = modules.Subsample(model_ref, period, dx=dx, dy=dy)

  if args.linear:
    model = modules.DeLinearize(model)
    model_ref = modules.DeLinearize(model_ref)

  name = os.path.basename(args.output)
  cbacks = [
      default_callbacks.LossCallback(env=name),
      callbacks.DemosaicVizCallback(val_data, model, model_ref, cuda=True, 
                                    shuffle=False, env=name),
      callbacks.PSNRCallback(env=name),
      ]

  metrics = {
      "psnr": losses.PSNR(crop=4)
      }

  log.info("Using {} loss".format(args.loss))
  if args.loss == "l2":
    criteria = { "l2": losses.L2Loss(), }
  elif args.loss == "l1":
    criteria = { "l1": losses.L1Loss(), }
  elif args.loss == "gradient":
    criteria = { 
      "gradient": losses.GradientLoss(), 
    }
  elif args.loss == "laplacian":
    criteria = { 
      "laplacian": losses.LaplacianLoss(), 
    }
  elif args.loss == "vgg":
    criteria = { "vgg": losses.VGGLoss(), }
  else:
    raise ValueError("not implemented")

  optimizer = optim.Adam
  optimizer_params = {}
  if args.optimizer == "sgd":
    optimizer = optim.SGD
    optimizer_params = {"momentum": 0.9}
  train_params = Trainer.Parameters(
      viz_step=100, lr=args.lr, batch_size=args.batch_size,
      optimizer=optimizer, optimizer_params=optimizer_params)

  trainer = Trainer(
      data, model, criteria, output=args.output, 
      params = train_params,
      model_params=model_params, verbose=args.debug, 
      callbacks=cbacks,
      metrics=metrics,
      valset=val_data, cuda=True)

  trainer.train()

运行主函数

#因为是作为脚本运行的,所以这个if条件为真
if __name__ == "__main__":
  # 创建 ArgumentParser() 对象
  parser = argparse.ArgumentParser()
  # I/O params
  # 设置要输入输入的参数
  parser.add_argument('data_dir')
  parser.add_argument('output')
  parser.add_argument('--val_data')
  parser.add_argument('--checkpoint')
  parser.add_argument('--pretrained')

  # Training
  # 设置训练网络参数
  parser.add_argument('--batch_size', type=int, default=16)
  parser.add_argument('--lr', type=float, default=1e-4)
  parser.add_argument('--fix_seed', dest="fix_seed", action="store_true")
  parser.add_argument('--loss', default="l2", choices=["l1", "vgg", "l2", "gradient", "laplacian"])
  parser.add_argument('--optimizer', default="adam", choices=["adam", "sgd"])

  # Monitoring
  parser.add_argument('--debug', dest="debug", action="store_true")

  # Model
  #dest是函数参数的名字,例如默认情况下,xtrans=false
  #action意思是一旦有这个参数,就将它设置为action的值
  parser.add_argument('--xtrans', dest="xtrans", action="store_true")
  parser.add_argument('--green_only', dest="green_only", action="store_true")
  parser.add_argument('--subsample', dest="subsample", action="store_true")
  parser.add_argument('--linear', dest="linear", action="store_true")
  parser.add_argument(
      '--params', nargs="*", default=["model=BayerNetwork"])

  # 设置了一些参数的默认值
  parser.set_defaults(debug=False, fix_seed=False, xtrans=False,
                      green_only=False, subsample=False, linear=False)
  
  # 属性给与args实例: 把parser中设置的所有"add_argument"给返回到args子类实例当中
  args = parser.parse_args()

  # 定义一个空字典
  params = {}
  # 如果args实例不是空的
  if args.params is not None:
    for p in args.params:
      # string类中的方法,以=为分界,key给k,value给v
      k, v = p.split("=")
      # 查看v是不是数字,如果是,就转换成整数;
      if v.isdigit():
        v = int(v)
      # 判断v是不是bool型,如果是,转换成bool
      elif v == "False":
        v = False
      elif v == "True":
        v = True

	  # 将处理后的value和key加进空字典params里
      params[k] = v
  # 设置日志,显示process id,日志等级,文件名,行号,信息
  logging.basicConfig(
      format="[%(process)d] %(levelname)s %(filename)s:%(lineno)s | %(message)s")
  # log中debug<info<warning<error<critial,若等级设置成info,则debug信息将无法被记录
  if args.debug:
    log.setLevel(logging.DEBUG)
  else:
    log.setLevel(logging.INFO)
  # os.path.basename(args.output)返回最后的文件名,不包括后缀,这里是bayer(在output文件夹下)
  # format用来格式化字符串,官方推荐,结果就是demosaic_bayer
  # setproctitle.setproctitle修改进程名字,不用这个语句,进程显示是python。用这个进程显示‘进程别名’。
  setproctitle.setproctitle(
      'demosaic_{}'.format(os.path.basename(args.output)))
  运行主函数
  main(args, params)

if name == “main”:的作用python3 if x 和 if x is not None 区别

split用法的小试验

make train_demo和train.py阅读笔记_主函数


log的python官方科普

make train_demo和train.py阅读笔记_python_02