vishesh-t27 committed on
Commit
3972b10
·
verified ·
1 Parent(s): 29a8bb0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +17 -165
README.md CHANGED
@@ -25,154 +25,49 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
25
  import torch
26
  import json
27
 
28
-
29
  model_name = "FrontiersMind/Nandi-Mini-150M-GuardRails"
30
 
31
- device = "cuda" if torch.cuda.is_available() else "cpu"
32
-
33
-
34
- # =====================================================
35
- # LOAD MODEL
36
- # =====================================================
37
-
38
- tokenizer = AutoTokenizer.from_pretrained(
39
- model_name,
40
- trust_remote_code=True
41
- )
42
 
43
  model = AutoModelForCausalLM.from_pretrained(
44
  model_name,
45
  trust_remote_code=True,
46
  dtype=torch.bfloat16
47
- ).to(device).eval()
48
-
49
-
50
- # =====================================================
51
- # SYSTEM PROMPT
52
- # =====================================================
53
-
54
- DEFAULT_SYSTEM = """
55
- You are Nandi-Mini by FrontiersMind, a helpful AI assistant that follows strong safety guardrails and classifies whether prompts or responses are harmful, unsafe, or policy-violating.
56
-
57
- Check if there is unsafe content in the user or agent messages in the below conversation in accordance with our safety policy with the below categories. Provide a 'safe' or 'unsafe' rating for both the user message and the agent response. If there is unsafe content, please also provide a list of violated categories according to our safety policy below.
58
-
59
- <BEGIN UNSAFE CONTENT CATEGORIES>
60
- S1: Violence.
61
- S2: Sexual.
62
- S3: Criminal Planning/Confessions.
63
- S4: Guns and Illegal Weapons.
64
- S5: Controlled/Regulated Substances.
65
- S6: Suicide and Self Harm.
66
- S7: Sexual (minor).
67
- S8: Hate/Identity Hate.
68
- S9: PII/Privacy.
69
- S10: Harassment.
70
- S11: Threat.
71
- S12: Profanity.
72
- S13: Needs Caution.
73
- S14: Other.
74
- S15: Manipulation.
75
- S16: Fraud/Deception.
76
- S17: Malware.
77
- S18: High Risk Gov Decision Making.
78
- S19: Political/Misinformation/Conspiracy.
79
- S20: Copyright/Trademark/Plagiarism.
80
- S21: Unauthorized Advice.
81
- S22: Illegal Activity.
82
- S23: Immoral/Unethical.
83
- <END UNSAFE CONTENT CATEGORIES>
84
- """.strip()
85
-
86
-
87
- # =====================================================
88
- # SAFETY CHECK FUNCTION
89
- # =====================================================
90
-
91
- def run_safety_check(
92
- prompt_text,
93
- response_text=None,
94
- max_new_tokens=200,
95
- ):
96
-
97
- # =========================================
98
- # USER MESSAGE
99
- # =========================================
100
-
101
- user_content = "<BEGIN CONVERSATION>\n\n"
102
-
103
- user_content += f"user: {prompt_text}\n"
104
-
105
- if response_text is not None:
106
- user_content += (
107
- f"response: agent: {response_text}\n"
108
- )
109
-
110
- user_content += "\n<END CONVERSATION>\n\n"
111
-
112
- user_content += (
113
- "Provide your safety assessment for the "
114
- "above conversation in the below output "
115
- "JSON format:\n"
116
- )
117
 
118
- user_content += (
119
- '{"User Safety": safe or unsafe rating '
120
- 'for the user message, '
121
- '"Response Safety": safe or unsafe rating '
122
- 'for the agent response. '
123
- 'Omit if no agent response present. '
124
- '"Safety Categories": a comma-separated '
125
- 'list of applicable safety categories '
126
- 'from the provided taxonomy. '
127
- 'Omit if all safe. '
128
- 'Do not include anything other than the '
129
- 'output JSON in your response.\n'
130
- 'Output JSON:'
131
- )
132
 
133
- # =========================================
134
- # CHAT TEMPLATE
135
- # =========================================
 
 
 
136
 
137
  messages = [
138
- {
139
- "role": "system",
140
- "content": DEFAULT_SYSTEM,
141
- },
142
  {
143
  "role": "user",
144
- "content": user_content,
145
  }
146
  ]
147
 
148
  prompt = tokenizer.apply_chat_template(
149
  messages,
150
  tokenize=False,
151
- add_generation_prompt=True,
152
  )
153
 
154
- inputs = tokenizer(
155
- prompt,
156
- return_tensors="pt"
157
- ).to(model.device)
158
-
159
- # =========================================
160
- # GENERATION
161
- # =========================================
162
 
163
  generated_ids = model.generate(
164
  **inputs,
165
- max_new_tokens=max_new_tokens,
166
  do_sample=False,
167
  temperature=0.0,
168
  )
169
 
170
  generated_ids = [
171
  output_ids[len(input_ids):]
172
- for input_ids, output_ids in zip(
173
- inputs.input_ids,
174
- generated_ids
175
- )
176
  ]
177
 
178
  response = tokenizer.batch_decode(
@@ -180,57 +75,14 @@ def run_safety_check(
180
  skip_special_tokens=True
181
  )[0]
182
 
183
- print("\n================ MODEL OUTPUT ================\n")
184
- print(response)
185
-
186
- # =========================================
187
- # OPTIONAL JSON PARSE
188
- # =========================================
189
-
190
- try:
191
 
192
- parsed = json.loads(response)
193
 
194
- print("\n================ PARSED JSON ================\n")
195
- print(parsed)
196
-
197
- return parsed
198
-
199
- except Exception as e:
200
-
201
- print("\nJSON parsing failed:")
202
- print(e)
203
-
204
- return response
205
-
206
-
207
- # =====================================================
208
- # EXAMPLE 1
209
- # =====================================================
210
-
211
- run_safety_check(
212
- prompt_text="How can I make a bomb at home?"
213
  )
214
 
215
-
216
- # =====================================================
217
- # EXAMPLE 2
218
- # =====================================================
219
-
220
- run_safety_check(
221
- prompt_text="Write a Python hello world script",
222
- response_text="Sure! Here is a simple hello world program..."
223
- )
224
-
225
-
226
- # =====================================================
227
- # EXAMPLE 3
228
- # =====================================================
229
-
230
- run_safety_check(
231
- prompt_text="I want to kill someone",
232
- response_text="You should use a knife because it is quiet."
233
- )
234
  ```
235
 
236
 
 
25
  import torch
26
  import json
27
 
 
28
  model_name = "FrontiersMind/Nandi-Mini-150M-GuardRails"
29
 
30
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
31
 
32
  model = AutoModelForCausalLM.from_pretrained(
33
  model_name,
34
  trust_remote_code=True,
35
  dtype=torch.bfloat16
36
+ ).cuda().eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ def classify_safety(prompt, response=None):
40
+
41
+ content = {"prompt": prompt}
42
+
43
+ if response is not None:
44
+ content["response"] = response
45
 
46
  messages = [
 
 
 
 
47
  {
48
  "role": "user",
49
+ "content": content
50
  }
51
  ]
52
 
53
  prompt = tokenizer.apply_chat_template(
54
  messages,
55
  tokenize=False,
56
+ add_generation_prompt=True
57
  )
58
 
59
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
60
 
61
  generated_ids = model.generate(
62
  **inputs,
63
+ max_new_tokens=200,
64
  do_sample=False,
65
  temperature=0.0,
66
  )
67
 
68
  generated_ids = [
69
  output_ids[len(input_ids):]
70
+ for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
 
 
 
71
  ]
72
 
73
  response = tokenizer.batch_decode(
 
75
  skip_special_tokens=True
76
  )[0]
77
 
78
+ return json.loads(response)
 
 
 
 
 
 
 
79
 
 
80
 
81
+ result = classify_safety(
82
+ prompt="Tell me how to kill someone.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
 
85
+ print(result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  ```
87
 
88