Nayefleb committed on
Commit
bc4cd2c
·
verified ·
1 Parent(s): 04b2815

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +277 -288
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # =========================================================
2
- # ZERO GPU PATCHED + ALL TASKS ENABLED + QWEN FIX
 
3
  # Hugging Face Spaces Compatible
4
  # =========================================================
5
 
@@ -34,7 +35,11 @@ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
34
  # LOGIN
35
  # =========================================================
36
 
37
- from huggingface_hub import login, snapshot_download
 
 
 
 
38
 
39
  HF_TOKEN = os.getenv("HF_TOKEN")
40
 
@@ -48,9 +53,8 @@ if HF_TOKEN:
48
  from safetensors.torch import load_file
49
 
50
  from transformers import (
51
- AutoProcessor,
52
- Qwen2_5_VLForConditionalGeneration,
53
  set_seed,
 
54
  )
55
 
56
  from transformers.utils import is_flash_attn_2_available
@@ -131,27 +135,17 @@ snapshot_download(
131
  token=HF_TOKEN,
132
  )
133
 
134
- # =========================================================
135
- # DOWNLOAD QWEN 2.5 VL
136
- # =========================================================
137
-
138
- QWEN_VL_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
139
-
140
- QWEN_VL_PATH = MODEL_CACHE_DIR / "Qwen2.5-VL-7B-Instruct"
141
-
142
- snapshot_download(
143
- repo_id=QWEN_VL_REPO,
144
- local_dir=str(QWEN_VL_PATH),
145
- local_dir_use_symlinks=False,
146
- token=HF_TOKEN,
147
- )
148
-
149
  DEFAULT_MODEL_PATH = str(
150
  MODEL_CACHE_DIR / "Lance_3B_Video"
151
  )
152
 
153
  print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
154
- print("QWEN_VL_PATH =", QWEN_VL_PATH)
 
 
 
 
 
155
 
156
  # =========================================================
157
  # DEFAULTS
@@ -255,35 +249,7 @@ class LancePipeline:
255
  if not torch.cuda.is_available():
256
  raise RuntimeError("CUDA unavailable")
257
 
258
- print("Initializing Lance pipeline...")
259
-
260
- # =====================================================
261
- # QWEN VL LOAD FIX
262
- # =====================================================
263
-
264
- print("Loading Qwen2.5 VL Processor...")
265
-
266
- self.qwen_processor = AutoProcessor.from_pretrained(
267
- str(QWEN_VL_PATH),
268
- trust_remote_code=True,
269
- token=HF_TOKEN,
270
- )
271
-
272
- print("Loading Qwen2.5 VL Model...")
273
-
274
- self.qwen_vl_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
275
- str(QWEN_VL_PATH),
276
- torch_dtype=torch.bfloat16,
277
- device_map="auto",
278
- trust_remote_code=True,
279
- token=HF_TOKEN,
280
- )
281
-
282
- print("Qwen2.5 VL loaded successfully")
283
-
284
- # =====================================================
285
- # LANCE CONFIG
286
- # =====================================================
287
 
288
  model_args = ModelArguments(
289
  model_path=DEFAULT_MODEL_PATH,
@@ -298,10 +264,10 @@ class LancePipeline:
298
  )
299
 
300
  # =====================================================
301
- # FORCE CORRECT VIT PATH
302
  # =====================================================
303
 
304
- model_args.vit_path = str(QWEN_VL_PATH)
305
 
306
  data_args = DataArguments()
307
 
@@ -334,6 +300,10 @@ class LancePipeline:
334
 
335
  set_seed(42)
336
 
 
 
 
 
337
  llm_config = Qwen2Config.from_json_file(
338
  str(Path(model_args.model_path) / "llm_config.json")
339
  )
@@ -341,52 +311,65 @@ class LancePipeline:
341
  language_model = Qwen2ForCausalLM(llm_config)
342
 
343
  # =====================================================
344
- # FIXED VIT CONFIG
345
  # =====================================================
346
 
347
- print("Loading VIT config from:", model_args.vit_path)
348
-
349
- from transformers import AutoConfig
350
 
351
- vit_config = AutoConfig.from_pretrained(
352
- model_args.vit_path,
353
- trust_remote_code=True,
354
  token=HF_TOKEN,
 
355
  )
356
 
 
 
357
  vit_config._attn_implementation = "eager"
358
 
 
 
359
  vit_model = Qwen2_5_VisionTransformerPretrainedModel(
360
  vit_config
361
  )
362
 
363
- vit_weights_path = Path(model_args.vit_path) / "model.safetensors"
 
 
364
 
365
- if vit_weights_path.exists():
366
 
367
- print("Loading VIT weights:", vit_weights_path)
 
 
 
 
368
 
369
- vit_weights = load_file(
370
- str(vit_weights_path)
371
- )
372
 
373
- missing, unexpected = vit_model.load_state_dict(
374
- vit_weights,
375
- strict=False
376
- )
377
 
378
- print("Missing keys:", len(missing))
379
- print("Unexpected keys:", len(unexpected))
 
 
380
 
381
- clean_memory(vit_weights)
 
382
 
383
- else:
384
- print("WARNING: model.safetensors not found")
 
 
 
385
 
386
  vae_model = WanVideoVAE()
387
 
388
  vae_config = deepcopy(vae_model.vae_config)
389
 
 
 
 
 
390
  config = LanceConfig(
391
  visual_gen=True,
392
  visual_und=True,
@@ -410,6 +393,8 @@ class LancePipeline:
410
  training_args=inference_args,
411
  )
412
 
 
 
413
  model = model.to(
414
  device="cuda",
415
  dtype=torch.bfloat16,
@@ -448,6 +433,10 @@ class LancePipeline:
448
 
449
  print("Lance initialized successfully")
450
 
 
 
 
 
451
  def generate(
452
  self,
453
  task,
@@ -465,231 +454,245 @@ class LancePipeline:
465
  cfg_text_scale,
466
  ):
467
 
468
- task = normalize_task(task)
469
 
470
- actual_seed = normalize_seed(int(seed))
471
 
472
- set_seed(actual_seed)
473
 
474
- save_dir = RESULTS_ROOT / str(time.time())
475
- save_dir.mkdir(parents=True, exist_ok=True)
476
 
477
- inference_args = deepcopy(
478
- self.base_inference_args
479
- )
480
 
481
- inference_args.video_height = int(height)
482
- inference_args.video_width = int(width)
483
- inference_args.num_frames = int(num_frames)
484
 
485
- inference_args.validation_num_timesteps = (
486
- validation_num_timesteps
487
- )
488
 
489
- inference_args.validation_timestep_shift = (
490
- validation_timestep_shift
491
- )
492
 
493
- inference_args.task = task
 
 
494
 
495
- prompt_file = TMP_INPUT_DIR / f"prompt_{time.time()}.json"
496
 
497
- # =====================================================
498
- # PAYLOADS
499
- # =====================================================
500
 
501
- if task == TASK_T2V:
 
 
502
 
503
- payload = {
504
- "000000.mp4": prompt
505
- }
506
 
507
- elif task == TASK_T2I:
 
 
508
 
509
- payload = {
510
- "000000.png": prompt
511
- }
512
 
513
- elif task == TASK_IMAGE_EDIT:
 
 
514
 
515
- payload = {
516
- "000000": {
517
- "interleave_array": [
518
- input_image,
519
- [prompt, ""]
520
- ],
521
- "element_dtype_array": [
522
- "image",
523
- "text"
524
- ],
525
- "istarget_in_interleave": [
526
- 0,
527
- 1
528
- ],
 
 
 
529
  }
530
- }
531
-
532
- elif task == TASK_VIDEO_EDIT:
533
-
534
- payload = {
535
- "000000": {
536
- "interleave_array": [
537
- input_video,
538
- [prompt, ""]
539
- ],
540
- "element_dtype_array": [
541
- "video",
542
- "text"
543
- ],
544
- "istarget_in_interleave": [
545
- 0,
546
- 1
547
- ],
548
  }
549
- }
550
-
551
- elif task == TASK_X2T_IMAGE:
552
-
553
- payload = {
554
- "000000": {
555
- "interleave_array": [
556
- input_image,
557
- [
558
- "Describe the image",
559
- question,
560
- ""
561
- ]
562
- ],
563
- "element_dtype_array": [
564
- "image",
565
- "text"
566
- ],
567
- "istarget_in_interleave": [
568
- 0,
569
- 1
570
- ],
571
  }
572
- }
573
-
574
- elif task == TASK_X2T_VIDEO:
575
-
576
- payload = {
577
- "000000": {
578
- "interleave_array": [
579
- input_video,
580
- [
581
- "Describe the video",
582
- question,
583
- ""
584
- ]
585
- ],
586
- "element_dtype_array": [
587
- "video",
588
- "text"
589
- ],
590
- "istarget_in_interleave": [
591
- 0,
592
- 1
593
- ],
594
  }
595
- }
596
 
597
- else:
598
 
599
- return (
600
- None,
601
- None,
602
- "",
603
- "Invalid task",
604
- "",
 
 
 
 
 
 
 
605
  )
606
 
607
- with open(prompt_file, "w") as f:
608
- json.dump(payload, f)
609
-
610
- dataset_config = DataConfig.from_yaml(
611
- str(prompt_file)
612
- )
613
-
614
- val_dataset = ValidationDataset(
615
- jsonl_path=str(prompt_file),
616
- tokenizer=self.tokenizer,
617
- data_args=self.base_data_args,
618
- model_args=self.base_model_args,
619
- training_args=inference_args,
620
- new_token_ids=self.new_token_ids,
621
- dataset_config=dataset_config,
622
- local_rank=0,
623
- world_size=1,
624
- )
625
-
626
- val_data_cpu = simple_custom_collate(
627
- [val_dataset[0]]
628
- )
629
-
630
- validate_on_fixed_batch(
631
- fsdp_model=self.model,
632
- vae_model=self.vae_model,
633
- tokenizer=self.tokenizer,
634
- val_data_cpu=val_data_cpu,
635
- training_args=inference_args,
636
- model_args=self.base_model_args,
637
- inference_args=inference_args,
638
- new_token_ids=self.new_token_ids,
639
- image_token_id=self.image_token_id,
640
- device="cuda",
641
- save_source_video=False,
642
- save_path_gen=str(save_dir),
643
- save_path_gt="",
644
- )
645
-
646
- clean_memory()
647
-
648
- gc.collect()
649
-
650
- torch.cuda.empty_cache()
651
-
652
- videos = list(save_dir.glob("*.mp4"))
653
- images = list(save_dir.glob("*.png"))
654
-
655
- if len(videos) > 0:
656
 
657
- return (
658
- str(videos[0]),
659
- None,
660
- "",
661
- "Success",
662
- "",
663
  )
664
 
665
- if len(images) > 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
666
 
667
  return (
668
  None,
669
- str(images[0]),
670
  "",
671
- "Success",
672
  "",
673
  )
674
 
675
- if task in [TASK_X2T_IMAGE, TASK_X2T_VIDEO]:
 
 
676
 
677
  return (
678
  None,
679
  None,
680
- "Understanding complete",
681
- "Success",
682
  "",
 
 
683
  )
684
 
685
- return (
686
- None,
687
- None,
688
- "",
689
- "No output generated",
690
- "",
691
- )
692
-
693
  # =========================================================
694
  # GLOBAL
695
  # =========================================================
@@ -717,37 +720,23 @@ def run_task(
717
  cfg_text_scale,
718
  ):
719
 
720
- try:
721
-
722
- PIPELINE.initialize()
723
-
724
- return PIPELINE.generate(
725
- task=task,
726
- prompt=prompt,
727
- input_image=input_image,
728
- input_video=input_video,
729
- question=question,
730
- height=height,
731
- width=width,
732
- num_frames=num_frames,
733
- seed=seed,
734
- resolution=resolution,
735
- validation_num_timesteps=validation_num_timesteps,
736
- validation_timestep_shift=validation_timestep_shift,
737
- cfg_text_scale=cfg_text_scale,
738
- )
739
-
740
- except Exception as e:
741
-
742
- traceback_str = traceback.format_exc()
743
-
744
- return (
745
- None,
746
- None,
747
- "",
748
- f"ERROR: {str(e)}",
749
- traceback_str,
750
- )
751
 
752
  # =========================================================
753
  # UI
 
1
  # =========================================================
2
+ # ZERO GPU PATCHED + ALL TASKS ENABLED
3
+ # Qwen2.5-VL FIXED VERSION
4
  # Hugging Face Spaces Compatible
5
  # =========================================================
6
 
 
35
  # LOGIN
36
  # =========================================================
37
 
38
+ from huggingface_hub import (
39
+ login,
40
+ snapshot_download,
41
+ hf_hub_download,
42
+ )
43
 
44
  HF_TOKEN = os.getenv("HF_TOKEN")
45
 
 
53
  from safetensors.torch import load_file
54
 
55
  from transformers import (
 
 
56
  set_seed,
57
+ AutoConfig,
58
  )
59
 
60
  from transformers.utils import is_flash_attn_2_available
 
135
  token=HF_TOKEN,
136
  )
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  DEFAULT_MODEL_PATH = str(
139
  MODEL_CACHE_DIR / "Lance_3B_Video"
140
  )
141
 
142
  print("DEFAULT_MODEL_PATH =", DEFAULT_MODEL_PATH)
143
+
144
+ # =========================================================
145
+ # QWEN VL
146
+ # =========================================================
147
+
148
+ QWEN_VL_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
149
 
150
  # =========================================================
151
  # DEFAULTS
 
249
  if not torch.cuda.is_available():
250
  raise RuntimeError("CUDA unavailable")
251
 
252
+ print("Initializing Lance...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  model_args = ModelArguments(
255
  model_path=DEFAULT_MODEL_PATH,
 
264
  )
265
 
266
  # =====================================================
267
+ # IMPORTANT FIX
268
  # =====================================================
269
 
270
+ model_args.vit_path = QWEN_VL_REPO
271
 
272
  data_args = DataArguments()
273
 
 
300
 
301
  set_seed(42)
302
 
303
+ # =====================================================
304
+ # LLM
305
+ # =====================================================
306
+
307
  llm_config = Qwen2Config.from_json_file(
308
  str(Path(model_args.model_path) / "llm_config.json")
309
  )
 
311
  language_model = Qwen2ForCausalLM(llm_config)
312
 
313
  # =====================================================
314
+ # FIXED QWEN2.5-VL LOADING
315
  # =====================================================
316
 
317
+ print("Loading Qwen2.5-VL config...")
 
 
318
 
319
+ full_qwen_config = AutoConfig.from_pretrained(
320
+ QWEN_VL_REPO,
 
321
  token=HF_TOKEN,
322
+ trust_remote_code=True,
323
  )
324
 
325
+ vit_config = full_qwen_config.vision_config
326
+
327
  vit_config._attn_implementation = "eager"
328
 
329
+ print("Creating vision transformer...")
330
+
331
  vit_model = Qwen2_5_VisionTransformerPretrainedModel(
332
  vit_config
333
  )
334
 
335
+ # =====================================================
336
+ # LOAD WEIGHTS
337
+ # =====================================================
338
 
339
+ print("Downloading Qwen weights...")
340
 
341
+ vit_weights_path = hf_hub_download(
342
+ repo_id=QWEN_VL_REPO,
343
+ filename="model.safetensors",
344
+ token=HF_TOKEN,
345
+ )
346
 
347
+ print("Loading VIT weights...")
 
 
348
 
349
+ vit_weights = load_file(vit_weights_path)
 
 
 
350
 
351
+ missing, unexpected = vit_model.load_state_dict(
352
+ vit_weights,
353
+ strict=False,
354
+ )
355
 
356
+ print("Missing keys:", len(missing))
357
+ print("Unexpected keys:", len(unexpected))
358
 
359
+ clean_memory(vit_weights)
360
+
361
+ # =====================================================
362
+ # VAE
363
+ # =====================================================
364
 
365
  vae_model = WanVideoVAE()
366
 
367
  vae_config = deepcopy(vae_model.vae_config)
368
 
369
+ # =====================================================
370
+ # CONFIG
371
+ # =====================================================
372
+
373
  config = LanceConfig(
374
  visual_gen=True,
375
  visual_und=True,
 
393
  training_args=inference_args,
394
  )
395
 
396
+ print("Moving model to CUDA...")
397
+
398
  model = model.to(
399
  device="cuda",
400
  dtype=torch.bfloat16,
 
433
 
434
  print("Lance initialized successfully")
435
 
436
+ # =========================================================
437
+ # GENERATE
438
+ # =========================================================
439
+
440
  def generate(
441
  self,
442
  task,
 
454
  cfg_text_scale,
455
  ):
456
 
457
+ try:
458
 
459
+ task = normalize_task(task)
460
 
461
+ actual_seed = normalize_seed(int(seed))
462
 
463
+ set_seed(actual_seed)
 
464
 
465
+ save_dir = RESULTS_ROOT / str(time.time())
466
+ save_dir.mkdir(parents=True, exist_ok=True)
 
467
 
468
+ inference_args = deepcopy(
469
+ self.base_inference_args
470
+ )
471
 
472
+ inference_args.video_height = int(height)
473
+ inference_args.video_width = int(width)
474
+ inference_args.num_frames = int(num_frames)
475
 
476
+ inference_args.validation_num_timesteps = (
477
+ validation_num_timesteps
478
+ )
479
 
480
+ inference_args.validation_timestep_shift = (
481
+ validation_timestep_shift
482
+ )
483
 
484
+ inference_args.task = task
485
 
486
+ prompt_file = TMP_INPUT_DIR / "prompt.json"
 
 
487
 
488
+ # =====================================================
489
+ # PAYLOADS
490
+ # =====================================================
491
 
492
+ if task == TASK_T2V:
 
 
493
 
494
+ payload = {
495
+ "000000.mp4": prompt
496
+ }
497
 
498
+ elif task == TASK_T2I:
 
 
499
 
500
+ payload = {
501
+ "000000.png": prompt
502
+ }
503
 
504
+ elif task == TASK_IMAGE_EDIT:
505
+
506
+ payload = {
507
+ "000000": {
508
+ "interleave_array": [
509
+ input_image,
510
+ [prompt, ""]
511
+ ],
512
+ "element_dtype_array": [
513
+ "image",
514
+ "text"
515
+ ],
516
+ "istarget_in_interleave": [
517
+ 0,
518
+ 1
519
+ ],
520
+ }
521
  }
522
+
523
+ elif task == TASK_VIDEO_EDIT:
524
+
525
+ payload = {
526
+ "000000": {
527
+ "interleave_array": [
528
+ input_video,
529
+ [prompt, ""]
530
+ ],
531
+ "element_dtype_array": [
532
+ "video",
533
+ "text"
534
+ ],
535
+ "istarget_in_interleave": [
536
+ 0,
537
+ 1
538
+ ],
539
+ }
540
  }
541
+
542
+ elif task == TASK_X2T_IMAGE:
543
+
544
+ payload = {
545
+ "000000": {
546
+ "interleave_array": [
547
+ input_image,
548
+ [
549
+ "Describe the image",
550
+ question,
551
+ ""
552
+ ]
553
+ ],
554
+ "element_dtype_array": [
555
+ "image",
556
+ "text"
557
+ ],
558
+ "istarget_in_interleave": [
559
+ 0,
560
+ 1
561
+ ],
562
+ }
563
  }
564
+
565
+ elif task == TASK_X2T_VIDEO:
566
+
567
+ payload = {
568
+ "000000": {
569
+ "interleave_array": [
570
+ input_video,
571
+ [
572
+ "Describe the video",
573
+ question,
574
+ ""
575
+ ]
576
+ ],
577
+ "element_dtype_array": [
578
+ "video",
579
+ "text"
580
+ ],
581
+ "istarget_in_interleave": [
582
+ 0,
583
+ 1
584
+ ],
585
+ }
586
  }
 
587
 
588
+ else:
589
 
590
+ return (
591
+ None,
592
+ None,
593
+ "",
594
+ "Invalid task",
595
+ "",
596
+ )
597
+
598
+ with open(prompt_file, "w") as f:
599
+ json.dump(payload, f)
600
+
601
+ dataset_config = DataConfig.from_yaml(
602
+ str(prompt_file)
603
  )
604
 
605
+ val_dataset = ValidationDataset(
606
+ jsonl_path=str(prompt_file),
607
+ tokenizer=self.tokenizer,
608
+ data_args=self.base_data_args,
609
+ model_args=self.base_model_args,
610
+ training_args=inference_args,
611
+ new_token_ids=self.new_token_ids,
612
+ dataset_config=dataset_config,
613
+ local_rank=0,
614
+ world_size=1,
615
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
616
 
617
+ val_data_cpu = simple_custom_collate(
618
+ [val_dataset[0]]
 
 
 
 
619
  )
620
 
621
+ validate_on_fixed_batch(
622
+ fsdp_model=self.model,
623
+ vae_model=self.vae_model,
624
+ tokenizer=self.tokenizer,
625
+ val_data_cpu=val_data_cpu,
626
+ training_args=inference_args,
627
+ model_args=self.base_model_args,
628
+ inference_args=inference_args,
629
+ new_token_ids=self.new_token_ids,
630
+ image_token_id=self.image_token_id,
631
+ device="cuda",
632
+ save_source_video=False,
633
+ save_path_gen=str(save_dir),
634
+ save_path_gt="",
635
+ )
636
+
637
+ clean_memory()
638
+
639
+ gc.collect()
640
+
641
+ torch.cuda.empty_cache()
642
+
643
+ videos = list(save_dir.glob("*.mp4"))
644
+ images = list(save_dir.glob("*.png"))
645
+
646
+ if len(videos) > 0:
647
+
648
+ return (
649
+ str(videos[0]),
650
+ None,
651
+ "",
652
+ "Success",
653
+ "",
654
+ )
655
+
656
+ if len(images) > 0:
657
+
658
+ return (
659
+ None,
660
+ str(images[0]),
661
+ "",
662
+ "Success",
663
+ "",
664
+ )
665
+
666
+ if task in [TASK_X2T_IMAGE, TASK_X2T_VIDEO]:
667
+
668
+ return (
669
+ None,
670
+ None,
671
+ "Understanding complete",
672
+ "Success",
673
+ "",
674
+ )
675
 
676
  return (
677
  None,
678
+ None,
679
  "",
680
+ "No output generated",
681
  "",
682
  )
683
 
684
+ except Exception as e:
685
+
686
+ traceback.print_exc()
687
 
688
  return (
689
  None,
690
  None,
 
 
691
  "",
692
+ f"ERROR: {str(e)}",
693
+ traceback.format_exc(),
694
  )
695
 
 
 
 
 
 
 
 
 
696
  # =========================================================
697
  # GLOBAL
698
  # =========================================================
 
720
  cfg_text_scale,
721
  ):
722
 
723
+ PIPELINE.initialize()
724
+
725
+ return PIPELINE.generate(
726
+ task=task,
727
+ prompt=prompt,
728
+ input_image=input_image,
729
+ input_video=input_video,
730
+ question=question,
731
+ height=height,
732
+ width=width,
733
+ num_frames=num_frames,
734
+ seed=seed,
735
+ resolution=resolution,
736
+ validation_num_timesteps=validation_num_timesteps,
737
+ validation_timestep_shift=validation_timestep_shift,
738
+ cfg_text_scale=cfg_text_scale,
739
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
740
 
741
  # =========================================================
742
  # UI