| """ |
| π§ Perplexity AI Integration for AI Dataset Studio |
| Automatically discovers relevant sources based on project descriptions |
| """ |
|
|
| import os |
| import requests |
| import json |
| import logging |
| import time |
| import re |
| from typing import List, Dict, Optional, Tuple |
| from urllib.parse import urlparse, urljoin |
| from dataclasses import dataclass |
| from enum import Enum |
|
|
| |
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
class SearchType(Enum):
    """Types of searches supported by Perplexity AI"""
    GENERAL = "general"      # broad, unconstrained web search
    ACADEMIC = "academic"    # scholarly papers and research articles
    NEWS = "news"            # journalistic / press coverage
    SOCIAL = "social"        # social-media and forum content
    TECHNICAL = "technical"  # documentation, repositories, developer content
|
|
@dataclass
class SourceResult:
    """Structure for individual source results"""
    url: str                  # direct link to the content
    title: str                # human-readable title (falls back to the domain)
    description: str          # short summary of the content (may be truncated)
    relevance_score: float    # heuristic suitability score on a 0-10 scale
    source_type: str          # one of: academic/news/government/technical/social/blog
    domain: str               # netloc extracted from the URL ("unknown" if unparseable)
    publication_date: Optional[str] = None  # date string when available; format not enforced
    author: Optional[str] = None            # author name when available
|
|
@dataclass
class SearchResults:
    """Container for search results"""
    query: str                    # the project description / query that was searched
    sources: List[SourceResult]   # parsed, deduplicated, score-sorted sources
    total_found: int              # count of sources found before truncation to max_sources
    search_time: float            # wall-clock seconds spent on the search
    perplexity_response: str      # raw text content returned by the API
    suggestions: List[str]        # related search-term suggestions (capped at 10)
|
|
class PerplexityClient:
    """
    Perplexity AI client for smart source discovery.

    Features:
    - Intelligent source discovery based on project descriptions
    - Multiple search strategies (academic, news, technical, etc.)
    - Quality filtering and relevance scoring
    - Rate limiting and error handling
    - Domain validation and safety checks
    """

    def __init__(self, api_key: Optional[str] = None):
        """
        Initialize the Perplexity AI client.

        Args:
            api_key: Perplexity API key. Falls back to the PERPLEXITY_API_KEY
                environment variable when not provided.
        """
        self.api_key = api_key or os.getenv('PERPLEXITY_API_KEY')
        self.base_url = "https://api.perplexity.ai"
        self.session = requests.Session()

        # Only attach auth headers when a key is available; key-less requests
        # are rejected up front by _validate_api_key().
        if self.api_key:
            self.session.headers.update({
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json',
                'User-Agent': 'AI-Dataset-Studio/1.0'
            })

        # Client-side rate limiting: minimum seconds between API requests.
        self.last_request_time = 0
        self.min_request_interval = 1.0

        # Retry/timeout policy for API calls.
        self.max_retries = 3
        self.timeout = 30

        logger.info("Perplexity AI client initialized")

    def _validate_api_key(self) -> bool:
        """Return True when an API key is configured; log an error otherwise."""
        if not self.api_key:
            logger.error("No Perplexity API key found. Set PERPLEXITY_API_KEY environment variable.")
            return False
        return True

    def _rate_limit(self):
        """Sleep as needed so consecutive requests are at least min_request_interval apart."""
        current_time = time.time()
        time_since_last = current_time - self.last_request_time

        if time_since_last < self.min_request_interval:
            sleep_time = self.min_request_interval - time_since_last
            logger.debug(f"Rate limiting: sleeping {sleep_time:.2f}s")
            time.sleep(sleep_time)

        self.last_request_time = time.time()

    def _make_request(self, payload: Dict) -> Optional[Dict]:
        """
        Make an API request to Perplexity with retries and error handling.

        Args:
            payload: JSON-serializable request payload.

        Returns:
            Decoded JSON response, or None if all attempts failed.
        """
        if not self._validate_api_key():
            return None

        self._rate_limit()

        for attempt in range(self.max_retries):
            try:
                logger.debug(f"Making Perplexity API request (attempt {attempt + 1})")

                response = self.session.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    timeout=self.timeout
                )

                if response.status_code == 200:
                    logger.debug("Perplexity API request successful")
                    return response.json()
                elif response.status_code == 429:
                    # Server-side rate limit: back off exponentially, then retry.
                    logger.warning("Rate limit hit, waiting longer...")
                    time.sleep(2 ** attempt)
                    continue
                else:
                    logger.error(f"API request failed: {response.status_code} - {response.text}")

            except requests.exceptions.Timeout:
                logger.warning(f"Request timeout (attempt {attempt + 1})")
            except requests.exceptions.RequestException as e:
                logger.error(f"Request error: {str(e)}")

            # Exponential backoff between failed (non-429) attempts.
            if attempt < self.max_retries - 1:
                time.sleep(2 ** attempt)

        logger.error("All retry attempts failed")
        return None

    def discover_sources(
        self,
        project_description: str,
        search_type: SearchType = SearchType.GENERAL,
        max_sources: int = 20,
        include_academic: bool = True,
        include_news: bool = True,
        domain_filter: Optional[List[str]] = None
    ) -> SearchResults:
        """
        Discover relevant sources based on a project description.

        Args:
            project_description: User's project description.
            search_type: Type of search to perform.
            max_sources: Maximum number of sources to return.
            include_academic: Include academic sources.
            include_news: Include news sources.
            domain_filter: Optional list of domains to focus on.

        Returns:
            SearchResults object with discovered sources (empty on failure).
        """
        start_time = time.time()

        logger.info(f"Discovering sources for: {project_description[:100]}...")

        search_prompt = self._build_search_prompt(
            project_description,
            search_type,
            max_sources,
            include_academic,
            include_news,
            domain_filter
        )

        # NOTE(review): this model name may be deprecated by Perplexity —
        # verify against the current API model list before relying on it.
        payload = {
            "model": "llama-3.1-sonar-large-128k-online",
            "messages": [
                {
                    "role": "system",
                    "content": "You are an expert research assistant specializing in finding high-quality, relevant sources for AI/ML dataset creation. Always provide specific URLs, titles, and descriptions."
                },
                {
                    "role": "user",
                    "content": search_prompt
                }
            ],
            "max_tokens": 4000,
            "temperature": 0.3,
            "top_p": 0.9
        }

        response = self._make_request(payload)

        if not response:
            logger.error("Failed to get response from Perplexity API")
            return self._create_empty_results(project_description, time.time() - start_time)

        try:
            content = response['choices'][0]['message']['content']
            sources = self._parse_sources_from_response(content)
            suggestions = self._extract_suggestions(content)

            search_time = time.time() - start_time

            logger.info(f"Found {len(sources)} sources in {search_time:.2f}s")

            return SearchResults(
                query=project_description,
                sources=sources[:max_sources],
                total_found=len(sources),
                search_time=search_time,
                perplexity_response=content,
                suggestions=suggestions
            )

        except Exception as e:
            # Malformed/unexpected response shape: degrade to empty results.
            logger.error(f"Error parsing Perplexity response: {str(e)}")
            return self._create_empty_results(project_description, time.time() - start_time)

    def _build_search_prompt(
        self,
        project_description: str,
        search_type: SearchType,
        max_sources: int,
        include_academic: bool,
        include_news: bool,
        domain_filter: Optional[List[str]]
    ) -> str:
        """Build an optimized search prompt for Perplexity AI."""

        prompt = f"""
Find {max_sources} high-quality, diverse sources for an AI/ML dataset creation project:

PROJECT DESCRIPTION: {project_description}

SEARCH REQUIREMENTS:
- Find sources with extractable text content suitable for ML training
- Prioritize sources with structured, high-quality content
- Include diverse perspectives and data types
- Focus on sources that are legally scrapable (respect robots.txt)

SEARCH TYPE: {search_type.value}
"""

        if include_academic:
            prompt += "\n- Include academic papers, research articles, and scholarly sources"

        if include_news:
            prompt += "\n- Include news articles, press releases, and journalistic content"

        if domain_filter:
            prompt += f"\n- Focus on these domains: {', '.join(domain_filter)}"

        prompt += f"""

OUTPUT FORMAT:
For each source, provide:
1. **URL**: Direct link to the content
2. **Title**: Clear, descriptive title
3. **Description**: 2-3 sentence summary of content and why it's relevant
4. **Type**: [academic/news/blog/government/technical/forum/social]
5. **Quality Score**: 1-10 rating for dataset suitability

ADDITIONAL REQUIREMENTS:
- Verify URLs are accessible and contain substantial text
- Avoid paywalled or login-required content when possible
- Prioritize sources with consistent formatting
- Include publication dates when available
- Suggest related search terms for expanding the dataset

Please provide concrete, actionable sources that can be immediately scraped for dataset creation.
"""

        return prompt

    def _parse_sources_from_response(self, content: str) -> List[SourceResult]:
        """
        Parse source entries out of the free-text Perplexity AI response.

        Returns a deduplicated list of SourceResult, sorted by descending
        relevance score.
        """
        sources = []

        # URL matcher: greedy body, final char excludes trailing punctuation.
        url_pattern = r'https?://[^\s<>"{}|\\^`\[\]]+[^\s<>"{}|\\^`\[\].,!?;:]'

        # The model typically separates sources with blank lines.
        sections = re.split(r'\n\s*\n', content)

        for section in sections:
            urls = re.findall(url_pattern, section, re.IGNORECASE)

            if urls:
                for url in urls:
                    try:
                        url = url.strip()

                        title = self._extract_title_from_section(section, url)
                        description = self._extract_description_from_section(section, url)
                        source_type = self._determine_source_type(url, section)
                        relevance_score = self._calculate_relevance_score(section, url)
                        domain = self._extract_domain(url)

                        if self._is_valid_url(url):
                            source = SourceResult(
                                url=url,
                                title=title,
                                description=description,
                                relevance_score=relevance_score,
                                source_type=source_type,
                                domain=domain
                            )
                            sources.append(source)

                    except Exception as e:
                        # Best-effort parsing: skip any malformed entry.
                        logger.debug(f"Error parsing source: {str(e)}")
                        continue

        # Deduplicate by URL, keeping first occurrence.
        seen_urls = set()
        unique_sources = []

        for source in sources:
            if source.url not in seen_urls:
                seen_urls.add(source.url)
                unique_sources.append(source)

        # Highest-scoring sources first.
        unique_sources.sort(key=lambda x: x.relevance_score, reverse=True)

        return unique_sources

    def _extract_title_from_section(self, section: str, url: str) -> str:
        """Extract a title for *url* from its section text; fall back to the domain."""
        lines = section.split('\n')

        for line in lines:
            if url in line:
                # Try common formats: bold markdown, heading, "Title:" label,
                # or "some title: <url>".
                title_patterns = [
                    r'\*\*([^*]+)\*\*',
                    r'#{1,6}\s*([^\n]+)',
                    r'Title:\s*([^\n]+)',
                    r'([^:\n]+):?\s*' + re.escape(url),
                ]

                for pattern in title_patterns:
                    match = re.search(pattern, line, re.IGNORECASE)
                    if match:
                        return match.group(1).strip()

        return self._extract_domain(url)

    def _extract_description_from_section(self, section: str, url: str) -> str:
        """Extract a short description for *url* from the surrounding section text."""
        lines = section.split('\n')
        description_lines = []

        for line in lines:
            if url not in line and line.strip():
                # Strip list markers / heading markup before keeping the line.
                clean_line = re.sub(r'^[#*\-\d\.]+\s*', '', line.strip())
                if len(clean_line) > 20:
                    description_lines.append(clean_line)

        description = ' '.join(description_lines)

        # Keep descriptions short for display.
        if len(description) > 200:
            description = description[:200] + "..."

        return description or "High-quality source for dataset creation"

    def _determine_source_type(self, url: str, section: str) -> str:
        """Classify the source based on its URL (and section context for .gov)."""
        url_lower = url.lower()
        section_lower = section.lower()

        if any(domain in url_lower for domain in [
            'arxiv.org', 'scholar.google', 'pubmed', 'ieee.org',
            'acm.org', 'springer.com', 'elsevier.com', 'nature.com',
            'sciencedirect.com', 'jstor.org'
        ]):
            return 'academic'

        if any(domain in url_lower for domain in [
            'cnn.com', 'bbc.com', 'reuters.com', 'ap.org', 'nytimes.com',
            'washingtonpost.com', 'theguardian.com', 'bloomberg.com',
            'techcrunch.com', 'wired.com'
        ]):
            return 'news'

        if '.gov' in url_lower or 'government' in section_lower:
            return 'government'

        if any(domain in url_lower for domain in [
            'docs.', 'documentation', 'github.com', 'stackoverflow.com',
            'medium.com', 'dev.to'
        ]):
            return 'technical'

        if any(domain in url_lower for domain in [
            'twitter.com', 'reddit.com', 'linkedin.com', 'facebook.com'
        ]):
            return 'social'

        return 'blog'

    def _calculate_relevance_score(self, section: str, url: str) -> float:
        """Calculate a heuristic relevance score for a source, clamped to 0-10."""
        score = 5.0  # neutral baseline

        section_lower = section.lower()
        url_lower = url.lower()

        # +0.5 per quality keyword present in the section text.
        quality_indicators = [
            'research', 'study', 'analysis', 'comprehensive', 'detailed',
            'expert', 'professional', 'authoritative', 'peer-reviewed',
            'dataset', 'data', 'machine learning', 'ai', 'artificial intelligence'
        ]

        for indicator in quality_indicators:
            if indicator in section_lower:
                score += 0.5

        # Bonus for well-known academic repositories.
        if any(domain in url_lower for domain in ['arxiv.org', 'scholar.google', 'pubmed']):
            score += 2.0

        # Bonus for government sources.
        if '.gov' in url_lower:
            score += 1.5

        # Penalty for low-signal social platforms.
        if any(domain in url_lower for domain in ['twitter.com', 'facebook.com']):
            score -= 1.0

        # Clamp to the documented 0-10 range (was only capped at 10 before).
        return max(0.0, min(score, 10.0))

    def _extract_domain(self, url: str) -> str:
        """Extract the network location (domain) from a URL, or "unknown"."""
        try:
            parsed = urlparse(url)
            return parsed.netloc
        except Exception:  # narrowed from bare except; urlparse may raise ValueError
            return "unknown"

    def _is_valid_url(self, url: str) -> bool:
        """Return True when the URL has both a scheme and a network location."""
        try:
            parsed = urlparse(url)
            return all([parsed.scheme, parsed.netloc])
        except Exception:  # narrowed from bare except
            return False

    def _extract_suggestions(self, content: str) -> List[str]:
        """Extract up to 10 related-search suggestions from the response text."""
        suggestions = []

        suggestion_patterns = [
            r'related search terms?:?\s*([^\n]+)',
            r'you might also search for:?\s*([^\n]+)',
            r'additional keywords?:?\s*([^\n]+)',
            r'suggestions?:?\s*([^\n]+)'
        ]

        for pattern in suggestion_patterns:
            matches = re.findall(pattern, content, re.IGNORECASE)
            for match in matches:
                # Suggestions may be comma/semicolon/pipe separated.
                terms = re.split(r'[,;|]', match)
                suggestions.extend([term.strip().strip('"\'') for term in terms if term.strip()])

        return suggestions[:10]

    def _create_empty_results(self, query: str, search_time: float) -> SearchResults:
        """Create an empty SearchResults object for failed searches."""
        return SearchResults(
            query=query,
            sources=[],
            total_found=0,
            search_time=search_time,
            perplexity_response="",
            suggestions=[]
        )

    def search_with_keywords(self, keywords: List[str], search_type: SearchType = SearchType.GENERAL) -> SearchResults:
        """
        Search using specific keywords.

        Args:
            keywords: List of search keywords.
            search_type: Type of search to perform.

        Returns:
            SearchResults object.
        """
        query = " ".join(keywords)
        return self.discover_sources(
            project_description=f"Find sources related to: {query}",
            search_type=search_type
        )

    def get_domain_sources(self, domain: str, topic: str, max_sources: int = 10) -> SearchResults:
        """
        Find sources from a specific domain.

        Args:
            domain: Target domain (e.g., "nature.com").
            topic: Topic to search for.
            max_sources: Maximum sources to return.

        Returns:
            SearchResults object.
        """
        return self.discover_sources(
            project_description=f"Find articles about {topic} from {domain}",
            domain_filter=[domain],
            max_sources=max_sources
        )

    def validate_sources(self, sources: List[SourceResult]) -> List[SourceResult]:
        """
        Validate and filter sources for quality and accessibility.

        Args:
            sources: List of source results to validate.

        Returns:
            Filtered list of valid sources (valid URL, known domain, score >= 3.0).
        """
        valid_sources = []

        for source in sources:
            try:
                if not self._is_valid_url(source.url):
                    logger.debug(f"Invalid URL: {source.url}")
                    continue

                domain = self._extract_domain(source.url)
                if not domain or domain == "unknown":
                    logger.debug(f"Unknown domain: {source.url}")
                    continue

                if source.relevance_score < 3.0:
                    logger.debug(f"Low quality score: {source.url}")
                    continue

                valid_sources.append(source)

            except Exception as e:
                logger.debug(f"Error validating source {source.url}: {str(e)}")
                continue

        logger.info(f"Validated {len(valid_sources)} out of {len(sources)} sources")
        return valid_sources

    def export_sources(self, results: SearchResults, format: str = "json") -> str:
        """
        Export search results to various formats.

        Args:
            results: SearchResults object to export.
            format: Export format ("json", "csv", "markdown").

        Returns:
            Exported data as a string.

        Raises:
            ValueError: If the format is not supported.
        """
        if format.lower() == "json":
            return self._export_json(results)
        elif format.lower() == "csv":
            return self._export_csv(results)
        elif format.lower() == "markdown":
            return self._export_markdown(results)
        else:
            raise ValueError(f"Unsupported export format: {format}")

    def _export_json(self, results: SearchResults) -> str:
        """Export results as a pretty-printed JSON string."""
        data = {
            "query": results.query,
            "total_found": results.total_found,
            "search_time": results.search_time,
            "sources": [
                {
                    "url": source.url,
                    "title": source.title,
                    "description": source.description,
                    "relevance_score": source.relevance_score,
                    "source_type": source.source_type,
                    "domain": source.domain,
                    "publication_date": source.publication_date,
                    "author": source.author
                }
                for source in results.sources
            ],
            "suggestions": results.suggestions
        }
        return json.dumps(data, indent=2)

    def _export_csv(self, results: SearchResults) -> str:
        """Export results as CSV text (header row + one row per source)."""
        import csv
        from io import StringIO

        output = StringIO()
        writer = csv.writer(output)

        writer.writerow([
            "URL", "Title", "Description", "Relevance Score",
            "Source Type", "Domain", "Publication Date", "Author"
        ])

        for source in results.sources:
            writer.writerow([
                source.url,
                source.title,
                source.description,
                source.relevance_score,
                source.source_type,
                source.domain,
                source.publication_date or "",
                source.author or ""
            ])

        return output.getvalue()

    def _export_markdown(self, results: SearchResults) -> str:
        """Export results as a Markdown report."""
        md = f"# Search Results for: {results.query}\n\n"
        md += f"**Total Sources Found:** {results.total_found}\n"
        md += f"**Search Time:** {results.search_time:.2f} seconds\n\n"

        md += "## Sources\n\n"

        for i, source in enumerate(results.sources, 1):
            md += f"### {i}. {source.title}\n\n"
            md += f"**URL:** {source.url}\n"
            md += f"**Type:** {source.source_type}\n"
            md += f"**Domain:** {source.domain}\n"
            md += f"**Relevance Score:** {source.relevance_score}/10\n"
            md += f"**Description:** {source.description}\n\n"

        if results.suggestions:
            md += "## Related Search Suggestions\n\n"
            for suggestion in results.suggestions:
                md += f"- {suggestion}\n"

        return md
|
|
| |
def test_perplexity_client():
    """Manual smoke test for the Perplexity client (requires API key + network)."""
    client = PerplexityClient()

    if not client._validate_api_key():
        print("No API key found. Set PERPLEXITY_API_KEY environment variable.")
        return

    # Run a small discovery query against the live API.
    results = client.discover_sources(
        project_description="Create a dataset for sentiment analysis of product reviews",
        search_type=SearchType.GENERAL,
        max_sources=10
    )

    print(f"Found {len(results.sources)} sources")
    for source in results.sources[:3]:
        print(f"  - {source.title}: {source.url}")

    # Exercise the JSON export path as well.
    json_export = client.export_sources(results, "json")
    print(f"JSON export: {len(json_export)} characters")
|
|
if __name__ == "__main__":
    # Manual smoke test entry point; needs PERPLEXITY_API_KEY and network access.
    test_perplexity_client()