import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import lightgbm as lgb
import numpy as np


class LinearProbeWrapper:
    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, output_hidden_states=True
        )
        # Despite the attribute name, the probe is a LightGBM booster
        # trained on layer-1 hidden states and loaded from disk.
        self.linear_probe = lgb.Booster(model_file='lgb_layer_1.txt')

    def __call__(self, inputs):
        prompt = inputs.get("inputs", "")
        inputs_tensors = self.tokenizer(prompt, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs_tensors)

        # hidden_states[0] is the embedding output, so index 1 is the first
        # transformer layer; shape (batch, seq_len, hidden_dim).
        layer_1_hidden_state = outputs.hidden_states[1]
        layer_1_hidden_state_np = layer_1_hidden_state.cpu().numpy().copy()

        # Pool the per-token states into a single (1, hidden_dim) feature row
        # for the probe. Mean pooling is an assumption here; it must match
        # whatever pooling the probe was trained with.
        features = layer_1_hidden_state_np[0].mean(axis=0, keepdims=True)
        y_pred = self.linear_probe.predict(features)[0]
        y_pred_class = int(np.argmax(y_pred))

        # Generate a continuation of the same prompt. max_new_tokens bounds
        # only the newly generated tokens; max_length would also count the
        # prompt, which can silently truncate generation for long inputs.
        generation_output = self.model.generate(
            **inputs_tensors, max_new_tokens=50, num_return_sequences=1
        )
        generated_text = self.tokenizer.decode(
            generation_output[0], skip_special_tokens=True
        )

        return {
            "generated_text": generated_text,
            # Score of the probe's argmax class, cast for JSON serialization.
            "probe_output": float(y_pred[y_pred_class]),
        }


def model_fn():
    return LinearProbeWrapper("mistralai/Mistral-7B-Instruct-v0.2")
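

# Minimal local smoke test (illustrative only): assumes 'lgb_layer_1.txt'
# sits in the working directory and that there is enough memory to load the
# 7B model. Shows the payload shape the wrapper expects and what it returns.
if __name__ == "__main__":
    wrapper = model_fn()
    result = wrapper({"inputs": "The capital of France is"})
    print(result["generated_text"])
    print("probe score:", result["probe_output"])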