| import os |
| from enum import Enum |
|
|
| import typer |
| from openai import OpenAI |
|
|
| from jssp_openenv.client import JSSPEnvClient |
| from jssp_openenv.gantt import gantt_chart |
| from jssp_openenv.policy import JSSPEnvPolicy, JSSPFifoPolicy, JSSPLLMPolicy, JSSPMaxMinPolicy |
| from jssp_openenv.solver import solve_jssp |
|
|
# Defaults used by the `solve` command for its --server-url and --max-steps options.
SERVER_URL = "http://localhost:8000"
MAX_STEPS = 1000

# Directory (relative to the current working directory) that receives the
# Gantt chart images. It is created as soon as this module is imported.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Typer application on which the commands of this module are registered.
cli = typer.Typer()
|
|
|
|
class PolicyName(str, Enum):
    """Scheduling policies that can be selected on the command line.

    Thanks to the ``str`` mixin every member is also a plain string and
    compares equal to its lowercase value (e.g. ``PolicyName.FIFO == "fifo"``).
    """

    # Declaration order is preserved: it is the order in which the members iterate.
    FIFO = "fifo"
    LLM = "llm"
    MAX_MIN = "maxmin"
|
|
|
|
@cli.command()
def solve(
    policy: PolicyName = typer.Argument(help="The policy to use"),
    server_url: str = typer.Option(SERVER_URL, help="The URL of the JSSP server"),
    max_steps: int = typer.Option(MAX_STEPS, help="The maximum number of steps per instance"),
    verbose: bool = typer.Option(False, "--verbose", "-v", help="Whether to print verbose output"),
    model_id: str = typer.Option(None, "--model-id", "-m", help="The ID of the model to use"),
):
    """Solve a JSSP instance using the given policy."""
    # NOTE: the docstring is deliberately left as a single line because Typer
    # displays it as the help text of this command.
    #
    # Flow: build the environment client, instantiate the requested policy,
    # run `solve_jssp`, report the result, then pass the schedule to
    # `gantt_chart` with a target path inside OUTPUT_DIR.
    #
    # Raises ValueError when the LLM policy is selected without --model-id or
    # without the HF_TOKEN environment variable being set.
    env_client = JSSPEnvClient(base_url=server_url)

    # Each branch selects the policy object together with the chart title and
    # the image file name that go with it.
    policy_obj: JSSPEnvPolicy
    if policy == PolicyName.FIFO:
        policy_obj = JSSPFifoPolicy()
        title = "FIFO Policy"
        filename = "gantt_fifo_policy.png"
    elif policy == PolicyName.LLM:
        # Validate the model id first, then the token, so that the first
        # missing prerequisite is the one reported.
        if not model_id:
            raise ValueError("You must set --model-id to use the LLM policy")
        api_key = os.environ.get("HF_TOKEN")
        if not api_key:
            raise ValueError("You must set the HF_TOKEN environment variable to use the LLM policy")
        policy_obj = JSSPLLMPolicy(
            client=OpenAI(base_url="https://router.huggingface.co/v1", api_key=api_key),
            model_id=model_id,
        )
        title = f"LLM Policy ({model_id})"
        # Turn '/', ':', '-' and ' ' into underscores in a single pass so the
        # model id can be embedded in a file name.
        safe_model_id = model_id.translate(str.maketrans("/:- ", "____"))
        filename = f"gantt_llm_policy_{safe_model_id}.png"
    elif policy == PolicyName.MAX_MIN:
        policy_obj = JSSPMaxMinPolicy()
        title = "Max-Min Policy"
        filename = "gantt_maxmin_policy.png"

    makespan, scheduled_events = solve_jssp(env_client, policy_obj, max_steps, verbose)

    if verbose:
        # One line per scheduled event: start time, job, machine and duration.
        print("Schedule events:")
        for event in scheduled_events:
            print(
                f"[{event.start_time}] Scheduling job {event.job_id} on machine {event.machine_id} for {event.end_time - event.start_time} minute(s)"
            )

    print(f"Solved in {makespan} steps")

    filepath = os.path.join(OUTPUT_DIR, filename)
    gantt_chart(scheduled_events, title=title, makespan=makespan, save_to=filepath)
    print(f"Saved Gantt chart to {filepath}")
|
|
|
|
# Run the Typer application only when this file is executed as a script.
if __name__ == "__main__":
    cli()
|
|