| import torch |
| from PIL import Image |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from pipeline import Text2SignPipeline |
|
|
def generate_and_save(prompt: str, checkpoint_path: str, output_path: str, device: str = "cuda") -> None:
    """Generate a sign-language video from *prompt* and save it as a filmstrip image.

    Loads a Text2SignPipeline from *checkpoint_path*, runs inference under
    torch.no_grad(), and lays the resulting frames out side-by-side in a
    single-row matplotlib figure saved to *output_path*.

    Args:
        prompt: Text prompt describing the sign to generate.
        checkpoint_path: Path to the pretrained pipeline checkpoint.
        output_path: Destination image file for the filmstrip (e.g. .png).
        device: Torch device string, "cuda" or "cpu".
    """
    pipeline = Text2SignPipeline.from_pretrained(checkpoint_path, device=device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        video_frames = pipeline(prompt, num_inference_steps=50, guidance_scale=7.5)[0]

    if len(video_frames) == 0:
        raise ValueError("Pipeline returned no frames; cannot build a filmstrip.")

    fig, axes = plt.subplots(1, len(video_frames), figsize=(2 * len(video_frames), 2))
    # plt.subplots returns a bare Axes (not an array) when there is exactly
    # one frame; normalize so indexing works for any frame count.
    axes = np.atleast_1d(axes)
    for ax, frame in zip(axes, video_frames):
        ax.imshow(frame)
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(output_path)
    # Close the figure explicitly so repeated calls don't accumulate open
    # figures and leak memory.
    plt.close(fig)
    print(f"Saved filmstrip to {output_path}")
|
|
if __name__ == "__main__":
    # CLI entry point: parse arguments and run a single generation pass.
    import argparse

    cli = argparse.ArgumentParser(
        description="Generate a sign-language video filmstrip from a text prompt."
    )
    cli.add_argument('--prompt', type=str, required=True,
                     help='Text prompt to generate sign language video')
    cli.add_argument('--checkpoint', type=str, default='checkpoint_epoch_70.pt',
                     help='Path to model checkpoint')
    cli.add_argument('--output', type=str, default='generated_filmstrip.png',
                     help='Output image path')
    cli.add_argument('--device', type=str, default='cuda',
                     help='Device: cuda or cpu')
    opts = cli.parse_args()

    generate_and_save(opts.prompt, opts.checkpoint, opts.output, opts.device)
|
|