class Attention(nn.Module):
    """
    Hierarchical attention: pools one side's player trajectories into a single vector.

    Each slice of ``x`` along dim 0 is scored with an additive-attention
    projection, ``tanh(x @ w_omega) @ u_omega``. The scores are softmax-normalized
    along dim 0, and the score-weighted sum along dim 0 is returned.

    Args:
        hidden_dim: size of the last dimension of ``x`` and of the square
            projection ``w_omega``. Defaults to 32, the originally hard-coded size.
    """

    def __init__(self, hidden_dim=32):
        super().__init__()
        self.w_omega = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.u_omega = nn.Parameter(torch.empty(hidden_dim, 1))
        nn.init.uniform_(self.w_omega, -0.1, 0.1)
        nn.init.uniform_(self.u_omega, -0.1, 0.1)
        # Normalizes the attention scores along dim 0, the axis forward() pools over.
        self.attention = nn.Softmax(dim=0)

    def forward(self, x):
        """
        Attend over dim 0 of ``x`` and return the pooled vector with its weights.

        Args:
            x: tensor whose last dimension equals ``hidden_dim``. Presumably
                shaped (num_items, hidden_dim) for a single side; TODO confirm
                against callers. Note that dim 0 is the pooled axis, so a
                leading batch dimension would be mixed into the softmax.

        Returns:
            (out, att_score):
                out: ``sum(x * att_score, dim=0)``.
                att_score: softmax weights with the shape of ``x @ u_omega``
                    (last dim 1), normalized along dim 0.

        Side effect:
            ``w_omega`` is clamped in place to ``<= -1`` on every call, in both
            train and eval mode.
        """
        # NOTE(review): clamping every weight to <= -1 discards the
        # uniform(-0.1, 0.1) init on the first call. The behaviour is preserved
        # here; confirm this constraint is intended.
        # An in-place clamp under no_grad avoids recording an autograd op and
        # reallocating the parameter's storage on every forward pass.
        with torch.no_grad():
            self.w_omega.clamp_(max=-1)
        u = torch.tanh(torch.matmul(x, self.w_omega))
        att = torch.matmul(u, self.u_omega)
        att_score = self.attention(att)
        out = torch.sum(x * att_score, dim=0)
        return out, att_score