# tt3d/app.py
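# Flask service that turns a text prompt into a 3D render (returned as an
# animated GIF) using OpenAI's shap-e. POST /run starts generation in a
# background thread; GET /status polls for the result.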
import torch
from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images
from flask import Flask, request, jsonify
from flask_cors import CORS
import threading
import io
import base64
app = Flask(__name__)
CORS(app)
# Shared state for the single in-flight generation job; app.config doubles as
# a crude cross-thread store, so only one generation can run at a time.
app.config['temp_response'] = None
app.config['generation_thread'] = None
def generate_image_gif(prompt):
    """Generate a panning 3D render of `prompt` and return it as a base64 GIF."""
    print('Downloading the model weights')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    try:
        # Load the models inside the try block so download/initialization
        # failures are reported to the caller instead of silently killing the
        # worker thread. Reloading per request is slow; caching the models at
        # startup would avoid the repeated cost.
        xm = load_model('transmitter', device=device)
        model = load_model('text300M', device=device)
        diffusion = diffusion_from_config(load_config('diffusion'))
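        # Sample one latent with classifier-free guidance on a Karras sigma
        # schedule; 64 steps follows the upstream shap-e text-to-3D example.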
batch_size = 1
guidance_scale = 30.0
latents = sample_latents(
batch_size=batch_size,
model=model,
diffusion=diffusion,
guidance_scale=guidance_scale,
model_kwargs=dict(texts=[prompt] * batch_size),
progress=True,
clip_denoised=True,
            use_fp16=device.type == 'cuda',  # fp16 sampling is only reliable on GPU
use_karras=True,
karras_steps=64,
sigma_min=1E-3,
sigma_max=160,
s_churn=0,
)
        render_mode = 'nerf'  # you can change this to 'stf'
        size = 256  # size of the renders; higher values take longer to render
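        # Render the latent from a ring of cameras panning around the object,
        # then assemble the frames into a looping GIF held in memory.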
cameras = create_pan_cameras(size, device)
        images = decode_latent_images(xm, latents[0], cameras, rendering_mode=render_mode)  # decode the single sampled latent
writer = io.BytesIO()
images[0].save(writer, format="GIF", save_all=True, append_images=images[1:], duration=100, loop=0)
writer.seek(0)
data = base64.b64encode(writer.read()).decode("ascii")
        response_data = {'base64_3d': data, 'status': None}
        print('3D generation finished')
        return response_data
    except Exception as e:
        print(f"Error generating 3D: {e}")
        # Return a plain dict: this runs in a background thread, and /status
        # expects a mutable mapping, not a Flask response tuple.
        return {"error": f"Failed to generate 3D animation: {str(e)}", "status": "failed"}
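# Worker entry point: runs the (long) generation off the request thread and
# stashes the result where /status can find it.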
def background(prompt):
with app.app_context():
data = generate_image_gif(prompt)
app.config['temp_response'] = data
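# POST /run with form field 'prompt': starts generation in a background thread
# and immediately returns a process id the client can poll /status with.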
@app.route('/run', methods=['POST'])
def handle_animation_request():
prompt = request.form.get('prompt')
if prompt:
generation_thread = threading.Thread(target=background, args=(prompt,))
app.config['generation_thread'] = generation_thread
generation_thread.start()
response_data = {"message": "3D generation started", "process_id": generation_thread.ident}
return jsonify(response_data)
else:
return jsonify({"message": "Please provide a valid text prompt."}), 400
@app.route('/status', methods=['GET'])
def check_animation_status():
    # process_id is informational only: the app tracks a single generation
    # thread, so the id is not used to look anything up.
    process_id = request.args.get('process_id')
    if not process_id:
        return jsonify({"message": "Please provide a process_id."}), 400
    generation_thread = app.config.get('generation_thread')
    if generation_thread and generation_thread.is_alive():
        return jsonify({"status": "in_progress"}), 200
    final_response = app.config.get('temp_response')
    if final_response:
        # Preserve a 'failed' status set by the worker; otherwise mark done.
        final_response['status'] = final_response.get('status') or 'completed'
        return jsonify(final_response)
    return jsonify({"status": "unknown", "message": "No result available."}), 404
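# Example usage (assuming the Flask dev server's default http://127.0.0.1:5000):
#   curl -X POST -F 'prompt=a shark' http://127.0.0.1:5000/run
#   curl 'http://127.0.0.1:5000/status?process_id=<id from /run>'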
if __name__ == '__main__':
    # Development server only; use a production WSGI server for deployment.
    app.run(debug=True)