File size: 1,561 Bytes
c89cf7e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 | import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import lightgbm as lgb
import numpy as np
class LinearProbeWrapper:
    """Serve a causal LM together with a LightGBM probe on one hidden layer.

    Calling an instance with a payload such as ``{"inputs": "some prompt"}``
    runs one forward pass to read the hidden states of ``layer``, scores every
    prompt token with the LightGBM booster, and then generates a continuation
    of the prompt with the same model.
    """

    def __init__(self, model_name, probe_path="lgb_layer_1.txt", layer=1):
        """Load the tokenizer, the language model and the probe.

        Args:
            model_name: Model id or local path given to ``from_pretrained``.
            probe_path: LightGBM model file loaded with ``lgb.Booster``. The
                default is the file name this wrapper has always used; being
                a relative path, it resolves against the current working
                directory.
            layer: Index into ``outputs.hidden_states`` whose activations are
                fed to the probe (in transformers, entry 0 is the embedding
                output). Should match the layer the probe was trained on.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # output_hidden_states=True makes every forward pass return the
        # per-layer hidden states that the probe consumes.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, output_hidden_states=True
        )
        self.linear_probe = lgb.Booster(model_file=probe_path)
        self.layer = layer

    def __call__(self, inputs, max_length=50):
        """Score the prompt with the probe and generate a continuation.

        Args:
            inputs: Request payload. The prompt is read from its ``"inputs"``
                key and defaults to ``""`` when the key is missing.
            max_length: Forwarded to ``model.generate``. It caps the length of
                the whole returned sequence, prompt tokens included, so a
                prompt at or beyond this length leaves no room for new tokens.

        Returns:
            dict with ``"generated_text"`` (the first generated sequence,
            decoded with special tokens skipped) and ``"probe_output"`` (the
            largest probe prediction over the prompt tokens, as a float).

        Raises:
            ValueError: if the probe produced no predictions.
        """
        prompt = inputs.get("inputs", "")
        encoded = self.tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**encoded)
            # Drop the batch axis: a single prompt was tokenized, so this
            # leaves one row of activations per prompt token.
            hidden = outputs.hidden_states[self.layer][0]
            # .float() before .numpy(): NumPy has no bfloat16 dtype, so a
            # model loaded in bf16 would otherwise fail here. No-op for fp32.
            features = hidden.float().cpu().numpy()
            y_pred = self.linear_probe.predict(features)
            probe_output = self._summarize_probe_output(y_pred)

            generation_output = self.model.generate(
                **encoded, max_length=max_length, num_return_sequences=1
            )

        generated_text = self.tokenizer.decode(
            generation_output[0], skip_special_tokens=True
        )
        return {
            "generated_text": generated_text,
            "probe_output": probe_output,
        }

    @staticmethod
    def _summarize_probe_output(y_pred):
        """Reduce the probe's per-token predictions to one Python float.

        ``Booster.predict`` yields one value per row (token) for a binary or
        regression probe, or a ``(rows, classes)`` array for a multiclass
        probe. Indexing ``y_pred`` with ``np.argmax(y_pred)`` is only valid in
        the 1-D case: ``np.argmax`` returns an index into the *flattened*
        array, which would be misused as a row index for 2-D output (and can
        raise ``IndexError``). The global maximum equals that value for 1-D
        input and is well defined for 2-D input.

        NOTE(review): max-over-tokens is the aggregation this wrapper has
        always reported; confirm it matches how the probe was trained.

        Raises:
            ValueError: if ``y_pred`` is empty.
        """
        scores = np.asarray(y_pred, dtype=float)
        if scores.size == 0:
            raise ValueError(
                "probe produced no predictions for an empty prompt encoding"
            )
        # float() turns the NumPy scalar into a plain, JSON-serialisable float.
        return float(scores.max())
def model_fn(model_name="mistralai/Mistral-7B-Instruct-v0.2"):
    """Build the LinearProbeWrapper used to serve requests.

    Args:
        model_name: Model id or path handed to ``LinearProbeWrapper``. The
            default is the checkpoint this function always loaded; pass a
            different value instead of editing the function body.

    Returns:
        A newly constructed ``LinearProbeWrapper``.
    """
    return LinearProbeWrapper(model_name)
|