| import os |
| import base64 |
| import requests |
| import numpy as np |
| import faiss |
| import re |
| import logging |
| from pathlib import Path |
|
|
| |
| |
# Populate os.environ from a local .env file (python-dotenv) before the
# os.getenv() calls further down read credentials and settings.
from dotenv import load_dotenv
load_dotenv()


from sentence_transformers import SentenceTransformer, CrossEncoder
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate


# rank_bm25 is optional: when it is not installed, BM25Okapi is None and the
# ranking pipeline substitutes zero BM25 scores.
try:
    from rank_bm25 import BM25Okapi
except ImportError:
    BM25Okapi = None
|
|
| |
| |
| |
| |
# Credentials and model settings, read from the environment.
GITHUB_API_KEY = os.getenv("GITHUB_API_KEY")

GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# NOTE(review): not referenced in the visible part of this file; kept for
# any code that imports it — confirm before removing.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Model used to re-rank the top candidates (overridable via the environment).
CROSS_ENCODER_MODEL = os.getenv("CROSS_ENCODER_MODEL", "cross-encoder/ms-marco-MiniLM-L-6-v2")


# Shared HTTP session for the GitHub REST API calls in this module.
session = requests.Session()
session.headers.update({
    "Accept": "application/vnd.github.v3+json"
})
# Authenticate only when a token is configured. Otherwise a missing key would
# be sent as the literal header "token None", which GitHub rejects as bad
# credentials instead of serving the request anonymously.
if GITHUB_API_KEY:
    session.headers["Authorization"] = f"token {GITHUB_API_KEY}"
|
|
| |
| |
| |
# Chat model behind the tag-generation chain (`prompt | llm`) below.
# NOTE(review): replies are expected to open with a <think> reasoning section;
# with max_tokens=512 a long trace may be cut off before any tags appear —
# consider raising the limit if truncated replies show up.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="deepseek-r1-distill-llama-70b",
    temperature=0.3,
    max_tokens=512,
    max_retries=3,
)
|
|
# Tag-generation prompt. The system message fixes the output contract
# (colon-separated lowercase tags, optionally ending in target-<language>);
# the human message carries the user's query via the {query} placeholder.
# NOTE(review): the system text below is a runtime string sent to the model —
# any edit to it changes model behaviour.
prompt = ChatPromptTemplate.from_messages([
    ("system",
     """You are a GitHub search optimization expert.

Your job is to:
1. Read a user's query about tools, research, or tasks.
2. Detect if the query mentions a specific programming language other than Python (for example, JavaScript or JS). If so, record that language as the target language.
3. Think iteratively and generate your internal chain-of-thought enclosed in <think> ... </think> tags.
4. After your internal reasoning, output up to five GitHub-style search tags or library names that maximize repository discovery.
Use as many tags as necessary based on the query's complexity, but never more than five.
5. If you detected a non-Python target language, append an additional tag at the end in the format target-[language] (e.g., target-javascript).
If no specific language is mentioned, do not include any target tag.

Output Format:
tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]

Rules:
- Use lowercase and hyphenated keywords (e.g., image-augmentation, chain-of-thought).
- Use terms commonly found in GitHub repo names, topics, or descriptions.
- Avoid generic terms like "python", "ai", "tool", "project".
- Do NOT use full phrases or vague words like "no-code", "framework", or "approach".
- Prefer real tools, popular methods, or dataset names when mentioned.
- If your output does not strictly match the required format, correct it after your internal reasoning.
- Choose high-signal keywords to ensure the search yields the most relevant GitHub repositories.

Excellent Examples:

Input: "No code tool to augment image and annotation"
Output: image-augmentation:albumentations

Input: "Repos around chain of thought prompting mainly for finetuned models"
Output: chain-of-thought:finetuned-llm

Input: "Find repositories implementing data augmentation pipelines in JavaScript"
Output: data-augmentation:target-javascript

Output must be ONLY the search tags separated by colons. Do not include any extra text, bullet points, or explanations.
"""),
    # Filled from chain.invoke({"query": ...}).
    ("human", "{query}")
])
# Runnable pipeline: format the prompt, then call the LLM. invoke() returns a
# message whose .content holds the raw model text (reasoning included).
chain = prompt | llm
|
|
def valid_tags(tags: str) -> bool:
    """Return True if *tags* is exactly 2-6 colon-separated ``[a-z0-9-]+`` tokens.

    ``re.fullmatch`` is used because ``$`` under ``re.match`` also matches just
    before a trailing newline, which would let an otherwise-valid string with a
    dangling newline pass the "strictly matches the format" check.
    """
    return re.fullmatch(r'[a-z0-9-]+(?::[a-z0-9-]+){1,5}', tags) is not None
|
|
def parse_search_tags(response: str) -> str:
    """Extract the colon-separated tag string from a raw model reply.

    Reasoning text is removed first:
      * complete ``<think>...</think>`` blocks,
      * an unterminated ``<think>`` section (reply cut off mid-reasoning),
      * anything before an orphan ``</think>``.
    Otherwise a truncated reply such as ``"<think>the format is tag1:tag2 ..."``
    would yield ``"tag1:tag2"``, which passes the format check.

    Returns the first run of 2-6 colon-separated ``[a-z0-9-]`` tokens in the
    remaining text, or that remaining text stripped if no such run exists.
    """
    cleaned = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    # Any <think> still present has no closing tag: drop it and everything after.
    cleaned = re.sub(r'<think>.*\Z', '', cleaned, flags=re.DOTALL)
    # Any </think> still present has no opening tag: keep only what follows it.
    if '</think>' in cleaned:
        cleaned = cleaned.rsplit('</think>', 1)[1]
    match = re.search(r'([a-z0-9-]+(?::[a-z0-9-]+){1,5})', cleaned)
    if match:
        return match.group(1).strip()
    return cleaned.strip()
|
|
def iterative_convert_to_search_tags(query: str, max_iterations: int = 2) -> str:
    """Ask the LLM chain for colon-separated GitHub search tags, retrying on bad format.

    Makes up to ``max_iterations`` attempts. Every retry re-sends the original
    query with an explicit format reminder appended. Progress is printed to
    stdout.

    Returns:
        The first parsed output accepted by ``valid_tags``; otherwise the last
        parsed output, which may be invalid ("" when ``max_iterations`` <= 0).
    """
    print(f"\n [iterative_convert_to_search_tags] Input Query: {query}")
    retry_prompt = (
        f"{query}\nPlease refine your answer so that the output strictly matches "
        "the format: tag1:tag2[:tag3[:tag4[:tag5[:target-language]]]]."
    )
    current_prompt = query
    candidate = ""
    for attempt in range(1, max_iterations + 1):
        print(f"\n Iteration {attempt}")
        reply = chain.invoke({"query": current_prompt})
        candidate = parse_search_tags(reply.content.strip())
        print(f"Output Tags: {candidate}")
        if not valid_tags(candidate):
            print(" Invalid tags format. Requesting refinement...")
            current_prompt = retry_prompt
            continue
        print("Valid tags format detected.")
        return candidate
    print("Final output (may be invalid):", candidate)
    return candidate
|
|
| |
| |
| |
def fetch_readme_content(repo_full_name: str, timeout: float = 10.0) -> str:
    """Return the decoded README of *repo_full_name*, or "" if it cannot be fetched.

    Calls the GitHub ``/repos/{full_name}/readme`` endpoint, whose JSON body
    carries the file base64-encoded under "content". Bytes that are not valid
    UTF-8 are replaced rather than raising.

    Args:
        repo_full_name: repository full name as returned by the search API.
        timeout: seconds to wait for the HTTP request. Without it, a stalled
            connection could block the whole ranking run indefinitely.

    Returns:
        The README text, or "" on network errors, non-200 responses, or
        malformed bodies (best-effort, like the original non-200 path).
    """
    readme_url = f"https://api.github.com/repos/{repo_full_name}/readme"
    try:
        response = session.get(readme_url, timeout=timeout)
    except requests.RequestException as e:
        logging.debug(f"[readme] Request failed for {repo_full_name}: {e}")
        return ""
    if response.status_code != 200:
        return ""
    try:
        encoded = response.json().get('content') or ''
        return base64.b64decode(encoded).decode('utf-8', errors='replace')
    except Exception:
        # Malformed JSON or base64: treat it as "no README", like a non-200 reply.
        return ""
|
|
def fetch_markdown_contents(repo_full_name: str, timeout: float = 10.0) -> str:
    """Return the text of every ``.md`` file in the repository root, concatenated.

    Lists the root via the GitHub ``/repos/{full_name}/contents`` endpoint, then
    downloads each ``.md`` file (case-insensitive suffix) from its
    ``download_url``. Each file's text is preceded by a newline, so the result
    is "" or starts with a newline.

    Failures are best-effort:
      * a failed listing (network error, non-200, unexpected body) yields "";
      * a failed download skips only that file.

    Args:
        repo_full_name: repository full name as returned by the search API.
        timeout: seconds to wait for each HTTP request.
    """
    url = f"https://api.github.com/repos/{repo_full_name}/contents"
    try:
        response = session.get(url, timeout=timeout)
    except requests.RequestException as e:
        logging.debug(f"[markdown] Listing failed for {repo_full_name}: {e}")
        return ""
    if response.status_code != 200:
        return ""
    try:
        items = response.json()
    except ValueError:
        return ""
    # The root listing is a JSON array; anything else has no files to iterate.
    if not isinstance(items, list):
        return ""

    texts = []
    for item in items:
        if item.get("type") != "file" or not (item.get("name") or "").lower().endswith(".md"):
            continue
        file_url = item.get("download_url")
        if not file_url:
            continue
        try:
            file_resp = requests.get(file_url, timeout=timeout)
        except requests.RequestException as e:
            logging.debug(f"[markdown] Download failed for {file_url}: {e}")
            continue
        if file_resp.status_code == 200:
            texts.append(file_resp.text)
    return "".join("\n" + text for text in texts)
|
|
def fetch_all_markdown(repo_full_name: str) -> str:
    """Return the repository README followed by its other root-level markdown.

    The root listing from ``fetch_markdown_contents`` normally includes
    README.md itself, so the README text would otherwise appear twice and its
    term counts would be doubled for BM25 scoring. One exact copy of the README
    text is therefore removed from the second part. If the README lives
    elsewhere (e.g. docs/), nothing matches and nothing is removed.
    """
    readme = fetch_readme_content(repo_full_name)
    other_md = fetch_markdown_contents(repo_full_name)
    if readme and readme in other_md:
        other_md = other_md.replace(readme, "", 1)
    return readme + "\n" + other_md
|
|
def fetch_github_repositories(query: str, max_results: int = 10, timeout: float = 10.0) -> list:
    """Run a GitHub repository search and return one dict per hit.

    Each dict has these keys:
      * "title": repo name;
      * "link": html_url;
      * "combined_text": description followed by the repo's markdown from
        ``fetch_all_markdown``, stripped.

    Errors are printed and yield an empty list rather than raising.

    Args:
        query: GitHub search string, qualifiers included.
        max_results: forwarded as the ``per_page`` search parameter.
        timeout: seconds to wait for the search request.
    """
    url = "https://api.github.com/search/repositories"
    params = {
        "q": query,
        "per_page": max_results
    }
    try:
        response = session.get(url, params=params, timeout=timeout)
    except requests.RequestException as e:
        print(f"Error: GitHub search request failed: {e}")
        return []
    if response.status_code != 200:
        # The error body may not be JSON (e.g. an HTML error page); don't let
        # parsing it raise and hide the status code.
        try:
            message = response.json().get('message')
        except (ValueError, AttributeError):
            message = response.text[:200]
        print(f"Error {response.status_code}: {message}")
        return []
    try:
        items = response.json().get('items', [])
    except (ValueError, AttributeError):
        print("Error: GitHub search returned an unreadable response.")
        return []

    repo_list = []
    for repo in items:
        description = repo.get('description') or ""
        combined_markdown = fetch_all_markdown(repo.get('full_name'))
        repo_list.append({
            "title": repo.get('name', 'No title available'),
            "link": repo.get('html_url'),
            "combined_text": (description + "\n" + combined_markdown).strip(),
        })
    return repo_list
|
|
| |
| |
| |
# Bi-encoder used for dense retrieval. It runs on CPU by default; set
# EMBEDDING_DEVICE (e.g. "cuda") to try another device first.
EMBEDDING_DEVICE = os.getenv("EMBEDDING_DEVICE", "cpu")
try:
    model = SentenceTransformer('all-mpnet-base-v2', device=EMBEDDING_DEVICE)
except Exception as e:
    # The retry always uses the CPU, so an unusable EMBEDDING_DEVICE (or a
    # transient load failure) still gets one more attempt.
    print(f"Error initializing SentenceTransformer on {EMBEDDING_DEVICE!r}; falling back to CPU:", e)
    model = SentenceTransformer('all-mpnet-base-v2', device='cpu')
|
|
def robust_min_max_norm(scores: np.ndarray) -> np.ndarray:
    """Linearly rescale *scores* to [0, 1].

    Edge cases:
      * Any array-like is accepted. Floating input keeps its dtype; other
        input is converted to float.
      * When the spread (max - min) is below 1e-10, every entry maps to 1.0
        instead of dividing by ~0.
      * Empty input returns an empty array instead of raising from ``min()``.
    """
    values = np.asarray(scores)
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)
    if values.size == 0:
        return values
    min_val = values.min()
    spread = values.max() - min_val
    if spread < 1e-10:
        return np.ones_like(values)
    return (values - min_val) / spread
|
|
| |
| |
| |
# Loaded CrossEncoder instances keyed by model name, so repeated rankings do
# not reload the model weights on every call.
_CROSS_ENCODER_CACHE = {}


def _get_cross_encoder(model_name: str):
    """Return a CPU CrossEncoder for *model_name*, loading it on first use only."""
    encoder = _CROSS_ENCODER_CACHE.get(model_name)
    if encoder is None:
        try:
            encoder = CrossEncoder(model_name, device='cpu')
        except Exception as e:
            # One retry (e.g. a transient download failure); both attempts use the CPU.
            print("Error initializing CrossEncoder on CPU; retrying once:", e)
            encoder = CrossEncoder(model_name, device='cpu')
        _CROSS_ENCODER_CACHE[model_name] = encoder
    return encoder


def cross_encoder_rerank_candidates(candidates: list, query: str, model_name: str, top_n: int = 10) -> list:
    """Attach a "cross_encoder_score" to every candidate (in place) and return the list.

    Scoring of each candidate's "combined_text":
      * The text is truncated to MAX_DOC_LENGTH characters.
      * Texts shorter than MIN_DOC_LENGTH are scored as one (query, text) pair.
      * Longer texts are split into CHUNK_SIZE-character chunks. Their score
        is 0.5 * best chunk score + 0.5 * mean chunk score; with one chunk this
        is just that chunk's score.
      * A candidate whose scoring raises gets 0.0.

    If any score is negative, all scores are then shifted up so the minimum is
    0. They are NOT rescaled to [0, 1].

    Args:
        candidates: dicts with an optional "combined_text" key; mutated in place.
        query: the user's natural-language query.
        model_name: CrossEncoder model identifier (loaded once, then cached).
        top_n: accepted for compatibility; not used.

    Returns:
        The same ``candidates`` list.
    """
    CHUNK_SIZE = 2000
    MAX_DOC_LENGTH = 5000
    MIN_DOC_LENGTH = 200

    cross_encoder = _get_cross_encoder(model_name)

    for candidate in candidates:
        doc = (candidate.get("combined_text") or "")[:MAX_DOC_LENGTH]
        if len(doc) < MIN_DOC_LENGTH:
            pairs = [[query, doc]]
        else:
            pairs = [[query, doc[i:i + CHUNK_SIZE]] for i in range(0, len(doc), CHUNK_SIZE)]
        try:
            scores = np.asarray(cross_encoder.predict(pairs), dtype=float).ravel()
        except Exception as e:
            logging.debug(f"[cross-encoder] Error scoring candidate {candidate.get('link', 'unknown')}: {e}")
            candidate["cross_encoder_score"] = 0.0
            continue
        candidate["cross_encoder_score"] = (
            0.5 * float(scores.max()) + 0.5 * float(scores.mean()) if scores.size else 0.0
        )

    all_scores = [candidate["cross_encoder_score"] for candidate in candidates]
    if all_scores:
        min_score = min(all_scores)
        if min_score < 0:
            for candidate in candidates:
                candidate["cross_encoder_score"] += -min_score

    return candidates
|
|
| |
| |
| |
def _collect_repositories(tag_list: list, lang_query: str) -> list:
    """Search GitHub once per tag, plus once with all tags OR-ed; return repos unique by link."""
    advanced_qualifier = "in:name,description,readme"
    all_repositories = []

    for tag in tag_list:
        github_query = f"{tag} {advanced_qualifier} {lang_query}"
        logging.info(f"[DeepGit] GitHub Query: {github_query}")
        all_repositories.extend(fetch_github_repositories(github_query, max_results=15))

    # With a single tag the OR query is the same search again, so skip it.
    if len(tag_list) > 1:
        combined_query = " OR ".join(tag_list)
        combined_query = f"({combined_query}) {advanced_qualifier} {lang_query}"
        logging.info(f"[DeepGit] Combined GitHub Query: {combined_query}")
        all_repositories.extend(fetch_github_repositories(combined_query, max_results=15))

    # A repository returned by several searches is built from the same
    # description + markdown each time. Keep one copy: concatenating the copies
    # multiplied its term counts for BM25.
    unique_repositories = {}
    for repo in all_repositories:
        unique_repositories.setdefault(repo["link"], repo)
    return list(unique_repositories.values())


def _hybrid_scores(query: str, docs: list, alpha: float = 0.8) -> np.ndarray:
    """Return alpha * dense + (1 - alpha) * BM25 scores, each min-max normalised, in docs order."""
    logging.info("[DeepGit] Step 5: Compute dense embeddings and scores.")
    doc_embeddings = model.encode(docs, convert_to_numpy=True, show_progress_bar=True, batch_size=16)
    if doc_embeddings.ndim == 1:
        doc_embeddings = doc_embeddings.reshape(1, -1)
    norms = np.linalg.norm(doc_embeddings, axis=1, keepdims=True)
    norm_doc_embeddings = doc_embeddings / (norms + 1e-10)

    query_embedding = model.encode(query, convert_to_numpy=True)
    if query_embedding.ndim == 1:
        query_embedding = query_embedding.reshape(1, -1)
    norm_query_embedding = query_embedding / (np.linalg.norm(query_embedding) + 1e-10)

    num_docs, dim = norm_doc_embeddings.shape
    index = faiss.IndexFlatIP(dim)
    index.add(norm_doc_embeddings)
    similarities, doc_ids = index.search(norm_query_embedding, num_docs)
    # faiss returns hits sorted by similarity, not in document order. Scatter
    # them back so dense_scores[i] belongs to docs[i]. This also keeps the
    # array 1-D when there is a single document.
    dense_scores = np.zeros(num_docs, dtype=similarities.dtype)
    dense_scores[doc_ids[0]] = similarities[0]
    norm_dense_scores = robust_min_max_norm(dense_scores)

    logging.info("[DeepGit] Step 6: Compute BM25 scores.")
    if BM25Okapi is not None:
        tokenized_docs = [re.findall(r'\w+', doc.lower()) for doc in docs]
        bm25 = BM25Okapi(tokenized_docs)
        query_tokens = re.findall(r'\w+', query.lower())
        norm_bm25_scores = robust_min_max_norm(np.array(bm25.get_scores(query_tokens)))
    else:
        norm_bm25_scores = np.zeros_like(norm_dense_scores)

    logging.info("[DeepGit] Step 7: Combine dense and BM25 scores.")
    return alpha * norm_dense_scores + (1 - alpha) * norm_bm25_scores


def _format_ranked_output(final_ranked: list) -> str:
    """Render the ranked repositories as the plain-text report returned to callers."""
    parts = ["\n=== Ranked Repositories ===\n"]
    for rank, repo in enumerate(final_ranked, 1):
        snippet = repo['combined_text'][:300].replace('\n', ' ')
        parts.append(f"Final Rank: {rank}\n")
        parts.append(f"Title: {repo['title']}\n")
        parts.append(f"Link: {repo['link']}\n")
        parts.append(f"Combined Score: {repo.get('combined_score', 0) * 100:.2f}%\n")
        parts.append(f"Cross-Encoder Score: {repo.get('cross_encoder_score', 0) * 100:.2f}%\n")
        parts.append(f"Final Score: {repo.get('final_score', 0) * 100:.2f}%\n")
        parts.append(f"Snippet: {snippet}...\n")
        parts.append('-' * 80 + "\n")
    parts.append("\n=== End of Results ===")
    return "".join(parts)


def run_repository_ranking(query: str, num_results: int = 10) -> str:
    """Find and rank GitHub repositories for a natural-language *query*.

    Pipeline:
      1. Turn the query into search tags with the LLM.
      2. Extract an optional target-<language> tag (default language: python).
      3. Search GitHub per tag, plus an OR-combined search when there are
         several tags.
      4. Score with min-max normalised dense similarity (0.8) + BM25 (0.2).
      5. Re-rank the top 100 with a cross-encoder.
      6. Blend 0.7 * retrieval score + 0.3 * normalised cross-encoder score.

    Returns:
        A formatted report of the top ``num_results`` repositories, or
        "No repositories found for your query.".
    """
    logging.info("[DeepGit] Step 1: Generate search tags from the query.")
    search_tags = iterative_convert_to_search_tags(query)
    tag_list = [tag.strip() for tag in search_tags.split(":") if tag.strip()]

    logging.info("[DeepGit] Step 2: Handle target language extraction.")
    target_tags = [tag for tag in tag_list if tag.startswith("target-")]
    if target_tags:
        lang_query = f"language:{target_tags[0].replace('target-', '')}"
        tag_list = [tag for tag in tag_list if not tag.startswith("target-")]
    else:
        lang_query = "language:python"
    if not tag_list:
        # Tag generation gave nothing usable: search with the raw query instead
        # of issuing an empty "()" query.
        tag_list = [query.strip()]

    logging.info("[DeepGit] Step 3: Build advanced search qualifiers and fetch repositories.")
    repositories = _collect_repositories(tag_list, lang_query)
    if not repositories:
        return "No repositories found for your query."

    logging.info("[DeepGit] Step 4: Prepare documents for dense retrieval.")
    docs = [repo.get("combined_text", "") for repo in repositories]

    combined_scores = _hybrid_scores(query, docs)
    for repo, score in zip(repositories, combined_scores):
        repo["combined_score"] = float(score)

    logging.info("[DeepGit] Step 8: Initial ranking by combined score.")
    ranked_repositories = sorted(repositories, key=lambda x: x.get("combined_score", 0), reverse=True)

    logging.info("[DeepGit] Step 9: Cross-encoder re-ranking.")
    top_candidates = ranked_repositories[:100]
    cross_encoder_rerank_candidates(top_candidates, query, model_name=CROSS_ENCODER_MODEL, top_n=len(top_candidates))

    logging.info("[DeepGit] Step 10: Final scoring and output formatting.")
    w1 = 0.7
    w2 = 0.3
    # Cross-encoder scores are not normalised (the re-ranker only shifts them to
    # be non-negative). Rescale them to [0, 1] like the retrieval scores so the
    # weights are meaningful and the printed percentages stay within 0-100%.
    cross_scores = robust_min_max_norm(
        np.array([c.get("cross_encoder_score", 0.0) for c in top_candidates], dtype=float)
    )
    for candidate, cross_score in zip(top_candidates, cross_scores):
        candidate["cross_encoder_score"] = float(cross_score)
        candidate["final_score"] = w1 * candidate.get("combined_score", 0) + w2 * candidate["cross_encoder_score"]

    final_ranked = sorted(top_candidates, key=lambda x: x.get("final_score", 0), reverse=True)[:num_results]
    return _format_ranked_output(final_ranked)
|
|
| |
| |
| |
if __name__ == "__main__":
    # The pipeline reports its "[DeepGit] Step ..." progress through
    # logging.info. Without configuring the root logger (default level
    # WARNING), those messages are silently dropped.
    logging.basicConfig(level=logging.INFO)
    test_query = "Chain of thought prompting for reasoning models"
    result = run_repository_ranking(test_query)
    print(result)
|
|