"""Flask service that renders an animated GIF from a text prompt with Shap-E.

``POST /run`` starts a generation in a background thread; ``GET /status``
reports whether it is still running and, once it has finished, returns the
GIF as a base64 string.
"""

import base64
import io
import threading

import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.diffusion.sample import sample_latents
from shap_e.models.download import load_config, load_model
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget

app = Flask(__name__)
CORS(app)

pipe = None

# Single-slot job state: only the most recently started generation is tracked.
# 'temp_response' holds the finished job's result dict (or None while there is
# no result yet); 'generation_thread' holds the worker Thread object.
app.config['temp_response'] = None
app.config['generation_thread'] = None


def _load_models():
    """Load the Shap-E models and return ``(device, xm, model, diffusion)``."""
    print('Downloading the model weights')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    xm = load_model('transmitter', device=device)
    model = load_model('text300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    return device, xm, model, diffusion


def _failure(message):
    """Build the plain-dict result that records a failed generation."""
    return {"error": message, "status": "failed"}


def generate_image_gif(prompt):
    """Generate an animated GIF for ``prompt`` and return it as a result dict.

    This runs in a worker thread (see ``background``), so it always returns a
    plain, JSON-serialisable dict instead of a Flask response object.

    Args:
        prompt: Text description passed to the text-conditioned model.

    Returns:
        On success ``{'base64_3d': <base64 GIF>, 'status': None}``; the status
        endpoint fills in the final status. On failure
        ``{'error': <message>, 'status': 'failed'}``.
    """
    # Broad excepts are deliberate here: this is the worker-thread boundary,
    # and an uncaught exception would end the thread without leaving any
    # result for the status endpoint to report.
    try:
        device, xm, model, diffusion = _load_models()
    except Exception as e:
        print(f"Error downloading the model: {e}")
        return _failure(f"Failed to download model: {e}")

    try:
        batch_size = 1
        guidance_scale = 30.0

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=64,
            sigma_min=1e-3,
            sigma_max=160,
            s_churn=0,
        )

        render_mode = 'nerf'  # you can change this to 'stf'
        size = 256  # size of the renders; higher values take longer to render

        cameras = create_pan_cameras(size, device)
        images = decode_latent_images(
            xm, latents, cameras, rendering_mode=render_mode
        )

        # Encode all frames as a single looping GIF held in memory.
        with io.BytesIO() as buffer:
            images[0].save(
                buffer,
                format="GIF",
                save_all=True,
                append_images=images[1:],
                duration=100,
                loop=0,
            )
            data = base64.b64encode(buffer.getvalue()).decode("ascii")

        # Log only the size; the payload itself would flood the console.
        print(f"response_data ready ({len(data)} base64 characters)")
        return {'base64_3d': data, 'status': None}
    except Exception as e:
        print(f"Error generating 3D: {e}")
        return _failure(f"Failed to generate 3D animation: {e}")


def background(prompt):
    """Thread target: run the generation and store its result in app config."""
    with app.app_context():
        data = generate_image_gif(prompt)
        app.config['temp_response'] = data


@app.route('/run', methods=['POST'])
def handle_animation_request():
    """Start a generation for the ``prompt`` form field.

    Returns a JSON body with ``message`` and ``process_id`` (the worker
    thread's ident), or HTTP 400 when the prompt is missing or blank.
    """
    prompt = request.form.get('prompt')
    if not prompt or not prompt.strip():
        return jsonify({"message": "Please provide a valid text prompt."}), 400

    # Clear the previous job's result so that polling for this job can never
    # be answered with stale data.
    app.config['temp_response'] = None

    generation_thread = threading.Thread(target=background, args=(prompt,))
    app.config['generation_thread'] = generation_thread
    generation_thread.start()

    response_data = {
        "message": "3D generation started",
        "process_id": generation_thread.ident,
    }
    return jsonify(response_data)


@app.route('/status', methods=['GET'])
def check_animation_status():
    """Report the state of the most recently started generation.

    The ``process_id`` query parameter is required, but only its presence is
    checked because a single job slot is tracked.

    Returns:
        400 if ``process_id`` is missing; ``{"status": "in_progress"}`` while
        the worker thread is alive; the stored result with status
        ``completed`` on success; the stored error with HTTP 500 on failure;
        404 if there is no running job and no stored result.
    """
    process_id = request.args.get('process_id', None)
    if not process_id:
        return jsonify({"message": "Please provide a process_id."}), 400

    generation_thread = app.config.get('generation_thread')
    if generation_thread is not None and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    result = app.config.get('temp_response')
    if not result:
        return jsonify({
            "status": "not_found",
            "message": "No generation result is available.",
        }), 404

    if 'error' in result:
        return jsonify(result), 500

    # Return a copy so the stored result is not mutated on every poll.
    return jsonify({**result, 'status': 'completed'})


if __name__ == '__main__':
    # NOTE(review): debug=True enables the Werkzeug debugger and reloader;
    # it must not be used when the service is exposed beyond local development.
    app.run(debug=True)