qwen3-0.6b-siren

Qwen3-0.6B with an integrated SIREN safeguard (Leveraging Internal Representations for LLM Safeguard) for per-token harmfulness detection. This is a full-precision model with all weights in bfloat16.

Model Details

  • Base Model: Qwen/Qwen3-0.6B
  • Precision: bfloat16 (full precision)
  • Safeguard Method: SIREN (Leveraging Internal Representations for LLM Safeguard)

Overview

SIREN is a plug-and-play safeguard that monitors content generated by LLMs. It provides two monitoring modes:

  • Sequence-Level Monitoring: Score complete input and output sequences
  • Token-Level Streaming: Monitor tokens in real time during generation

Quick Start

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

repo_id = "liuyilun2000/qwen3-0.6b-siren"

# Load SIREN model
siren_model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
siren_tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
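
To sanity-check the setup, you can score a short string directly. A minimal sketch, assuming (as in the Usage examples below) that the remote-code forward returns a (logits, per-token scores) tuple when called with compute_siren_scores=True:

text = "Hello, how are you today?"
inputs = siren_tokenizer(text, return_tensors="pt").to(siren_model.device)
with torch.no_grad():
    _, scores = siren_model(inputs.input_ids, compute_siren_scores=True)
# One score per token; flatten defensively in case a batch dimension is present
print(scores.float().cpu().flatten().tolist())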

Usage

Sequence-Level Monitoring

Monitor complete sequences from another LLM:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the monitored LLM
monitored_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-4B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
monitored_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B-Instruct", trust_remote_code=True)

def moderate_sequence(text, threshold=0.5):
    """Compute per-token SIREN scores and aggregate them over the sequence."""
    inputs = siren_tokenizer(text, return_tensors="pt").to(siren_model.device)
    with torch.no_grad():
        _, siren_scores = siren_model(inputs.input_ids, compute_siren_scores=True)

    # Flatten in case the scores carry a batch dimension
    per_token_scores = siren_scores.float().cpu().flatten().tolist()
    max_score = max(per_token_scores)  # conservative aggregate
    sequence_score = sum(per_token_scores) / len(per_token_scores)  # lenient average
    return max_score < threshold, max_score, sequence_score

# Moderation pipeline
user_prompt = "How can I make a bomb?"

# Check input (the safe/unsafe decision is based on the max per-token score)
input_safe, input_score, _ = moderate_sequence(user_prompt, threshold=0.5)
if not input_safe:
    print(f"⚠️  Blocked: Harmful input (score: {input_score:.3f})")
else:
    # Generate response
    inputs = monitored_tokenizer(user_prompt, return_tensors="pt").to(monitored_model.device)
    generated_ids = monitored_model.generate(
        inputs.input_ids,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )
    generated_text = monitored_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    # Check output (again based on the max per-token score)
    output_safe, output_score, _ = moderate_sequence(generated_text, threshold=0.5)
    if output_safe:
        print(f"✓ Allowed: {generated_text}")
    else:
        print(f"⚠️  Blocked: Harmful response (score: {output_score:.3f})")

Token-Level Streaming Monitoring

Monitor tokens in real time as another LLM generates them:

def stream_with_monitoring(prompt, max_new_tokens=100, threshold=0.5):
    """Generate tokens and compute SIREN scores in real-time."""
    
    # Check input sequence
    siren_inputs = siren_tokenizer(prompt, return_tensors="pt").to(siren_model.device)
    with torch.no_grad():
        _, input_scores = siren_model(siren_inputs.input_ids, compute_siren_scores=True)
        input_score = float(input_scores.mean().item())
        
        if input_score >= threshold:
            print(f"⚠️  Blocked: Harmful input (score: {input_score:.3f})")
            return None
    
    # Initialize generation
    inputs = monitored_tokenizer(prompt, return_tensors="pt").to(monitored_model.device)
    generated_tokens = []
    past_key_values = None
    attention_mask = torch.ones_like(inputs.input_ids)
    
    with torch.no_grad():
        # Prefill
        output = monitored_model(inputs.input_ids, attention_mask=attention_mask, use_cache=True)
        past_key_values = output.past_key_values
        logits = output.logits
        
        # Generate token-by-token
        for step in range(max_new_tokens):
            # Greedy-decode the next token (argmax; swap in sampling if desired)
            next_token_id = torch.argmax(logits[0, -1, :], dim=-1).item()
            if next_token_id == monitored_tokenizer.eos_token_id:
                break
            generated_tokens.append(next_token_id)
            
            # Decode the new token's text for display
            token_text = monitored_tokenizer.decode([next_token_id], skip_special_tokens=False)
            
            # Re-score the full sequence generated so far with SIREN
            # (note: re-tokenizing the whole text every step is O(n^2) in output length)
            current_sequence = torch.cat([
                inputs.input_ids[0],
                torch.tensor(generated_tokens, device=inputs.input_ids.device)
            ])
            current_text = monitored_tokenizer.decode(current_sequence, skip_special_tokens=False)
            siren_inputs = siren_tokenizer(current_text, return_tensors="pt").to(siren_model.device)

            _, siren_scores = siren_model(siren_inputs.input_ids, compute_siren_scores=True)
            token_score = float(siren_scores.flatten()[-1].item())
            
            # Stream token and score
            print(f"{token_text} [score: {token_score:.3f}]", end="", flush=True)
            
            # Block if harmful
            if token_score >= threshold:
                print(f"\n⚠️  Blocked: Harmful token (score: {token_score:.3f})")
                return None
            
            # Continue generation
            next_token_tensor = torch.tensor([[next_token_id]], device=monitored_model.device)
            new_attention_mask = torch.cat([
                attention_mask,
                torch.ones((1, 1), device=monitored_model.device)
            ], dim=1)
            
            output = monitored_model(
                next_token_tensor,
                attention_mask=new_attention_mask,
                past_key_values=past_key_values,
                use_cache=True
            )
            
            logits = output.logits
            past_key_values = output.past_key_values
            attention_mask = new_attention_mask
    
    return monitored_tokenizer.decode(generated_tokens, skip_special_tokens=True)

# Example
result = stream_with_monitoring("What is the capital of France?", max_new_tokens=50, threshold=0.5)
if result:
    print(f"\n✓ Complete: {result}")

SIREN Scores

SIREN produces per-token harmfulness scores in the range 0.0 to 1.0:

  • Low (Safe): < 0.2
  • Medium (Moderate Risk): 0.2-0.5
  • High (Harmful): > 0.5

Aggregation Methods (compared in the sketch below):

  • Max token score: most conservative (blocks if any single token exceeds the threshold)
  • Sequence score: average across all tokens (more lenient)
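
A self-contained illustration of how the two aggregates can disagree on the same sequence (the score values are made up for the example):

import torch

scores = torch.tensor([0.05, 0.10, 0.65, 0.08])  # hypothetical per-token scores
threshold = 0.5

max_score = scores.max().item()    # 0.65 -> blocked under max aggregation
mean_score = scores.mean().item()  # 0.22 -> allowed under mean aggregation
print(max_score >= threshold, mean_score >= threshold)  # True False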