# qwen3-0.6b-siren

Qwen3-0.6B with integrated SIREN (Leveraging Internal Representations for LLM Safeguard) for per-token harmfulness detection. Full-precision model with all weights in bfloat16.
## Model Details
- Base Model: Qwen/Qwen3-0.6B
- Precision: bfloat16 (full precision)
- Safeguard Method: SIREN (Internal Representation-based Safeguard)
## Overview
SIREN is a plug-and-play safeguard that monitors both the prompts an LLM receives and the content it generates. It provides two monitoring modes:
- Sequence-Level Monitoring: score complete input and output sequences
- Token-Level Streaming: score tokens in real time during generation
## Quick Start
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "liuyilun2000/qwen3-0.6b-siren"

# Load SIREN model
siren_model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
siren_tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
```
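As a quick sanity check, the loaded model can score a single string directly. This is a minimal sketch; the `compute_siren_scores=True` keyword and the `(logits, per_token_scores)` return value follow the Usage examples below, and the scores are assumed to come back as one flat value per input token.

```python
# Quick check: per-token harmfulness scores for one string
text = "How do I bake sourdough bread?"
inputs = siren_tokenizer(text, return_tensors="pt").to(siren_model.device)
with torch.no_grad():
    _, siren_scores = siren_model(inputs.input_ids, compute_siren_scores=True)
# Assumes a flat per-token score tensor, as in the Usage examples below
print([round(s, 3) for s in siren_scores.float().cpu().tolist()])
```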
## Usage
### Sequence-Level Monitoring
Monitor complete sequences from another LLM:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the monitored LLM
monitored_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
monitored_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct", trust_remote_code=True)

def moderate_sequence(text, threshold=0.5):
    """Compute sequence-level harmfulness score."""
    inputs = siren_tokenizer(text, return_tensors="pt").to(siren_model.device)
    with torch.no_grad():
        _, siren_scores = siren_model(inputs.input_ids, compute_siren_scores=True)
    per_token_scores = siren_scores.cpu().tolist()
    max_score = max(per_token_scores)
    sequence_score = sum(per_token_scores) / len(per_token_scores)
    return max_score < threshold, max_score, sequence_score

# Moderation pipeline
user_prompt = "How can I make a bomb?"

# Check input
input_safe, _, input_score = moderate_sequence(user_prompt, threshold=0.5)
if not input_safe:
    print(f"⚠️ Blocked: Harmful input (score: {input_score:.3f})")
else:
    # Generate response
    inputs = monitored_tokenizer(user_prompt, return_tensors="pt").to(monitored_model.device)
    generated_ids = monitored_model.generate(
        inputs.input_ids,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )
    generated_text = monitored_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Check output
    output_safe, _, output_score = moderate_sequence(generated_text, threshold=0.5)
    if output_safe:
        print(f"✓ Allowed: {generated_text}")
    else:
        print(f"⚠️ Blocked: Harmful response (score: {output_score:.3f})")
```
### Token-Level Streaming Monitoring
Monitor tokens in real time as another LLM generates them:
```python
def stream_with_monitoring(prompt, max_new_tokens=100, threshold=0.5):
    """Generate tokens and compute SIREN scores in real-time."""
    # Check input sequence
    siren_inputs = siren_tokenizer(prompt, return_tensors="pt").to(siren_model.device)
    with torch.no_grad():
        _, input_scores = siren_model(siren_inputs.input_ids, compute_siren_scores=True)
    input_score = float(input_scores.mean().item())
    if input_score >= threshold:
        print(f"⚠️ Blocked: Harmful input (score: {input_score:.3f})")
        return None

    # Initialize generation
    inputs = monitored_tokenizer(prompt, return_tensors="pt").to(monitored_model.device)
    generated_tokens = []
    past_key_values = None
    attention_mask = torch.ones_like(inputs.input_ids)

    with torch.no_grad():
        # Prefill
        output = monitored_model(inputs.input_ids, use_cache=True)
        past_key_values = output.past_key_values
        logits = output.logits

        # Generate token-by-token
        for step in range(max_new_tokens):
            # Greedy-select the next token (argmax over the last position)
            next_token_id = torch.argmax(logits[0, -1, :], dim=-1).item()
            generated_tokens.append(next_token_id)

            # Decode and display token
            token_text = monitored_tokenizer.decode([next_token_id], skip_special_tokens=False)

            # Compute SIREN score for current sequence
            current_sequence = torch.cat([
                inputs.input_ids[0],
                torch.tensor(generated_tokens, device=inputs.input_ids.device)
            ])
            current_text = monitored_tokenizer.decode(current_sequence, skip_special_tokens=False)
            siren_inputs = siren_tokenizer(current_text, return_tensors="pt").to(siren_model.device)
            _, siren_scores = siren_model(siren_inputs.input_ids, compute_siren_scores=True)
            token_score = float(siren_scores[-1].item())

            # Stream token and score
            print(f"{token_text} [score: {token_score:.3f}]", end="", flush=True)

            # Block if harmful
            if token_score >= threshold:
                print(f"\n⚠️ Blocked: Harmful token (score: {token_score:.3f})")
                return None

            # Continue generation
            next_token_tensor = torch.tensor([[next_token_id]], device=monitored_model.device)
            new_attention_mask = torch.cat([
                attention_mask,
                torch.ones((1, 1), dtype=attention_mask.dtype, device=monitored_model.device)
            ], dim=1)
            output = monitored_model(
                next_token_tensor,
                attention_mask=new_attention_mask,
                past_key_values=past_key_values,
                use_cache=True
            )
            logits = output.logits
            past_key_values = output.past_key_values
            attention_mask = new_attention_mask

            if next_token_id == monitored_tokenizer.eos_token_id:
                break

    return monitored_tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Example
result = stream_with_monitoring("What is the capital of France?", max_new_tokens=50, threshold=0.5)
if result:
    print(f"\n✓ Complete: {result}")
```
## SIREN Scores
Per-token harmfulness scores (0.0 to 1.0):
- Low (Safe): < 0.2
- Medium (Moderate Risk): 0.2-0.5
- High (Harmful): > 0.5
Aggregation Methods:
- Max token score: Most conservative (blocks if any token exceeds threshold)
- Sequence score: Average across tokens (more lenient)
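The aggregation choices above can be made concrete with a small helper. This is a minimal sketch; `aggregate_scores` and `classify_score` are illustrative names, not part of the model's API, and the thresholds simply restate the bands listed above.

```python
def aggregate_scores(per_token_scores, threshold=0.5):
    """Illustrative helper: aggregate per-token SIREN scores both ways."""
    max_score = max(per_token_scores)                            # most conservative
    mean_score = sum(per_token_scores) / len(per_token_scores)   # more lenient
    return {
        "max": max_score,
        "mean": mean_score,
        "blocked_by_max": max_score >= threshold,
        "blocked_by_mean": mean_score >= threshold,
    }

def classify_score(score):
    """Illustrative helper: map a score onto the risk bands above."""
    if score < 0.2:
        return "low (safe)"
    if score <= 0.5:
        return "medium (moderate risk)"
    return "high (harmful)"

print(aggregate_scores([0.01, 0.03, 0.72, 0.10]))  # blocked by max, not by mean
print(classify_score(0.72))                        # "high (harmful)"
```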