OmAlve commited on
Commit
07293fd
·
verified ·
1 Parent(s): 8484934

Upload inference_example.py

Browse files
Files changed (1) hide show
  1. inference_example.py +65 -0
inference_example.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Example inference script for the fine-tuned Reading Steiner model.
3
+ Replace with your actual LoRA adapter path after training.
4
+ """
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+ from peft import PeftModel
7
+ import torch
8
+
9
+ # Load base model
10
+ base_model = AutoModelForCausalLM.from_pretrained(
11
+ "Qwen/Qwen3.5-2B",
12
+ torch_dtype=torch.bfloat16,
13
+ device_map="auto",
14
+ trust_remote_code=True,
15
+ )
16
+
17
+ # Load LoRA adapter (replace with your actual adapter path)
18
+ # model = PeftModel.from_pretrained(base_model, "OmAlve/reading-steiner-qwen3.5-2b")
19
+ # Optionally merge adapter into base model for faster inference:
20
+ # model = model.merge_and_unload()
21
+ model = base_model # Replace with above after training
22
+
23
+ # Tokenizer
24
+ tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3.5-2B", trust_remote_code=True)
25
+
26
+ # Example input
27
+ messages = [
28
+ {"role": "system", "content": (
29
+ "You are Reading Steiner, a web content extraction model. "
30
+ "Given a webpage split into indexed blocks, identify which blocks contain the main content. "
31
+ "Output indices as a Python list of [start, end] intervals."
32
+ )},
33
+ {"role": "user", "content": (
34
+ "URL: https://example.com\n"
35
+ "Title: Example Page\n"
36
+ "Blocks:\n"
37
+ '[1] <div class="nav">Home | About | Contact</div>\n'
38
+ '[2] <div class="sidebar">Trending</div>\n'
39
+ '[3] <p>This is the main article content.</p>\n'
40
+ '[4] <p>More content here.</p>\n'
41
+ '[5] <div class="footer">Copyright 2025</div>'
42
+ )},
43
+ ]
44
+
45
+ inputs = tokenizer.apply_chat_template(
46
+ messages,
47
+ tokenize=True,
48
+ return_tensors="pt",
49
+ add_generation_prompt=True,
50
+ )
51
+ inputs = inputs.to(model.device)
52
+
53
+ # Generate
54
+ with torch.no_grad():
55
+ outputs = model.generate(
56
+ inputs,
57
+ max_new_tokens=128,
58
+ temperature=0.1,
59
+ do_sample=False,
60
+ pad_token_id=tokenizer.pad_token_id,
61
+ )
62
+
63
+ result = tokenizer.decode(outputs[0], skip_special_tokens=True)
64
+ print("=== Generated Output ===")
65
+ print(result)