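"""Topic-based word generation for crossword puzzles.

Combines sentence-transformer embeddings searched with FAISS, Wikipedia
category lookups, KMeans clustering as a fallback, and PageRank re-ranking
to produce a list of words related to a given topic.
"""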
import csv
import os

import faiss
import networkx as nx
import numpy as np
import requests
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans


def get_vocab():
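    """Fetch an English word list from the dwyl/english-words repo (an online
    alternative to the local CSV used by CrosswordGenerator2)."""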
    url = "https://raw.githubusercontent.com/dwyl/english-words/master/words.txt"
    response = requests.get(url, timeout=30)
    if response.status_code == 200:
        return [word.strip().lower() for word in response.text.splitlines() if len(word.strip()) > 2]
    raise RuntimeError("Failed to fetch vocabulary list")


class CrosswordGenerator2:
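    """Generates topic-related word lists by combining embedding similarity
    search over a fixed vocabulary with Wikipedia category data."""
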
    def get_dict_vocab(self):
        """Read the dictionary CSV file and return a list of words longer than two characters."""
        dict_path = os.path.join(os.path.dirname(__file__), 'dict-words', 'dict.csv')
        words = []

        try:
            with open(dict_path, 'r', encoding='utf-8') as csvfile:
                reader = csv.DictReader(csvfile)
                for row in reader:
                    word = row['word'].strip().lower()
                    if word and len(word) > 2:
                        words.append(word)
        except FileNotFoundError as e:
            raise FileNotFoundError(f"Dictionary file not found: {dict_path}") from e
        except Exception as e:
            raise RuntimeError(f"Error reading dictionary file: {e}") from e

        return words

    def __init__(self, cache_dir='./model_cache'):
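        """Load the vocabulary, embed it with a sentence-transformer model, and
        build a FAISS inner-product index over L2-normalized embeddings
        (inner product of unit vectors equals cosine similarity)."""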
        self.vocab = self.get_dict_vocab()
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', cache_folder=cache_dir)
        embeddings = self.model.encode(self.vocab, convert_to_numpy=True)
        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
        faiss.normalize_L2(embeddings)
        self.dimension = embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatIP(self.dimension)
        self.faiss_index.add(embeddings)
        self.max_results = 50

    def get_wikipedia_subcats(self, topic):
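        """Return lowercase Wikipedia subcategory names for the topic, falling
        back to categories of the topic's top search result when no matching
        category page exists. Returns an empty list on any API failure."""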
        topic_cap = topic.capitalize().replace(' ', '_')
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{topic_cap}&cmtype=subcat&format=json&cmlimit=50"
        try:
            response = requests.get(url, timeout=10).json()
            members = response.get('query', {}).get('categorymembers', [])
            if members:
                return [member['title'].replace('Category:', '').lower() for member in members]
            else:
                # No direct category: look up the topic's top search result and reuse one of its categories.
                search_url = f"https://en.wikipedia.org/w/api.php?action=query&list=search&srsearch={topic}&format=json"
                search_response = requests.get(search_url, timeout=10).json()
                search_results = search_response.get('query', {}).get('search', [])
                if search_results:
                    main_title = search_results[0]['title']
                    cat_url = f"https://en.wikipedia.org/w/api.php?action=query&prop=categories&titles={main_title}&format=json&cllimit=50"
                    cat_response = requests.get(cat_url, timeout=10).json()
                    pages = cat_response.get('query', {}).get('pages', {})
                    if pages:
                        cats = list(pages.values())[0].get('categories', [])
                        cat_titles = [cat['title'].replace('Category:', '').lower() for cat in cats]
                        relevant_cats = [c for c in cat_titles if any(t in c for t in topic.lower().split())]
                        if relevant_cats:
                            subcat_topic = relevant_cats[0].capitalize().replace(' ', '_')
                            sub_url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{subcat_topic}&cmtype=subcat&format=json&cmlimit=50"
                            sub_response = requests.get(sub_url, timeout=10).json()
                            sub_members = sub_response.get('query', {}).get('categorymembers', [])
                            return [m['title'].replace('Category:', '').lower() for m in sub_members]
            return []
        except Exception:
            return []

    def get_category_pages(self, category):
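        """Return lowercase, single-word page titles belonging to the given
        Wikipedia category; empty list on any API failure."""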
        cat_cap = category.capitalize().replace(' ', '_')
        url = f"https://en.wikipedia.org/w/api.php?action=query&list=categorymembers&cmtitle=Category:{cat_cap}&cmtype=page&format=json&cmlimit=50"
        try:
            response = requests.get(url, timeout=10).json()
            members = response.get('query', {}).get('categorymembers', [])
            # Keep only single-word titles of a usable length.
            return [member['title'].lower() for member in members if ' ' not in member['title'] and len(member['title']) > 3]
        except Exception:
            return []

    def is_subcategory(self, topic, word):
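        """Return True if any Wikipedia category of the page for ``word`` mentions the topic."""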
| url = f"https://en.wikipedia.org/w/api.php?action=query&prop=categories&format=json&titles={word.capitalize()}" |
| try: |
| response = requests.get(url).json() |
| pages = response.get('query', {}).get('pages', {}) |
| if pages: |
| cats = list(pages.values())[0].get('categories', []) |
| return any(topic.lower() in cat['title'].lower() for cat in cats) |
| return False |
| except Exception: |
| return False |

    def generate_words(self, topic, num_words=20):
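        """Generate up to ``num_words`` words related to ``topic``.

        Candidates are collected from Wikipedia subcategory pages and from
        FAISS similarity search over the vocabulary (for the topic and its
        singular/plural variation, expanded through subcategories). The
        accumulated scores are then re-ranked with weighted PageRank over a
        star graph centred on the topic.
        """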
        variations = [topic.lower()]
        if topic.endswith('s'):
            variations.append(topic[:-1])
        else:
            variations.append(topic + 's')

        all_results = {}

        subcats = self.get_wikipedia_subcats(topic)
        print('wiki subcats', subcats)

        # Pages that belong to the topic's Wikipedia subcategories get a fixed base score.
        for sub in subcats:
            pages = self.get_category_pages(sub)
            for p in pages:
                all_results[p] = all_results.get(p, 0) + 0.8

        for variation in variations:
            # Embed the topic variation, optionally jittered so repeated runs can differ.
            topic_embedding = self.model.encode([variation], convert_to_numpy=True)
            noise_factor = float(os.getenv("SEARCH_RANDOMNESS", "0.02"))
            if noise_factor > 0:
                noise = np.random.normal(0, noise_factor, topic_embedding.shape)
                topic_embedding += noise
            topic_embedding = np.ascontiguousarray(topic_embedding, dtype=np.float32)
            faiss.normalize_L2(topic_embedding)

            search_size = min(self.max_results * 3, len(self.vocab))
            scores, indices = self.faiss_index.search(topic_embedding, search_size)

            initial_results = []
            for i in range(len(indices[0])):
                idx = indices[0][i]
                score = scores[0][i]
                if score > 0.3:
                    initial_results.append(self.vocab[idx])

            # If Wikipedia returned no subcategories, promote similar words that look like subcategories.
            if not subcats:
                additional_subcats = [w for w in initial_results[:30] if self.is_subcategory(topic, w)]
                subcats.extend(additional_subcats)

            # Still nothing: cluster the similar words and use the word nearest each cluster centre as a pseudo-subcategory.
            if not subcats and len(initial_results) >= 3:
                result_embeddings = self.model.encode(initial_results, convert_to_numpy=True)
                result_embeddings = np.ascontiguousarray(result_embeddings, dtype=np.float32)
                faiss.normalize_L2(result_embeddings)
                kmeans = KMeans(n_clusters=min(3, len(initial_results)), random_state=42).fit(result_embeddings)
                cluster_centers = kmeans.cluster_centers_.astype(np.float32)
                faiss.normalize_L2(cluster_centers)
                _, subcat_indices = self.faiss_index.search(cluster_centers, 1)
                subcats = [self.vocab[subcat_indices[j][0]] for j in range(len(subcat_indices))]

            # Search the vocabulary around each subcategory, discounting scores by level (only one level here).
            for level, subs in enumerate([subcats], start=1):
                for sub in subs:
                    sub_embedding = self.model.encode([sub], convert_to_numpy=True)
                    sub_embedding = np.ascontiguousarray(sub_embedding, dtype=np.float32)
                    faiss.normalize_L2(sub_embedding)
                    sub_scores, sub_indices = self.faiss_index.search(sub_embedding, search_size)
                    for i in range(len(sub_indices[0])):
                        idx = sub_indices[0][i]
                        score = sub_scores[0][i]
                        if score > 0.3:
                            w = self.vocab[idx]
                            weighted_score = score * (0.8 ** level)
                            all_results[w] = all_results.get(w, 0) + weighted_score

            # Add the direct similarity scores for this variation.
            for i in range(len(indices[0])):
                idx = indices[0][i]
                score = scores[0][i]
                if score > 0.3:
                    w = self.vocab[idx]
                    all_results[w] = all_results.get(w, 0) + score

        # Re-rank with PageRank on a star graph: topic at the centre, accumulated scores as edge weights.
        G = nx.Graph()
        G.add_node(topic)
        for w, score in all_results.items():
            G.add_edge(topic, w, weight=score)
        pr = nx.pagerank(G, weight='weight')

        sorted_results = sorted(pr.items(), key=lambda x: x[1], reverse=True)
        final_words = [w for w, _ in sorted_results if w != topic][:num_words]

        return final_words


if __name__ == "__main__":
    cache_dir = os.path.join(os.path.dirname(__file__), 'model_cache')
    os.makedirs(cache_dir, exist_ok=True)

    generator = CrosswordGenerator2(cache_dir=cache_dir)
    topics = ["animal", "animal", "science", "technology", "food", "indian food", "chinese food"]
    for topic in topics:
        print(f"------------- {topic} ------------")
        generated_words = generator.generate_words(topic)
        sorted_generated_words = sorted(generated_words)
        print(f"Generated words for topic '{topic}':")
        print(sorted_generated_words)
|