rupeshs commited on
Commit
2676544
·
1 Parent(s): 4189926

Add hf demo support

Browse files
Files changed (2) hide show
  1. src/app.py +10 -0
  2. src/frontend/webui/hf_demo.py +4 -4
src/app.py CHANGED
@@ -37,6 +37,12 @@ group.add_argument(
37
  action="store_true",
38
  help="Start Web UI",
39
  )
 
 
 
 
 
 
40
  group.add_argument(
41
  "-a",
42
  "--api",
@@ -330,6 +336,10 @@ elif args.mcp:
330
  from backend.api.mcp_server import start_mcp_server
331
 
332
  start_mcp_server(args.port)
 
 
 
 
333
  else:
334
  context = get_context(InterfaceType.CLI)
335
  config = app_settings.settings
 
37
  action="store_true",
38
  help="Start Web UI",
39
  )
40
+ group.add_argument(
41
+ "-d",
42
+ "--hfdemo",
43
+ action="store_true",
44
+ help="Start HF demo",
45
+ )
46
  group.add_argument(
47
  "-a",
48
  "--api",
 
336
  from backend.api.mcp_server import start_mcp_server
337
 
338
  start_mcp_server(args.port)
339
+ elif args.hfdemo:
340
+ from frontend.webui.hf_demo import start_demo
341
+
342
+ start_demo()
343
  else:
344
  context = get_context(InterfaceType.CLI)
345
  config = app_settings.settings
src/frontend/webui/hf_demo.py CHANGED
@@ -11,7 +11,7 @@ from constants import APP_VERSION
11
  from backend.device import is_openvino_device
12
  from PIL import Image
13
  from backend.models.lcmdiffusion_setting import DiffusionTask
14
- from backend.safety_check import is_safe_image
15
  from pprint import pprint
16
  from transformers import pipeline
17
 
@@ -24,6 +24,7 @@ classifier = pipeline(
24
  "image-classification",
25
  model="Falconsai/nsfw_image_detection",
26
  )
 
27
 
28
 
29
  # https://github.com/gradio-app/gradio/issues/2635#issuecomment-1423531319
@@ -71,8 +72,7 @@ def predict(
71
  latency = perf_counter() - start
72
  print(f"Latency: {latency:.2f} seconds")
73
  result = images[0]
74
- if is_safe_image(
75
- classifier,
76
  result,
77
  ):
78
  return result # .resize([512, 512], Image.LANCZOS)
@@ -171,6 +171,6 @@ with gr.Blocks(css=css) as demo:
171
  generate_btn.click(fn=predict, inputs=inputs, outputs=image)
172
 
173
 
174
- if __name__ == "__main__":
175
  demo.queue()
176
  demo.launch(share=False)
 
11
  from backend.device import is_openvino_device
12
  from PIL import Image
13
  from backend.models.lcmdiffusion_setting import DiffusionTask
14
+ from backend.safety_checker import SafetyChecker
15
  from pprint import pprint
16
  from transformers import pipeline
17
 
 
24
  "image-classification",
25
  model="Falconsai/nsfw_image_detection",
26
  )
27
+ safety_checker = SafetyChecker()
28
 
29
 
30
  # https://github.com/gradio-app/gradio/issues/2635#issuecomment-1423531319
 
72
  latency = perf_counter() - start
73
  print(f"Latency: {latency:.2f} seconds")
74
  result = images[0]
75
+ if safety_checker.is_safe(
 
76
  result,
77
  ):
78
  return result # .resize([512, 512], Image.LANCZOS)
 
171
  generate_btn.click(fn=predict, inputs=inputs, outputs=image)
172
 
173
 
174
+ def start_demo():
175
  demo.queue()
176
  demo.launch(share=False)