| import uuid |
| import datetime |
| from typing import Optional, Tuple, Dict, Any, List |
| from .models import CloudAction, CloudObservation, CloudState, CloudActionType |
|
|
class CloudAuditEnv:
    """In-memory cloud-audit environment for the openenv-core 0.1.1 interface.

    The environment exposes a fixed snapshot of S3 buckets, EC2 instances and
    auth logs. An agent explores it with LIST / DESCRIBE / LOGS actions and
    finishes the selected task as follows:

    * ``"easy"``   -- SUBMIT the ID of the public prod S3 bucket.
    * ``"medium"`` -- MODIFY the EC2 instance so that no security-group rule
      opens port 3389 to ``0.0.0.0/0``.
    * ``"hard"``   -- SUBMIT the IP that performed ``DeleteStorage`` in
      ``auth-logs``.

    Each step that does not complete the task earns ``STEP_REWARD``; the step
    that completes it earns ``SUCCESS_REWARD``. ``score`` accumulates rewards
    (starting at 0.01) and is capped at 0.99.
    """

    # The episode is reported as done once step_count reaches this many steps.
    MAX_STEPS = 20
    # Reward for any step that does not complete the task.
    STEP_REWARD = 0.005
    # Reward for the step that completes the task.
    SUCCESS_REWARD = 0.85
    # Instance whose security group the "medium" task must fix.
    MEDIUM_TARGET_INSTANCE = "i-0abcdef1234567890"

    def __init__(self):
        self.task_id = "easy"
        self._initialize_state()

    def _initialize_state(self):
        """Start a fresh episode: new episode ID, zeroed counters, pristine resources."""
        self.episode_id = str(uuid.uuid4())
        self.step_count = 0
        self.is_completed = False
        self.score = 0.01

        # Rebuilt on every reset so rules changed by a previous episode's
        # MODIFY do not carry over.
        self.resources: Dict[str, Any] = {
            "s3": [
                {"id": "prod-data-001", "region": "us-east-1", "public": True, "tags": {"env": "prod"}},
                {"id": "prod-logs-002", "region": "us-east-1", "public": False, "tags": {"env": "prod"}},
                {"id": "dev-test-01", "region": "us-west-2", "public": True, "tags": {"env": "dev"}},
            ],
            "ec2": [
                {
                    "id": "i-0abcdef1234567890",
                    "type": "t2.micro",
                    "state": "running",
                    "tags": {"env": "dev"},
                    "security_groups": [
                        {
                            "id": "sg-01",
                            "rules": [
                                {"port": 22, "cidr": "0.0.0.0/0"},
                                {"port": 3389, "cidr": "0.0.0.0/0"},
                            ],
                        }
                    ],
                },
                {
                    "id": "i-0987654321fedcba0",
                    "type": "m5.large",
                    "state": "running",
                    "tags": {"env": "prod"},
                    "security_groups": [
                        {"id": "sg-02", "rules": [{"port": 443, "cidr": "0.0.0.0/0"}]}
                    ],
                },
            ],
            "logs": {
                "auth-logs": [
                    {"timestamp": "2026-04-05T10:00:00Z", "user": "admin", "action": "Login", "ip": "1.1.1.1"},
                    {"timestamp": "2026-04-05T10:15:00Z", "user": "iam-role-01", "action": "DeleteStorage", "ip": "192.168.1.50"},
                    {"timestamp": "2026-04-05T10:30:00Z", "user": "user-02", "action": "ListBuckets", "ip": "2.2.2.2"},
                ]
            },
        }

    def reset(self, task_id: str = "easy") -> CloudObservation:
        """Start a new episode for ``task_id`` and return the initial observation.

        Required by openenv-core 0.1.1: takes task_id, returns JUST the observation.
        """
        self.task_id = task_id
        self._initialize_state()
        return CloudObservation(info=f"Environment reset. Task: {self.task_id}", reward=0.01, done=False)

    def step(self, action: CloudAction) -> CloudObservation:
        """Apply one agent action and return the resulting observation.

        Required by openenv-core 0.1.1: takes an action and returns JUST the
        observation, with its ``reward`` and ``done`` fields filled in.

        ``done`` is True when the task was completed on this step or the
        episode has reached ``MAX_STEPS`` steps. Any unexpected exception is
        logged to stderr and reported as a terminal observation instead of
        being propagated.
        """
        try:
            self.step_count += 1
            reward = self.STEP_REWARD
            terminated = False
            truncated = self.step_count >= self.MAX_STEPS

            obs = CloudObservation()
            kind = action.action

            # Compared with == (not a dict lookup) so plain-string action values
            # still match the CloudActionType members, as in the original chain.
            if kind == CloudActionType.LIST:
                self._do_list(action, obs)
            elif kind == CloudActionType.DESCRIBE:
                self._do_describe(action, obs)
            elif kind == CloudActionType.MODIFY:
                terminated = self._do_modify(action, obs)
            elif kind == CloudActionType.LOGS:
                self._do_logs(action, obs)
            elif kind == CloudActionType.SUBMIT:
                terminated = self._do_submit(action, obs)
            else:
                obs.status = f"Unsupported action: {kind}"

            if terminated:
                reward = self.SUCCESS_REWARD
                # Record completion so state() reflects the solved task.
                self.is_completed = True

            self.score = min(0.99, self.score + reward)
            obs.reward = reward
            obs.done = terminated or truncated
            return obs
        except Exception as e:
            # Top-level boundary: log the traceback and end the episode rather
            # than propagate the error to the serving layer.
            import sys
            import traceback

            print(f"ERROR in environment.step: {str(e)}", file=sys.stderr)
            traceback.print_exc(file=sys.stderr)
            return CloudObservation(status=f"Internal Server Error: {str(e)}", reward=0.01, done=True)

    def _find_resource(self, res_id: Any) -> Optional[Dict[str, Any]]:
        """Return the S3 bucket or EC2 instance dict whose ``id`` equals *res_id*, else None."""
        for r_type in ("s3", "ec2"):
            for res in self.resources[r_type]:
                if res["id"] == res_id:
                    return res
        return None

    def _do_list(self, action: CloudAction, obs: CloudObservation) -> None:
        """Handle LIST: put every resource of ``action.resource_type`` into ``obs.resources``."""
        r_type = action.resource_type
        if r_type in self.resources:
            # NOTE(review): r_type "logs" yields the logs dict (keyed by log
            # name), not a list -- confirm CloudObservation.resources accepts it.
            obs.resources = self.resources[r_type]
            obs.status = f"Listed {len(obs.resources)} {r_type} resources."
        else:
            obs.status = f"Unknown resource type: {r_type}"

    def _do_describe(self, action: CloudAction, obs: CloudObservation) -> None:
        """Handle DESCRIBE: put the S3/EC2 resource named by ``action.resource_id`` into ``obs.details``."""
        res_id = action.resource_id
        res = self._find_resource(res_id)
        if res is None:
            obs.status = f"Resource not found: {res_id}"
        else:
            obs.details = res
            obs.status = f"Described resource {res_id}"

    @staticmethod
    def _is_open_rdp_rule(rule: Dict[str, Any]) -> bool:
        """Return True if *rule* opens RDP (port 3389) to ``0.0.0.0/0``.

        The port is compared numerically so ``"3389"`` and ``3389`` are treated
        alike; a strict ``== 3389`` would let a string port pass as "removed".
        """
        try:
            port = int(rule.get("port"))
        except (TypeError, ValueError):
            return False
        return port == 3389 and str(rule.get("cidr", "")).strip() == "0.0.0.0/0"

    def _do_modify(self, action: CloudAction, obs: CloudObservation) -> bool:
        """Handle MODIFY; return True when the "medium" task has just been solved.

        Only the "medium" task permits MODIFY, and only on
        ``MEDIUM_TARGET_INSTANCE``. ``action.patch["rules"]`` (a list of
        ``{"port": ..., "cidr": ...}`` dicts) replaces the rules of every
        security group on that instance.
        """
        res_id = action.resource_id
        if self.task_id != "medium":
            obs.status = "Action not permitted or invalid resource."
            return False
        if res_id != self.MEDIUM_TARGET_INSTANCE:
            obs.status = (
                f"Invalid resource ID '{res_id}'. Use the EC2 instance ID "
                f"'{self.MEDIUM_TARGET_INSTANCE}', not the security group ID."
            )
            return False

        instance = self._find_resource(res_id)
        patch = action.patch
        if patch and "rules" in patch:
            new_rules = patch["rules"]
            # Reject malformed input up front. Indexing into it later would
            # raise and end the episode through the generic error handler.
            if not isinstance(new_rules, (list, tuple)) or not all(isinstance(rule, dict) for rule in new_rules):
                obs.status = "Invalid patch: 'rules' must be a list of {\"port\": ..., \"cidr\": ...} objects."
                return False
            for sg in instance["security_groups"]:
                # Copy so the agent's objects are not aliased into env state.
                sg["rules"] = [dict(rule) for rule in new_rules]
            obs.status = f"Updated security groups for EC2 instance {res_id}"
        else:
            obs.status = f"No 'rules' in patch; security groups for EC2 instance {res_id} were not changed."

        still_open = any(
            self._is_open_rdp_rule(rule)
            for sg in instance["security_groups"]
            for rule in sg["rules"]
        )
        if still_open:
            obs.info = "Port 3389 is still open. Remove it by omitting it from the rules list."
            return False
        obs.info = "Success! Port 3389 removed. Task completed."
        return True

    def _do_logs(self, action: CloudAction, obs: CloudObservation) -> None:
        """Handle LOGS: put the entries of the log named by ``action.resource_id`` into ``obs.logs``."""
        log_name = action.resource_id
        logs = self.resources["logs"]
        if log_name in logs:
            obs.logs = logs[log_name]
            obs.status = f"Fetched logs for {log_name}"
        else:
            obs.status = f"Logs not found: {log_name}"

    def _do_submit(self, action: CloudAction, obs: CloudObservation) -> bool:
        """Handle SUBMIT; return True when the answer completes the current task."""
        answer = action.answer

        if self.task_id == "easy":
            if not answer:
                obs.info = "No answer provided. Submit the public prod S3 bucket ID."
                return False
            # Ignore empty entries so a trailing comma is not counted as a wrong ID.
            answers = [a.strip() for a in answer.split(",") if a.strip()]
            if set(answers) == {"prod-data-001"}:
                obs.info = "Correct! Task completed."
                return True
            obs.info = f"Incorrect. Expected the public prod S3 bucket ID. Got: {answers}"
            return False

        if self.task_id == "hard":
            if answer and answer.strip() == "192.168.1.50":
                obs.info = "Correct! Rogue IP identified. Task completed."
                return True
            obs.info = f"Wrong IP address. Got: {action.answer}. Check the auth-logs for the DeleteStorage action."
            return False

        if self.task_id == "medium":
            obs.info = "For the medium task, use the 'modify' action to update the EC2 security group, not 'submit'."
            return False

        obs.info = f"Unknown task: {self.task_id}"
        return False

    def state(self) -> CloudState:
        """Return a snapshot of episode bookkeeping: ID, step count, task, completion and score."""
        return CloudState(
            episode_id=self.episode_id,
            step_count=self.step_count,
            task_id=self.task_id,
            is_completed=self.is_completed,
            score=self.score,
        )
|
|