File size: 1,967 Bytes
d443994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""Simple inference script to test the HuggingFace LangFlow model."""

import argparse
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer


def _positive_int(value):
    """Parse *value* as a strictly positive integer (argparse ``type=`` hook).

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer >= 1, so that
            argparse reports a usage error up front instead of the bad value
            failing later inside the model's sampling code.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"expected a positive integer, got {number}")
    return number


def main():
    """Load a LangFlow checkpoint, draw samples from it, and print previews.

    Command-line options:
        --model_path: directory passed to
            ``AutoModelForMaskedLM.from_pretrained``.
        --tokenizer: tokenizer name or path used to decode samples
            (default ``"gpt2"``, which this script previously hard-coded).
        --num_samples, --num_steps, --seq_length: forwarded unchanged to the
            model's ``generate_samples`` method; must be positive integers.
        --seed: seeds the torch RNG, and all CUDA devices when CUDA is present.
        --preview_chars: maximum characters printed per sample (default 500);
            longer texts are cut and suffixed with ``"..."``.
    """
    parser = argparse.ArgumentParser(description="Generate samples with LangFlow")
    parser.add_argument(
        "--model_path", type=str, default="hf_release/model_weights",
        help="Path to the HuggingFace model directory")
    parser.add_argument(
        "--tokenizer", type=str, default="gpt2",
        help="Tokenizer name or path used to decode generated token ids")
    parser.add_argument(
        "--num_samples", type=_positive_int, default=5,
        help="Number of samples to generate")
    parser.add_argument(
        "--num_steps", type=_positive_int, default=128,
        help="Number of denoising steps")
    parser.add_argument(
        "--seq_length", type=_positive_int, default=1024,
        help="Sequence length")
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed")
    parser.add_argument(
        "--preview_chars", type=_positive_int, default=500,
        help="Maximum number of characters printed per sample")
    args = parser.parse_args()

    # Seed for reproducibility of the sampling run.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    # trust_remote_code=True executes Python code shipped with the checkpoint;
    # only point --model_path at directories you trust.
    model = AutoModelForMaskedLM.from_pretrained(
        args.model_path,
        trust_remote_code=True
    )
    model = model.to(device)
    model.eval()

    print(f"\nGenerating {args.num_samples} samples with {args.num_steps} steps...")
    with torch.no_grad():
        # generate_samples is a custom method provided by the checkpoint's
        # remote code, not a standard transformers API. Its result is assumed
        # to be a batch of token-id sequences consumable by batch_decode --
        # TODO confirm against the model code.
        samples = model.generate_samples(
            num_samples=args.num_samples,
            seq_length=args.seq_length,
            num_steps=args.num_steps,
            device=device
        )

    texts = tokenizer.batch_decode(samples, skip_special_tokens=True)
    for index, text in enumerate(texts, start=1):
        print(f"\n--- Sample {index} ---")
        # Truncate long samples to keep console output manageable.
        truncated = len(text) > args.preview_chars
        print(text[:args.preview_chars] + ("..." if truncated else ""))


if __name__ == "__main__":
    # Script entry point: run the CLI and exit with main()'s return value
    # (None, i.e. a normal zero exit status).
    raise SystemExit(main())