| |
| import os |
| import numpy as np |
| import faiss |
| from sentence_transformers import SentenceTransformer |
| import requests |
| from sklearn.cluster import KMeans |
| import networkx as nx |
|
|
def get_vocab():
    """Download and return the google-10000-english word list (no swears).

    Returns:
        list[str]: lowercased, whitespace-stripped words, blank lines removed.

    Raises:
        Exception: if the word list cannot be fetched (non-200 response).
    """
    url = "https://raw.githubusercontent.com/first20hours/google-10000-english/master/google-10000-english-no-swears.txt"
    # Timeout so a stalled connection cannot hang program startup forever.
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return [word.strip().lower() for word in response.text.splitlines() if word.strip()]
    else:
        raise Exception("Failed to fetch vocabulary list")
|
|
class CrosswordGenerator:
    """Generate topic-related word lists (e.g. for seeding crossword puzzles).

    A common-English vocabulary is embedded once with a sentence-transformer
    model and stored in a FAISS inner-product index over L2-normalised
    vectors, so search scores are cosine similarities.
    """

    # Minimum cosine similarity for a search hit to count as topic-related.
    MIN_SIMILARITY = 0.3
    # Per-level weight decay for scores coming from subcategory expansion.
    SUBCAT_DECAY = 0.8

    def __init__(self):
        self.vocab = get_vocab()
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        embeddings = self.model.encode(self.vocab, convert_to_numpy=True)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        # Normalise so IndexFlatIP (inner product) returns cosine similarity.
        faiss.normalize_L2(embeddings)
        self.dimension = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(self.dimension)
        self.faiss_index.add(embeddings)
        # Baseline result budget; actual search size is 3x this (capped by vocab).
        self.max_results = 50

    def _embed(self, texts, jitter=False):
        """Encode *texts* into a contiguous, L2-normalised float32 matrix.

        When *jitter* is True, Gaussian noise (std dev taken from the
        SEARCH_RANDOMNESS env var, default 0.02) is added before
        normalisation so repeated queries explore slightly different
        neighbourhoods of the embedding space.
        """
        emb = self.model.encode(texts, convert_to_numpy=True)
        if jitter:
            noise_factor = float(os.getenv("SEARCH_RANDOMNESS", "0.02"))
            if noise_factor > 0:
                emb = emb + np.random.normal(0, noise_factor, emb.shape)
        emb = np.ascontiguousarray(emb, dtype=np.float32)
        faiss.normalize_L2(emb)
        return emb

    def is_subcategory(self, topic, word):
        """Best-effort check whether *word*'s Wikipedia page belongs to a
        category whose title mentions *topic*.

        Any network or parse failure returns False (deliberately silent:
        this is an optional signal, not a hard dependency).
        """
        params = {
            "action": "query",
            "prop": "categories",
            "format": "json",
            "titles": word.capitalize(),
        }
        try:
            # params= URL-encodes the title (raw interpolation broke on
            # reserved characters); timeout keeps a slow API from hanging us.
            response = requests.get(
                "https://en.wikipedia.org/w/api.php", params=params, timeout=10
            ).json()
            pages = response.get('query', {}).get('pages', {})
            if pages:
                cats = next(iter(pages.values())).get('categories', [])
                return any(topic.lower() in cat['title'].lower() for cat in cats)
            return False
        except Exception:
            return False

    def generate_words(self, topic, num_words=20):
        """Return up to *num_words* vocabulary words related to *topic*.

        Pipeline: nearest-neighbour search around the topic embedding,
        subtopic discovery (Wikipedia categories, with a KMeans-clustering
        fallback), score accumulation, then PageRank ranking over a star
        graph centred on the topic.
        """
        variations = [topic.lower()]
        all_results = {}  # word -> accumulated similarity score

        for variation in variations:
            topic_embedding = self._embed([variation], jitter=True)
            search_size = min(self.max_results * 3, len(self.vocab))
            scores, indices = self.faiss_index.search(topic_embedding, search_size)

            # Direct neighbours of the topic above the similarity floor.
            initial_results = [
                self.vocab[idx]
                for idx, score in zip(indices[0], scores[0])
                if score > self.MIN_SIMILARITY
            ]

            # Prefer Wikipedia-confirmed subtopics among the top hits.
            subcats = [w for w in initial_results[:30] if self.is_subcategory(topic, w)]

            # Fallback: cluster the hits and take the vocabulary word nearest
            # each cluster centre as a pseudo-subtopic.
            if not subcats and len(initial_results) >= 3:
                result_embeddings = self._embed(initial_results)
                kmeans = KMeans(
                    n_clusters=min(3, len(initial_results)), random_state=42
                ).fit(result_embeddings)
                cluster_centers = kmeans.cluster_centers_.astype(np.float32)
                faiss.normalize_L2(cluster_centers)
                _, subcat_indices = self.faiss_index.search(cluster_centers, 1)
                subcats = [self.vocab[row[0]] for row in subcat_indices]

            # Expand each subtopic; scores decay with depth (a single level
            # today, but the loop is written to allow more).
            for level, subs in enumerate([subcats], start=1):
                decay = self.SUBCAT_DECAY ** level
                for sub in subs:
                    sub_embedding = self._embed([sub])
                    sub_scores, sub_indices = self.faiss_index.search(
                        sub_embedding, search_size
                    )
                    for idx, score in zip(sub_indices[0], sub_scores[0]):
                        if score > self.MIN_SIMILARITY:
                            w = self.vocab[idx]
                            all_results[w] = all_results.get(w, 0) + score * decay

            # Direct matches count at full weight.
            for idx, score in zip(indices[0], scores[0]):
                if score > self.MIN_SIMILARITY:
                    w = self.vocab[idx]
                    all_results[w] = all_results.get(w, 0) + score

        # Rank with PageRank over a star graph: topic at the centre, one
        # weighted edge per candidate word.
        G = nx.Graph()
        G.add_node(topic)
        for w, score in all_results.items():
            G.add_edge(topic, w, weight=score)
        pr = nx.pagerank(G, weight='weight')

        sorted_results = sorted(pr.items(), key=lambda x: x[1], reverse=True)
        return [w for w, _ in sorted_results if w != topic][:num_words]
|
|
if __name__ == "__main__":
    # Demo driver: print a sorted related-word list for each sample topic.
    # NOTE(review): "animal" appears twice — presumably to exercise the
    # SEARCH_RANDOMNESS jitter on repeated queries; confirm before removing.
    generator = CrosswordGenerator()
    topics = ["animal", "animal", "science", "technology", "food", "indian food", "chinese food"]
    for topic in topics:
        print(f"------------- {topic} ------------")
        words = generator.generate_words(topic)
        print(f"Generated words for topic '{topic}':")
        print(sorted(words))
|
|