Mmanikandan committed on
Commit
9d3f61d
·
1 Parent(s): d34f0ce

fixes applied

Browse files
Files changed (3) hide show
  1. Dockerfile +2 -2
  2. inference.py +114 -138
  3. server/app.py +1 -1
Dockerfile CHANGED
@@ -7,6 +7,6 @@ RUN pip install --no-cache-dir -r requirements.txt
7
 
8
  COPY . .
9
 
10
- EXPOSE 8000
11
 
12
- CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
 
7
 
8
  COPY . .
9
 
10
+ EXPOSE 5000
11
 
12
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "5000"]
inference.py CHANGED
@@ -40,7 +40,7 @@ def get_environment_config() -> Dict[str, str]:
40
  "api_base_url": os.getenv("API_BASE_URL", "http://localhost:11434/v1"),
41
  "model_name": os.getenv("MODEL_NAME", "llama2"),
42
  "hf_token": os.getenv("HF_TOKEN", ""),
43
- "env_url": os.getenv("ENV_URL", "http://localhost:5001"), # FIXED: Changed from 5000 to 5001
44
  "api_key": os.getenv("HF_TOKEN", "not-needed-for-local"),
45
  }
46
  return config
@@ -618,149 +618,125 @@ def run_inference(config: Optional[Dict[str, str]] = None) -> None:
618
  except Exception as e:
619
  print(f"Warning: Could not initialize LLM client: {e}", file=sys.stderr)
620
 
621
- # Initialize variables for error handling
622
- rewards = []
623
- step_num = 0
624
- action_str = "initialization"
625
 
626
- try:
627
- # Reset environment
628
- reset_response = requests.post(
629
- f"{env_url}/reset",
630
- timeout=10
631
- )
632
- reset_response.raise_for_status()
633
- reset_data = reset_response.json()
634
 
635
- observation = reset_data.get("observation", {})
636
- task_name = observation.get("email_id", "email_workflow")
637
- email_subject = observation.get("subject", "")
638
- email_body = observation.get("body", "")
639
- customer_history = observation.get("customer_history", "")
640
- workflow_context = observation.get("previous_decisions", {}) # ✅ FIXED: Changed from "workflow_context" to "previous_decisions"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
 
642
- # Log start
643
- log_start(task_name, env_name, model_name)
 
644
 
645
- rewards = []
646
- step_num = 0
647
- done = False
648
 
649
- # Multi-step workflow loop
650
- while not done and step_num < 5:
651
- step_num += 1
652
 
653
- # Generate action based on current step
654
- if step_num == 1:
655
- action = generate_classification_action(
656
- email_subject, email_body, customer_history, client, model_name
657
- )
658
- elif step_num == 2:
659
- classification = workflow_context.get("classification", "tech")
660
- action = generate_prioritization_action(
661
- email_subject, email_body, customer_history, classification, client, model_name
662
- )
663
- elif step_num == 3:
664
- classification = workflow_context.get("classification", "tech")
665
- priority = workflow_context.get("priority", "medium")
666
- sentiment = observation.get("customer_sentiment", "neutral") # ✅ FIXED: Use actual sentiment from observation
667
- action = generate_strategy_action(
668
- email_subject, email_body, customer_history, classification, priority, sentiment, client, model_name
669
- )
670
- elif step_num == 4:
671
- classification = workflow_context.get("classification", "tech")
672
- priority = workflow_context.get("priority", "medium")
673
- strategy = workflow_context.get("strategy", "auto_resolve")
674
- action = generate_response_action(
675
- email_subject, email_body, customer_history, classification, priority, strategy, workflow_context, client, model_name
676
- )
677
- elif step_num == 5:
678
- action = generate_escalation_action(
679
- workflow_context, email_subject, email_body, customer_history, client, model_name
680
- )
681
- if action is None:
682
- # No escalation needed, end episode
683
- break
684
-
685
- # Convert action to string for logging
686
- if action["action_type"] == "escalate":
687
- action_str = f"escalate_{action['content'].get('escalation_level', 'unknown')}"
688
- else:
689
- content_preview = str(action["content"])[:50].replace("\n", " ")
690
- action_str = f"{action['action_type']}:{content_preview}"
691
-
692
- # Step environment
693
- step_response = requests.post(
694
- f"{env_url}/step",
695
- json=action,
696
- timeout=15
697
- )
698
- step_response.raise_for_status()
699
- step_data = step_response.json()
700
-
701
- reward = step_data.get("reward", 0.0)
702
- done = step_data.get("done", True)
703
- info = step_data.get("info", {})
704
-
705
- # Update workflow context for next step
706
- workflow_context = info.get("workflow_state", workflow_context)
707
-
708
- rewards.append(reward)
709
-
710
- # Log step
711
- log_step(step_num, action_str, reward, done, None)
712
-
713
- # Prepare final metrics
714
- total_score = sum(rewards)
715
- success = total_score > 2.0 # Threshold for successful multi-step completion
716
-
717
- # CRITICAL FIX: Normalize score to [0,1] range as per OpenEnv spec
718
- MAX_POSSIBLE_REWARD = 2.5 # Maximum theoretical score across all steps
719
- normalized_score = total_score / MAX_POSSIBLE_REWARD
720
- normalized_score = min(max(normalized_score, 0.0), 1.0)
721
-
722
- # Log end
723
- log_end(success, step_num, normalized_score, rewards)
724
-
725
- except requests.exceptions.RequestException as e:
726
- error_msg = f"Step {step_num} failed: {str(e)}"
727
- log_step(step_num, action_str, 0.0, False, error_msg)
728
- rewards.append(0.0)
729
- # Prepare final metrics after error
730
- total_score = sum(rewards)
731
- success = False
732
- normalized_score = 0.0
733
- log_end(success, step_num, normalized_score, rewards)
734
- print(f"Error: {error_msg}", file=sys.stderr)
735
- return # Exit function instead of break
736
-
737
- except Exception as e:
738
- error_msg = f"Step {step_num} error: {str(e)}"
739
- log_step(step_num, action_str, 0.0, False, error_msg)
740
- rewards.append(0.0)
741
- # Prepare final metrics after error
742
- total_score = sum(rewards)
743
- success = False
744
- normalized_score = 0.0
745
- log_end(success, step_num, normalized_score, rewards)
746
- print(f"Error: {error_msg}", file=sys.stderr)
747
- return # Exit function instead of break
748
-
749
- except requests.exceptions.RequestException as e:
750
- error_msg = f"Environment request failed: {str(e)}"
751
- log_start("error", env_name, model_name)
752
- log_step(1, "error", 0.0, False, error_msg)
753
- log_end(False, 1, 0.0, [0.0])
754
- print(f"Error: {error_msg}", file=sys.stderr)
755
- sys.exit(1)
756
-
757
- except Exception as e:
758
- error_msg = f"Inference failed: {str(e)}"
759
- log_start("error", env_name, model_name)
760
- log_step(1, "error", 0.0, False, error_msg)
761
- log_end(False, 1, 0.0, [0.0])
762
- print(f"Error: {error_msg}", file=sys.stderr)
763
- sys.exit(1)
764
 
765
 
766
  if __name__ == "__main__":
 
40
  "api_base_url": os.getenv("API_BASE_URL", "http://localhost:11434/v1"),
41
  "model_name": os.getenv("MODEL_NAME", "llama2"),
42
  "hf_token": os.getenv("HF_TOKEN", ""),
43
+ "env_url": os.getenv("ENV_URL", "http://localhost:5000"), # Fixed to match server default port
44
  "api_key": os.getenv("HF_TOKEN", "not-needed-for-local"),
45
  }
46
  return config
 
618
  except Exception as e:
619
  print(f"Warning: Could not initialize LLM client: {e}", file=sys.stderr)
620
 
621
+ episode_count = 3
 
 
 
622
 
623
+ for episode_idx in range(1, episode_count + 1):
624
+ rewards = []
625
+ step_num = 0
626
+ action_str = "initialization"
 
 
 
 
627
 
628
+ try:
629
+ # Reset environment
630
+ reset_response = requests.post(
631
+ f"{env_url}/reset",
632
+ timeout=10
633
+ )
634
+ reset_response.raise_for_status()
635
+ reset_data = reset_response.json()
636
+
637
+ observation = reset_data.get("observation", {})
638
+ task_name = observation.get("email_id", f"email_workflow_{episode_idx}")
639
+ email_subject = observation.get("subject", "")
640
+ email_body = observation.get("body", "")
641
+ customer_history = observation.get("customer_history", "")
642
+ workflow_context = observation.get("previous_decisions", {})
643
+
644
+ # Log start
645
+ log_start(task_name, env_name, model_name)
646
+
647
+ done = False
648
+
649
+ # Multi-step workflow loop
650
+ while not done and step_num < 5:
651
+ step_num += 1
652
+
653
+ # Generate action based on current step
654
+ if step_num == 1:
655
+ action = generate_classification_action(
656
+ email_subject, email_body, customer_history, client, model_name
657
+ )
658
+ elif step_num == 2:
659
+ classification = workflow_context.get("classification", "tech")
660
+ action = generate_prioritization_action(
661
+ email_subject, email_body, customer_history, classification, client, model_name
662
+ )
663
+ elif step_num == 3:
664
+ classification = workflow_context.get("classification", "tech")
665
+ priority = workflow_context.get("priority", "medium")
666
+ sentiment = observation.get("customer_sentiment", "neutral")
667
+ action = generate_strategy_action(
668
+ email_subject, email_body, customer_history, classification, priority, sentiment, client, model_name
669
+ )
670
+ elif step_num == 4:
671
+ classification = workflow_context.get("classification", "tech")
672
+ priority = workflow_context.get("priority", "medium")
673
+ strategy = workflow_context.get("strategy", "auto_resolve")
674
+ action = generate_response_action(
675
+ email_subject, email_body, customer_history, classification, priority, strategy, workflow_context, client, model_name
676
+ )
677
+ elif step_num == 5:
678
+ action = generate_escalation_action(
679
+ workflow_context, email_subject, email_body, customer_history, client, model_name
680
+ )
681
+ if action is None:
682
+ # No escalation needed, end episode
683
+ done = True
684
+ break
685
+
686
+ # Convert action to string for logging
687
+ if action["action_type"] == "escalate":
688
+ action_str = f"escalate_{action['content'].get('escalation_level', 'unknown')}"
689
+ else:
690
+ content_preview = str(action["content"])[:50].replace("\n", " ")
691
+ action_str = f"{action['action_type']}:{content_preview}"
692
+
693
+ # Step environment
694
+ step_response = requests.post(
695
+ f"{env_url}/step",
696
+ json=action,
697
+ timeout=15
698
+ )
699
+ step_response.raise_for_status()
700
+ step_data = step_response.json()
701
 
702
+ reward = step_data.get("reward", 0.0)
703
+ done = step_data.get("done", True)
704
+ info = step_data.get("info", {})
705
 
706
+ # Update workflow context for next step
707
+ workflow_context = info.get("workflow_state", workflow_context)
 
708
 
709
+ rewards.append(reward)
 
 
710
 
711
+ # Log step
712
+ log_step(step_num, action_str, reward, done, None)
713
+
714
+ # Prepare final metrics
715
+ total_score = sum(rewards)
716
+ MAX_POSSIBLE_REWARD = 2.5
717
+ normalized_score = total_score / MAX_POSSIBLE_REWARD
718
+ normalized_score = min(max(normalized_score, 0.0), 1.0)
719
+ success = normalized_score >= 0.7
720
+
721
+ log_end(success, step_num, normalized_score, rewards)
722
+
723
+ except requests.exceptions.RequestException as e:
724
+ error_msg = f"Step {step_num} failed: {str(e)}"
725
+ log_step(step_num, action_str, 0.0, False, error_msg)
726
+ rewards.append(0.0)
727
+ normalized_score = 0.0
728
+ log_end(False, step_num, normalized_score, rewards)
729
+ print(f"Error: {error_msg}", file=sys.stderr)
730
+ continue
731
+
732
+ except Exception as e:
733
+ error_msg = f"Step {step_num} error: {str(e)}"
734
+ log_step(step_num, action_str, 0.0, False, error_msg)
735
+ rewards.append(0.0)
736
+ normalized_score = 0.0
737
+ log_end(False, step_num, normalized_score, rewards)
738
+ print(f"Error: {error_msg}", file=sys.stderr)
739
+ continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
 
741
 
742
  if __name__ == "__main__":
server/app.py CHANGED
@@ -156,7 +156,7 @@ def root() -> Dict[str, str]:
156
  def main():
157
  """Main entry point for running the server."""
158
  import uvicorn
159
- uvicorn.run(app, host="0.0.0.0", port=5001)
160
 
161
 
162
  if __name__ == "__main__":
 
156
  def main():
157
  """Main entry point for running the server."""
158
  import uvicorn
159
+ uvicorn.run(app, host="0.0.0.0", port=5000)
160
 
161
 
162
  if __name__ == "__main__":