sandbox-5ca717e4 / memory.py
Justin-lee's picture
Update memory.py with web tools in system prompt
1211240 verified
#!/usr/bin/env python3
"""
CodePilot Memory System — 仿 Claude Code 的四層記憶架構
=========================================================
層級 1: CODEPILOT.md 指令層級(走 CWD 到根目錄)
~/.codepilot/CODEPILOT.md ← 全域個人偏好
./CODEPILOT.md ← 專案指令(提交到 repo)
./.codepilot/CODEPILOT.md ← 備選位置
./.codepilot/rules/*.md ← 條件規則
./CODEPILOT.local.md ← 私人覆蓋(gitignore)
層級 2: MEMORY.md 自動記憶(跨 session)
~/.codepilot/projects/<project>/memory/MEMORY.md
記住:用戶偏好、專案決策、修正過的錯誤
層級 3: Session 對話歷史(JSONL 持久化)
~/.codepilot/projects/<project>/<session-id>.jsonl
層級 4: 對話內壓縮(context window 管理)
自動偵測 token 使用量,觸發 7 段摘要壓縮
"""
import json, os, re, uuid, hashlib, html
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional, List
# Root directory for all CodePilot user-level state (global instructions,
# per-project memory and session transcripts live underneath it).
CONFIG_DIR = Path.home().joinpath(".codepilot")
# ============================================================
# Layer 1: CODEPILOT.md Instruction Hierarchy
# ============================================================
# Header placed in front of the merged instruction files so the model treats
# them as overriding its defaults.
MEMORY_INSTRUCTION_PROMPT = "".join((
    "Codebase and user instructions are shown below. Be sure to adhere to these instructions. ",
    "IMPORTANT: These instructions OVERRIDE any default behavior and you MUST follow them exactly as written.",
))
# Character cap applied to each CODEPILOT.md-style instruction file.
MAX_MEMORY_CHARS = 40 * 1000
def _strip_html_comments(text: str) -> str:
"""移除 HTML 註解(讓你放私人筆記模型看不到)"""
return re.sub(r'<!--.*?-->', '', text, flags=re.DOTALL)
def _sanitize_path(path: str) -> str:
"""把路徑轉成安全的目錄名"""
return hashlib.md5(path.encode()).hexdigest()[:12] + "_" + os.path.basename(path)
def _read_instruction_file(path: Path, limit: int) -> Optional[str]:
    """Return the model-visible text of one instruction file, or None if it is unusable.

    HTML comments are stripped *before* the text is capped at *limit* characters.
    A cut can therefore never split a comment and expose the private text after
    an unmatched ``<!--``. Paths that are not regular files, or that cannot be
    read, yield None so that one bad file does not abort prompt assembly.
    """
    if not path.is_file():
        return None
    try:
        raw = path.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return None
    return _strip_html_comments(raw)[:limit]


def load_instructions(cwd: str) -> str:
    """Collect every CODEPILOT instruction file that applies to *cwd*, lowest priority first.

    Later entries take precedence over earlier ones:

    1. ``CONFIG_DIR/CODEPILOT.md``: global, user-level preferences.
    2. For each directory from the filesystem root down to *cwd*:
       ``<dir>/CODEPILOT.md``, ``<dir>/.codepilot/CODEPILOT.md``, then every
       ``<dir>/.codepilot/rules/*.md`` in sorted order (capped at 10,000 chars each).
    3. ``<cwd>/CODEPILOT.local.md``: private, uncommitted overrides.

    A file reachable by more than one route is included only once, under the
    label of its first occurrence. For example, the user-level file is also
    ``<home>/.codepilot/CODEPILOT.md`` when *cwd* lies under the home directory.

    Args:
        cwd: Directory to search from; resolved to an absolute path.

    Returns:
        ``MEMORY_INSTRUCTION_PROMPT`` followed by each file's labelled contents,
        separated by ``---`` rules, or ``""`` if no instruction file was found.
    """
    sections: List[str] = []
    seen = set()

    def add(path: Path, limit: int, label: str) -> None:
        # Deduplicate on the resolved path so the same file is never included twice.
        try:
            key = path.resolve()
        except (OSError, RuntimeError):  # e.g. symlink loops
            key = path
        if key in seen:
            return
        content = _read_instruction_file(path, limit)
        if content is None:
            return
        seen.add(key)
        sections.append(f"Contents of {path} ({label}):\n\n{content}")

    # Global user-level instructions (lowest priority).
    add(CONFIG_DIR / "CODEPILOT.md", MAX_MEMORY_CHARS, "user-level instructions")

    # Walk from the filesystem root down to cwd so deeper directories come later.
    cwd_path = Path(cwd).resolve()
    for directory in [*reversed(cwd_path.parents), cwd_path]:
        for candidate in (directory / "CODEPILOT.md", directory / ".codepilot" / "CODEPILOT.md"):
            add(candidate, MAX_MEMORY_CHARS, "project instructions")
        rules_dir = directory / ".codepilot" / "rules"
        if rules_dir.is_dir():
            for rule_file in sorted(rules_dir.glob("*.md")):
                add(rule_file, 10_000, "rule")

    # Local overrides (highest priority; intended to be gitignored).
    add(cwd_path / "CODEPILOT.local.md", MAX_MEMORY_CHARS, "local overrides, private")

    if not sections:
        return ""
    return MEMORY_INSTRUCTION_PROMPT + "\n\n" + "\n\n---\n\n".join(sections)
# ============================================================
# Layer 2: MEMORY.md Auto-Memory (Cross-Session)
# ============================================================
def _get_project_dir(cwd: str) -> Path:
    """Return ``CONFIG_DIR/projects/<sanitized cwd>``, creating it (and parents) if missing."""
    project_dir = CONFIG_DIR.joinpath("projects", _sanitize_path(cwd))
    project_dir.mkdir(exist_ok=True, parents=True)
    return project_dir
def _get_memory_dir(cwd: str) -> Path:
    """Return the ``memory`` subdirectory of the project directory for *cwd*, creating it if missing."""
    memory_dir = _get_project_dir(cwd).joinpath("memory")
    memory_dir.mkdir(exist_ok=True, parents=True)
    return memory_dir
def load_memory(cwd: str, max_lines: int = 200, max_chars: int = 25_000) -> str:
    """Read the project's auto-memory file (MEMORY.md), bounded in size.

    At most *max_lines* lines are kept, and the kept text is then capped at
    *max_chars* characters. A warning line is appended only when content was
    actually dropped. A file with exactly *max_lines* lines, or one that exceeds
    *max_chars* only because of line terminators, is not reported as truncated.

    Args:
        cwd: Project working directory; selects whose memory is read.
        max_lines: Maximum number of lines to keep (default 200).
        max_chars: Maximum number of characters to keep (default 25,000).

    Returns:
        The (possibly truncated) memory text, or ``""`` if MEMORY.md does not exist.
    """
    mem_file = _get_memory_dir(cwd) / "MEMORY.md"
    if not mem_file.is_file():
        return ""
    content = mem_file.read_text(encoding="utf-8", errors="replace")
    all_lines = content.splitlines()
    kept = "\n".join(all_lines[:max_lines])
    # Compare against what was kept rather than the raw file, so that
    # nothing-dropped cases do not trigger the warning.
    was_truncated = len(all_lines) > max_lines or len(kept) > max_chars
    result = kept[:max_chars]
    if was_truncated:
        result += "\n\n⚠️ Memory truncated. Consider consolidating."
    return result
def save_memory(cwd: str, content: str):
    """Replace the whole MEMORY.md of the project at *cwd* with *content*, written as UTF-8."""
    memory_path = _get_memory_dir(cwd).joinpath("MEMORY.md")
    memory_path.write_text(content, encoding="utf-8")
def append_memory(cwd: str, entry: str):
    """Append *entry* to MEMORY.md as a bullet stamped with local time (``YYYY-MM-DD HH:MM``)."""
    memory_path = _get_memory_dir(cwd).joinpath("MEMORY.md")
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M")
    bullet = f"\n- [{stamp}] {entry}\n"
    with memory_path.open("a", encoding="utf-8") as handle:
        handle.write(bullet)
# ============================================================
# Layer 3: Session Transcript (JSONL Persistence)
# ============================================================
class SessionTranscript:
    """Append-only JSONL transcript of one conversation session.

    Each line is one JSON object with the keys ``type``, ``uuid``,
    ``parentUuid``, ``timestamp`` (UTC, ISO-8601) and ``message``.
    ``parentUuid`` is the uuid of the entry written before it, so the
    entries form a chain that can be replayed to resume the session.
    """

    def __init__(self, cwd: str, session_id: Optional[str] = None):
        """Bind to ``<project dir>/<session_id>.jsonl``.

        Args:
            cwd: Project working directory; selects the project storage directory.
            session_id: Existing session to reopen. When omitted, a new random
                8-character id is generated.
        """
        self.session_id = session_id or str(uuid.uuid4())[:8]
        self.project_dir = _get_project_dir(cwd)
        self.transcript_file = self.project_dir / f"{self.session_id}.jsonl"
        # uuid of the newest entry in the chain; becomes the next entry's parentUuid.
        self.last_uuid = None

    def append(self, msg_type: str, content) -> str:
        """Append one entry to the transcript and return its 12-character uuid.

        Args:
            msg_type: Stored as ``type``. When *content* is a plain string, it
                is also used as the role on reload.
            content: JSON-serializable message payload, stored as ``message``.
        """
        msg_uuid = str(uuid.uuid4())[:12]
        entry = {
            "type": msg_type,
            "uuid": msg_uuid,
            "parentUuid": self.last_uuid,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "message": content,
        }
        with open(self.transcript_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
        self.last_uuid = msg_uuid
        return msg_uuid

    def load_messages(self) -> List[dict]:
        """Reload the stored messages, e.g. to resume a previous session.

        Dict payloads are returned as-is. String payloads become
        ``{"role": <type>, "content": <text>}``. Blank, malformed, or
        incomplete lines are skipped.

        The file is decoded as UTF-8, the encoding ``append`` writes. It is
        split only on real line breaks: ``json.dumps(ensure_ascii=False)``
        leaves characters such as U+2028 unescaped, and ``str.splitlines``
        would cut entries there. ``last_uuid`` is set to the uuid of the last
        valid entry, so later appends continue the parentUuid chain instead of
        starting a new root.
        """
        if not self.transcript_file.exists():
            return []
        messages: List[dict] = []
        last_seen = None
        with open(self.transcript_file, encoding="utf-8", errors="replace") as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    entry = json.loads(line)
                    msg = entry["message"]
                    if isinstance(msg, dict):
                        messages.append(msg)
                    elif isinstance(msg, str):
                        messages.append({"role": entry["type"], "content": msg})
                    last_seen = entry.get("uuid") or last_seen
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Corrupt line or entry missing required keys: skip it.
                    continue
        if last_seen is not None:
            self.last_uuid = last_seen
        return messages

    @classmethod
    def find_latest(cls, cwd: str) -> Optional['SessionTranscript']:
        """Reopen the most recently modified session of the project at *cwd*, or return None if there is none."""
        project_dir = _get_project_dir(cwd)
        candidates = list(project_dir.glob("*.jsonl"))
        if not candidates:
            return None
        newest = max(candidates, key=lambda f: f.stat().st_mtime)
        return cls(cwd, newest.stem)
# ============================================================
# Layer 4: Context Compaction (Summary)
# ============================================================
# Instruction sent as the final user turn when asking the model to summarize
# the conversation during compaction (used by compact_messages). It asks for
# seven numbered sections.
COMPACT_PROMPT = """Your task is to create a detailed summary of the conversation so far.
Focus on information that will be needed to continue the work.
Your summary MUST include these sections:
1. **Primary Request and Intent** — What the user wants to accomplish
2. **Key Technical Concepts** — Languages, frameworks, patterns discussed
3. **Files and Code Sections** — Include actual code snippets that were written/modified
4. **Errors and Fixes** — What went wrong and how it was resolved
5. **Current State** — Where we are right now in the task
6. **Pending Tasks** — What still needs to be done
7. **Important User Preferences** — Any stated preferences about style, tools, etc.
Be thorough. Include actual code when relevant."""
# Fraction of the context window at which should_compact() starts returning True.
AUTOCOMPACT_THRESHOLD_PCT = 0.75 # compact at 75% of context window
def _content_chars(content) -> int:
    """Count the text characters in one message ``content`` value; non-text values count as 0."""
    if isinstance(content, str):
        return len(content)
    if isinstance(content, dict):
        # NOTE(review): assumes content parts shaped like {"type": "text", "text": ...};
        # other part types (e.g. images) are not counted. Confirm against callers.
        text = content.get("text")
        return len(text) if isinstance(text, str) else 0
    if isinstance(content, (list, tuple)):
        return sum(_content_chars(part) for part in content)
    return 0  # None (e.g. tool-call-only messages) or unknown types


def estimate_tokens(messages: List[dict]) -> int:
    """Roughly estimate the token count of *messages*.

    Uses one blended ratio of about 3 characters per token, a compromise between
    about 4 chars/token for English and about 2 chars/token for CJK text.
    Missing or ``None`` content counts as 0 instead of raising. List-of-parts
    content is measured by the text it contains rather than by the number of
    parts.
    """
    total_chars = sum(_content_chars(m.get("content")) for m in messages)
    return int(total_chars / 3)
def should_compact(messages: List[dict], context_window: int = 32768) -> bool:
    """Report whether the estimated size of *messages* has reached the auto-compact threshold.

    The threshold is ``AUTOCOMPACT_THRESHOLD_PCT`` of *context_window* tokens,
    truncated to an int; reaching it exactly counts as "should compact".
    """
    limit = int(context_window * AUTOCOMPACT_THRESHOLD_PCT)
    return not estimate_tokens(messages) < limit
def compact_messages(messages: List[dict], model_chat_fn, recently_edited_files: Optional[List[str]] = None) -> List[dict]:
    """Compress a long conversation into a model-written summary.

    Steps:
      1. Ask the model (through *model_chat_fn*) to summarize the conversation
         using COMPACT_PROMPT.
      2. Rebuild the history as: the first message, a marker user turn, and an
         assistant turn holding the summary.
      3. Re-inject the current contents (first 5,000 characters) of up to 5
         recently edited files, so the model sees their latest state.

    If the summarization call raises, the function falls back to the first
    message plus the last 20 messages after it (about 10 rounds). The first
    message is never repeated, even when the history is short.

    Args:
        messages: Full history. ``messages[0]`` is kept verbatim and is assumed
            to be the system prompt.
        model_chat_fn: Called as ``model_chat_fn(messages, max_tokens=2048)``;
            its return value is embedded in the summary turn as text.
        recently_edited_files: Optional file paths to re-inject. Missing,
            unreadable, or non-UTF-8 files are skipped.

    Returns:
        The compacted message list, or *messages* itself if it has 3 or fewer entries.
    """
    if len(messages) <= 3:
        return messages  # too short to be worth compacting

    system_msg = messages[0]

    # Ask the model for a summary of everything so far.
    summary_request = messages + [
        {"role": "user", "content": COMPACT_PROMPT}
    ]
    try:
        summary = model_chat_fn(summary_request, max_tokens=2048)
    except Exception:
        # Best effort: keep roughly the last 10 rounds instead. Slice after
        # messages[0] so the system message is not duplicated for short
        # histories. KeyboardInterrupt/SystemExit still propagate.
        return [system_msg] + messages[1:][-20:]

    new_messages = [
        system_msg,
        {"role": "user", "content": "[System: Previous conversation was summarized to save context space]"},
        {"role": "assistant", "content": f"Here's what we've discussed so far:\n\n{summary}"},
    ]

    # Re-inject recently edited files (at most 5).
    if recently_edited_files:
        file_context = []
        for fpath in recently_edited_files[:5]:
            try:
                content = Path(fpath).read_text(encoding="utf-8")[:5000]
            except (OSError, UnicodeDecodeError):
                continue  # deleted, unreadable, or not UTF-8: skip this file
            file_context.append(f"--- {fpath} (current state) ---\n{content}")
        if file_context:
            new_messages.append({
                "role": "user",
                "content": "Here are the current states of recently edited files:\n\n" + "\n\n".join(file_context)
            })
            new_messages.append({
                "role": "assistant",
                "content": "I've reviewed the current state of these files. Ready to continue."
            })
    return new_messages
# ============================================================
# File State Cache (Read-Before-Edit)
# ============================================================
class FileStateCache:
    """Track files the model has read, to enforce read-before-edit and skip redundant re-reads.

    Entries are keyed by absolute path. Each stores the recorded content, the
    file's mtime at record time, and the read window (offset / limit /
    is_partial). At most MAX_ENTRIES entries are kept. Recording a path (read
    or edit) makes it the most recent entry, and the least recently recorded
    entry is evicted first.
    """

    MAX_ENTRIES = 100
    UNCHANGED_STUB = ("File unchanged since last read. Refer to the earlier Read result.")

    def __init__(self):
        # absolute path -> state dict; dict insertion order doubles as recency order
        self._cache = {}
        # absolute paths of edited files, least recently edited first, no duplicates
        self._edited_files = []

    def _store(self, path: str, state: dict) -> None:
        """Record *state* as the most recent entry for *path* and evict the oldest overflow."""
        # Pop first so a re-recorded path moves to the end. Plain assignment would
        # keep its old position and make it the next eviction victim.
        self._cache.pop(path, None)
        self._cache[path] = state
        while len(self._cache) > self.MAX_ENTRIES:
            del self._cache[next(iter(self._cache))]

    def record_read(self, path: str, content: str, offset=None, limit=None, is_partial=False):
        """Remember that *path* was read with the given window, at its current mtime.

        Raises:
            OSError: if *path* cannot be stat'ed (from ``os.path.getmtime``).
        """
        path = os.path.abspath(path)
        self._store(path, {
            "content": content,
            "mtime": os.path.getmtime(path),
            "offset": offset,
            "limit": limit,
            "is_partial": is_partial,
            "from_edit": False,
        })

    def check_can_edit(self, path: str) -> Optional[str]:
        """Return None if *path* may be edited, otherwise an error message for the model.

        Editing requires that *path* has a cache entry, still exists, and has the
        same mtime as when it was recorded (i.e. no external modification since).
        """
        path = os.path.abspath(path)
        if path not in self._cache:
            return "❌ 必須先用 read_file 讀取文件才能編輯"  # must read_file before editing
        state = self._cache[path]
        if not os.path.exists(path):
            return "❌ 文件不存在"  # file does not exist
        if os.path.getmtime(path) != state["mtime"]:
            return "❌ 文件已被外部修改,請重新 read_file"  # modified externally; read again
        return None

    def record_edit(self, path: str, new_content: str):
        """Update the cache after *path* was edited, and mark it as the most recently edited file.

        The stored mtime is read from disk here, so call this after the new
        content has been written.
        """
        path = os.path.abspath(path)
        self._store(path, {
            "content": new_content,
            "mtime": os.path.getmtime(path),
            "offset": None,
            "limit": None,
            "is_partial": False,
            # This entry was not produced by a read. check_dedup must not tell the
            # model its earlier (pre-edit) Read result is still current.
            "from_edit": True,
        })
        # Move to the end so get_recently_edited() reflects the latest edit order.
        if path in self._edited_files:
            self._edited_files.remove(path)
        self._edited_files.append(path)

    def check_dedup(self, path: str, offset=None, limit=None) -> Optional[str]:
        """Return UNCHANGED_STUB if re-reading *path* with this window is redundant, else None.

        A read is redundant only if the cached entry came from a read with the same
        offset and limit, and the file's mtime has not changed since.
        """
        path = os.path.abspath(path)
        state = self._cache.get(path)
        if state is None or state.get("from_edit"):
            return None
        if state["offset"] != offset or state["limit"] != limit:
            return None
        try:
            unchanged = os.path.getmtime(path) == state["mtime"]
        except OSError:
            return None  # file vanished or became inaccessible: let the caller read it
        return self.UNCHANGED_STUB if unchanged else None

    def get_recently_edited(self, max_files=5) -> List[str]:
        """Return up to *max_files* edited paths, ordered oldest to newest edit."""
        if max_files <= 0:
            return []  # a [-0:] slice would return the whole list
        return self._edited_files[-max_files:]
# ============================================================
# System Prompt Builder
# ============================================================
def build_full_system_prompt(cwd: str, git_context: str = "") -> str:
    """Assemble the complete system prompt (section order modelled on Claude Code).

    Sections, separated by blank lines:
      1. identity; 2. working rules; 3. CODEPILOT.md instructions, if any;
      4. MEMORY.md auto-memory, if any; 5. environment details; 6. tool catalogue.

    Args:
        cwd: Project working directory. Used to load instructions and memory,
            and shown to the model.
        git_context: Optional git summary text; only its first line is shown.

    Returns:
        The assembled prompt string.
    """
    import platform
    import sys

    identity = """You are CodePilot, an expert AI programming assistant.
You work directly in the user's project — reading, editing, and creating files, running commands, and searching code.
You are thorough, precise, and always verify your changes."""

    rules = """## Important Rules
- ALWAYS read a file before editing it
- For edit_file: old_string must EXACTLY match file content (whitespace matters)
- Prefer edit_file over write_file for existing files (smaller diff, safer)
- After making changes, verify by reading the file or running tests
- For git: stage specific files, never `git add -A`; create new commits, don't amend
- If a command might take > 30s, warn the user first"""

    tools = """## Tools (use <tool>name
{json}</tool>)
- read_file: {"path":"...","offset":1,"limit":200} — also reads PDF, .ipynb, images
- edit_file: {"path":"...","old_string":"...","new_string":"..."} (must read first)
- write_file: {"path":"...","content":"..."}
- run_command: {"command":"...","timeout":120}
- search_files: {"pattern":"...","glob":"*.py"}
- list_files: {"pattern":"*","max_depth":3}
- git_status: {}
- web_fetch: {"url":"https://..."} — fetch a webpage (returns text content)
- web_search: {"query":"how to ..."} — search the web (DuckDuckGo)"""

    # Optional sections are included only when their loader returns non-empty text.
    optional_sections = []
    instructions = load_instructions(cwd)
    if instructions:
        optional_sections.append(instructions)
    memory = load_memory(cwd)
    if memory:
        optional_sections.append(f"## Project Memory (auto-saved across sessions)\n{memory}")

    git_line = git_context.split("\n")[0] if git_context else "(not a git repo)"
    environment = f"""## Environment
- Working directory: {cwd}
- Git: {git_line}
- Platform: {sys.platform}
- OS: {platform.system()} {platform.release()}
- Python: {sys.version.split()[0]}"""

    return "\n\n".join([identity, rules, *optional_sections, environment, tools])