| import streamlit as st |
| import torch |
| from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer |
| from diffusers import StableDiffusionPipeline |
|
|
| |
# Choose the compute device once at start-up. The model loaders read this
# module-level value to pick both the target device and the tensor dtype.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
st.write(f"Using device: {device}")
|
|
| |
@st.cache_resource(show_spinner="⏳ Loading text model...")
def _load_text_pipeline():
    """Build the TinyLlama text-generation pipeline, cached across reruns.

    Exceptions propagate on purpose. ``st.cache_resource`` does not cache a
    call that raises, so a failed load is retried on the next rerun instead
    of a ``None`` result being remembered for the lifetime of the process.
    """
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

    tokenizer = AutoTokenizer.from_pretrained(model_name)

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # float16 only when running on the GPU; CPU inference stays float32.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        low_cpu_mem_usage=True,
    ).to(device)

    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # transformers convention: GPU index 0, or -1 for CPU.
        device=0 if device == "cuda" else -1,
    )


def load_text_model():
    """Return the text-generation pipeline, or ``None`` if it cannot be loaded.

    The heavy work is cached in ``_load_text_pipeline``. This wrapper is
    deliberately uncached so that a load failure is shown with ``st.error``
    and retried on the next rerun, rather than being cached as ``None``.

    Returns:
        The transformers pipeline, or ``None`` when loading raised.
    """
    try:
        text_pipeline = _load_text_pipeline()
    except Exception as e:  # UI boundary: report any load failure instead of crashing the page.
        st.error(f"❌ Error loading text model: {e}")
        return None
    st.write("✅ Text model loaded successfully!")
    return text_pipeline


story_generator = load_text_model()
|
|
| |
@st.cache_resource(show_spinner="⏳ Loading image model...")
def _load_image_pipeline():
    """Build the SD-Turbo text-to-image pipeline, cached across reruns.

    Exceptions propagate on purpose. ``st.cache_resource`` does not cache a
    call that raises, so a failed load is retried on the next rerun instead
    of a ``None`` result being remembered for the lifetime of the process.
    """
    model_id = "stabilityai/sd-turbo"
    image_pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        # float16 only when running on the GPU; CPU inference stays float32.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)
    # diffusers' attention-slicing option; presumably enabled to keep peak
    # memory low on small GPUs / CPU boxes.
    image_pipeline.enable_attention_slicing()
    return image_pipeline


def load_image_model():
    """Return the image-generation pipeline, or ``None`` if it cannot be loaded.

    The heavy work is cached in ``_load_image_pipeline``. This wrapper is
    deliberately uncached so that a load failure is shown with ``st.error``
    and retried on the next rerun, rather than being cached as ``None``.

    Returns:
        The diffusers pipeline, or ``None`` when loading raised.
    """
    try:
        image_pipeline = _load_image_pipeline()
    except Exception as e:  # UI boundary: report any load failure instead of crashing the page.
        st.error(f"❌ Error loading image model: {e}")
        return None
    st.write("✅ Image model loaded successfully!")
    return image_pipeline


image_generator = load_image_model()
|
|
| |
def generate_story(prompt, max_new_tokens=100):
    """Generate a short comic-style story about ``prompt``.

    Uses the module-level ``story_generator`` pipeline.

    Args:
        prompt: The user's story idea, embedded in a fixed instruction.
        max_new_tokens: Upper bound on tokens generated after the prompt.
            Unlike ``max_length``, this does not count the prompt, so long
            prompts still leave room for the story.

    Returns:
        The generated story text (prompt excluded). If the model is not
        loaded or generation fails, returns an error string instead; in the
        failure case the error is also shown with ``st.error``.
    """
    if not story_generator:
        return "❌ Error: Story model not loaded."

    formatted_prompt = f"Write a short comic-style story about: {prompt}\n\nStory:"

    try:
        with st.spinner("⏳ Generating story..."):
            outputs = story_generator(
                formatted_prompt,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_k=30,
                num_return_sequences=1,
                truncation=True,
                # Return only the continuation. This avoids stripping the
                # prompt back out of the decoded text, which breaks whenever
                # decoding does not reproduce the prompt byte-for-byte.
                return_full_text=False,
            )
            story = outputs[0]["generated_text"].strip()
    except Exception as e:  # UI boundary: show the failure instead of crashing the page.
        st.error(f"❌ Error generating story: {e}")
        return "Error generating story."

    st.write("✅ Story generated successfully!")
    return story
|
|
| |
st.title("🦸‍♂️ AI Comic Story Generator")
st.write("Enter a prompt to generate a comic-style story and image!")

# Strip so that whitespace-only input counts as "no prompt yet".
user_prompt = st.text_input("📝 Enter your story prompt:").strip()

if user_prompt:
    st.subheader("📖 AI-Generated Story")
    generated_story = generate_story(user_prompt)
    st.write(generated_story)

    st.subheader("🖼️ AI-Generated Image")

    if not image_generator:
        st.error("❌ Error: Image model not loaded.")
    else:
        with st.spinner("⏳ Generating image..."):
            try:
                image = image_generator(
                    # Steer the image toward the comic style that the title,
                    # caption and story prompt all promise.
                    f"comic book style illustration, {user_prompt}",
                    num_inference_steps=8,
                    # SD-Turbo is meant to run without classifier-free
                    # guidance (see its model card). With guidance_scale > 1
                    # the pipeline also runs an unconditional pass each step.
                    guidance_scale=0.0,
                    height=256,
                    width=256,
                ).images[0]
                st.write("✅ Image generated successfully!")
                st.image(image, caption="Generated Comic Image", use_container_width=True)
            except Exception as e:  # UI boundary: show the failure instead of crashing the page.
                st.error(f"❌ Error generating image: {e}")
|
|