| |
| |
|
|
| import torch |
| import torch.nn.functional as F |
| import pandas as pd |
| import sys |
| import os |
|
|
| from transformers.utils.hub import cached_file |
|
|
# Resolve the mask table inside the 'Cainiao-AI/TAAS' hub repository; this runs
# at import time and returns the path of the locally cached file.
resolved_module_file = cached_file('Cainiao-AI/TAAS', 'htc_mask_dict_old.pkl')

# Five weights, one per position of a 5-part code; HTCLoss tiles this list over
# every consecutive group of five loss entries.
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]

# Table indexed by zero-padded code-prefix strings (e.g. '{:02d}'.format(aa));
# each value is used by HTCLoss as the index set allowed at the next position.
# NOTE(security): read_pickle unpickles the file, which can execute arbitrary
# code. Only load this from a repository/revision you trust.
htc_mask_dict = pd.read_pickle(resolved_module_file)
| import numpy as np |
| import operator |
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len=6):
    """Count prefix-level matches between predicted and reference 5-part codes.

    Both inputs are reshaped to ``(-1, sequence_len, 5)``. For every code and
    every prefix length ``k`` in 1..5, the prefix is *counted* when none of its
    first ``k`` reference entries equals -100, and it is counted as *correct*
    when, additionally, the first ``k`` predicted entries equal the reference
    entries.

    Args:
        predicted_htc: Flat or nested sequence of predicted code entries
            (anything ``np.array`` accepts).
        y: Reference entries; a tensor/array exposing ``.tolist()`` or any
            array-like. Entries equal to -100 mark positions to ignore.
        sequence_len: Number of codes per sample used for the reshape.

    Returns:
        Tuple ``(acc_cnt, total_cnt)`` of two length-5 integer NumPy arrays:
        correct prefixes and counted prefixes per prefix length.

    Raises:
        ValueError: If the two inputs do not reshape to the same shape.
    """
    num_levels = 5

    # ``tolist`` works for tensors regardless of device or memory layout,
    # unlike ``view`` which rejects non-contiguous tensors.
    reference = y.tolist() if hasattr(y, "tolist") else y
    reference = np.asarray(reference).reshape(-1, sequence_len, num_levels)
    predicted = np.array(predicted_htc).reshape(-1, sequence_len, num_levels)
    if reference.shape != predicted.shape:
        raise ValueError(
            "predicted_htc and y must hold the same number of entries: "
            "got shapes {} and {}".format(predicted.shape, reference.shape)
        )

    # A prefix of length k is counted only if none of its first k reference
    # entries is -100; the running logical-and encodes "no -100 so far".
    counted = np.logical_and.accumulate(reference != -100, axis=-1)
    # A prefix is correct only if every entry up to k matches.
    matched = np.logical_and.accumulate(reference == predicted, axis=-1)

    acc_cnt = (counted & matched).sum(axis=(0, 1))
    total_cnt = counted.sum(axis=(0, 1))
    return acc_cnt, total_cnt
| |
|
|
class HTCLoss(torch.nn.Module):
    """Weighted negative log-likelihood over 5-part codes with optional masking.

    Logits are treated as rows of 100 scores, grouped five rows per code. When
    ``using_htc`` is true, each row is restricted to the indices that the mask
    table allows given the choices already made for the earlier rows of the
    same code; every other index is set to a constant low score (-5) before
    the softmax.

    Attributes set in ``__init__``: ``reduction``, ``htc_weights``, ``device``,
    ``using_htc`` and ``htc_mask_dict`` (prefix string -> index tensor).
    """

    _NUM_LEVELS = 5          # rows (positions) per code
    _NUM_CLASSES = 100       # scores per row
    _IGNORE_INDEX = -100     # target value excluded from the loss
    _MASK_FILL = -5          # score given to indices outside the allowed set
    # Format of each chosen entry when it is appended to the lookup key.
    _KEY_FORMATS = ('{:02d}', '{:02d}', '{:02d}', '{:01d}')

    def __init__(self, device, reduction='mean', using_htc = True):
        """Store the configuration and build this instance's mask table.

        Args:
            device: Device on which the mask tensors and the masked logits are
                placed.
            reduction: 'mean' or 'sum'; any other value leaves the per-entry
                loss vector unreduced.
            using_htc: Whether ``forward`` applies the table-driven masking.
        """
        super(HTCLoss, self).__init__()
        self.reduction = reduction
        self.htc_weights = htc_weights
        self.device = device
        self.using_htc = using_htc
        # Build a private tensor copy instead of overwriting the shared
        # module-level table: mutating it in place made a second instance
        # re-wrap tensors and silently move the first instance's table to
        # another device.
        self.htc_mask_dict = {
            key: torch.as_tensor(value).clone().detach().to(self.device)
            for key, value in htc_mask_dict.items()
        }

    def _decode_hierarchy(self, logits):
        """Greedily pick one entry per position and build the masked logits.

        Returns:
            Tuple ``(masked_logits, codes)``: a ``(-1, 5, 100)`` tensor holding
            the original scores at allowed indices and -5 elsewhere, and a
            flat list with five chosen integers per code.
        """
        level_logits = logits.reshape(-1, self._NUM_LEVELS, self._NUM_CLASSES)
        masked = torch.full_like(level_logits, self._MASK_FILL).to(self.device)

        # First position: only indices 1..31 are candidates; the +1 maps the
        # position inside the slice back to the index in the full row.
        masked[:, 0, 1:32] = level_logits[:, 0, 1:32]
        first_choices = (torch.max(level_logits[:, 0, 1:32], 1)[1] + 1).tolist()

        codes = []
        for row, first in enumerate(first_choices):
            chosen = [first]
            for level in range(1, self._NUM_LEVELS):
                # Key = concatenation of the zero-padded entries chosen so far.
                key = ''.join(
                    fmt.format(entry)
                    for fmt, entry in zip(self._KEY_FORMATS, chosen)
                )
                candidates = self.htc_mask_dict[key]
                scores = level_logits[row, level, candidates]
                masked[row, level, candidates] = scores
                best = torch.max(scores, 0)[1]
                chosen.append(candidates[best].item())
            codes.extend(chosen)
        return masked, codes

    def forward(self, logits, target):
        """Compute the weighted loss and, in HTC mode, the decoded codes.

        Args:
            logits: Scores reshapeable to ``(-1, 100)``, five consecutive rows
                per code.
            target: Integer class indices, one per row after flattening;
                entries equal to -100 are ignored.

        Returns:
            Tuple ``(loss, predict_res)``. ``loss`` is a scalar for 'mean' and
            'sum', otherwise the per-row weighted loss with ignored rows set
            to zero. 'mean' averages over the non-ignored rows and yields 0
            (instead of NaN) when every row is ignored. ``predict_res`` is the
            flat list of decoded entries, or ``[]`` when ``using_htc`` is
            false.

        Raises:
            ValueError: If the number of rows is not a multiple of 5.
        """
        flat_target = target.reshape(-1)
        valid = flat_target != self._IGNORE_INDEX
        # Ignored rows still need an in-range index for gather; their loss is
        # zeroed right after.
        safe_target = flat_target.masked_fill(~valid, 0)

        if self.using_htc:
            masked_logits, predict_res = self._decode_hierarchy(logits)
            scores = masked_logits.reshape(-1, self._NUM_CLASSES)
        else:
            predict_res = []
            # Flatten first so the softmax always runs over the 100 scores,
            # also for inputs shaped (batch, 5, 100).
            scores = logits.reshape(-1, self._NUM_CLASSES)
        neg_log_prob = -1.0 * F.log_softmax(scores, dim=1)

        loss = neg_log_prob.gather(1, safe_target.unsqueeze(1)).squeeze(1)
        loss = loss * valid

        rows = loss.shape[0]
        if rows % self._NUM_LEVELS != 0:
            raise ValueError(
                'number of logit rows ({}) must be a multiple of {}'.format(
                    rows, self._NUM_LEVELS
                )
            )
        weights = torch.tensor(
            self.htc_weights, dtype=torch.float32, device=loss.device
        ).repeat(rows // self._NUM_LEVELS)
        loss = loss.mul(weights) * self._NUM_LEVELS

        if self.reduction == 'mean':
            kept = loss[valid]
            # loss.sum() is exactly zero here and keeps the autograd graph.
            loss = kept.mean() if kept.numel() > 0 else loss.sum()
        elif self.reduction == 'sum':
            loss = loss[valid].sum()
        return loss, predict_res

    def get_htc_code(self, logits):
        """Return the flat list of decoded entries (five per code) for ``logits``."""
        _, predict_res = self._decode_hierarchy(logits)
        return predict_res
| |
|
|