出错的版本:
def compute_balance_loss(self, routing_probs):
    """Return the batch-averaged entropy of the routing distributions.

    Args:
        routing_probs: Probabilities whose last dimension indexes experts
            (presumably ``(batch_size, num_experts)`` softmax output from the
            caller — confirm).

    Returns:
        Scalar tensor: entropy of each row (summed over the last dim),
        averaged over the remaining dims.

    NOTE(review): this measures how spread out each row's routing is, not
    how evenly load is shared across experts within the batch; confirm this
    (and the sign with which it enters the training loss) is intended.
    """
    # The small epsilon keeps log() finite when a probability is exactly 0.
    log_probs = torch.log(routing_probs + 1e-10)
    per_row_entropy = -(routing_probs * log_probs).sum(dim=-1)
    return per_row_entropy.mean()
def forward(self, x, gate_inputs):
    """Top-1 (hard) mixture-of-experts forward pass.

    Each sample is routed only to the expert with the highest routing
    probability. That expert's output is scaled by the winning probability,
    so the gating network still receives a gradient.

    Args:
        x: Expert input; rows are selected per expert with a boolean mask
            over dim 0.
        gate_inputs: Input to ``self.gating_network``; its output is
            softmax-normalised over the last dimension.

    Returns:
        Tuple ``(expert_outputs, balance_loss)``. ``expert_outputs`` has
        shape ``(batch_size, self.expert_hidden_dim)`` in the default dtype
        of ``torch.zeros``; rows not assigned by any expert stay zero.
        ``balance_loss`` is ``self.compute_balance_loss`` applied to the
        full routing distribution.

    NOTE(review): experts that receive no samples in a batch produce no
    gradient for their parameters; under DistributedDataParallel this
    usually requires ``find_unused_parameters=True`` — confirm for your setup.
    """
    batch_size = x.size(0)
    # Shape (batch_size, num_experts)
    routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)
    # Index of the most probable expert for each sample, shape (batch_size,)
    top1_expert_indices = routing_probs.argmax(dim=-1)

    expert_outputs = torch.zeros(batch_size, self.expert_hidden_dim, device=x.device)
    for i in range(self.num_experts):
        expert_mask = top1_expert_indices == i
        if not expert_mask.any():
            continue
        expert_output = self.experts[i](x[expert_mask])
        # Boolean-mask indexing already returns a new tensor, so no clone is needed.
        gate = routing_probs[expert_mask, i].unsqueeze(-1)
        # Masked assignment requires identical dtypes. Expert outputs may be
        # half precision (autocast) or float64, so cast to the buffer's dtype.
        expert_outputs[expert_mask] = (expert_output * gate).to(expert_outputs.dtype)

    balance_loss = self.compute_balance_loss(routing_probs)
    return expert_outputs, balance_loss
没有出错的版本:
def compute_balance_loss(self, routing_probs):
    """Compute the mean routing entropy used as the balance loss.

    For every row of ``routing_probs`` the entropy ``-sum(p * log p)`` is
    taken over the last dimension; the result is the mean over all rows.

    Args:
        routing_probs: Probabilities with experts on the last dimension
            (presumably ``(batch_size, num_experts)`` — confirm with caller).

    Returns:
        Scalar tensor holding the averaged entropy.

    NOTE(review): this rewards/penalises per-sample spread of routing rather
    than the distribution of load across experts; confirm it is intended.
    """
    # Epsilon avoids log(0) = -inf for experts with zero probability.
    weighted_logs = routing_probs * torch.log(routing_probs + 1e-10)
    row_sums = torch.sum(weighted_logs, dim=-1)
    return row_sums.mean().neg()
def forward(self, x, gate_inputs):
    """Dense (soft) mixture-of-experts forward pass.

    Every expert processes the whole input. Each expert's output is scaled
    by that expert's routing probability, and the scaled outputs are summed.

    Args:
        x: Input handed unchanged to every expert; ``x.size(0)`` is the
            batch size.
        gate_inputs: Input to ``self.gating_network``; its output is
            softmax-normalised over the last dimension.

    Returns:
        Tuple ``(expert_outputs, balance_loss)``. ``expert_outputs`` has
        shape ``(batch_size, self.expert_hidden_dim)`` and
        ``balance_loss`` is ``self.compute_balance_loss`` on the routing
        probabilities.
    """
    gate_logits = self.gating_network(gate_inputs)
    probs = F.softmax(gate_logits, dim=-1)  # (batch_size, num_experts)

    mixed = torch.zeros(x.size(0), self.expert_hidden_dim, device=x.device)
    for col, expert_fn in enumerate(self.experts):
        weight = probs[:, col].unsqueeze(-1)  # (batch_size, 1)
        # In-place accumulation keeps the buffer's dtype, as the original did.
        mixed.add_(expert_fn(x) * weight)

    return mixed, self.compute_balance_loss(probs)