aghilsabu committed on
Commit
4703744
·
1 Parent(s): 2e0dde8

feat: add DOT diagram generation and management

Browse files
Files changed (1) hide show
  1. src/core/diagram.py +470 -0
src/core/diagram.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diagram Generator Module
3
+
4
+ Handles Graphviz rendering, sanitization, and diagram management.
5
+ """
6
+
7
+ import os
8
+ import re
9
+ import json
10
+ import logging
11
+ import tempfile
12
+ import hashlib
13
+ from datetime import datetime
14
+ from dataclasses import dataclass, field
15
+ from pathlib import Path
16
+ from typing import Optional, List, Tuple, Dict
17
+
18
+ from ..config import get_config, DIAGRAMS_DIR
19
+
20
# Module-level logger; the name is fixed so log configuration can target it.
logger = logging.getLogger("codeatlas.diagram")

# Graphviz is an optional dependency. Record whether the Python bindings can
# be imported so rendering code can fall back to a plain-text view.
# NOTE(review): this only checks the Python package; the `dot` executable may
# still be missing at render time — confirm how callers handle that.
try:
    import graphviz
    GRAPHVIZ_AVAILABLE = True
except ImportError:
    GRAPHVIZ_AVAILABLE = False
    logger.warning("Graphviz not available")
29
+
30
+
31
@dataclass
class LayoutOptions:
    """Options for diagram layout."""

    direction: str = "TB"  # Graphviz rankdir code: TB, LR, BT, RL
    splines: str = "polyline"  # polyline, ortho, spline, line
    nodesep: float = 0.5
    ranksep: float = 0.75
    zoom: float = 1.0  # CSS scale factor applied to the rendered diagram

    # Deliberately unannotated: a class-level lookup table (UI label ->
    # rankdir code), not a dataclass field.
    DIRECTION_MAP = {
        "Top → Down": "TB",
        "Left → Right": "LR",
        "Bottom → Up": "BT",
        "Right → Left": "RL",
    }

    @classmethod
    def from_ui(cls, direction: str, splines: str, nodesep: float, ranksep: float, zoom: float = 1.0):
        """Create from UI values.

        Args:
            direction: A UI label from ``DIRECTION_MAP`` or a raw rankdir
                code (``TB``/``LR``/``BT``/``RL``). Anything else falls
                back to ``"TB"``.
            splines: Edge routing style, passed through unchanged.
            nodesep: Node separation, passed through unchanged.
            ranksep: Rank separation, passed through unchanged.
            zoom: Zoom factor, passed through unchanged.

        Returns:
            A new ``LayoutOptions`` instance.
        """
        rankdir = cls.DIRECTION_MAP.get(direction)
        if rankdir is None:
            # Accept an already-translated code instead of silently
            # resetting it to the default.
            rankdir = direction if direction in cls.DIRECTION_MAP.values() else "TB"
        return cls(
            direction=rankdir,
            splines=splines,
            nodesep=nodesep,
            ranksep=ranksep,
            zoom=zoom,
        )
57
+
58
+
59
@dataclass
class DiagramInfo:
    """Information about a saved diagram."""

    filename: str  # file name of the saved .dot file
    repo_name: str  # repository label shown for this entry
    timestamp: str  # raw "YYYYMMDD_HHMMSS" stamp
    formatted_timestamp: str  # display form of the timestamp
    file_path: Path  # location of the .dot file
    # Metadata fields (loaded from JSON if available)
    model_name: str = ""
    files_processed: int = 0
    total_characters: int = 0
    node_count: int = 0
    edge_count: int = 0
73
+
74
+
75
class DiagramGenerator:
    """Generates and manages architecture diagrams.

    Turns Graphviz DOT source into an HTML fragment with an embedded SVG,
    and saves/loads raw DOT files (plus optional JSON metadata) under
    ``self.diagrams_dir``.
    """

    # DOT keywords and attribute names that must never be counted as nodes.
    _NON_NODE_WORDS = frozenset({
        "digraph", "graph", "subgraph", "node", "edge", "rankdir", "splines",
        "nodesep", "ranksep", "label", "style", "shape", "color", "fillcolor",
        "fontname", "fontsize", "margin", "pad", "width", "height",
        "arrowsize", "cluster", "tb", "lr", "bt", "rl",
    })

    # Source node + arrow; the target is only looked ahead at so that it can
    # serve as the source of the next hop in a chain such as `a -> b -> c`.
    _EDGE_RE = re.compile(r'(?:"([^"]+)"|(\w+))\s*->\s*(?=(?:"([^"]+)"|(\w+)))')

    # A real graph header at the start of a line: [strict] (graph | digraph).
    _GRAPH_HEADER_RE = re.compile(
        r"^\s*(?:strict\s+)?(?:di)?graph\b", re.IGNORECASE | re.MULTILINE
    )

    # Graph attributes that _apply_layout owns and therefore strips first.
    _LAYOUT_ATTRS = ("rankdir", "splines", "nodesep", "ranksep")

    def __init__(self):
        self.config = get_config()
        self.diagrams_dir = self.config.diagrams_dir

    def render(
        self,
        dot_source: str,
        layout: Optional["LayoutOptions"] = None,
        repo_name: str = "",
        save_to_history: bool = False,
        metadata: Optional[Dict] = None,
    ) -> str:
        """Render DOT source to HTML with embedded SVG.

        Args:
            dot_source: Graphviz DOT source code
            layout: Layout options
            repo_name: Repository name for saving
            save_to_history: Whether to save to history
            metadata: Optional metadata dict (model, files, chars) to save with diagram

        Returns:
            HTML string with SVG diagram. An error box is returned when
            ``dot_source`` starts with ``"Error:"``; a text fallback is
            returned when Graphviz is unavailable or rendering fails.
        """
        if dot_source.startswith("Error:"):
            return self._error_html(dot_source)

        if not GRAPHVIZ_AVAILABLE:
            return self._fallback_html(dot_source)

        if layout is None:
            layout = LayoutOptions()

        try:
            # Clean and prepare DOT source
            dot_source = self._prepare_dot(dot_source)

            # Save the raw (pre-sanitization) diagram if requested
            if save_to_history:
                self._save_diagram(dot_source, "raw", repo_name, metadata)

            dot_source = self._sanitize_dot(dot_source)
            dot_source = self._apply_layout(dot_source, layout)
            svg_content = self._render_svg(dot_source)
            return self._wrap_svg(svg_content, layout.zoom)

        except Exception as e:
            # UI boundary: never propagate, show the text view instead.
            logger.exception("Rendering failed")
            return self._fallback_html(dot_source, error=str(e))

    def _prepare_dot(self, dot_source: str) -> str:
        """Prepare DOT source by removing markdown and ensuring structure.

        Strips a surrounding markdown code fence and wraps the text in
        ``digraph G { ... }`` when no graph header is present.
        """
        dot_source = dot_source.strip()

        # Remove markdown code fences
        if dot_source.startswith("```"):
            lines = dot_source.split("\n")[1:]
            if lines and lines[-1].strip() == "```":
                lines = lines[:-1]
            dot_source = "\n".join(lines)

        # Look for an actual header rather than the substring "graph", which
        # also occurs in identifiers and labels (e.g. "graphql").
        if not self._GRAPH_HEADER_RE.search(dot_source):
            dot_source = f"digraph G {{\n{dot_source}\n}}"

        return dot_source

    def _sanitize_dot(self, dot_source: str) -> str:
        """Sanitize DOT source to fix common LLM output issues.

        Removes ``<...>`` tags, drops dangling edges and closes unterminated
        quotes in the last three lines, and appends missing closing braces.

        Raises:
            ValueError: If the text looks like a rate-limit error response.
        """
        # Check for error responses
        if "Error" in dot_source and any(
            marker in dot_source for marker in ("429", "RESOURCE_EXHAUSTED", "quota")
        ):
            raise ValueError("Rate limited - received error instead of diagram")

        lines = dot_source.split("\n")
        total = len(lines)
        sanitized = []
        brace_count = 0

        for i, line in enumerate(lines):
            stripped = line.strip()
            if not stripped or stripped.startswith(("//", "#")):
                sanitized.append(line)
                continue

            # Remove HTML-like tags. NOTE: this also strips genuine Graphviz
            # HTML labels (label=<...>); kept as the established behavior.
            line = re.sub(r"<[^>]+>", "", line)

            # Truncated output only affects the tail of the text.
            near_end = i >= total - 3
            if near_end:
                # Drop an edge that has no target.
                if stripped.endswith("->"):
                    continue
                # Close an unterminated quote, and the attribute list too
                # if one is actually open.
                if line.count('"') % 2 == 1:
                    line = line.rstrip() + '"'
                    if line.count("[") > line.count("]"):
                        line += "]"
                    line += ";"

            # Count braces only on lines that are kept, so a dropped line
            # cannot leave the closing-brace tally off by one.
            brace_count += line.count("{") - line.count("}")
            sanitized.append(line)

        result = "\n".join(sanitized)

        # Balance braces
        if brace_count > 0:
            result += "\n" + "}" * brace_count

        return result

    def _apply_layout(self, dot_source: str, layout: "LayoutOptions") -> str:
        """Apply layout settings to DOT source.

        Existing rankdir/splines/nodesep/ranksep assignments are removed and
        a fresh settings block is inserted after the first ``{``.
        """
        # Remove existing layout attributes (bare or quoted values, with an
        # optional trailing ';' or ',').
        for attr in self._LAYOUT_ATTRS:
            dot_source = re.sub(
                rf'\b{attr}\s*=\s*(?:"[^"]*"|[\w.]+)\s*[;,]?', "", dot_source
            )

        # Add new layout settings
        layout_settings = (
            "\n"
            f"    rankdir={layout.direction};\n"
            f"    splines={layout.splines};\n"
            f"    nodesep={layout.nodesep};\n"
            f"    ranksep={layout.ranksep};\n"
            "    pad=0.5;\n"
            '    node [shape=box, style="rounded,filled", fontname="Helvetica", '
            'fontsize=12, width=2.0, height=0.6, margin="0.2,0.1"];\n'
            '    edge [fontname="Helvetica", fontsize=10, arrowsize=0.8];\n'
        )
        dot_source = dot_source.replace("{", "{" + layout_settings, 1)

        # For ortho splines, convert label to xlabel
        if layout.splines == "ortho":
            converted = []
            for line in dot_source.split("\n"):
                if "->" in line and "xlabel=" not in line:
                    # \b keeps headlabel= / taillabel= intact.
                    line = re.sub(r"\blabel=", "xlabel=", line)
                converted.append(line)
            dot_source = "\n".join(converted)

        return dot_source

    def _render_svg(self, dot_source: str) -> str:
        """Render DOT to SVG using Graphviz and return the SVG text."""
        graph = graphviz.Source(dot_source)

        with tempfile.TemporaryDirectory() as tmpdir:
            output_path = os.path.join(tmpdir, "diagram")
            graph.render(output_path, format="svg", cleanup=True)

            with open(output_path + ".svg", "r", encoding="utf-8") as f:
                return f.read()

    def _wrap_svg(self, svg_content: str, zoom: float = 1.0) -> str:
        """Wrap SVG in responsive HTML container."""
        # Make SVG responsive
        svg_content = re.sub(r'<svg([^>]*?)width="[^"]*"', r'<svg\1width="100%"', svg_content)
        svg_content = re.sub(r'height="[^"]*"', 'height="auto"', svg_content, count=1)

        # Apply zoom
        transform = (
            f'style="transform: scale({zoom}); transform-origin: top left;"'
            if zoom != 1.0
            else ""
        )

        return (
            '<div class="diagram-box">\n'
            f'    <div class="diagram-inner" {transform}>\n'
            f"        {svg_content}\n"
            "    </div>\n"
            "</div>"
        )

    def _save_diagram(
        self,
        dot_content: str,
        prefix: str,
        repo_name: str = "",
        metadata: Optional[Dict] = None,
    ) -> Optional[Path]:
        """Save DOT content to history with optional metadata.

        Args:
            dot_content: DOT text to write.
            prefix: File name prefix (``render`` passes ``"raw"``).
            repo_name: Repository name; reduced to a file-safe form.
            metadata: Optional dict written beside the diagram as JSON. The
                caller's dict is not modified.

        Returns:
            Path of the written ``.dot`` file, or ``None`` if saving failed.
        """
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Sanitize repo name
        safe_repo = repo_name.replace("/", "_").replace(" ", "_") if repo_name else ""
        safe_repo = "".join(c for c in safe_repo if c.isalnum() or c in "_-")[:50]

        filename = f"{prefix}_{safe_repo}_{timestamp}.dot" if safe_repo else f"{prefix}_{timestamp}.dot"
        filepath = self.diagrams_dir / filename

        try:
            filepath.parent.mkdir(parents=True, exist_ok=True)
            with open(filepath, "w", encoding="utf-8") as f:
                f.write(dot_content)
            logger.info("Saved diagram: %s", filepath)

            # Save metadata as JSON if provided
            if metadata:
                node_count, edge_count = self._count_nodes_edges(dot_content)
                # Work on a copy so the caller's dict stays untouched.
                record = dict(metadata)
                record["node_count"] = node_count
                record["edge_count"] = edge_count
                record["repo_name"] = repo_name
                record["timestamp"] = timestamp

                meta_filepath = filepath.with_suffix(".json")
                with open(meta_filepath, "w", encoding="utf-8") as f:
                    json.dump(record, f, indent=2)
                logger.info("Saved metadata: %s", meta_filepath)

            return filepath
        except (OSError, TypeError, ValueError) as e:
            # Best effort: a failed save must not prevent rendering.
            logger.warning("Failed to save diagram: %s", e)
            return None

    def _count_nodes_edges(self, dot_content: str) -> Tuple[int, int]:
        """Count nodes and edges in DOT source.

        This is a regex heuristic, not a DOT parser. Every ``->`` hop counts
        as one edge, so ``a -> b -> c`` yields two edges.

        Returns:
            Tuple of (unique node count, edge count).
        """
        keywords = self._NON_NODE_WORDS
        nodes = set()

        # Match quoted node names in definitions: "Node Name" [...]
        nodes.update(re.findall(r'^\s*"([^"]+)"\s*\[', dot_content, re.MULTILINE))

        # Match unquoted node definitions: NodeName [...]
        for name in re.findall(r'^\s*(\w+)\s*\[', dot_content, re.MULTILINE):
            if name.lower() not in keywords and not name.isdigit():
                nodes.add(name)

        # Walk every hop; both endpoints (quoted or unquoted) are nodes.
        edges = 0
        for hop in self._EDGE_RE.finditer(dot_content):
            edges += 1
            endpoints = (hop.group(1) or hop.group(2), hop.group(3) or hop.group(4))
            for name in endpoints:
                if name and name.lower() not in keywords and not name.isdigit():
                    nodes.add(name)

        return len(nodes), edges

    def get_history(self, limit: int = 50) -> List["DiagramInfo"]:
        """Get list of saved diagrams.

        Only files named ``raw_[<repo>_]<YYYYMMDD>_<HHMMSS>.dot`` are
        listed; other files are ignored before ``limit`` is applied.

        Args:
            limit: Maximum number of diagrams to return

        Returns:
            List of DiagramInfo objects, newest first
        """
        if not self.diagrams_dir.exists():
            return []

        prefix = "raw_"
        entries = []
        for path in self.diagrams_dir.iterdir():
            if path.suffix != ".dot" or not path.name.startswith(prefix):
                continue
            # Strip only the leading prefix; "raw_" may also occur inside a
            # repository name.
            parts = path.stem[len(prefix):].split("_")
            if len(parts) < 2:
                continue
            date_part, time_part = parts[-2], parts[-1]
            if len(date_part) != 8 or len(time_part) != 6:
                continue
            if not (date_part + time_part).isdigit():
                continue
            entries.append((date_part + time_part, path.name, path, parts[:-2]))

        # Sort by timestamp (newest first); the file name breaks ties.
        entries.sort(key=lambda entry: entry[:2], reverse=True)

        # Build DiagramInfo list (no deduplication - show all history)
        diagrams = []
        for stamp, _, path, repo_parts in entries[:limit]:
            date_part, time_part = stamp[:8], stamp[8:]
            # NOTE(review): only the last '_' segment is used as the repo
            # label, so names containing underscores are shortened. The
            # metadata value, when present, takes precedence below.
            repo_name = repo_parts[-1] if repo_parts else "local"
            formatted_ts = (
                f"{date_part[:4]}-{date_part[4:6]}-{date_part[6:8]} "
                f"{time_part[:2]}:{time_part[2:4]}"
            )

            # Load metadata if available
            metadata = self._load_metadata(path) or {}

            diagrams.append(DiagramInfo(
                filename=path.name,
                repo_name=metadata.get("repo_name", repo_name),
                timestamp=f"{date_part}_{time_part}",
                formatted_timestamp=formatted_ts,
                file_path=path,
                model_name=metadata.get("model_name", ""),
                files_processed=metadata.get("files_processed", 0),
                total_characters=metadata.get("total_characters", 0),
                node_count=metadata.get("node_count", 0),
                edge_count=metadata.get("edge_count", 0),
            ))

        return diagrams

    def _load_metadata(self, dot_filepath: Path) -> Optional[Dict]:
        """Load metadata JSON for a diagram file.

        Returns:
            The parsed dict, or ``None`` when the sidecar ``.json`` file is
            missing, unreadable, or does not contain a JSON object.
        """
        meta_filepath = dot_filepath.with_suffix(".json")
        if not meta_filepath.exists():
            return None
        try:
            with open(meta_filepath, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            logger.warning("Failed to load metadata: %s", e)
            return None
        # Callers use .get(); reject anything that is not a JSON object.
        return data if isinstance(data, dict) else None

    def get_history_choices(self) -> List[Tuple[str, str]]:
        """Get history as (label, filename) choices for Gradio dropdown."""
        return [
            (f"{d.repo_name} — {d.formatted_timestamp}", d.filename)
            for d in self.get_history()
        ]

    def _history_path(self, filename: str) -> Optional[Path]:
        """Resolve a history file name to a file inside ``diagrams_dir``.

        Returns ``None`` for empty names, for names that resolve outside
        the diagrams directory (e.g. ``../x``), and for missing files.
        """
        if not filename:
            return None
        base = self.diagrams_dir.resolve()
        candidate = (base / filename).resolve()
        if base not in candidate.parents:
            return None
        if not candidate.is_file():
            return None
        return candidate

    def load_from_history(self, filename: str) -> Optional[str]:
        """Load a diagram from history.

        Args:
            filename: Name of the diagram file

        Returns:
            DOT source or None if not found
        """
        filepath = self._history_path(filename)
        if filepath is None:
            return None

        try:
            return filepath.read_text(encoding="utf-8")
        except (OSError, ValueError) as e:
            logger.warning("Failed to load diagram: %s", e)
            return None

    def load_from_history_with_metadata(self, filename: str) -> Tuple[Optional[str], Optional[Dict]]:
        """Load a diagram and its metadata from history.

        Args:
            filename: Name of the diagram file

        Returns:
            Tuple of (DOT source, metadata dict) or (None, None) if not found
        """
        filepath = self._history_path(filename)
        if filepath is None:
            return None, None

        try:
            dot_source = filepath.read_text(encoding="utf-8")
        except (OSError, ValueError) as e:
            logger.warning("Failed to load diagram: %s", e)
            return None, None

        metadata = self._load_metadata(filepath)

        # If no metadata file, compute node/edge counts from DOT
        if metadata is None:
            node_count, edge_count = self._count_nodes_edges(dot_source)
            metadata = {
                "node_count": node_count,
                "edge_count": edge_count,
            }

        return dot_source, metadata

    def _error_html(self, message: str) -> str:
        """Generate error display HTML; ``message`` is HTML-escaped."""
        from html import escape

        return (
            '<div style="color:#dc2626; padding:20px; text-align:center;">\n'
            f"    <strong>⚠️ {escape(message)}</strong>\n"
            "</div>"
        )

    def _fallback_html(self, content: str, error: Optional[str] = None) -> str:
        """Generate fallback HTML when Graphviz is unavailable.

        Args:
            content: Text to show; only the first 2000 characters are used.
            error: Optional rendering error message shown above the text.

        Both values are HTML-escaped so markup in DOT labels or exception
        messages is displayed literally.
        """
        from html import escape

        error_msg = (
            f"<p style='color: #dc2626;'>Rendering error: {escape(error)}</p>"
            if error
            else ""
        )
        # Truncate before escaping so an entity is never cut in half.
        preview = escape(content[:2000])
        return (
            '<div style="background: #f9fafb; padding: 1.5rem; border-radius: 12px;">\n'
            "    <strong>📊 Architecture (Text View)</strong>\n"
            f"    {error_msg}\n"
            '    <pre style="background: #fff; padding: 1rem; border-radius: 8px; '
            f'overflow-x: auto;">{preview}</pre>\n'
            "</div>"
        )