|
|
| """
|
| Created on Sun Sep 15 18:27:17 2024
|
|
|
| @author: salikha4
|
| """
|
|
|
| import os
|
| import csv
|
| import json
|
| import shutil
|
| import random
|
| import argparse
|
| from datetime import datetime
|
| import pandas as pd
|
| import time
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torch.utils.data import Dataset, DataLoader, TensorDataset
|
| from torch.optim import Adam
|
| import numpy as np
|
| import warnings
|
# Install an 'ignore' filter for every warning category, process-wide, as a
# side effect of importing this module.
# NOTE(review): this can also hide useful diagnostics raised by torch/numpy
# calls in this file; consider scoping it with warnings.catch_warnings()
# around the specific noisy call instead.
warnings.filterwarnings('ignore')
|
|
|
def lwm_inference(preprocessed_chs, input_type, lwm_model, device):
    """Embed preprocessed channels with an LWM model and pick the requested view.

    Args:
        preprocessed_chs: Iterable of ``(input_ids, masked_tokens, masked_pos)``
            triples, in the form accepted by ``prepare_for_lwm``.
        input_type: Selects which slice of the model output to return:
            ``'cls_emb'`` keeps index 0 along dim 1, ``'channel_emb'`` keeps
            indices 1 onward along dim 1, and any other value returns the
            concatenated model output unsliced.
        lwm_model: Model handed to ``evaluate``; it is called as
            ``model(input_ids, masked_pos)`` and its second return value is
            used as the embedding batch.
        device: Device on which the input tensors are created.

    Returns:
        torch.Tensor: The selected embeddings cast to float32.
    """
    loader = prepare_for_lwm(preprocessed_chs, device)

    # evaluate() also reports a normalised MSE loss; inference does not use it.
    _, embeddings = evaluate(lwm_model, loader)

    if input_type == 'cls_emb':
        # Index 0 along dim 1 -- presumably the CLS token; confirm against the model.
        selected = embeddings[:, 0]
    elif input_type == 'channel_emb':
        selected = embeddings[:, 1:]
    else:
        selected = embeddings

    return selected.float()
|
|
|
def prepare_for_lwm(data, device, batch_size=64, shuffle=False):
    """Pack ``(input_ids, masked_tokens, masked_pos)`` samples into a DataLoader.

    Args:
        data: Non-empty iterable of 3-item samples. Each field is stacked
            across samples, so a given field must have the same shape in
            every sample.
        device: Device on which the tensors are allocated.
        batch_size: Number of samples per batch.
        shuffle: Whether the DataLoader reshuffles samples every epoch.

    Returns:
        DataLoader: Yields ``(input_ids, masked_tokens, masked_pos)`` batches
        whose dtypes are float32, float32 and int64 respectively.

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    samples = list(data)
    if not samples:
        # zip(*[]) would otherwise fail with a confusing unpacking error.
        raise ValueError("prepare_for_lwm() received no samples")

    input_ids, masked_tokens, masked_pos = zip(*samples)

    # Stack each field into a single ndarray before handing it to torch:
    # torch.tensor() on a sequence of ndarrays converts element by element
    # and is very slow (PyTorch warns about this, but the module-level
    # warnings filter hides it). Dtype inference and the final casts match
    # the previous direct torch.tensor(...) calls.
    input_ids_tensor = torch.tensor(np.asarray(input_ids), device=device).float()
    masked_tokens_tensor = torch.tensor(np.asarray(masked_tokens), device=device).float()
    masked_pos_tensor = torch.tensor(np.asarray(masked_pos), device=device).long()

    dataset = TensorDataset(input_ids_tensor, masked_tokens_tensor, masked_pos_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
|
|
def evaluate(model, dataloader):
    """Run ``model`` over ``dataloader`` without gradients and collect outputs.

    The model is switched to eval mode and left there. For each batch the
    model is called as ``model(input_ids, masked_pos)``. Its second output
    is collected, and its first output is scored against ``masked_tokens``
    with an MSE normalised by the variance of ``masked_tokens``.

    Args:
        model: Module exposing ``eval()`` and returning a pair
            ``(logits_lm, output)`` when called.
        dataloader: Iterable of indexable batches whose items 0, 1 and 2 are
            ``input_ids``, ``masked_tokens`` and ``masked_pos`` (e.g. the
            loader built by ``prepare_for_lwm``).

    Returns:
        tuple: ``(average_loss, output_total)``. ``average_loss`` is the mean
        of the per-batch losses, with every batch weighted equally regardless
        of its size. ``output_total`` is every batch's ``output``
        concatenated along dim 0.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    criterion_mcm = nn.MSELoss()
    running_loss = 0.0
    outputs = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids, masked_tokens, masked_pos = batch[0], batch[1], batch[2]

            logits_lm, output = model(input_ids, masked_pos)
            outputs.append(output)

            # Dividing by the target variance makes the loss invariant to the
            # overall scale of the data. torch.var is unbiased by default, so a
            # batch whose masked_tokens holds a single element yields NaN here.
            loss = criterion_mcm(logits_lm, masked_tokens) / torch.var(masked_tokens)
            running_loss += loss.item()

    if not outputs:
        # Previously surfaced as a ZeroDivisionError / torch.cat failure.
        raise ValueError("evaluate() received a dataloader that yielded no batches")

    # len(outputs) is the number of batches seen; it equals len(dataloader)
    # for map-style loaders and also works for loaders without __len__.
    average_loss = running_loss / len(outputs)
    output_total = torch.cat(outputs, dim=0)
    return average_loss, output_total
|
|
|
def create_raw_dataset(data, device):
    """Create a dataset for raw channel data.

    Takes the first field of every ``(input_ids, _, _)`` sample, stacks the
    fields into one tensor on ``device``, drops index 0 along dim 1 and casts
    the result to float32. Index 0 is presumably a CLS/placeholder slot;
    confirm against the preprocessing code.

    Args:
        data: Non-empty iterable of 3-item samples. Only the first item is
            used, and it must have the same shape in every sample.
        device: Device on which the tensor is allocated.

    Returns:
        torch.Tensor: ``input_ids`` stacked over samples, minus index 0 of
        dim 1, as float32.

    Raises:
        ValueError: If ``data`` contains no samples.
    """
    samples = list(data)
    if not samples:
        # zip(*[]) would otherwise fail with a confusing unpacking error.
        raise ValueError("create_raw_dataset() received no samples")

    input_ids, _, _ = zip(*samples)

    # Stack into one ndarray first: torch.tensor() on a sequence of ndarrays
    # converts element by element and is very slow.
    stacked = torch.tensor(np.asarray(input_ids), device=device)
    return stacked[:, 1:].float()
|
| |