# my-probed-model / inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import lightgbm as lgb
import numpy as np
class LinearProbeWrapper:
    """Pairs a causal LM with a LightGBM probe trained on layer-1 hidden states."""

    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
        # Load the pretrained probe; lgb_layer_1.txt is expected next to this file.
        self.linear_probe = lgb.Booster(model_file="lgb_layer_1.txt")
    def __call__(self, inputs):
        prompt = inputs.get("inputs", "")
        inputs_tensors = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            # Forward pass to collect hidden states for every layer.
            outputs = self.model(**inputs_tensors)
            # hidden_states[1] is the residual stream after the first
            # transformer block; adjust the index to probe another layer.
            layer_1_hidden_state = outputs.hidden_states[1]
            layer_1_hidden_state_np = layer_1_hidden_state.cpu().numpy()
            # Score each token position with the probe and keep the
            # highest-scoring position.
            y_pred = self.linear_probe.predict(layer_1_hidden_state_np[0])
            y_pred_class = np.argmax(y_pred)
            # Generate a continuation of the prompt.
            generation_output = self.model.generate(
                **inputs_tensors, max_length=50, num_return_sequences=1
            )
        generated_text = self.tokenizer.decode(
            generation_output[0], skip_special_tokens=True
        )
        return {
            "generated_text": generated_text,
            # Cast to float so the payload is JSON-serializable.
            "probe_output": float(y_pred[y_pred_class]),
        }
def model_fn():
    # Entry point for the serving runtime; replace with the desired model.
    return LinearProbeWrapper("mistralai/Mistral-7B-Instruct-v0.2")
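
# A minimal usage sketch (assumes lgb_layer_1.txt is present in the working
# directory and that enough memory is available for the 7B model; the prompt
# below is purely illustrative):
if __name__ == "__main__":
    wrapper = model_fn()
    result = wrapper({"inputs": "The capital of France is"})
    print(result["generated_text"])
    print("probe score:", result["probe_output"])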