Multi-shot 核心代码梳理(/data/rczhang/PencilFolder/multi-shot)
1. 核心入口与代码地图
multi_view/train.py:训练入口,封装WanTrainingModule,解析conf/multi-view.yaml。multi_view/datasets/videodataset.py:数据集读取、视频采样、参考图像(人脸)裁剪逻辑。multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py:核心 Pipeline,包含多镜头 prompt 处理、shot index、RoPE、attention mask 等设计。multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py:DiT 模型改动,支持 shot-aware self-attn 与 ID token 设计。multi_view/conf/multi-view.yaml与multi_view/train.sh:训练配置与启动脚本。
2. 数据集与输入格式
2.1 预期 JSON 结构
videodataset.py 期望 dataset_base_path 指向一个 JSON,结构类似:
{
"video_key_or_path": {
"disk_path": "/path/to/video.mp4",
"text": "caption",
"facedetect_v1": [...],
"facedetect_v1_frame_index": [...]
}
}
2.2 视频与参考图像采样
- 采样固定 5 秒片段,原视频若超过 5 秒则随机截取一段;目标帧率 16 FPS,得到 81 帧。
get_ref_id根据人脸 yaw/pitch/roll 差异阈值(默认 50)挑选 3 帧差异明显的参考图像。- 参考图像裁剪后 resize/补白到
(height, width)。 - 数据集随机 95/5 划分 train/test。
2.3 Dataset 输出(期望 vs 实际)
train.py 的 forward_preprocess 期望每个样本至少包含:
video: list[Image],拼接后的多镜头视频帧。pre_shot_caption: list[str],每个 shot 的文字描述。ref_images: list[list[list[Image]]],形如[batch][ID][3张ref]。ref_num与可选的shot_cut_frames、ID_2_shot(用于 shot 索引与注意力设计)。
当前 videodataset.py 里 __getitem__ 返回的是占位字段(如 ID_num、Image0 等未定义变量),并未真正输出多镜头结构。
3. 框架设计(multi-shot 重点)
3.1 Pipeline 分层
WanVideoPipeline 内部以 Unit 流水线方式处理:
WanVideoUnit_ShapeChecker:尺寸与帧数对齐WanVideoUnit_NoiseInitializer:噪声初始化WanVideoUnit_PromptEmbedder:shot prompt 编码WanVideoUnit_InputVideoEmbedder:视频 VAE 编码WanVideoUnit_RefEmbedderFused:ref image VAE 编码并拼接WanVideoUnit_CfgMerger:CFG 合并WanVideoUnit_ShotEmbedder:shot 索引构造
3.2 Prompt/shot 编码设计
WanVideoUnit_PromptEmbedder.encode_prompt_separately 设计思路:
- 输入为
{"shot_caption": [shot1, shot2, ...]}。 - 将各 shot caption 拼接编码,并记录每个 shot 在 token 序列中的起止位置。
- 输出
context(text embedding)和text_cut_positions(shot->token range)。
3.3 Shot 索引与 cross-attention mask
WanVideoUnit_ShotEmbedder 设计思路:
- 输入
shot_cut_frames(每个样本的 shot 结束帧索引列表)与num_frames。 - 将 frame 边界映射到 latent 帧索引,生成
shot_indices(每个 latent token 属于哪个 shot)。
在 model_fn_wan_video 中:
- 使用
shot_indices和text_cut_positions构建 cross-attn mask。 - 约束“每个视频 token 只能 attend 自己 shot 的文本 token”,ref image token 不 attend text。
3.4 RoPE 设计
WanModel预计算 3D RoPE(f/h/w)及 4Dshot_freqs(shot + f/h/w)。model_fn_wan_video中:shot_rope:为每个 shot 生成独立 RoPE(shot 维度注入)。split_rope:对 ref image 的 RoPE 位置做 offset,避免与视频 token 混叠(split1/2/3 三种策略)。
3.5 ID/Ref 图像注入设计
WanVideoUnit_RefEmbedderFused 将 ref image 编码并拼到 latent 序列末尾。
wan_video_dit.py 的 attention_per_batch_with_shots 设计了 ID token 机制:
- 每个 shot 可以附加“该 shot 相关 ID 的 ref tokens”作为额外 K/V。
- 通过
ID_2_shot映射控制哪些 ID token 参与哪个 shot 的 attention。
4. 训练链路
训练入口:multi_view/train.py
- 解析
conf/multi-view.yaml写入 args(数据路径、分辨率、ref_num、split_rope 等)。 - 初始化
MulltiShot_MultiView_Dataset。 WanTrainingModule内部构造WanVideoPipeline,设置 scheduler、冻结模型、可选 LoRA。launch_training_task使用 Accelerate 做分布式训练与 checkpoint。
5. 可运行性评估(当前阻塞点)
5.1 直接阻塞(会导致运行失败)
videodataset.py中返回值包含未定义变量:ID_num、Image0/1/2。WanVideoUnit_PromptEmbedder与WanVideoUnit_ShotEmbedder存在大量拼写/变量错误:pip/probmpt/enmurate/toch等。model_fn_wan_video函数签名本身有语法错误(ID_2_shot参数缺少逗号)。DiTBlock调用CrossAttention时多传了attn_mask参数,但CrossAttention.forward不接受。- cross-attn mask 构造中引用了未定义变量
shot_ranges,且断言语句语法不正确。 shot_rope分支硬编码了shots_nums_batch,未与 dataset 输出对齐。wan_parser未定义--shot_rope,但model_fn_wan_video依赖args.shot_rope。
5.2 依赖与路径问题
- 训练与推理脚本大量使用绝对路径(
/root/paddlejob/...),本地环境无法复用。 requirements.txt仅包含 build/test 依赖,缺少 runtime 依赖(torch、accelerate、peft、decord、modelscope、flash-attn 等)。WanVideoPipeline.from_pretrained内部写死 tokenizer 路径,无法在通用环境运行。
5.3 结论
当前代码的多镜头设计方向清晰,但处于实验/草案阶段;在 dataset、pipeline 与模型实现上存在多处语法/变量/接口不一致的问题,按现状无法跑通训练或推理。
6. 最小可跑通修复清单(建议)
- 补齐
videodataset.py输出:真实的pre_shot_caption、shot_cut_frames、ref_images、ID_2_shot,并移除未定义变量。 - 修正
wan_video_new.py的拼写与变量错误,确保 prompt 编码与 shot 索引能正常生成。 - 统一 cross-attn 接口:
CrossAttention.forward增加attn_mask或在DiTBlock中移除传参。 - 修正 shot mask 构造中的未定义变量与断言语句。
- 将
shot_rope的 shot 帧数与 dataset 对齐,去掉硬编码。 wan_parser补充--shot_rope,并清理绝对路径依赖。- 补齐运行依赖(torch/accelerate/peft/decord/modelscope/flash-attn 等)并在 README/requirements 中显式声明。