import json
import os

from gensim.models import KeyedVectors
|
|
# Pretrained GoogleNews word2vec vectors (binary, gzip-compressed).
# The default is a machine-specific absolute path; set the
# WORD2VEC_MODEL_PATH environment variable to point the script at a local
# copy without editing the source.
model_path = os.environ.get(
    "WORD2VEC_MODEL_PATH",
    "/home/yuqian_fu/Projects/VQA/GoogleNews-vectors-negative300.bin.gz",
)
# Loaded once at import time; every embedding lookup below uses this model.
word2vec_model = KeyedVectors.load_word2vec_format(model_path, binary=True)
|
|
# Evaluation pairs: each ground-truth object name alongside the text to
# compare against it (presumably a LLaVA prediction, judging by the key name).
data = [
    dict(obj_name="bicycle chain", llava_text="bicycle wheel"),
    dict(obj_name="cup", llava_text="mug"),
    dict(obj_name="pot", llava_text="pan"),
]

# A pair counts as correct only when its cosine similarity is strictly
# greater than this value.
SIMILARITY_THRESHOLD: float = 0.8
|
|
def get_word2vec_embedding(word, model=None):
    """Return the Word2Vec embedding for a word or short phrase.

    Candidates are tried in order:

    1. ``word`` verbatim;
    2. ``word`` with its whitespace-separated tokens joined by ``_`` (the
       GoogleNews vectors include phrase tokens spelled this way, so a key
       containing a space is never found);
    3. the mean of the vectors of the individual tokens that are in the
       vocabulary.

    Args:
        word: The word or phrase to embed.
        model: A vector store supporting ``key in model`` and ``model[key]``.
            Defaults to the module-level ``word2vec_model``.

    Returns:
        The embedding vector, or ``None`` (after printing a notice) when
        neither candidate key nor any individual token is in the vocabulary.
    """
    if model is None:
        model = word2vec_model

    tokens = word.split()
    for key in (word, "_".join(tokens)):
        if key in model:
            return model[key]

    # Fall back to averaging the known tokens so multi-word labels are
    # compared by meaning instead of silently scoring 0.
    token_vectors = [model[tok] for tok in tokens if tok in model]
    if token_vectors:
        return sum(token_vectors) / len(token_vectors)

    print(f"Word '{word}' not found in Word2Vec vocabulary.")
    return None
|
|
def calculate_cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two embedding vectors.

    gensim's ``cosine_similarities`` compares one vector against a batch of
    vectors, so ``vec2`` is passed as a one-element batch and the single
    resulting score is unpacked.
    """
    (score,) = word2vec_model.cosine_similarities(vec1, [vec2])
    return score
|
|
# Score every pair: a prediction counts as correct when the embedding
# similarity to its ground-truth name exceeds SIMILARITY_THRESHOLD.
correct_count = 0
similarity_list = []

for item in data:
    obj_name = item["obj_name"]
    llava_text = item["llava_text"]

    obj_vec = get_word2vec_embedding(obj_name)
    llava_vec = get_word2vec_embedding(llava_text)

    # If either side has no embedding, the pair scores 0.0, which still
    # counts against both the accuracy and the average similarity.
    if obj_vec is None or llava_vec is None:
        similarity = 0.0
    else:
        # Normalize the numpy scalar to a Python float so the list holds a
        # single type and the average below is summed in double precision.
        similarity = float(calculate_cosine_similarity(obj_vec, llava_vec))

    print(
        f"Similarity between '{obj_name}' and '{llava_text}': {similarity:.4f}"
    )
    similarity_list.append(similarity)

    if similarity > SIMILARITY_THRESHOLD:
        correct_count += 1

# Aggregate statistics. An empty dataset would otherwise raise
# ZeroDivisionError; report zeros instead.
total_samples = len(data)
if total_samples:
    accuracy = correct_count / total_samples
    average_similarity = sum(similarity_list) / total_samples
else:
    accuracy = 0.0
    average_similarity = 0.0

# Summary: correct count, total count, accuracy, average similarity.
print(f"正确样本数: {correct_count}")
print(f"总样本数: {total_samples}")
print(f"正确样本比例: {accuracy:.2%}")
print(f"平均相似性: {average_similarity:.4f}")