|
import torch |
|
from shap_e.diffusion.sample import sample_latents |
|
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config |
|
from shap_e.models.download import load_model, load_config |
|
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, gif_widget |
|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
import threading |
|
import io |
|
import base64 |
|
|
|
# Flask application object; CORS(app) enables cross-origin requests using
# flask_cors's default settings so a browser front end can call the API.
app = Flask(__name__)
CORS(app)

# Not referenced anywhere in the visible code.
# NOTE(review): possibly a leftover placeholder; confirm before removing.
pipe = None

# Single-slot job state shared by the request handlers and the worker thread:
#   temp_response     -- result dict of the most recent finished job, or None
#   generation_thread -- threading.Thread of the most recent job, or None
app.config.update(temp_response=None, generation_thread=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Loaded (xm, model, diffusion) triples keyed by device string. Loading (and on
# first use downloading) the weights is expensive, so it is done once per
# process rather than once per request.
_MODEL_CACHE = {}
# Serializes first-time loading so two concurrent jobs never load the weights twice.
_MODEL_CACHE_LOCK = threading.Lock()


def _get_models(device):
    """Return the cached ``(xm, model, diffusion)`` triple for *device*, loading it on first use."""
    key = str(device)
    with _MODEL_CACHE_LOCK:
        if key not in _MODEL_CACHE:
            print('Downloading the model weights')
            xm = load_model('transmitter', device=device)
            model = load_model('text300M', device=device)
            diffusion = diffusion_from_config(load_config('diffusion'))
            _MODEL_CACHE[key] = (xm, model, diffusion)
        return _MODEL_CACHE[key]


def _encode_gif_base64(images, duration=100):
    """Encode *images* as an endlessly looping animated GIF and return it base64-encoded as ASCII.

    *images* must be a non-empty sequence of objects supporting
    ``save(fp, format="GIF", save_all=True, append_images=...)``
    (presumably PIL images returned by ``decode_latent_images``).
    *duration* is the per-frame display time passed to ``save``.

    Raises:
        ValueError: if *images* is empty.
    """
    if not images:
        raise ValueError("no frames were rendered")
    buffer = io.BytesIO()
    images[0].save(buffer, format="GIF", save_all=True, append_images=images[1:], duration=duration, loop=0)
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def generate_image_gif(prompt, *, guidance_scale=30.0, size=256, render_mode='nerf', karras_steps=64):
    """Generate a 3D object from *prompt* and return a base64 turntable GIF of it.

    Args:
        prompt: Text description given to the text-conditional shap_e model.
        guidance_scale: Guidance strength passed to ``sample_latents``.
        size: Render resolution passed to ``create_pan_cameras``.
        render_mode: ``rendering_mode`` passed to ``decode_latent_images``.
        karras_steps: Number of Karras sampling steps passed to ``sample_latents``.

    Returns:
        On success ``{'base64_3d': <base64 GIF>, 'status': None}``.
        On failure ``{'error': <message>, 'status': 'failed'}``.
        Failures are reported in the return value instead of being raised,
        because the caller runs this on a background worker thread.
    """
    batch_size = 1
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        xm, model, diffusion = _get_models(device)

        latents = sample_latents(
            batch_size=batch_size,
            model=model,
            diffusion=diffusion,
            guidance_scale=guidance_scale,
            model_kwargs=dict(texts=[prompt] * batch_size),
            progress=True,
            clip_denoised=True,
            use_fp16=True,
            use_karras=True,
            karras_steps=karras_steps,
            sigma_min=1E-3,
            sigma_max=160,
            s_churn=0,
        )

        cameras = create_pan_cameras(size, device)
        # Decode a single latent (not the whole batch), as shap_e's text-to-3D
        # example does; with batch_size == 1 that is latents[0].
        images = decode_latent_images(xm, latents[0], cameras, rendering_mode=render_mode)
        data = _encode_gif_base64(images)
    except Exception as e:  # worker-thread boundary: report instead of dying silently
        print(f"Error generating 3D: {e}")
        # A plain dict (not a Flask response tuple) so the stored result can be
        # inspected and serialized later by the status endpoint.
        return {"error": f"Failed to generate 3D animation: {str(e)}", "status": "failed"}

    # Log only the size: the payload itself can be megabytes of base64.
    print(f'response_data: GIF generated ({len(data)} base64 characters)')
    return {'base64_3d': data, 'status': None}
|
|
|
def background(prompt):
    """Run one generation job for *prompt* and publish its result for ``/status``.

    Intended as a ``threading.Thread`` target (started by ``/run``). Runs under
    an application context and stores the outcome in
    ``app.config['temp_response']``. Any exception escaping
    ``generate_image_gif`` is turned into an error dict, so the job always
    publishes a result instead of the thread dying silently.
    """
    with app.app_context():
        try:
            data = generate_image_gif(prompt)
        except Exception as e:  # thread boundary: nothing else would observe this error
            print(f"Error generating 3D: {e}")
            data = {"error": f"Failed to generate 3D animation: {str(e)}", "status": "failed"}
        app.config['temp_response'] = data
|
|
|
# Makes the "is a job already running?" check and the thread start atomic, so
# two simultaneous /run requests cannot both start a job.
_RUN_LOCK = threading.Lock()


@app.route('/run', methods=['POST'])
def handle_animation_request():
    """Start a text-to-3D generation job on a background thread.

    The prompt is read from form field ``prompt``; if that is absent, a JSON
    body ``{"prompt": ...}`` is accepted as a fallback.

    Returns:
        200 with ``{"message", "process_id"}`` where ``process_id`` is the
        worker thread's ident, to be passed to ``/status``.
        400 if the prompt is missing or blank.
        409 if a previous job is still running. The app tracks only one job
        (``app.config['generation_thread']``), so a second concurrent job would
        make status reporting for the first one incorrect.
    """
    prompt = request.form.get('prompt')
    if prompt is None:
        payload = request.get_json(silent=True)
        if isinstance(payload, dict):
            prompt = payload.get('prompt')

    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"message": "Please provide a valid text prompt."}), 400

    with _RUN_LOCK:
        current = app.config.get('generation_thread')
        if current is not None and current.is_alive():
            return jsonify({
                "message": "A 3D generation is already in progress.",
                "process_id": current.ident,
            }), 409

        # Drop the previous job's result so /status cannot report it for this job.
        app.config['temp_response'] = None
        generation_thread = threading.Thread(target=background, args=(prompt,))
        generation_thread.start()
        # Publish only after start() so the stored thread always has an ident.
        app.config['generation_thread'] = generation_thread

    response_data = {"message": "3D generation started", "process_id": generation_thread.ident}
    return jsonify(response_data)
|
|
|
@app.route('/status', methods=['GET'])
def check_animation_status():
    """Report the state of the generation job identified by ``process_id``.

    Query parameters:
        process_id: The ident returned by ``/run``.

    Returns:
        400 if ``process_id`` is missing.
        404 if it does not match the job currently tracked.
        200 ``{"status": "in_progress"}`` while the worker thread is alive.
        200 with the stored result and ``"status": "completed"`` on success.
        500 with ``"status": "failed"`` if the job errored or published nothing.
    """
    process_id = request.args.get('process_id', None)
    if not process_id:
        return jsonify({"message": "Please provide a process_id."}), 400

    generation_thread = app.config.get('generation_thread')
    # Query args are strings while Thread.ident is an int, so compare as strings.
    if generation_thread is None or str(generation_thread.ident) != process_id:
        return jsonify({"status": "not_found", "message": "Unknown process_id."}), 404

    if generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200

    result = app.config.get('temp_response')
    if result is None:
        # The worker exited without publishing anything (e.g. it raised).
        return jsonify({"status": "failed", "message": "3D generation finished without a result."}), 500
    if isinstance(result, tuple):
        # generate_image_gif's error path may hand back a ready-made
        # (response, status_code) pair; Flask can return it as-is.
        return result

    # Copy so repeated polls never mutate the stored result.
    final_response = dict(result)
    if final_response.get('error'):
        final_response['status'] = 'failed'
        return jsonify(final_response), 500
    final_response['status'] = 'completed'
    return jsonify(final_response)
|
|
|
if __name__ == '__main__':
    # Launch Flask's built-in development server. debug=True turns on the
    # reloader and the interactive debugger; the debugger can execute code,
    # so never expose this server to untrusted networks.
    debug_enabled = True
    app.run(debug=debug_enabled)