# Multi-shot 核心代码梳理(/data/rczhang/PencilFolder/multi-shot) ## 1. 核心入口与代码地图 - `multi_view/train.py`:训练入口,封装 `WanTrainingModule`,解析 `conf/multi-view.yaml`。 - `multi_view/datasets/videodataset.py`:数据集读取、视频采样、参考图像(人脸)裁剪逻辑。 - `multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py`:核心 Pipeline,包含多镜头 prompt 处理、shot index、RoPE、attention mask 等设计。 - `multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py`:DiT 模型改动,支持 shot-aware self-attn 与 ID token 设计。 - `multi_view/conf/multi-view.yaml` 与 `multi_view/train.sh`:训练配置与启动脚本。 ## 2. 数据集与输入格式 ### 2.1 预期 JSON 结构 `videodataset.py` 期望 `dataset_base_path` 指向一个 JSON,结构类似: ```json { "video_key_or_path": { "disk_path": "/path/to/video.mp4", "text": "caption", "facedetect_v1": [...], "facedetect_v1_frame_index": [...] } } ``` ### 2.2 视频与参考图像采样 - 采样固定 5 秒片段,原视频若超过 5 秒则随机截取一段;目标帧率 16 FPS,得到 81 帧。 - `get_ref_id` 根据人脸 yaw/pitch/roll 差异阈值(默认 50)挑选 3 帧差异明显的参考图像。 - 参考图像裁剪后 resize/补白到 `(height, width)`。 - 数据集随机 95/5 划分 train/test。 ### 2.3 Dataset 输出(期望 vs 实际) `train.py` 的 `forward_preprocess` 期望每个样本至少包含: - `video`: list[Image],拼接后的多镜头视频帧。 - `pre_shot_caption`: list[str],每个 shot 的文字描述。 - `ref_images`: list[list[list[Image]]],形如 `[batch][ID][3张ref]`。 - `ref_num` 与可选的 `shot_cut_frames`、`ID_2_shot`(用于 shot 索引与注意力设计)。 当前 `videodataset.py` 里 `__getitem__` 返回的是占位字段(如 `ID_num`、`Image0` 等未定义变量),并未真正输出多镜头结构。 ## 3. 框架设计(multi-shot 重点) ### 3.1 Pipeline 分层 `WanVideoPipeline` 内部以 Unit 流水线方式处理: 1) `WanVideoUnit_ShapeChecker`:尺寸与帧数对齐 2) `WanVideoUnit_NoiseInitializer`:噪声初始化 3) `WanVideoUnit_PromptEmbedder`:shot prompt 编码 4) `WanVideoUnit_InputVideoEmbedder`:视频 VAE 编码 5) `WanVideoUnit_RefEmbedderFused`:ref image VAE 编码并拼接 6) `WanVideoUnit_CfgMerger`:CFG 合并 7) `WanVideoUnit_ShotEmbedder`:shot 索引构造 ### 3.2 Prompt/shot 编码设计 `WanVideoUnit_PromptEmbedder.encode_prompt_separately` 设计思路: - 输入为 `{"shot_caption": [shot1, shot2, ...]}`。 - 将各 shot caption 拼接编码,并记录每个 shot 在 token 序列中的起止位置。 - 输出 `context`(text embedding)和 `text_cut_positions`(shot->token range)。 ### 3.3 Shot 索引与 cross-attention mask `WanVideoUnit_ShotEmbedder` 设计思路: - 输入 `shot_cut_frames`(每个样本的 shot 结束帧索引列表)与 `num_frames`。 - 将 frame 边界映射到 latent 帧索引,生成 `shot_indices`(每个 latent token 属于哪个 shot)。 在 `model_fn_wan_video` 中: - 使用 `shot_indices` 和 `text_cut_positions` 构建 cross-attn mask。 - 约束“每个视频 token 只能 attend 自己 shot 的文本 token”,ref image token 不 attend text。 ### 3.4 RoPE 设计 - `WanModel` 预计算 3D RoPE(f/h/w)及 4D `shot_freqs`(shot + f/h/w)。 - `model_fn_wan_video` 中: - `shot_rope`:为每个 shot 生成独立 RoPE(shot 维度注入)。 - `split_rope`:对 ref image 的 RoPE 位置做 offset,避免与视频 token 混叠(split1/2/3 三种策略)。 ### 3.5 ID/Ref 图像注入设计 `WanVideoUnit_RefEmbedderFused` 将 ref image 编码并拼到 latent 序列末尾。 `wan_video_dit.py` 的 `attention_per_batch_with_shots` 设计了 ID token 机制: - 每个 shot 可以附加“该 shot 相关 ID 的 ref tokens”作为额外 K/V。 - 通过 `ID_2_shot` 映射控制哪些 ID token 参与哪个 shot 的 attention。 ## 4. 训练链路 训练入口:`multi_view/train.py` - 解析 `conf/multi-view.yaml` 写入 args(数据路径、分辨率、ref_num、split_rope 等)。 - 初始化 `MulltiShot_MultiView_Dataset`。 - `WanTrainingModule` 内部构造 `WanVideoPipeline`,设置 scheduler、冻结模型、可选 LoRA。 - `launch_training_task` 使用 Accelerate 做分布式训练与 checkpoint。 ## 5. 可运行性评估(当前阻塞点) ### 5.1 直接阻塞(会导致运行失败) 1) `videodataset.py` 中返回值包含未定义变量:`ID_num`、`Image0/1/2`。 2) `WanVideoUnit_PromptEmbedder` 与 `WanVideoUnit_ShotEmbedder` 存在大量拼写/变量错误:`pip`/`probmpt`/`enmurate`/`toch` 等。 3) `model_fn_wan_video` 函数签名本身有语法错误(`ID_2_shot` 参数缺少逗号)。 4) `DiTBlock` 调用 `CrossAttention` 时多传了 `attn_mask` 参数,但 `CrossAttention.forward` 不接受。 5) cross-attn mask 构造中引用了未定义变量 `shot_ranges`,且断言语句语法不正确。 6) `shot_rope` 分支硬编码了 `shots_nums_batch`,未与 dataset 输出对齐。 7) `wan_parser` 未定义 `--shot_rope`,但 `model_fn_wan_video` 依赖 `args.shot_rope`。 ### 5.2 依赖与路径问题 - 训练与推理脚本大量使用绝对路径(`/root/paddlejob/...`),本地环境无法复用。 - `requirements.txt` 仅包含 build/test 依赖,缺少 runtime 依赖(torch、accelerate、peft、decord、modelscope、flash-attn 等)。 - `WanVideoPipeline.from_pretrained` 内部写死 tokenizer 路径,无法在通用环境运行。 ### 5.3 结论 当前代码的多镜头设计方向清晰,但处于实验/草案阶段;在 dataset、pipeline 与模型实现上存在多处语法/变量/接口不一致的问题,**按现状无法跑通训练或推理**。 ## 6. 最小可跑通修复清单(建议) 1) 补齐 `videodataset.py` 输出:真实的 `pre_shot_caption`、`shot_cut_frames`、`ref_images`、`ID_2_shot`,并移除未定义变量。 2) 修正 `wan_video_new.py` 的拼写与变量错误,确保 prompt 编码与 shot 索引能正常生成。 3) 统一 cross-attn 接口:`CrossAttention.forward` 增加 `attn_mask` 或在 `DiTBlock` 中移除传参。 4) 修正 shot mask 构造中的未定义变量与断言语句。 5) 将 `shot_rope` 的 shot 帧数与 dataset 对齐,去掉硬编码。 6) `wan_parser` 补充 `--shot_rope`,并清理绝对路径依赖。 7) 补齐运行依赖(torch/accelerate/peft/decord/modelscope/flash-attn 等)并在 README/requirements 中显式声明。