|
|
| from typing import List |
| from diffusers.modular_pipelines import ( |
| PipelineState, |
| ModularPipelineBlocks, |
| InputParam, |
| OutputParam, |
| ) |
| from PIL import Image |
| from google import genai |
| from io import BytesIO |
|
|
# Shared Gemini API client. It is created once, at module import time, and
# used by every NanoBanana block instance in this module.
# NOTE(review): constructing it at import presumably requires credentials to
# be available in the environment at import time — confirm against the
# google-genai SDK configuration in use.
client = genai.Client()
|
|
class NanoBanana(ModularPipelineBlocks):
    """Modular pipeline block that generates or edits an image with Gemini.

    The block sends the required ``prompt`` (plus the optional input ``image``)
    to ``client.models.generate_content`` and exposes the image returned by
    the model as ``output_image``. The input image, if any, is passed through
    unchanged as ``old_image`` (``None`` when no image was supplied).
    """

    def __init__(self, model_id="gemini-2.5-flash-image"):
        """Create the block.

        Args:
            model_id: Gemini model name passed to ``generate_content``.
        """
        super().__init__()
        self.model_id = model_id

    @property
    def expected_components(self):
        # All work is delegated to the remote API; no pipeline components needed.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Image.Image,
                required=False,
                description="Image to use",
            ),
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            ),
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "output_image",
                type_hint=Image.Image,
                description="Output image",
            ),
            OutputParam(
                "old_image",
                type_hint=Image.Image,
                description="Old image (if) provided by the user",
            ),
        ]

    def _extract_image(self, response):
        """Return the last inline image found in the first candidate of *response*.

        Text parts are skipped (a part carrying text is never treated as an
        image), but their content is kept for the error message.

        Raises:
            ValueError: if the response has no candidate content or contains
                no inline image data, e.g. the model replied with text only or
                the request was blocked.
        """
        # Blocked/empty responses may have no candidates, or a candidate whose
        # content / parts are None; treat all of those as "no parts".
        candidates = response.candidates or []
        content = candidates[0].content if candidates else None
        parts = (content.parts if content is not None else None) or []

        image = None
        texts = []
        for part in parts:
            if part.text is not None:
                texts.append(part.text)
            elif part.inline_data is not None and part.inline_data.data:
                # Later images overwrite earlier ones (last image wins).
                image = Image.open(BytesIO(part.inline_data.data))

        if image is None:
            message = f"Gemini model {self.model_id!r} returned no image data."
            reply = " ".join(text for text in texts if text).strip()
            if reply:
                message += f" Model response text: {reply}"
            raise ValueError(message)
        return image

    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Run the block: query Gemini and store the resulting image in *state*.

        Returns:
            The ``(components, state)`` pair, following the modular-pipeline
            block convention.

        Raises:
            ValueError: if the model response contains no image.
        """
        block_state = self.get_block_state(state)

        old_image = block_state.image
        contents = [block_state.prompt]
        if old_image is not None:
            contents.append(old_image)

        response = client.models.generate_content(
            model=self.model_id, contents=contents
        )

        block_state.output_image = self._extract_image(response)
        # Pass the user's input image through unchanged (None if absent).
        block_state.old_image = old_image

        self.set_block_state(state, block_state)
        return components, state