| import base64 |
| import json |
| import os |
| from typing import Dict, List, Optional, Any |
|
|
| import chess |
| import chess.engine |
| from langchain.chat_models import init_chat_model |
| from langchain.schema import SystemMessage, HumanMessage |
| from langchain.tools import tool |
| from langchain_google_genai import ChatGoogleGenerativeAI |
| from langchain_openai import ChatOpenAI |
| from pydantic import BaseModel, Field |
|
|
| from utils.prompt_manager import prompt_mgmt |
|
|
|
|
def encode_image_to_base64(image_path: str) -> str:
    """Return the raw bytes of the file at *image_path* as a base64 text string.

    The file is read in binary mode in one go; the base64 output is ASCII, so
    decoding it as UTF-8 is lossless.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded_bytes = base64.b64encode(raw_bytes)
    return encoded_bytes.decode('utf-8')
|
|
|
|
class ChessPiecePosition(BaseModel):
    """Model for chess piece position"""

    # NOTE: pydantic copies the class docstring and the Field descriptions into
    # the model's JSON schema, so their wording is intentionally left as-is.

    # Board square in algebraic notation, e.g. "e4". Required (default=...).
    square: str = Field(
        default=...,
        description="Chess square notation (e.g., 'e4', 'a1')",
    )
    # "<color>_<piece>" string, e.g. "white_king"; it is turned into a FEN
    # letter by ChessBoardAnalysis._piece_to_char. Required (default=...).
    piece: str = Field(
        default=...,
        description="Piece type and color (e.g., 'white_king', 'black_queen')",
    )
|
|
|
|
class ChessBoardAnalysis(BaseModel):
    """Model for complete chess board analysis"""

    # NOTE: the class docstring and the Field description end up in the model's
    # JSON schema, so their text is left unchanged.
    positions: List[ChessPiecePosition] = Field(..., description="List of all piece positions on the board")

    def add_positions(self, positions: List[ChessPiecePosition]) -> None:
        """Append *positions* to this analysis as-is (no de-duplication by square)."""
        self.positions.extend(positions)

    def merge_with(self, other: Optional['ChessBoardAnalysis']) -> None:
        """Merge another analysis into this one (overwriting conflicts).

        Every square that *other* reports replaces whatever this analysis held
        for that square; squares only this analysis knows about are kept.
        ``None`` (e.g. a model reply that failed to parse) and ``self`` are
        no-ops.
        """
        if other is None or other is self:
            return
        # Snapshot first in case the two analyses share a positions list.
        incoming = list(other.positions)
        incoming_squares = {self._normalize_square(p.square) for p in incoming}
        # Drop our entries for squares the other analysis covers, then append
        # its entries, so the other analysis wins on shared squares.
        self.positions[:] = [
            p for p in self.positions
            if self._normalize_square(p.square) not in incoming_squares
        ]
        self.positions.extend(incoming)

    def to_fen(self, active_color: str) -> str:
        """Convert the analysis to FEN notation (simplified).

        :param active_color: side to move; "white"/"w" (case-insensitive,
            surrounding whitespace ignored) gives "w", anything else "b".
        :return: a full six-field FEN string. Only piece placement and side to
            move come from the analysis; castling ("-"), en passant ("-"),
            halfmove clock (0) and fullmove number (1) are fixed placeholders,
            since positions carry only square and piece.

        Positions with malformed or off-board squares are skipped. When a square
        appears more than once, the later entry wins.
        """
        board = [[''] * 8 for _ in range(8)]
        for position in self.positions:
            indices = self._square_to_indices(position.square)
            if indices is None:
                continue
            rank_idx, file_idx = indices
            board[rank_idx][file_idx] = self._piece_to_char(position.piece)

        # Row 0 is rank 8, matching FEN's top-to-bottom order.
        piece_placement = '/'.join(self._fen_row(row) for row in board)

        active_color_char = 'w' if active_color.strip().lower() in ('white', 'w') else 'b'

        castling_rights = "-"
        en_passant = "-"
        halfmove_clock = 0
        fullmove_number = 1
        return ' '.join([
            piece_placement,
            active_color_char,
            castling_rights,
            en_passant,
            str(halfmove_clock),
            str(fullmove_number),
        ])

    @staticmethod
    def _normalize_square(square: str) -> str:
        """Canonical key for a square name (trimmed, lowercase)."""
        return square.strip().lower()

    @staticmethod
    def _square_to_indices(square: str) -> Optional[tuple]:
        """Map e.g. "e4" to ``(rank_idx, file_idx)`` board indices, or None if invalid.

        Requires exactly one file letter a-h and one rank digit 1-8, so inputs
        like "e10" are rejected instead of being read as "e1".
        """
        name = square.strip().lower()
        if len(name) != 2 or name[0] not in 'abcdefgh' or name[1] not in '12345678':
            return None
        return 8 - int(name[1]), ord(name[0]) - ord('a')

    @staticmethod
    def _fen_row(row: List[str]) -> str:
        """Encode one board row, collapsing runs of empty cells into digits."""
        parts = []
        empty_run = 0
        for cell in row:
            if not cell:
                empty_run += 1
                continue
            if empty_run:
                parts.append(str(empty_run))
                empty_run = 0
            parts.append(cell)
        if empty_run:
            parts.append(str(empty_run))
        return ''.join(parts)

    def _piece_to_char(self, piece: str) -> str:
        """Convert piece description to FEN character.

        "white_knight" -> "N", "black_knight" -> "n". Any colour other than
        "black" is treated as white. Unknown piece types, and strings without an
        underscore, yield "" (rendered as an empty square).
        """
        color, separator, piece_type = piece.strip().lower().partition('_')
        if not separator:
            return ''
        piece_chars = {
            'king': 'K', 'queen': 'Q', 'rook': 'R',
            'bishop': 'B', 'knight': 'N', 'pawn': 'P'
        }
        char = piece_chars.get(piece_type, '')
        return char.lower() if color == 'black' else char
|
|
|
|
class ChessVisionAnalyzer:
    """Detect the piece placement on a chess-board image with two vision LLMs.

    ``llm1`` (OpenAI ``gpt-4.1`` via ``init_chat_model``) and ``llm2`` (Gemini
    ``gemini-2.5-flash``) each list the pieces they see. Squares on which both
    agree form the consensus; disagreements are re-examined by
    ``arbitrate_conflicts``.
    """

    def __init__(self):
        self.llm1 = init_chat_model(model="openai:gpt-4.1", temperature=0.0)
        self.llm2 = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

    @staticmethod
    def _image_content_part(image_path: str) -> dict:
        """Return the ``image_url`` message part embedding *image_path* as a base64 data URL."""
        import mimetypes

        # Label the data URL with the file's actual image type (e.g. PNG
        # screenshots) instead of always claiming JPEG. Unknown or non-image
        # extensions keep the previous "image/jpeg" default.
        mime_type = mimetypes.guess_type(image_path)[0]
        if not mime_type or not mime_type.startswith("image/"):
            mime_type = "image/jpeg"
        base64_image = encode_image_to_base64(image_path)
        return {
            "type": "image_url",
            "image_url": {
                "url": f"data:{mime_type};base64,{base64_image}", "detail": "high"
            },
        }

    def analyze_board_orientation(self, active_color: str, image_path: str) -> str:
        """Ask ``llm1`` to describe the board orientation shown in the image.

        :param active_color: side to move, passed to the model as a hint.
        :param image_path: path of the board image.
        :return: the model's reply (``response.content``). It is inserted
            verbatim into the piece-detection prompt.
        """
        image_part = self._image_content_part(image_path)
        messages = [
            SystemMessage(
                content=prompt_mgmt.render_template("chess_board_orientation", {})),
            HumanMessage(content=[
                {
                    "type": "text",
                    "text": f"Analyze this chess board image and return the chess board orientation. I know that the "
                            f"active color is {active_color}"
                },
                image_part,
            ]),
        ]
        return self.llm1.invoke(messages).content

    def analyze_board_from_image(self, board_orientation: str, image_path: str, llm_no: int,
                                 squares: Optional[list] = None) -> Optional[ChessBoardAnalysis]:
        """Ask one of the two models to list the pieces it sees on the board.

        :param board_orientation: orientation description inserted verbatim into
            the prompt (see analyze_board_orientation).
        :param image_path: path of the board image.
        :param llm_no: 1 selects ``llm1``; any other value selects ``llm2``.
        :param squares: optional collection of *piece names* (despite the name)
            the model should re-check. Any iterable ``sorted`` accepts works;
            arbitrate_conflicts passes a set.
        :return: the parsed analysis, or ``None`` if the reply could not be parsed.
        """
        image_part = self._image_content_part(image_path)

        squares_text = ""
        if squares:
            squares_text = (f"Focus only on these pieces {sorted(squares)} "
                            f"*** Important: make sure you detect correctly their position as they were challanged by another model. Take into account the board "
                            f"orientation."
                            )

        prompt_text = f"""Analyze this chess board image and return the pieces positions.
{board_orientation}

{squares_text}
Return the positions of the pieces in JSON format.
Use the following schema for each piece:
[{{
"square": "chess notation (e.g., 'e4', 'a1')",
"piece": "color_piece (e.g., 'white_king', 'black_queen')"
}},...

{{
"square": "chess notation (e.g., 'e4', 'a1')",
"piece": "color_piece (e.g., 'white_king', 'black_queen')"
}}
]
Very Important: Return only this list!
"""
        messages = [
            SystemMessage(
                content=prompt_mgmt.render_template("chess_board_detection", {})),
            HumanMessage(content=[
                {"type": "text", "text": prompt_text},
                image_part,
            ]),
        ]
        llm = self.llm1 if llm_no == 1 else self.llm2
        response = llm.invoke(messages)
        return self._parse_llm_response(response.content)

    def analyze_board(self, active_color: str, file_reference: str) -> str:
        """Detect the board in *file_reference* and return it as a FEN string.

        Both models analyse the image. Any disagreement is arbitrated (up to
        ``depth=3`` extra rounds), and the resulting consensus is converted with
        ``to_fen(active_color)``.

        :raises ValueError: if neither model produced a parseable analysis.
        """
        board_orientation = self.analyze_board_orientation(active_color, file_reference)
        first_analysis_res = self.analyze_board_from_image(board_orientation, file_reference, 1)
        second_analysis_res = self.analyze_board_from_image(board_orientation, file_reference, 2)

        result = self.compare_analyses(first_analysis_res, second_analysis_res)
        if result.get("conflicts"):
            result = self.arbitrate_conflicts(result, board_orientation, file_reference, 3)

        # Previously the no-conflict branch computed the FEN without returning
        # it, so the caller received None.
        consensus = result.get("consensus")
        if consensus is None:
            raise ValueError(f"Could not detect the chess pieces in {file_reference!r}")
        return consensus.to_fen(active_color)

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten message content (a str, or a list of str / {"type": "text"} parts) to text."""
        if isinstance(content, str):
            return content
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type", "text") == "text":
                pieces.append(str(part.get("text", "")))
        return "".join(pieces)

    def _parse_llm_response(self, response: Any) -> Optional[ChessBoardAnalysis]:
        """Parse a model reply into a ChessBoardAnalysis.

        Accepts a bare JSON list, or one wrapped in a ``` / ```json fence, or a
        ``{"positions": [...]}`` object. Entries that are not objects, or whose
        "piece" is missing or empty, are skipped.

        :return: the analysis, or ``None`` if the reply cannot be parsed.
        """
        # Best-effort boundary: any failure to read model output returns None,
        # which compare_analyses handles, instead of aborting the whole pipeline.
        try:
            json_str = self._content_to_text(response).strip()
            if "```json" in json_str:
                json_str = json_str.split("```json")[1].split("```")[0].strip()
            elif "```" in json_str:
                json_str = json_str.split("```")[1].strip()

            data = json.loads(json_str)
            print(data)
            if isinstance(data, dict):
                data = data["positions"]

            positions = [
                ChessPiecePosition(**item)
                for item in data
                if isinstance(item, dict) and item.get("piece")
            ]
            return ChessBoardAnalysis(positions=positions)
        except Exception as e:
            print(f"Failed to parse LLM response: {e}")
            return None

    def compare_analyses(self, analysis_1: Optional[ChessBoardAnalysis],
                         analysis_2: Optional[ChessBoardAnalysis]) -> dict:
        """Compare two analyses square by square.

        :return: a dict with keys
            ``"conflicts"``: list of ``{"square", "analysis_1", "analysis_2"}``
            entries for squares where the pieces differ (None means the model
            reported nothing there);
            ``"consensus"``: a ChessBoardAnalysis of agreed squares, the
            surviving analysis if only one is available, or None if both are
            missing;
            ``"need_arbitration"``: whether conflicts is non-empty.
        """
        if analysis_1 is None and analysis_2 is None:
            return {"conflicts": [], "consensus": None, "need_arbitration": False}
        if analysis_1 is None or analysis_2 is None:
            # Only one model produced a parseable answer. Use it rather than
            # returning no consensus, which made every caller crash on None.
            survivor = analysis_1 if analysis_1 is not None else analysis_2
            return {"conflicts": [], "consensus": survivor, "need_arbitration": False}

        dict_1 = {pos.square: pos.piece for pos in analysis_1.positions}
        dict_2 = {pos.square: pos.piece for pos in analysis_2.positions}

        conflicts = []
        consensus = []
        # Sorted for a deterministic ordering of consensus and conflicts.
        for square in sorted(set(dict_1) | set(dict_2)):
            piece_1 = dict_1.get(square)
            piece_2 = dict_2.get(square)
            if piece_1 == piece_2:
                if piece_1:
                    consensus.append(ChessPiecePosition(square=square, piece=piece_1))
            else:
                conflicts.append({
                    "square": square,
                    "analysis_1": piece_1,
                    "analysis_2": piece_2
                })

        return {
            "conflicts": conflicts,
            "consensus": ChessBoardAnalysis(positions=consensus),
            "need_arbitration": len(conflicts) > 0
        }

    def arbitrate_conflicts(self, state: dict, board_orientation: str, image_path: str, depth: int = 1) -> dict:
        """Re-ask both models about the pieces involved in conflicts.

        Each round, both models re-analyse the image focusing on the conflicting
        piece names. Their new agreement is merged with the consensus from
        *state*, where entries from *state* overwrite entries on shared squares.
        If conflicts remain, the method recurses while ``depth > 0``. After
        that, ``llm2``'s focused answer is merged last and therefore wins.

        :param state: a dict shaped like compare_analyses' result.
        :return: a dict of the same shape with the merged consensus.
        """
        print(f"Arbitrating conflicts with depth {depth}")

        conflict_pieces = {
            piece
            for conflict in state.get("conflicts", [])
            for piece in (conflict["analysis_1"], conflict["analysis_2"])
            if piece is not None
        }
        print("Pieces with conflicts:", conflict_pieces)

        first_analysis_res = self.analyze_board_from_image(board_orientation, image_path, 1, conflict_pieces)
        second_analysis_res = self.analyze_board_from_image(board_orientation, image_path, 2, conflict_pieces)
        result = self.compare_analyses(first_analysis_res, second_analysis_res)

        consensus = result.get("consensus")
        previous = state.get("consensus")
        if consensus is None:
            # Both focused re-analyses failed: keep what was already agreed.
            result["consensus"] = consensus = previous
        elif previous is not None:
            consensus.merge_with(previous)

        if result.get("conflicts"):
            if depth > 0:
                return self.arbitrate_conflicts(result, board_orientation, image_path, depth - 1)
            print("Arbitrage completed with conflicts. took llm2 as ground truth")
            if consensus is not None and second_analysis_res is not None:
                consensus.merge_with(second_analysis_res)

        return result
|
|
|
|
class ChessEngineAnalyzer:
    """Wrapper around a UCI engine process (Stockfish by default) driven via python-chess.

    Can be used as a context manager so the engine process is always shut down:

        with ChessEngineAnalyzer(path) as analyzer:
            analyzer.analyze_position(fen)
    """

    def __init__(self, stockfish_path: str = "stockfish"):
        """Launch the engine.

        :param stockfish_path: engine executable. ``None`` or an empty string
            falls back to ``"stockfish"``, since callers pass ``os.getenv(...)``,
            which is None when the variable is unset.
        """
        self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path or "stockfish")

    def __enter__(self) -> "ChessEngineAnalyzer":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always stop the engine process; exceptions propagate (returns None).
        self.close()

    def analyze_position(self, fen: str, depth: int = 18) -> Dict[str, Any]:
        """Analyse *fen* to the given search *depth*.

        :return: a dict with
            ``"best_move"``: first move of the principal variation in UCI
            notation, or None if the engine returned no variation;
            ``"evaluation"``: ``str()`` of the engine score (a zero centipawn
            score from White's view if none was reported);
            ``"depth"``: the requested depth;
            ``"analysis"``: the raw python-chess info dict.

        ``chess.Board(fen)`` raises ValueError for a malformed FEN.
        """
        board = chess.Board(fen)
        info = self.engine.analyse(board, chess.engine.Limit(depth=depth))

        principal_variation = info.get("pv")
        best_move = principal_variation[0] if principal_variation else None
        evaluation = info.get("score", chess.engine.PovScore(chess.engine.Cp(0), chess.WHITE))

        return {
            "best_move": best_move.uci() if best_move else None,
            "evaluation": str(evaluation),
            "depth": depth,
            "analysis": info
        }

    def close(self):
        """Shut down the engine process."""
        self.engine.quit()
|
|
|
|
class ChessMoveExplainer:
    """Turn an engine-recommended move into a plain-language explanation via an OpenAI chat model."""

    def __init__(self):
        # Chat client used for every explanation request.
        self.llm = ChatOpenAI(
            model="gpt-4"
        )

    def explain_move(self, fen: str, move: str, analysis: Dict) -> str:
        """Ask the model why *move* is recommended in position *fen*.

        :param fen: position in which the move is played.
        :param move: the move in UCI notation (e.g. "e2e4"); it is converted to
            SAN for the prompt.
        :param analysis: engine result; its 'evaluation' and 'depth' entries are
            quoted in the prompt.
        :return: the ``content`` of the model's reply.

        Errors from python-chess for an invalid FEN or move, and KeyError for a
        missing analysis entry, are not caught here.
        """
        position = chess.Board(fen)
        san_move = position.san(chess.Move.from_uci(move))
        evaluation = analysis['evaluation']
        search_depth = analysis['depth']

        prompt = f"""
Chess position FEN: {fen}
Recommended move: {san_move} ({move})
Engine evaluation: {evaluation}
Analysis depth: {search_depth}

Explain this move recommendation in simple terms. Consider:
1. Why this move is strong
2. What threats it creates or prevents
3. The strategic implications
4. Alternative moves and why they're inferior
5. Keep it concise but informative for an intermediate player
"""

        reply = self.llm.invoke([HumanMessage(content=prompt)])
        return reply.content
|
|
|
|
@tool
def chess_analysis_tool(active_color: str, file_reference: str) -> str:
    """
    Tool for analyzing a chess board images and recommending moves
    :param active_color: The color that should execute the next move
    :param file_reference: the reference of the image to be analyzed
    :return: the recommended move along with an analysis
    """
    # NOTE: the docstring above is the tool description exposed to the agent
    # by the @tool decorator, so its wording is intentionally left unchanged.
    vision_analyzer = ChessVisionAnalyzer()
    move_explainer = ChessMoveExplainer()
    fen = vision_analyzer.analyze_board(active_color, file_reference)
    print(f"Got fen {fen}")

    # Start the engine only once a FEN exists, and always shut it down, so a
    # failing vision step or engine analysis cannot leak the subprocess.
    # An unset CHESS_ENGINE_PATH falls back to "stockfish" on PATH.
    engine_analyzer = ChessEngineAnalyzer(os.getenv("CHESS_ENGINE_PATH") or "stockfish")
    try:
        analysis_result = engine_analyzer.analyze_position(fen)
        print(f"Got analysis reslut {analysis_result}")
    finally:
        engine_analyzer.close()

    best_move = analysis_result["best_move"]
    if best_move is None:
        # explain_move cannot convert a missing move; report it instead of crashing.
        return (f"The engine returned no move for the position {fen}; "
                f"the game may already be over (checkmate or stalemate).")

    return move_explainer.explain_move(fen, best_move, analysis_result)
|
|