device = torch.device('cuda' if (self.worker == 'gpu' and torch.cuda.is_available()) else 'cpu')

if device.type == 'cuda' and torch.cuda.device_count() > 1:  # multiple GPUs
    model = torch.nn.DataParallel(model, device_ids=list(range(self.config.gpu_num)))
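The device selection above can be packaged as a small, self-contained helper. This is a minimal sketch, not the author's code: the function name setup_device and its parameters use_gpu and gpu_num are hypothetical stand-ins for self.worker == 'gpu' and self.config.gpu_num. It also caps the GPU count at what torch.cuda.device_count() reports.

```python
import torch
import torch.nn as nn


def setup_device(model: nn.Module, use_gpu: bool = True, gpu_num: int = 1):
    """Pick a device and optionally wrap the model in DataParallel.

    Hypothetical helper: `use_gpu` and `gpu_num` stand in for
    `self.worker == 'gpu'` and `self.config.gpu_num`.
    """
    use_cuda = use_gpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    # Never request more GPUs than are actually visible.
    n = min(gpu_num, torch.cuda.device_count()) if use_cuda else 0
    model = model.to(device)  # move parameters before wrapping
    if n > 1:
        model = nn.DataParallel(model, device_ids=list(range(n)))
    return model, device


model, device = setup_device(nn.Linear(4, 2), use_gpu=False)
print(device)  # cpu
```

Wrapping only when more than one GPU is in use keeps single-GPU and CPU runs free of the DataParallel wrapper, so model.state_dict() keys do not gain a "module." prefix in those cases.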


Places where .to(device) needs to be added:

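  1. model (e.g. model.to(device))
  2. input (older code wraps it in Variable, e.g. input = Variable(input).to(device); Variable has been deprecated since PyTorch 0.4, so input = input.to(device) is enough)
  3. target (likewise, e.g. target = Variable(torch.from_numpy(np.array(target)).long()).to(device), or without Variable: target = torch.from_numpy(np.array(target)).long().to(device))
  4. nn.CrossEntropyLoss() (e.g. criterion = nn.CrossEntropyLoss().to(device); this only matters when a weight tensor is passed, since otherwise the loss module holds no tensors)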
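Put together, the four items above look roughly like the following single training step. This is a minimal sketch assuming a toy nn.Linear model, random inputs, hypothetical labels, and a uniform class weight (added only so that moving the criterion actually moves a tensor); it uses plain tensors instead of Variable.

```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(8, 3).to(device)                      # 1. model
# 4. loss: the (assumed) class-weight buffer moves with .to(device)
criterion = nn.CrossEntropyLoss(weight=torch.ones(3)).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(5, 8).to(device)                   # 2. input
labels = [0, 2, 1, 1, 0]                                # hypothetical labels
target = torch.from_numpy(np.array(labels)).long().to(device)  # 3. target

optimizer.zero_grad()
loss = criterion(model(inputs), target)  # all tensors share one device
loss.backward()
optimizer.step()
print(loss.item())
```

Note that nn.Module.to moves the module in place, while Tensor.to returns a new tensor, which is why inputs and target must be reassigned.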