lnyan commited on
Commit
e9fdd34
1 Parent(s): 444252b

Add cuda check

Browse files
Files changed (1) hide show
  1. app.py +21 -20
app.py CHANGED
@@ -14,7 +14,7 @@ import skimage
14
  import skimage.measure
15
  from utils import *
16
 
17
-
18
  def load_html():
19
  body, canvaspy = "", ""
20
  with open("index.html", encoding="utf8") as f:
@@ -126,22 +126,23 @@ def run_outpaint(
126
  state,
127
  ):
128
  base64_str = "base64"
129
- # data = base64.b64decode(str(sel_buffer_str))
130
- # pil = Image.open(io.BytesIO(data))
131
- # sel_buffer = np.array(pil)
132
- # sel_buffer[:, :, 3]=255
133
- # sel_buffer[:, :, 0]=255
134
- # out_pil = Image.fromarray(sel_buffer)
135
- # out_buffer = io.BytesIO()
136
- # out_pil.save(out_buffer, format="PNG")
137
- # out_buffer.seek(0)
138
- # base64_bytes = base64.b64encode(out_buffer.read())
139
- # base64_str = base64_bytes.decode("ascii")
140
- # return (
141
- # gr.update(label=str(state + 1), value=base64_str,),
142
- # gr.update(label="Prompt"),
143
- # state + 1,
144
- # )
 
145
  if True:
146
  text2img, inpaint = get_model()
147
  if enable_safety:
@@ -223,9 +224,9 @@ outpaint_button_js = load_js("outpaint")
223
  proceed_button_js = load_js("proceed")
224
  mode_js = load_js("mode")
225
  setup_button_js = load_js("setup")
226
-
227
- # def get_model(x):
228
- # pass
229
  get_model(get_token())
230
 
231
  with blocks as demo:
 
14
  import skimage.measure
15
  from utils import *
16
 
17
+ cuda_available = torch.cuda.is_available()
18
  def load_html():
19
  body, canvaspy = "", ""
20
  with open("index.html", encoding="utf8") as f:
 
126
  state,
127
  ):
128
  base64_str = "base64"
129
+ if not cuda_available:
130
+ data = base64.b64decode(str(sel_buffer_str))
131
+ pil = Image.open(io.BytesIO(data))
132
+ sel_buffer = np.array(pil)
133
+ sel_buffer[:, :, 3]=255
134
+ sel_buffer[:, :, 0]=255
135
+ out_pil = Image.fromarray(sel_buffer)
136
+ out_buffer = io.BytesIO()
137
+ out_pil.save(out_buffer, format="PNG")
138
+ out_buffer.seek(0)
139
+ base64_bytes = base64.b64encode(out_buffer.read())
140
+ base64_str = base64_bytes.decode("ascii")
141
+ return (
142
+ gr.update(label=str(state + 1), value=base64_str,),
143
+ gr.update(label="Prompt"),
144
+ state + 1,
145
+ )
146
  if True:
147
  text2img, inpaint = get_model()
148
  if enable_safety:
 
224
  proceed_button_js = load_js("proceed")
225
  mode_js = load_js("mode")
226
  setup_button_js = load_js("setup")
227
+ if not torch.cuda.is_available():
228
+ def get_model(x):
229
+ pass
230
  get_model(get_token())
231
 
232
  with blocks as demo: