#!/usr/bin/env python3
"""
Re-annotate action segments using LLM (GPT-4o-mini).
1. Re-classify existing segments with better accuracy
2. Infer actions in unlabeled gaps based on context (scene, surrounding actions)
3. Output improved annotations with higher coverage
"""
import os
import sys
import json
import re
import time
import copy
import glob
import urllib.request
from collections import Counter
# Data locations. "${PULSE_ROOT}" is a shell-style placeholder that Python does
# not expand on its own; os.path.expandvars substitutes it from the environment
# (and leaves the string unchanged when PULSE_ROOT is unset).
ANN_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_by_scene")
OUTPUT_DIR = os.path.expandvars("${PULSE_ROOT}/annotations_v2")
DATASET_DIR = os.path.expandvars("${PULSE_ROOT}/dataset")
# OpenAI-compatible chat-completions endpoint.
API_URL = "https://api.chatanywhere.tech/v1/chat/completions"


def parse_api_keys(raw):
    """Split a comma- and/or whitespace-separated key string into a list of keys.

    Args:
        raw: The raw string (e.g. the value of an environment variable), or None.

    Returns:
        The non-empty keys in their original order; an empty list for None/"".
    """
    if not raw:
        return []
    return [key for key in re.split(r"[,\s]+", raw) if key]


# API keys come from the LLM_API_KEYS environment variable instead of being
# committed to source control. Keys that were previously hard-coded in this file
# must be treated as leaked and revoked.
API_KEYS = parse_api_keys(os.environ.get("LLM_API_KEYS", ""))
if not API_KEYS:
    print("WARNING: LLM_API_KEYS is not set; every LLM call will return None.",
          file=sys.stderr)
# Chinese description of each recording scenario, keyed by scenario id (the
# annotation file stem, "s1".."s8"). The text is interpolated into the LLM
# prompts as scene context.
SCENE_DESCRIPTIONS = dict(
    s1="办公桌面整理与工作准备(整理文件、电源线、鼠标、笔记本电脑等)",
    s2="快递打包发送(折叠纸箱、放入物品、封箱、贴标签等)",
    s3="厨房调料整理(拿取调料瓶、倒调料、拧瓶盖、擦拭等)",
    s4="清理餐后桌面(收碗碟、擦桌子、整理餐具、倒残渣等)",
    s5="餐前桌面布置(铺桌布、摆放餐具碗碟、放杯子等)",
    s6="商务旅行行李箱打包(折叠衣物、放入行李箱、整理物品等)",
    s7="冲泡咖啡/饮品(取杯子、放咖啡粉/茶包、倒热水、搅拌等)",
    s8="晾衣架整理与衣物收纳(取衣架、挂衣服、折叠衣物等)",
)
# Definitions of the 11 action labels (Grasp ... Idle) plus disambiguation
# notes, written in Chinese. The text is inserted verbatim into both the
# reclassification prompt and the gap-inference prompt. Any label added here
# must also be added to the allowed-label lists those prompts and their parsers
# use.
ACTION_CATEGORIES = """动作类别定义(共11类):
1. Grasp - 抓取/拿起物体(手从无接触到接触并握住物体)
2. Place - 放置/放下物体(将物体放到某个位置并释放)
3. Pour - 倾倒/注入液体或颗粒(倒水、倒调料、倒咖啡粉等)
4. Wipe - 擦拭/清洁表面(用抹布或手擦桌面、瓶身等)
5. Fold - 折叠/卷起(折衣服、折桌布、折纸箱等)
6. OpenClose - 打开/关闭/旋开/旋紧(开盒子、拧瓶盖、拉拉链、合箱盖等)
7. Stir - 搅拌(搅拌咖啡、搅拌饮品等)
8. TearCut - 撕/剪/粘贴(撕胶带、剪快递单、贴标签等)
9. Arrange - 整理/摆放/调整位置(摆餐具、整理文件、调整物品位置、理线等)
10. Transport - 搬运/移动物体到较远位置(把包裹搬到架子、把碗端到水槽等)
11. Idle - 空闲/过渡/无明确操作(双手无目的性动作、等待、观察等)
注意:
- 只有真正没有任何手部操作时才标Idle
- "调整姿态"、"检查物体"等属于Arrange
- "插入"、"装入"等属于Place
- "提起并移动"如果距离短属于Grasp,距离远属于Transport
"""
# Mutable client state shared across calls:
# - current_key_idx: index into API_KEYS of the key in use. It is advanced
#   whenever a request fails with a quota/rate-limit or unclassified error.
# - call_count: number of successful completions, reported in the run summary.
current_key_idx = 0
call_count = 0


def _is_timeout_error(exc):
    """Return True if *exc* looks like a network/read timeout.

    Socket timeouts stringify as "timed out" (possibly wrapped by urllib as
    "<urlopen error timed out>"). That text does not contain the substring
    "timeout", so a plain substring test on "timeout" misses them.
    """
    if isinstance(exc, TimeoutError):
        return True
    if isinstance(getattr(exc, "reason", None), TimeoutError):
        return True
    text = str(exc).lower()
    return "timed out" in text or "timeout" in text


def _describe_error(exc):
    """Return ``str(exc)``, with the response body appended for HTTP errors.

    ``str(HTTPError)`` only carries the status line. Servers typically explain
    quota/rate-limit failures in the body, so including it lets the caller's
    "quota"/"limit" checks actually see that text.
    """
    from urllib.error import HTTPError

    text = str(exc)
    if isinstance(exc, HTTPError):
        try:
            body = exc.read().decode("utf-8", errors="replace")
        except Exception:  # body is diagnostics only; never let it mask the error
            body = ""
        if body:
            text = f"{text} {body.strip()}"
    return text


def call_llm(prompt, max_tokens=1000, retries=3):
    """Send *prompt* as one user message to gpt-4o-mini and return the reply text.

    Makes up to ``retries * len(API_KEYS)`` attempts. Error handling:
    - Quota/rate-limit errors (HTTP 429/402, or "quota"/"limit" in the error
      text) rotate to the next key.
    - Timeouts retry the same key after a 1 s pause.
    - Any other error rotates to the next key and pauses 0.5 s.

    Args:
        prompt: User message content.
        max_tokens: Completion token cap sent to the API.
        retries: Attempts per configured key.

    Returns:
        ``choices[0].message.content`` of the first successful response, or None
        when every attempt failed (immediately so if API_KEYS is empty).
    """
    global current_key_idx, call_count
    # The request body is identical for every attempt, so build it once.
    payload = json.dumps({
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.1,
    }).encode()
    for _attempt in range(retries * len(API_KEYS)):
        key = API_KEYS[current_key_idx]
        req = urllib.request.Request(
            API_URL, data=payload,
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
        )
        try:
            # The context manager closes the HTTP response even if parsing fails.
            with urllib.request.urlopen(req, timeout=30) as resp:
                result = json.loads(resp.read())
            content = result["choices"][0]["message"]["content"]
        except Exception as e:
            err = _describe_error(e)
            lowered = err.lower()
            if "429" in err or "402" in err or "quota" in lowered or "limit" in lowered:
                # Key exhausted or rate-limited: rotate to the next key.
                print(f" Key {current_key_idx+1} exhausted, rotating...")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
            elif _is_timeout_error(e):
                time.sleep(1)
            else:
                print(f" API error: {err[:100]}")
                current_key_idx = (current_key_idx + 1) % len(API_KEYS)
                time.sleep(0.5)
            continue
        call_count += 1
        return content
    print(" WARNING: All API keys failed!")
    return None
def _parse_classification_response(response):
    """Extract an ``{id: action}`` mapping from the model's JSON-array reply.

    Malformed entries are skipped individually rather than discarding the whole
    reply:
    - non-dict items;
    - an id that is missing or not convertible to int;
    - a non-string action.
    String ids such as "2" are normalised to int so they match 1-based lookups.
    Actions are whitespace-stripped but not otherwise validated.

    Returns:
        The mapping (possibly empty), or None when no JSON array can be parsed.
    """
    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        print(f" Parse error: no JSON array, response: {response[:200]}")
        return None
    try:
        results = json.loads(match.group())
    except json.JSONDecodeError as e:
        print(f" Parse error: {e}, response: {response[:200]}")
        return None
    labels = {}
    for item in results:
        if not isinstance(item, dict):
            continue
        try:
            seg_id = int(item["id"])
        except (KeyError, TypeError, ValueError):
            continue
        action = item.get("action")
        if isinstance(action, str):
            labels[seg_id] = action.strip()
    return labels


def reclassify_segments(segments, scene_id):
    """Ask the LLM to assign one of the 11 action labels to every segment.

    Args:
        segments: Segment dicts; each must have "timestamp" and "task".
        scene_id: Scenario id used to look up the scene description
            ("日常活动" / everyday activity when unknown).

    Returns:
        A mapping from 1-based segment position to the label string the model
        returned, or None if the API call failed or the reply was not parseable.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    # Numbered list so the model can refer back to segments by 1-based id.
    seg_text = "\n".join(
        f"{i}. [{seg['timestamp']}] {seg['task']}"
        for i, seg in enumerate(segments, start=1)
    )
    prompt = f"""你是一个人体动作标注专家。请为以下每个动作片段分配一个动作类别。
场景:{scene_desc}
{ACTION_CATEGORIES}
动作片段列表:
{seg_text}
请严格按以下JSON格式返回,不要添加任何额外文字:
[{{"id": 1, "action": "类别名"}}, {{"id": 2, "action": "类别名"}}, ...]
每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""
    # Roughly 40 output tokens per {"id": N, "action": "..."} entry.
    response = call_llm(prompt, max_tokens=len(segments) * 40)
    if response is None:
        return None
    return _parse_classification_response(response)
# Labels the gap-inference prompt allows; entries with any other action are dropped.
_GAP_ACTIONS = frozenset({
    "Grasp", "Place", "Pour", "Wipe", "Fold", "OpenClose",
    "Stir", "TearCut", "Arrange", "Transport", "Idle",
})


def _parse_gap_response(response, gap_start, gap_end):
    """Return the well-formed inferred segments from the model's JSON-array reply.

    An entry is kept only if all of the following hold:
    - it is a dict;
    - "timestamp" and "task" are strings;
    - "action" is one of the 11 labels;
    - the "MM:SS-MM:SS" span satisfies gap_start <= start < end <= gap_end.
    Always returns a list (empty when nothing usable was found).
    """
    match = re.search(r'\[.*\]', response, re.DOTALL)
    if not match:
        return []
    try:
        results = json.loads(match.group())
    except json.JSONDecodeError as e:
        print(f" Parse error: {e}")
        return []
    valid = []
    for item in results:
        if not isinstance(item, dict):
            continue
        timestamp, action = item.get("timestamp"), item.get("action")
        if not isinstance(timestamp, str) or not isinstance(item.get("task"), str):
            continue
        if not isinstance(action, str) or action not in _GAP_ACTIONS:
            continue
        ts_match = re.match(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)', timestamp)
        if not ts_match:
            continue
        start = int(ts_match.group(1)) * 60 + int(ts_match.group(2))
        end = int(ts_match.group(3)) * 60 + int(ts_match.group(4))
        if gap_start <= start < end <= gap_end:
            valid.append(item)
    return valid


def infer_gap_actions(scene_id, before_seg, after_seg, gap_start, gap_end):
    """Ask the LLM what actions probably happened in an unlabeled time gap.

    Args:
        scene_id: Scenario id used to look up the scene description.
        before_seg: Segment dict preceding the gap, or None at recording start.
        after_seg: Segment dict following the gap, or None at recording end.
        gap_start: Gap start in whole seconds (formatted with ``:02d``, so ints).
        gap_end: Gap end in whole seconds.

    Returns:
        A list of ``{"timestamp", "task", "action"}`` dicts that lie inside the
        gap and use a valid label. Never None: the list is empty on API failure
        or an unparseable reply.
    """
    scene_desc = SCENE_DESCRIPTIONS.get(scene_id, "日常活动")
    gap_duration = gap_end - gap_start
    before_text = f"[{before_seg['timestamp']}] {before_seg['task']}" if before_seg else "(录制开始)"
    after_text = f"[{after_seg['timestamp']}] {after_seg['task']}" if after_seg else "(录制结束)"
    prompt = f"""你是一个人体动作标注专家。在一段日常活动录制中,有一段时间没有被标注。请根据场景和前后动作推断这段时间内最可能发生的动作。
场景:{scene_desc}
未标注时间段:{gap_start//60:02d}:{gap_start%60:02d} - {gap_end//60:02d}:{gap_end%60:02d}(共{gap_duration}秒)
前一个标注动作:{before_text}
后一个标注动作:{after_text}
{ACTION_CATEGORIES}
请推断这段时间内可能发生的动作序列。每个动作段落2-4秒,时间用MM:SS格式。
如果确实是空闲等待,标注为Idle。
严格按以下JSON格式返回,不要添加任何额外文字:
[{{"timestamp": "MM:SS-MM:SS", "task": "动作描述", "action": "类别名"}}]
每个action必须是以下之一:Grasp, Place, Pour, Wipe, Fold, OpenClose, Stir, TearCut, Arrange, Transport, Idle"""
    response = call_llm(prompt, max_tokens=500)
    if response is None:
        return []
    return _parse_gap_response(response, gap_start, gap_end)
def get_recording_duration(vol, scenario, dataset_dir=None):
    """Return the aligned length of a recording in seconds, or None if unknown.

    Reads ``<dataset_dir>/<vol>/<scenario>/alignment_metadata.json``. It prefers
    ``aligned_length_sec`` and otherwise converts ``aligned_length_frames``
    using a hard-coded 100 frames per second (NOTE(review): confirm that rate).

    Args:
        vol: Volume directory name (e.g. "v1").
        scenario: Scenario directory name (e.g. "s3").
        dataset_dir: Dataset root; defaults to the module-level DATASET_DIR.

    Returns:
        The duration in seconds, or None when the metadata file or both
        length keys are missing.
    """
    base = DATASET_DIR if dataset_dir is None else dataset_dir
    meta_path = os.path.join(base, vol, scenario, "alignment_metadata.json")
    if not os.path.exists(meta_path):
        return None
    # Context manager closes the file; explicit UTF-8 avoids locale-dependent decoding.
    with open(meta_path, encoding="utf-8") as f:
        meta = json.load(f)
    if "aligned_length_sec" in meta:
        return meta["aligned_length_sec"]
    if "aligned_length_frames" in meta:
        return meta["aligned_length_frames"] / 100.0
    return None
# Action labels accepted from the LLM (the 11 categories listed in the prompts).
_VALID_ACTION_LABELS = frozenset({
    "Grasp", "Place", "Pour", "Wipe", "Fold", "OpenClose",
    "Stir", "TearCut", "Arrange", "Transport", "Idle",
})
# "MM:SS-MM:SS" span (minutes may exceed 59).
_SPAN_RE = re.compile(r'(\d+):(\d+)\s*-\s*(\d+):(\d+)')
# Minimum unlabeled span, in seconds, worth asking the LLM to fill.
_MIN_GAP_SEC = 3


def _parse_span(timestamp):
    """Parse "MM:SS-MM:SS" into (start_sec, end_sec); None if it does not match."""
    m = _SPAN_RE.match(timestamp) if isinstance(timestamp, str) else None
    if m is None:
        return None
    start = int(m.group(1)) * 60 + int(m.group(2))
    end = int(m.group(3)) * 60 + int(m.group(4))
    return start, end


def _start_key(seg):
    """Sort key: start second of the segment's timestamp (0 if unparsable)."""
    ts = seg.get("timestamp")
    m = re.match(r'(\d+):(\d+)', ts) if isinstance(ts, str) else None
    return int(m.group(1)) * 60 + int(m.group(2)) if m else 0


def _apply_labels(segments, classifications):
    """Write "action_label" onto each segment and return how many came from the LLM.

    classifications maps 1-based segment position to a label, or is None/empty
    when the LLM call failed. A segment without a valid LLM label keeps any
    label it already had; otherwise it is marked "Idle".
    """
    applied = 0
    for idx, seg in enumerate(segments, start=1):
        action = classifications.get(idx) if classifications else None
        if isinstance(action, str) and action in _VALID_ACTION_LABELS:
            seg["action_label"] = action
            applied += 1
        else:
            seg.setdefault("action_label", "Idle")
    return applied


def _make_inferred_segment(inf):
    """Build an output segment dict from one gap-inference result; invalid labels become Idle."""
    label = inf.get("action")
    return {
        "timestamp": inf["timestamp"],
        "task": inf["task"],
        "action_label": label if isinstance(label, str) and label in _VALID_ACTION_LABELS else "Idle",
        "source": "llm_inferred",
        "left_hand": "",
        "right_hand": "",
        "bimanual_interaction": "",
        "objects": [],
    }


def _infer_gap(scenario, before_seg, after_seg, gap_start, gap_end):
    """Ask the LLM to fill one gap and return the new segment dicts (never None)."""
    inferred = infer_gap_actions(scenario, before_seg, after_seg, gap_start, gap_end) or []
    return [_make_inferred_segment(inf) for inf in inferred]


def process_one_file(ann_path, vol, scenario):
    """Reclassify one annotation file's segments and fill unlabeled gaps via the LLM.

    Args:
        ann_path: Path to the annotation JSON (an object with a "segments" list).
        vol: Volume name, used to locate the recording metadata.
        scenario: Scenario id, used for the scene description and metadata.

    Returns:
        A tuple ``(result, stats)``.
        - result: a copy of the file's data whose "segments" are sorted by
          start time and include the LLM-inferred ones.
        - stats: ``{"reclassified": segments labelled from the LLM reply,
          "gaps_filled": inferred segments added}``.
    """
    with open(ann_path, encoding="utf-8") as f:
        data = json.load(f)
    segments = data.get("segments") or []
    if not segments:
        return data, {"reclassified": 0, "gaps_filled": 0}

    # Step 1: reclassify existing segments.
    print(f" Reclassifying {len(segments)} segments...")
    classifications = reclassify_segments(segments, scenario)
    reclassified = _apply_labels(segments, classifications)

    # Step 2: find and fill gaps of at least _MIN_GAP_SEC seconds.
    parsed, unparsed = [], []
    for seg in segments:
        span = _parse_span(seg.get("timestamp"))
        if span is None:
            unparsed.append(seg)  # kept in the output rather than silently dropped
        else:
            parsed.append((span[0], span[1], seg))
    # Sort on (start, end) only: on ties, tuple comparison would reach the dicts and raise TypeError.
    parsed.sort(key=lambda item: (item[0], item[1]))

    total_dur = get_recording_duration(vol, scenario)
    new_segments = list(unparsed)
    gaps_filled = 0

    # Furthest end time seen so far, so a long segment overlapping later ones
    # is not mistaken for a gap.
    covered_until = 0
    for i, (_, end, seg) in enumerate(parsed):
        new_segments.append(seg)
        covered_until = max(covered_until, end)
        if i + 1 < len(parsed):
            gap_end, after_seg = parsed[i + 1][0], parsed[i + 1][2]
        elif total_dur:
            gap_end, after_seg = int(total_dur), None
        else:
            continue
        gap_start = covered_until
        gap_duration = gap_end - gap_start
        if gap_duration >= _MIN_GAP_SEC:
            print(f" Filling gap {gap_start}s-{gap_end}s ({gap_duration}s)...")
            filled = _infer_gap(scenario, seg, after_seg, gap_start, gap_end)
            new_segments.extend(filled)
            gaps_filled += len(filled)

    # Also check the gap at the beginning of the recording.
    if parsed and parsed[0][0] >= _MIN_GAP_SEC:
        print(f" Filling start gap 0s-{parsed[0][0]}s...")
        filled = _infer_gap(scenario, None, parsed[0][2], 0, parsed[0][0])
        new_segments[:0] = filled
        gaps_filled += len(filled)

    new_segments.sort(key=_start_key)
    result = copy.deepcopy(data)
    result["segments"] = new_segments
    return result, {"reclassified": reclassified, "gaps_filled": gaps_filled}
def main():
    """Re-annotate every ``<ANN_DIR>/v*/s*.json`` file into ``<OUTPUT_DIR>/<vol>/``.

    Processes volumes and scenario files in sorted order, writes each result as
    UTF-8 JSON, and prints per-file progress plus a final summary that includes
    the number of successful LLM calls.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    totals = Counter()

    for volume_path in sorted(glob.glob(f"{ANN_DIR}/v*")):
        volume = os.path.basename(volume_path)
        volume_out = os.path.join(OUTPUT_DIR, volume)
        os.makedirs(volume_out, exist_ok=True)

        for annotation_path in sorted(glob.glob(f"{volume_path}/s*.json")):
            scene = os.path.basename(annotation_path).replace(".json", "")
            print(f"\n[{volume}/{scene}]", flush=True)
            annotated, file_stats = process_one_file(annotation_path, volume, scene)

            destination = os.path.join(volume_out, f"{scene}.json")
            with open(destination, "w", encoding="utf-8") as handle:
                json.dump(annotated, handle, ensure_ascii=False, indent=2)

            totals["reclassified"] += file_stats["reclassified"]
            totals["gaps_filled"] += file_stats["gaps_filled"]
            totals["files"] += 1
            print(f" Done: {file_stats['reclassified']} reclassified, {file_stats['gaps_filled']} gaps filled",
                  flush=True)

    divider = "=" * 60
    print(f"\n{divider}")
    print(f"Total: {totals['files']} files processed")
    print(f" Reclassified: {totals['reclassified']} segments")
    print(f" Gap-filled: {totals['gaps_filled']} new segments")
    print(f" API calls: {call_count}")
    print(f" Output: {OUTPUT_DIR}")
# Run the batch re-annotation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|