出错的版本:

def compute_balance_loss(self, routing_probs):
        """
        Compute the balance loss: the mean entropy of the routing distributions.

        The entropy ``-sum(p * log(p + eps))`` is taken over the last
        dimension of ``routing_probs`` and then averaged over everything
        that remains.

        Args:
            routing_probs: Tensor of routing probabilities; the last
                dimension is the one the entropy is summed over.

        Returns:
            A scalar tensor holding the averaged entropy.
        """
        eps = 1e-10
        if routing_probs.is_floating_point():
            # 1e-10 underflows to zero in float16, which would turn an
            # exact-zero probability into 0 * log(0) = NaN.  Keep the
            # stabiliser at or above the smallest normal number of the input
            # dtype; for float32 and float64 this is still exactly 1e-10.
            eps = max(eps, torch.finfo(routing_probs.dtype).tiny)

        per_row_entropy = -torch.sum(
            routing_probs * torch.log(routing_probs + eps), dim=-1
        )
        return per_row_entropy.mean()

    def forward(self, x, gate_inputs):
        """
        Route every sample to its single most probable expert (top-1 routing).

        Args:
            x: Input handed to the experts; dimension 0 is the batch
                dimension.
            gate_inputs: Input of ``self.gating_network``.  Its output is
                soft-maxed over the last dimension and is assumed to have
                shape ``(batch_size, num_experts)`` -- confirm against the
                gating network.

        Returns:
            Tuple ``(expert_outputs, balance_loss)``.  ``expert_outputs`` has
            shape ``(batch_size, self.expert_hidden_dim)`` and holds, for
            each sample, the output of its selected expert scaled by that
            expert's routing probability.  ``balance_loss`` is the value of
            ``self.compute_balance_loss(routing_probs)``.
        """
        batch_size = x.size(0)

        # Routing probabilities and the winning expert of every sample.
        routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)
        top1_expert_indices = routing_probs.argmax(dim=-1)

        # The buffer is created lazily so that it takes the dtype the experts
        # actually produce.  A buffer in the default dtype makes the masked
        # assignment below fail whenever the model runs in another dtype
        # (e.g. float64), because that assignment needs matching dtypes.
        expert_outputs = None

        for i in range(self.num_experts):
            expert_mask = top1_expert_indices == i
            if not expert_mask.any():
                # This expert is not called at all for this batch.
                # NOTE(review): training wrappers that expect every parameter
                # to take part in each step may object to that -- confirm
                # against the training setup.
                continue

            expert_output = self.experts[i](x[expert_mask])

            # Indexing with a mask already yields a new tensor, so no explicit
            # clone is needed before scaling.
            selected_prob = routing_probs[expert_mask, i].unsqueeze(-1)
            weighted_output = expert_output * selected_prob

            if expert_outputs is None:
                expert_outputs = torch.zeros(
                    batch_size,
                    self.expert_hidden_dim,
                    device=x.device,
                    dtype=weighted_output.dtype,
                )

            # The cast is a no-op unless experts disagree on their dtype.
            expert_outputs[expert_mask] = weighted_output.to(expert_outputs.dtype)

        if expert_outputs is None:
            # No expert was selected (empty batch or no experts).
            expert_outputs = torch.zeros(
                batch_size, self.expert_hidden_dim, device=x.device
            )

        # The balance loss is computed from the full routing distribution.
        balance_loss = self.compute_balance_loss(routing_probs)

        return expert_outputs, balance_loss

没有出错的版本:

def compute_balance_loss(self, routing_probs):
        """
        Compute the balance loss: the mean entropy of the routing distributions.

        For each distribution along the last dimension of ``routing_probs``
        the entropy ``-sum(p * log(p + eps))`` is evaluated; the entropies
        are then averaged into a single value.

        Args:
            routing_probs: Tensor of routing probabilities; the last
                dimension is the one the entropy is summed over.

        Returns:
            A scalar tensor holding the averaged entropy.
        """
        eps = 1e-10
        if routing_probs.is_floating_point():
            # In float16 the constant 1e-10 rounds to zero, so a probability
            # of exactly zero would yield 0 * log(0) = NaN.  Bound the
            # stabiliser from below by the dtype's smallest normal number;
            # float32 and float64 keep using exactly 1e-10.
            eps = max(eps, torch.finfo(routing_probs.dtype).tiny)

        log_probs = torch.log(routing_probs + eps)
        summed = torch.sum(routing_probs * log_probs, dim=-1)
        return -torch.mean(summed)

    def forward(self, x, gate_inputs):
        """
        Combine the outputs of all experts, weighted by routing probability.

        Every expert is applied to the whole batch; the result is the sum of
        the expert outputs, each scaled by that expert's routing probability
        for the corresponding sample.

        Args:
            x: Input handed to every expert; dimension 0 is the batch
                dimension.
            gate_inputs: Input of ``self.gating_network``.  Its output is
                soft-maxed over the last dimension and is assumed to have
                shape ``(batch_size, num_experts)`` -- confirm against the
                gating network.

        Returns:
            Tuple ``(expert_outputs, balance_loss)``.  ``expert_outputs`` is
            the probability-weighted sum of the expert outputs, shaped
            ``(batch_size, self.expert_hidden_dim)``.  ``balance_loss`` is the
            value of ``self.compute_balance_loss(routing_probs)``.
        """
        batch_size = x.size(0)

        # Routing probabilities over the experts.
        routing_probs = F.softmax(self.gating_network(gate_inputs), dim=-1)

        # The accumulator is created once the first weighted output is known
        # so that it shares its dtype.  A default-dtype buffer updated with
        # an in-place ``+=`` fails when the weighted output is wider than the
        # buffer (e.g. a float64 model).
        expert_outputs = None

        for i, expert in enumerate(self.experts):
            # Column i holds this expert's weight for every sample.
            routing_prob = routing_probs[:, i].unsqueeze(-1)
            weighted_output = expert(x) * routing_prob

            if expert_outputs is None:
                expert_outputs = torch.zeros(
                    batch_size,
                    self.expert_hidden_dim,
                    device=x.device,
                    dtype=weighted_output.dtype,
                )

            # Out-of-place accumulation lets normal type promotion apply.
            expert_outputs = expert_outputs + weighted_output

        if expert_outputs is None:
            # There are no experts to combine.
            expert_outputs = torch.zeros(
                batch_size, self.expert_hidden_dim, device=x.device
            )

        # The balance loss is computed from the full routing distribution.
        balance_loss = self.compute_balance_loss(routing_probs)

        return expert_outputs, balance_loss