| import graphviz |
| import json |
| from tempfile import NamedTemporaryFile |
| import os |
| from graph_generator_utils import add_nodes_and_edges |
|
|
# Child-list keys recognised at every task level, in priority order: the first
# key whose value is a list supplies a task's children.
_WBS_CHILD_KEYS = ('tasks', 'subtasks', 'sub_subtasks', 'sub_sub_subtasks', 'final_level_tasks')


def _wbs_fill_color(depth, base_hex_color, lightening_factor=0.12):
    """Return ``base_hex_color`` lightened towards white according to ``depth``.

    ``base_hex_color`` must be a ``#rrggbb`` string. Each channel moves
    ``depth * lightening_factor`` of its distance to 255 (truncated to an int)
    and is clamped to 255.
    """
    channels = (int(base_hex_color[i:i + 2], 16) for i in (1, 3, 5))
    lightened = (min(255, c + int((255 - c) * depth * lightening_factor)) for c in channels)
    return '#' + ''.join(f'{c:02x}' for c in lightened)


def _wbs_font_color(fill_hex_color):
    """Return 'white' or 'black', whichever is more legible on a ``#rrggbb`` fill."""
    r, g, b = (int(fill_hex_color[i:i + 2], 16) for i in (1, 3, 5))
    # Rec. 709 luminance weights applied directly to the gamma-encoded channels:
    # a cheap approximation that is good enough to pick a readable text colour.
    luminance = (0.2126 * r + 0.7152 * g + 0.0722 * b) / 255
    return 'white' if luminance < 0.5 else 'black'


def _add_wbs_node(dot, node_id, label, parent_id, depth, base_color):
    """Add one styled WBS box at ``depth`` and connect it to ``parent_id``."""
    fill_color = _wbs_fill_color(depth, base_color)
    dot.node(
        node_id,
        label,
        shape='box',
        style='filled,rounded',
        fillcolor=fill_color,
        # Derive the text colour from the fill actually used so the two always agree.
        fontcolor=_wbs_font_color(fill_color),
        # Labels shrink with depth but never go below 9pt.
        fontsize=str(max(9, 14 - depth * 2)),
    )
    dot.edge(parent_id, node_id, color='#4a4a4a', arrowhead='none')


def _add_wbs_tasks(dot, parent_id, tasks, depth, base_color):
    """Recursively add ``tasks`` and all their descendants under ``parent_id``.

    Raises:
        ValueError: If a task is not a JSON object or lacks a truthy ``id``/``label``.
    """
    for task in tasks:
        is_obj = isinstance(task, dict)
        task_id = task.get('id') if is_obj else None
        task_label = task.get('label') if is_obj else None
        if not (task_id and task_label):
            raise ValueError(f"Invalid task data at depth {depth}: {task}")

        _add_wbs_node(dot, task_id, task_label, parent_id, depth, base_color)

        # Only the first child key holding a list is followed.
        for key in _WBS_CHILD_KEYS:
            children = task.get(key)
            if isinstance(children, list):
                _add_wbs_tasks(dot, task_id, children, depth + 1, base_color)
                break


def _render_to_tempfile(dot, output_format):
    """Render ``dot`` as ``output_format`` next to a fresh temp path; return the output path."""
    # Reserve a unique path, then close the handle *before* rendering: graphviz
    # writes the DOT source to this path and (cleanup=True) deletes it again,
    # which fails on Windows while the file is still held open.
    with NamedTemporaryFile(delete=False, suffix=f'.{output_format}') as tmp:
        source_path = tmp.name
    try:
        # render() returns the path of the file it actually produced, so the
        # result stays correct even if the effective format differs.
        return dot.render(source_path, format=output_format, cleanup=True)
    except Exception:
        # cleanup never ran; remove the placeholder (best effort) before re-raising.
        try:
            os.remove(source_path)
        except OSError:
            pass
        raise


def generate_wbs_diagram(json_input: str, output_format: str) -> str:
    """
    Generates a Work Breakdown Structure (WBS) Diagram from JSON input.

    The project title becomes the root box. Each phase hangs below it, and each
    task hangs below its parent. Boxes get lighter and their labels smaller the
    deeper they sit.

    Args:
        json_input (str): A JSON string describing the WBS structure.
                          It must follow the Expected JSON Format Example below.
                          Every phase/task needs a non-empty "id" and "label".
                          A task's children are taken from the first of
                          "tasks", "subtasks", "sub_subtasks",
                          "sub_sub_subtasks", "final_level_tasks" that holds a list.
        output_format (str): Graphviz output format / file extension, e.g. "png".

    Expected JSON Format Example:
    {
        "project_title": "AI Model Development Project",
        "phases": [
            {
                "id": "phase_prep",
                "label": "Preparation",
                "tasks": [
                    {
                        "id": "task_1_1_vision",
                        "label": "Identify Vision",
                        "subtasks": [
                            {
                                "id": "subtask_1_1_1_design_staff",
                                "label": "Design & Staffing",
                                "sub_subtasks": [
                                    {
                                        "id": "ss_task_1_1_1_1_env_setup",
                                        "label": "Environment Setup",
                                        "sub_sub_subtasks": [
                                            {
                                                "id": "sss_task_1_1_1_1_1_lib_install",
                                                "label": "Install Libraries",
                                                "final_level_tasks": [
                                                    {"id": "ft_1_1_1_1_1_1_data_access", "label": "Grant Data Access"}
                                                ]
                                            }
                                        ]
                                    }
                                ]
                            }
                        ]
                    }
                ]
            },
            {
                "id": "phase_plan",
                "label": "Planning",
                "tasks": [
                    {
                        "id": "task_2_1_cost_analysis",
                        "label": "Cost Analysis",
                        "subtasks": [
                            {
                                "id": "subtask_2_1_1_benefit_analysis",
                                "label": "Benefit Analysis",
                                "sub_subtasks": [
                                    {
                                        "id": "ss_task_2_1_1_1_risk_assess",
                                        "label": "AI Risk Assessment",
                                        "sub_sub_subtasks": [
                                            {
                                                "id": "sss_task_2_1_1_1_1_model_selection",
                                                "label": "Model Selection",
                                                "final_level_tasks": [
                                                    {"id": "ft_2_1_1_1_1_1_data_strategy", "label": "Data Strategy"}
                                                ]
                                            }
                                        ]
                                    }
                                ]
                            }
                        ]
                    }
                ]
            },
            {
                "id": "phase_dev",
                "label": "Development",
                "tasks": [
                    {
                        "id": "task_3_1_change_mgmt",
                        "label": "Data Preprocessing",
                        "subtasks": [
                            {
                                "id": "subtask_3_1_1_implementation",
                                "label": "Feature Engineering",
                                "sub_subtasks": [
                                    {
                                        "id": "ss_task_3_1_1_1_beta_testing",
                                        "label": "Model Training",
                                        "sub_sub_subtasks": [
                                            {
                                                "id": "sss_task_3_1_1_1_1_other_task",
                                                "label": "Model Evaluation",
                                                "final_level_tasks": [
                                                    {"id": "ft_3_1_1_1_1_1_hyperparam_tune", "label": "Hyperparameter Tuning"}
                                                ]
                                            }
                                        ]
                                    }
                                ]
                            }
                        ]
                    }
                ]
            }
        ]
    }

    Returns:
        str: On success, the filepath of the rendered diagram (a temporary file
             whose extension is ``output_format``; the caller owns it). On any
             failure, a message starting with "Error: " instead — this function
             does not raise.
    """
    try:
        if not json_input or not json_input.strip():
            return "Error: Empty input"

        data = json.loads(json_input)

        if not isinstance(data, dict):
            raise ValueError("Invalid WBS input: top-level JSON value must be an object")
        if 'project_title' not in data or 'phases' not in data:
            raise ValueError("Missing required fields: project_title or phases")
        if not isinstance(data['phases'], list):
            raise ValueError("Invalid WBS input: 'phases' must be a list")

        dot = graphviz.Digraph(
            name='WBSDiagram',
            # Fallback only: _render_to_tempfile passes output_format explicitly.
            format='png',
            graph_attr={
                'rankdir': 'TB',
                'splines': 'ortho',
                'bgcolor': 'white',
                'pad': '0.5',
                'ranksep': '0.6',
                'nodesep': '0.5',
            },
        )

        # Darkest colour, used for the root; deeper levels are lightened from it.
        base_color = '#19191a'

        dot.node(
            'project_root',
            data['project_title'],
            shape='box',
            style='filled,rounded',
            fillcolor=base_color,
            fontcolor='white',
            fontsize='18',
        )

        phase_depth = 1
        for phase in data['phases']:
            is_obj = isinstance(phase, dict)
            phase_id = phase.get('id') if is_obj else None
            phase_label = phase.get('label') if is_obj else None
            if not (phase_id and phase_label):
                raise ValueError(f"Invalid phase data: {phase}")

            _add_wbs_node(dot, phase_id, phase_label, 'project_root', phase_depth, base_color)

            # Phases only recurse through "tasks".
            tasks = phase.get('tasks')
            if isinstance(tasks, list):
                _add_wbs_tasks(dot, phase_id, tasks, phase_depth + 1, base_color)

        return _render_to_tempfile(dot, output_format)

    except json.JSONDecodeError:
        return "Error: Invalid JSON format"
    except Exception as e:
        return f"Error: {str(e)}"
|
|
|
|