| from typing import Dict, List, Optional, Tuple |
|
|
| from swift.llm.template import split_str_parts_by |
|
|
|
|
def calculate_loss_scale(query: str,
                         response: str,
                         response_loss_scale_map: Dict[str, list],
                         query_loss_scale_map: Optional[Dict[str, list]] = None) -> Tuple[List[str], List[float]]:
    """Calculate the loss scale by splitting the agent response.

    This algorithm comes from paper: https://arxiv.org/pdf/2309.00986.pdf

    Agent response format:

    ```text
    Thought: you should always think about what to do
    Action: the action to take, should be one of the above tools[fire_recognition,
        fire_alert, call_police, call_fireman]
    Action Input: the input to the action
    Observation: the result of the action
    ... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
    Thought: I now know the final answer
    Final Answer: the final answer to the original input question
    ```

    Args:
        query: The user query. If it contains any key of ``query_loss_scale_map``
            (checked in the map's iteration order), the whole response is returned
            as a single part weighted by that key's scale, and no splitting happens.
        response: The agent response to split.
        response_loss_scale_map: Maps a delimiter to its scale(s).
            - A two-element value ``[key_scale, content_scale]`` marks a delimiter
              that is split on in non-regex mode. Both the delimiter text and its
              content are emitted, each with its own weight.
            - A one-element value ``[scale]`` marks a regex delimiter. Only the
              matched part's content is emitted, with that weight.
            If any two-element entries exist, only they are used for splitting and
            the one-element (regex) entries are ignored.
        query_loss_scale_map: Optional map from a query substring to a scale. The
            scale is either a bare number or a list whose first element is used.
            This mapping is not modified.

    Returns:
        A tuple ``(parts, weights)`` of agent response parts and their float
        weights, aligned index by index. Parts whose key is not in
        ``response_loss_scale_map`` get weight ``1.0``.
    """
    if query_loss_scale_map is not None:
        for key, scale in query_loss_scale_map.items():
            if key in query:
                # Accept a bare number or a list, but do not normalize the caller's
                # mapping in place; reading the first element is enough.
                loss_scale_value = scale if isinstance(scale, (float, int)) else scale[0]
                return [response], [float(loss_scale_value)]

    # Two-element (literal) delimiters take precedence over one-element (regex) ones.
    delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 2]
    if delimiters:
        agent_parts = split_str_parts_by(response, delimiters)
    else:
        regex_delimiters = [k for k, v in response_loss_scale_map.items() if len(v) == 1]
        agent_parts = split_str_parts_by(response, regex_delimiters, regex_mode=True)

    weights: List[float] = []
    agent_content: List[str] = []
    for part in agent_parts:
        key, content = part['key'], part['content']
        if key not in response_loss_scale_map:
            # Text not attached to a configured delimiter keeps the default weight.
            weights.append(1.)
            agent_content.append(content)
            continue
        loss_scale = response_loss_scale_map[key]
        assert len(loss_scale) in {1, 2}, f'loss_scale: {loss_scale}'
        # Coerce to float so both return paths honor the List[float] annotation.
        weights.extend(float(w) for w in loss_scale)
        if len(loss_scale) == 1:
            # One weight: emit only the content, not the delimiter text.
            agent_content.append(content)
        else:
            # Two weights: delimiter text gets loss_scale[0], content gets loss_scale[1].
            agent_content += [key, content]
    return agent_content, weights
|
|