| from sentence_transformers import SentenceTransformer |
| import torch |
| import torch.nn.functional as F |
| import json |
|
|
# Run the encoder on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input dataset. The evaluation loop later in this file expects a JSON list of
# samples, each with a "first_frame_anns" list whose entries carry "obj_name"
# and "llava_text" strings.
data_path = "/home/yuqian_fu/Projects/PSALM/check_text_select_scene_600_objname_llavatext_correct.json"

# JSON text is UTF-8 (RFC 8259). Say so explicitly so loading does not depend on
# the platform's default locale encoding; the labels may contain non-ASCII text.
with open(data_path, "r", encoding="utf-8") as f:
    datas = json.load(f)

# A sample counts as "correct" when its mean similarity is strictly above this value.
SIMILARITY_THRESHOLD = 0.5

# Sentence-BERT encoder used to embed both the object names and the LLaVA texts.
model = SentenceTransformer("all-MiniLM-L6-v2").to(device)
# Put the model in inference mode (disables training-only layers such as dropout).
model.eval()
|
|
def get_sbert_embedding(text):
    """Embed ``text`` with the module-level Sentence-BERT ``model``.

    Encoding is requested on ``device`` and the result is returned as a torch
    tensor (``convert_to_tensor=True``) rather than a NumPy array.
    """
    encode_options = {"convert_to_tensor": True, "device": device}
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        return model.encode(text, **encode_options)
|
|
def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two embedding vectors as a Python float.

    Both inputs are expected to be 1-D tensors of equal length (for example,
    the embedding of a single string). Each gets a leading batch axis so that
    ``F.cosine_similarity`` compares them along the feature axis. The resulting
    one-element tensor is then unwrapped with ``.item()``.
    """
    row_a = embedding1.unsqueeze(0)
    row_b = embedding2.unsqueeze(0)
    cosine = F.cosine_similarity(row_a, row_b)
    return cosine.item()
|
|
| |
# Number of samples whose mean similarity exceeds SIMILARITY_THRESHOLD.
correct_count = 0
# Mean similarity of every scored sample, in dataset order.
similarity_list = []
# Samples with no first-frame annotations: no mean can be computed for them.
skipped_count = 0

# v1 (denominator = number of objects) is not implemented in this script.

# v2 (denominator = number of samples/frames): each sample's score is the mean
# cosine similarity between every annotation's "obj_name" and its "llava_text".
for data in datas:
    annos = data["first_frame_anns"]
    if not annos:
        # Averaging over zero annotations would divide by zero; skip and report.
        skipped_count += 1
        continue

    score = 0.0
    for anno in annos:
        obj_embedding = get_sbert_embedding(anno["obj_name"])
        llava_embedding = get_sbert_embedding(anno["llava_text"])
        score += calculate_cosine_similarity(obj_embedding, llava_embedding)

    sim_avg = score / len(annos)
    similarity_list.append(sim_avg)
    if sim_avg > SIMILARITY_THRESHOLD:
        correct_count += 1

# Only samples that were actually scored count toward the denominators.
# Guard against an empty dataset (or one where every sample was skipped).
total_samples = len(similarity_list)
if total_samples:
    accuracy = correct_count / total_samples
    average_similarity = sum(similarity_list) / total_samples
else:
    accuracy = 0.0
    average_similarity = 0.0

if skipped_count:
    print(f"跳过的无标注样本数: {skipped_count}")
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")