Simple and essential:

PyTorch single-machine multi-GPU training

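import torch
import torch.nn as nn
# assumes `model` and `criterion` are already-defined nn.Modules
device_ids = list(range(torch.cuda.device_count()))  # use every visible GPU
# DataParallel needs the model's parameters on device_ids[0] before wrapping
model = nn.DataParallel(model.cuda(device_ids[0]), device_ids=device_ids)
# the loss needs no DataParallel wrapper: outputs are gathered on device_ids[0]
criterion = criterion.cuda(device_ids[0])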
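To show where these lines fit into training, here is a minimal end-to-end sketch. It assumes a hypothetical toy `nn.Linear(10, 3)` classifier, random inputs, and SGD; the function name `train_data_parallel` is also illustrative. When more than one GPU is visible, `nn.DataParallel` splits each batch along dimension 0, runs one chunk per GPU, and gathers the outputs on the first device. Otherwise the sketch simply runs on a single GPU or the CPU.

```python
import torch
import torch.nn as nn


def train_data_parallel(steps: int = 5, batch_size: int = 32) -> list:
    """Train a hypothetical toy classifier, using nn.DataParallel when >1 GPU exists.

    Returns the loss value recorded at each step.
    """
    torch.manual_seed(0)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(10, 3)          # hypothetical toy model
    criterion = nn.CrossEntropyLoss()

    # Parameters must live on the first device before wrapping in DataParallel.
    model = model.to(device)
    n_gpus = torch.cuda.device_count()
    if n_gpus > 1:
        model = nn.DataParallel(model, device_ids=list(range(n_gpus)))

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Random data placed on the first device; DataParallel scatters it from there.
    x = torch.randn(batch_size, 10, device=device)
    y = torch.randint(0, 3, (batch_size,), device=device)

    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        out = model(x)                # with DataParallel: split, run per GPU, gather on cuda:0
        loss = criterion(out, y)      # plain loss on the gathered outputs
        loss.backward()               # gradients are reduced onto the original module
        optimizer.step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    print(train_data_parallel())
```

When saving a wrapped model, `model.module.state_dict()` gives keys without the `module.` prefix that `nn.DataParallel` adds, so the checkpoint loads into an unwrapped model.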