Isshi14 commited on
Commit
71b59d0
·
verified ·
1 Parent(s): 620c87c

Update script_gen.py

Browse files
Files changed (1) hide show
  1. script_gen.py +13 -16
script_gen.py CHANGED
@@ -58,7 +58,10 @@ def _get_client() -> InferenceClient:
58
  "HF_TOKEN environment variable is not set. "
59
  "Please set your Hugging Face API token to use the script generation feature."
60
  )
61
- return InferenceClient(token=token)
 
 
 
62
 
63
 
64
  def generate_script(
@@ -95,28 +98,22 @@ def generate_script(
95
 
96
  client = _get_client()
97
 
98
- # Call the model using text_generation
99
- response = client.text_generation(
100
- prompt=_format_prompt(SYSTEM_PROMPT, user_message),
101
  model=MODEL_ID,
102
- max_new_tokens=MAX_NEW_TOKENS,
 
 
 
 
103
  temperature=TEMPERATURE,
104
  top_p=0.9,
105
- do_sample=True,
106
  )
107
 
108
- script = response.strip()
109
 
110
  if not script:
111
  raise RuntimeError("The model returned an empty script. Please try again.")
112
 
113
  logger.info("Script generated: %d characters", len(script))
114
- return script
115
-
116
-
117
- def _format_prompt(system: str, user: str) -> str:
118
- """
119
- Format a prompt for SmolLM3 instruction following.
120
- Uses a simple system + user format.
121
- """
122
- return f"<|im_start|>system\n{system}<|im_end|>\n<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n"
 
58
  "HF_TOKEN environment variable is not set. "
59
  "Please set your Hugging Face API token to use the script generation feature."
60
  )
61
+ return InferenceClient(
62
+ provider="hf-inference",
63
+ token=token,
64
+ )
65
 
66
 
67
  def generate_script(
 
98
 
99
  client = _get_client()
100
 
101
+ # Call the model using chat_completion
102
+ response = client.chat_completion(
 
103
  model=MODEL_ID,
104
+ messages=[
105
+ {"role": "system", "content": SYSTEM_PROMPT},
106
+ {"role": "user", "content": user_message},
107
+ ],
108
+ max_tokens=MAX_NEW_TOKENS,
109
  temperature=TEMPERATURE,
110
  top_p=0.9,
 
111
  )
112
 
113
+ script = response.choices[0].message.content.strip()
114
 
115
  if not script:
116
  raise RuntimeError("The model returned an empty script. Please try again.")
117
 
118
  logger.info("Script generated: %d characters", len(script))
119
+ return script