multimodalart (HF staff) committed
Commit 51261e6
1 Parent(s): 1e18f63

Update app.py

Files changed (1)
  1. app.py +9 -4
app.py CHANGED
@@ -40,7 +40,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 hf_hub_download(repo_id="ai-forever/Real-ESRGAN", filename="RealESRGAN_x4.pth", local_dir="model_real_esran")
 snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
 
-pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to(device)
+pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cpu")
 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
     "THUDM/CogVideoX-5b",
@@ -50,7 +50,7 @@ pipe_video = CogVideoXVideoToVideoPipeline.from_pretrained(
     tokenizer=pipe.tokenizer,
     text_encoder=pipe.text_encoder,
     torch_dtype=torch.bfloat16,
-).to(device)
+).to("cpu")
 
 pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
     "THUDM/CogVideoX-5b-I2V",
@@ -62,7 +62,7 @@ pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
     tokenizer=pipe.tokenizer,
     text_encoder=pipe.text_encoder,
     torch_dtype=torch.bfloat16,
-).to(device)
+).to("cpu")
 
 
 # pipe.transformer.to(memory_format=torch.channels_last)
@@ -229,6 +229,7 @@ def infer(
 
     if video_input is not None:
         video = load_video(video_input)[:49]  # Limit to 49 frames
+        pipe_video.to(device)
         video_pt = pipe_video(
             video=video,
             prompt=prompt,
@@ -240,7 +241,9 @@
             guidance_scale=guidance_scale,
             generator=torch.Generator(device="cpu").manual_seed(seed),
         ).frames
+        pipe_video.to("cpu")
     elif image_input is not None:
+        pipe_image.to(device)
         image_input = Image.fromarray(image_input).resize(size=(720, 480))  # Convert to PIL
         image = load_image(image_input)
         video_pt = pipe_image(
@@ -253,7 +256,9 @@
             guidance_scale=guidance_scale,
             generator=torch.Generator(device="cpu").manual_seed(seed),
         ).frames
+        pipe_image.to("cpu")
     else:
+        pipe.to(device)
        video_pt = pipe(
             prompt=prompt,
             num_videos_per_prompt=1,
@@ -264,7 +269,7 @@
             guidance_scale=guidance_scale,
             generator=torch.Generator(device="cpu").manual_seed(seed),
         ).frames
-
+        pipe.to("cpu")
     return (video_pt, seed)
 
 
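
The change keeps the three CogVideoX pipelines on the CPU after loading and moves only the one needed for the current request onto the GPU, returning it to the CPU once its frames are generated, so idle pipelines do not occupy GPU memory. Below is a minimal sketch of the same move-on-demand pattern, wrapped in a context manager so the pipeline is handed back to the CPU even if generation raises. The on_device helper and its usage are illustrative assumptions, not part of the commit.

# Illustrative sketch, not part of the commit: temporary device placement for a
# diffusers pipeline, with a guaranteed move back to the CPU afterwards.
from contextlib import contextmanager

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


@contextmanager
def on_device(pipeline, target):
    # Move the pipeline to the target device for the duration of the block,
    # then return it to the CPU even if the body raises.
    pipeline.to(target)
    try:
        yield pipeline
    finally:
        pipeline.to("cpu")


# Hypothetical usage inside infer(), assuming pipe_video was loaded with .to("cpu"):
# with on_device(pipe_video, device) as p:
#     video_pt = p(video=video, prompt=prompt, ...).frames

For comparison, diffusers pipelines also expose enable_model_cpu_offload() (backed by accelerate), which shuttles individual submodules between CPU and GPU automatically; the commit instead moves each whole pipeline explicitly, which keeps the active pipeline entirely on the GPU while it generates.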