multishot / MULTI_SHOT_CORE_SUMMARY.md
PencilHu's picture
Upload folder using huggingface_hub
85752bc verified

Multi-shot 核心代码梳理(/data/rczhang/PencilFolder/multi-shot)

1. 核心入口与代码地图

  • multi_view/train.py:训练入口,封装 WanTrainingModule,解析 conf/multi-view.yaml
  • multi_view/datasets/videodataset.py:数据集读取、视频采样、参考图像(人脸)裁剪逻辑。
  • multi_view/DiffSynth-Studio-main/diffsynth/pipelines/wan_video_new.py:核心 Pipeline,包含多镜头 prompt 处理、shot index、RoPE、attention mask 等设计。
  • multi_view/DiffSynth-Studio-main/diffsynth/models/wan_video_dit.py:DiT 模型改动,支持 shot-aware self-attn 与 ID token 设计。
  • multi_view/conf/multi-view.yamlmulti_view/train.sh:训练配置与启动脚本。

2. 数据集与输入格式

2.1 预期 JSON 结构

videodataset.py 期望 dataset_base_path 指向一个 JSON,结构类似:

{
  "video_key_or_path": {
    "disk_path": "/path/to/video.mp4",
    "text": "caption",
    "facedetect_v1": [...],
    "facedetect_v1_frame_index": [...]
  }
}

2.2 视频与参考图像采样

  • 采样固定 5 秒片段,原视频若超过 5 秒则随机截取一段;目标帧率 16 FPS,得到 81 帧。
  • get_ref_id 根据人脸 yaw/pitch/roll 差异阈值(默认 50)挑选 3 帧差异明显的参考图像。
  • 参考图像裁剪后 resize/补白到 (height, width)
  • 数据集随机 95/5 划分 train/test。

2.3 Dataset 输出(期望 vs 实际)

train.pyforward_preprocess 期望每个样本至少包含:

  • video: list[Image],拼接后的多镜头视频帧。
  • pre_shot_caption: list[str],每个 shot 的文字描述。
  • ref_images: list[list[list[Image]]],形如 [batch][ID][3张ref]
  • ref_num 与可选的 shot_cut_framesID_2_shot(用于 shot 索引与注意力设计)。

当前 videodataset.py__getitem__ 返回的是占位字段(如 ID_numImage0 等未定义变量),并未真正输出多镜头结构。

3. 框架设计(multi-shot 重点)

3.1 Pipeline 分层

WanVideoPipeline 内部以 Unit 流水线方式处理:

  1. WanVideoUnit_ShapeChecker:尺寸与帧数对齐
  2. WanVideoUnit_NoiseInitializer:噪声初始化
  3. WanVideoUnit_PromptEmbedder:shot prompt 编码
  4. WanVideoUnit_InputVideoEmbedder:视频 VAE 编码
  5. WanVideoUnit_RefEmbedderFused:ref image VAE 编码并拼接
  6. WanVideoUnit_CfgMerger:CFG 合并
  7. WanVideoUnit_ShotEmbedder:shot 索引构造

3.2 Prompt/shot 编码设计

WanVideoUnit_PromptEmbedder.encode_prompt_separately 设计思路:

  • 输入为 {"shot_caption": [shot1, shot2, ...]}
  • 将各 shot caption 拼接编码,并记录每个 shot 在 token 序列中的起止位置。
  • 输出 context(text embedding)和 text_cut_positions(shot->token range)。

3.3 Shot 索引与 cross-attention mask

WanVideoUnit_ShotEmbedder 设计思路:

  • 输入 shot_cut_frames(每个样本的 shot 结束帧索引列表)与 num_frames
  • 将 frame 边界映射到 latent 帧索引,生成 shot_indices(每个 latent token 属于哪个 shot)。

model_fn_wan_video 中:

  • 使用 shot_indicestext_cut_positions 构建 cross-attn mask。
  • 约束“每个视频 token 只能 attend 自己 shot 的文本 token”,ref image token 不 attend text。

3.4 RoPE 设计

  • WanModel 预计算 3D RoPE(f/h/w)及 4D shot_freqs(shot + f/h/w)。
  • model_fn_wan_video 中:
    • shot_rope:为每个 shot 生成独立 RoPE(shot 维度注入)。
    • split_rope:对 ref image 的 RoPE 位置做 offset,避免与视频 token 混叠(split1/2/3 三种策略)。

3.5 ID/Ref 图像注入设计

WanVideoUnit_RefEmbedderFused 将 ref image 编码并拼到 latent 序列末尾。 wan_video_dit.pyattention_per_batch_with_shots 设计了 ID token 机制:

  • 每个 shot 可以附加“该 shot 相关 ID 的 ref tokens”作为额外 K/V。
  • 通过 ID_2_shot 映射控制哪些 ID token 参与哪个 shot 的 attention。

4. 训练链路

训练入口:multi_view/train.py

  • 解析 conf/multi-view.yaml 写入 args(数据路径、分辨率、ref_num、split_rope 等)。
  • 初始化 MulltiShot_MultiView_Dataset
  • WanTrainingModule 内部构造 WanVideoPipeline,设置 scheduler、冻结模型、可选 LoRA。
  • launch_training_task 使用 Accelerate 做分布式训练与 checkpoint。

5. 可运行性评估(当前阻塞点)

5.1 直接阻塞(会导致运行失败)

  1. videodataset.py 中返回值包含未定义变量:ID_numImage0/1/2
  2. WanVideoUnit_PromptEmbedderWanVideoUnit_ShotEmbedder 存在大量拼写/变量错误:pip/probmpt/enmurate/toch 等。
  3. model_fn_wan_video 函数签名本身有语法错误(ID_2_shot 参数缺少逗号)。
  4. DiTBlock 调用 CrossAttention 时多传了 attn_mask 参数,但 CrossAttention.forward 不接受。
  5. cross-attn mask 构造中引用了未定义变量 shot_ranges,且断言语句语法不正确。
  6. shot_rope 分支硬编码了 shots_nums_batch,未与 dataset 输出对齐。
  7. wan_parser 未定义 --shot_rope,但 model_fn_wan_video 依赖 args.shot_rope

5.2 依赖与路径问题

  • 训练与推理脚本大量使用绝对路径(/root/paddlejob/...),本地环境无法复用。
  • requirements.txt 仅包含 build/test 依赖,缺少 runtime 依赖(torch、accelerate、peft、decord、modelscope、flash-attn 等)。
  • WanVideoPipeline.from_pretrained 内部写死 tokenizer 路径,无法在通用环境运行。

5.3 结论

当前代码的多镜头设计方向清晰,但处于实验/草案阶段;在 dataset、pipeline 与模型实现上存在多处语法/变量/接口不一致的问题,按现状无法跑通训练或推理

6. 最小可跑通修复清单(建议)

  1. 补齐 videodataset.py 输出:真实的 pre_shot_captionshot_cut_framesref_imagesID_2_shot,并移除未定义变量。
  2. 修正 wan_video_new.py 的拼写与变量错误,确保 prompt 编码与 shot 索引能正常生成。
  3. 统一 cross-attn 接口:CrossAttention.forward 增加 attn_mask 或在 DiTBlock 中移除传参。
  4. 修正 shot mask 构造中的未定义变量与断言语句。
  5. shot_rope 的 shot 帧数与 dataset 对齐,去掉硬编码。
  6. wan_parser 补充 --shot_rope,并清理绝对路径依赖。
  7. 补齐运行依赖(torch/accelerate/peft/decord/modelscope/flash-attn 等)并在 README/requirements 中显式声明。