import numpy as np
import torch
import lightgbm as lgb
from transformers import AutoModelForCausalLM, AutoTokenizer


class LinearProbeWrapper:
    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # output_hidden_states=True makes every forward pass return the
        # residual-stream activations for all layers.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, output_hidden_states=True
        )
        # Pre-trained LightGBM probe over layer-1 hidden states (note:
        # despite the attribute name, a LightGBM booster is tree-based
        # unless it was trained with linear_tree=True).
        self.linear_probe = lgb.Booster(model_file="lgb_layer_1.txt")

    def __call__(self, inputs):
        prompt = inputs.get("inputs", "")
        inputs_tensors = self.tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            # Forward pass; hidden_states is a tuple with one tensor per
            # layer (index 0 is the embedding output).
            outputs = self.model(**inputs_tensors)

            # Hidden states (residual stream) after layer 1 -- adjust the
            # index if the probe was trained on a different layer.
            layer_1_hidden_state = outputs.hidden_states[1]
            layer_1_hidden_state_np = layer_1_hidden_state.cpu().numpy().copy()

            # Score every token position with the probe. For a binary probe
            # this yields one score per token; report the highest one.
            y_pred = self.linear_probe.predict(layer_1_hidden_state_np[0])
            max_idx = int(np.argmax(y_pred))

            # Generate a continuation of the prompt.
            generation_output = self.model.generate(
                **inputs_tensors,
                max_length=50,
                num_return_sequences=1,
            )
        generated_text = self.tokenizer.decode(
            generation_output[0], skip_special_tokens=True
        )

        return {
            "generated_text": generated_text,
            # Cast to a plain float so the result is JSON-serializable.
            "probe_output": float(y_pred[max_idx]),
        }


def model_fn():
    # Replace with the desired model
    return LinearProbeWrapper("mistralai/Mistral-7B-Instruct-v0.2")
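
# Hypothetical sketch of how a probe file like "lgb_layer_1.txt" could be
# produced. This is not part of the wrapper above, and the data layout is an
# assumption: a binary LightGBM model fit on per-token layer-1 hidden states
# (rows = token positions, columns = hidden dimensions).
def train_probe(hidden_states, labels, out_path="lgb_layer_1.txt"):
    dataset = lgb.Dataset(hidden_states, label=labels)
    booster = lgb.train({"objective": "binary", "metric": "auc"}, dataset)
    booster.save_model(out_path)
    return booster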
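
# Minimal local usage sketch, assuming "lgb_layer_1.txt" exists in the working
# directory and the model weights are available. model_fn() mirrors a serving
# entry point, so locally we can simply call it and invoke the wrapper.
if __name__ == "__main__":
    wrapper = model_fn()
    result = wrapper({"inputs": "The quick brown fox"})
    print(result["generated_text"])
    print("probe score:", result["probe_output"])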