class Attention(nn.Module):
    """Hierarchical attention that pools one side's player trajectories into a single vector.

    Every position along ``dim`` of the input is scored with additive attention,
    ``tanh(x @ w_omega) @ u_omega``. The scores are softmax-normalised along
    ``dim`` and used as weights in a weighted sum of ``x`` over that same axis.

    Args:
        hidden_dim: Size of the last (feature) axis of the input. Both
            parameters are sized from it. Defaults to 32, the value that was
            previously hard-coded.
        dim: Axis that is attended over and reduced. Defaults to 0, the
            original behaviour. It must not be the last (feature) axis.
        weight_max: Optional upper bound applied in place to ``w_omega`` at the
            start of every forward pass. ``None`` (the default) disables it.
            Pass ``-1`` to reproduce the old, always-on ``clamp(max=-1)``.
    """

    def __init__(self, hidden_dim=32, dim=0, weight_max=None):
        super().__init__()
        self.dim = dim
        self.weight_max = weight_max
        # Parameter names are kept unchanged so existing state_dicts still load.
        self.w_omega = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        self.u_omega = nn.Parameter(torch.empty(hidden_dim, 1))
        nn.init.uniform_(self.w_omega, -0.1, 0.1)
        nn.init.uniform_(self.u_omega, -0.1, 0.1)
        self.attention = nn.Softmax(dim=dim)

    def forward(self, x):
        """Attend over ``self.dim`` of ``x`` and return the pooled vector.

        Args:
            x: Tensor whose last axis has size ``hidden_dim``.

        Returns:
            A tuple ``(out, att_score)``:
            - ``out`` is the attention-weighted sum of ``x`` over ``self.dim``.
              Its shape is ``x``'s shape with that axis removed.
            - ``att_score`` holds the softmax weights, with shape
              ``x.shape[:-1] + (1,)``.
        """
        if self.weight_max is not None:
            # Opt-in projection of the weights, done outside autograd. This used
            # to run unconditionally with max=-1, which overwrote the
            # [-0.1, 0.1] init on the first call and pinned every weight to <= -1.
            with torch.no_grad():
                self.w_omega.clamp_(max=self.weight_max)
        u = torch.tanh(torch.matmul(x, self.w_omega))
        att = torch.matmul(u, self.u_omega)
        att_score = self.attention(att)
        # att_score keeps a trailing size-1 axis, so it broadcasts across features.
        out = torch.sum(x * att_score, dim=self.dim)
        return out, att_score