| import torch |
| from shap_e.diffusion.sample import sample_latents |
| from shap_e.diffusion.gaussian_diffusion import diffusion_from_config |
| from shap_e.models.download import load_model, load_config |
| from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget |
| from flask import Flask, request, jsonify |
| from flask_cors import CORS |
| import threading |
| import io |
| import base64 |
|
|
# Flask application object. CORS(app) applies flask-cors with its default
# settings to this app so browser clients can call the endpoints below.
app = Flask(__name__)
CORS(app)

# Placeholder global; nothing in the visible code assigns or reads it again.
pipe = None

# Single-slot job tracking kept on the app config:
#   temp_response     - result of the most recent generation (None until set)
#   generation_thread - threading.Thread running the most recent generation
app.config.update(temp_response=None, generation_thread=None)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def generate_image_gif(prompt, *, guidance_scale=30.0, karras_steps=64,
                       render_mode='nerf', size=256):
    """Generate a panning GIF of a 3D object described by ``prompt``.

    Loads the Shap-E ``transmitter`` and ``text300M`` models, samples one
    latent conditioned on the prompt, renders it from the pan cameras and
    encodes the frames as a base64 GIF.

    Args:
        prompt: Text description passed to the text-conditional model.
        guidance_scale: Guidance scale forwarded to ``sample_latents``.
        karras_steps: Number of Karras sampling steps.
        render_mode: Rendering mode forwarded to ``decode_latent_images``.
        size: Size passed to ``create_pan_cameras``.

    Returns:
        On success, ``{'base64_3d': <base64 GIF>, 'status': None}``.
        On failure, ``{'error': <message>, 'status': 'failed'}``. A plain
        dict is returned in both cases because this function runs in a
        background thread and its result is stored and later indexed like a
        dict; a Flask ``(response, status)`` tuple would break that consumer.
    """
    batch_size = 1

    try:
        # Model loading is inside the try so that download/load failures are
        # reported through the same error path as sampling failures.
        print('Downloading the model weights')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        diffusion = diffusion_from_config(load_config('diffusion'))

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=karras_steps,
            sigma_min=1E-3,
            sigma_max=160,
            s_churn=0,
        )

        cameras = create_pan_cameras(size, device)
        # decode_latent_images is given a single latent (as in the Shap-E
        # sample notebook), not the whole batch.
        images = decode_latent_images(
            xm, latents[0], cameras, rendering_mode=render_mode
        )

        writer = io.BytesIO()
        images[0].save(
            writer,
            format="GIF",
            save_all=True,
            append_images=images[1:],
            duration=100,
            loop=0,
        )
        data = base64.b64encode(writer.getvalue()).decode("ascii")
        response_data = {'base64_3d': data, 'status': None}
        # Log only the payload size; the full base64 string can be very large.
        print(f'response_data ready ({len(data)} base64 characters)')
        return response_data
    except Exception as e:
        print(f"Error generating 3D: {e}")
        return {
            "error": f"Failed to generate 3D animation: {str(e)}",
            "status": "failed",
        }
|
|
def background(prompt):
    """Thread target: run the generation and publish its result.

    The outcome is stored in ``app.config['temp_response']`` so the status
    endpoint can pick it up. Any exception that escapes
    ``generate_image_gif`` is caught here, at the thread boundary, and
    recorded as an error dict; otherwise the thread would die silently and
    pollers would never see a result for this job.

    Args:
        prompt: Text prompt forwarded to ``generate_image_gif``.
    """
    with app.app_context():
        try:
            data = generate_image_gif(prompt)
        except Exception as e:
            print(f"Error generating 3D: {e}")
            data = {
                "error": f"Failed to generate 3D animation: {str(e)}",
                "status": "failed",
            }
        app.config['temp_response'] = data
|
|
@app.route('/run', methods=['POST'])
def handle_animation_request():
    """Start a 3D generation job for the ``prompt`` form field.

    Returns:
        200 with ``{"message", "process_id"}`` once the worker thread has
        been started (``process_id`` is the thread's ``ident``), or 400 with
        an explanatory message when no prompt was supplied.
    """
    prompt = request.form.get('prompt')
    if not prompt:
        return jsonify({"message": "Please provide a valid text prompt."}), 400

    # Drop the previous job's result so the status endpoint cannot serve it
    # as if it belonged to the job being started now.
    app.config['temp_response'] = None

    generation_thread = threading.Thread(target=background, args=(prompt,))
    app.config['generation_thread'] = generation_thread
    generation_thread.start()

    # ident is only populated after start(), so read it afterwards.
    response_data = {
        "message": "3D generation started",
        "process_id": generation_thread.ident,
    }
    return jsonify(response_data)
|
|
@app.route('/status', methods=['GET'])
def check_animation_status():
    """Report the state of the most recent generation job.

    Query parameters:
        process_id: Must be present and non-empty. It is not compared with
            the stored thread; only one job is tracked at a time.

    Returns:
        * 400 when ``process_id`` is missing.
        * 200 ``{"status": "in_progress"}`` while the worker thread is alive.
        * 200 with the stored result plus ``"status": "completed"`` on success.
        * 500 with the stored result plus ``"status": "failed"`` when the
          result carries an ``error`` key, or when the worker ended without
          storing anything.
        * 404 when no job has been started.

    Every path returns a response; a Flask view that returns ``None`` is
    itself an error.
    """
    process_id = request.args.get('process_id', None)
    if not process_id:
        return jsonify({"message": "Please provide a process_id."}), 400

    generation_thread = app.config.get('generation_thread')
    if generation_thread and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    result = app.config.get('temp_response')
    if not result:
        if generation_thread is None:
            return jsonify({
                "status": "not_found",
                "message": "No 3D generation has been started.",
            }), 404
        return jsonify({
            "status": "failed",
            "error": "3D generation finished without producing a result.",
        }), 500

    if not isinstance(result, dict):
        # The stored value may already be a ready-made Flask return value
        # (e.g. a (response, status) tuple); hand it back unchanged rather
        # than trying to index it like a dict.
        return result

    if 'error' in result:
        return jsonify({**result, 'status': 'failed'}), 500

    # Build a copy so the shared stored result is not mutated per request.
    return jsonify({**result, 'status': 'completed'})
|
|
if __name__ == '__main__':
    # Script entry point: start Flask's built-in server with debug mode on.
    # NOTE(review): debug=True turns on the interactive debugger and the
    # reloader; confirm this is never used for a publicly reachable deployment.
    app.run(debug=True)