# LigUnity / HGNN / PL_Aggregator.py
# Committed by Irwiny123
# Commit message: "Commit the initial LigUnity code"
# Commit 94391f2
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import random
from Attention import Attention
class PLAggregator(nn.Module):
    """Aggregate neighbor embeddings into a node embedding with attention.

    For each target node, every (neighbor, relation) pair is embedded, the two
    embeddings are concatenated and passed through a two-layer ReLU MLP, the
    resulting rows are weighted by ``Attention`` against the node's own
    embedding, and the weighted sum is averaged with the node's own embedding.

    Args:
        v2e: Embedding table for neighbor nodes; read through ``.weight``.
        r2e: Embedding table for relation/label ids; read through ``.weight``.
        u2e: Embedding table for target nodes; read through ``.weight``.
        embed_dim: Width of every embedding handled by this module.
        cuda: Device on which ``forward`` allocates its output tensor.
        uv: Stored as ``self.uv``; not read anywhere in this class.
    """

    def __init__(self, v2e=None, r2e=None, u2e=None, embed_dim=128, cuda="cpu", uv=True):
        super().__init__()
        self.uv = uv
        self.v2e = v2e
        self.r2e = r2e
        self.u2e = u2e
        self.device = cuda
        self.embed_dim = embed_dim

        # Fuses a concatenated (neighbor, relation) pair back down to embed_dim.
        self.w_r1 = nn.Linear(self.embed_dim * 2, self.embed_dim)
        self.w_r2 = nn.Linear(self.embed_dim, self.embed_dim)
        self.att = Attention(self.embed_dim)

        # NOTE(review): if v2e/u2e are nn.Module embeddings (they are indexed
        # via ``.weight`` below), this assignment only sets a plain attribute
        # on the module and presumably does not stop gradients reaching
        # ``.weight``; ``requires_grad_(False)`` would. Left unchanged so that
        # training behaviour stays as it was -- TODO confirm the intent.
        if self.v2e is not None:
            self.v2e.requires_grad = False
        if self.u2e is not None:
            self.u2e.requires_grad = False

    def _fuse_neighbors(self, neighbor_embed, rel_embed, self_embed):
        """Attention-pool fused neighbor/relation rows and average with ``self_embed``."""
        fused = F.relu(self.w_r1(torch.cat((neighbor_embed, rel_embed), 1)))
        fused = F.relu(self.w_r2(fused))
        att_w = self.att(fused, self_embed, neighbor_embed.shape[0])
        # (embed_dim, N) @ att_w, transposed back so features lie on the last axis.
        pooled = torch.mm(fused.t(), att_w).t()
        return (pooled + self_embed) / 2

    def forward(self, nodes_u, input_hist):
        """Build one aggregated embedding per target node.

        Args:
            nodes_u: Indexable of target-node ids, aligned with ``input_hist``.
            input_hist: For each target node, a sequence of pairs whose first
                item indexes ``v2e`` and whose second item indexes ``r2e``.

        Returns:
            Float tensor with ``len(input_hist)`` rows and ``embed_dim``
            columns. A node with no history gets its own ``u2e`` embedding.
        """
        embed_matrix = torch.zeros(
            len(input_hist), self.embed_dim, dtype=torch.float, device=self.device
        )
        for i, pairs in enumerate(input_hist):
            self_embed = self.u2e.weight[nodes_u[i]]
            if len(pairs) == 0:
                # Nothing to aggregate: keep the node's own embedding.
                embed_matrix[i] = self_embed
                continue
            history = [pair[0] for pair in pairs]
            label = [pair[1] for pair in pairs]
            embed_matrix[i] = self._fuse_neighbors(
                self.v2e.weight[history], self.r2e.weight[label], self_embed
            )
        return embed_matrix

    def forward_inference(self, pocket_embed, neighbor_list):
        """Aggregate pre-computed neighbor embeddings for a single node.

        Args:
            pocket_embed: Embedding of the target node.
            neighbor_list: Sequence of indexable entries; item ``1`` is the
                neighbor's embedding tensor and item ``2`` is a tensor holding
                its relation id (both are passed to ``torch.stack``). Item
                ``0`` is not read here.

        Returns:
            The average of the attention-pooled neighbors and ``pocket_embed``.
            With no neighbors, ``pocket_embed`` itself is returned (with a
            leading axis added when it is 1-D), mirroring the empty-history
            branch of ``forward`` instead of failing inside ``torch.stack``.
            NOTE(review): the added axis assumes ``Attention`` yields a single
            weight column, so the regular result has one row -- confirm.
        """
        if len(neighbor_list) == 0:
            return pocket_embed.unsqueeze(0) if pocket_embed.dim() == 1 else pocket_embed
        neighbor_embed = torch.stack([entry[1] for entry in neighbor_list])
        rel_ids = torch.stack([entry[2] for entry in neighbor_list])
        return self._fuse_neighbors(neighbor_embed, self.r2e.weight[rel_ids], pocket_embed)