multimodalart (HF staff) committed
Commit: fff6dc8
Parent: 799d1da

Update app.py

Files changed (1):
  app.py  +15 -13
app.py CHANGED
@@ -44,18 +44,9 @@ snapshot_download(repo_id="AlexWortega/RIFE", local_dir="model_rife")
 pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cpu")
 pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
-pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
-    "THUDM/CogVideoX-5b-I2V",
-    transformer=CogVideoXTransformer3DModel.from_pretrained(
-        "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
-    ),
-    vae=pipe.vae,
-    scheduler=pipe.scheduler,
-    tokenizer=pipe.tokenizer,
-    text_encoder=pipe.text_encoder,
-    torch_dtype=torch.bfloat16,
-).to("cpu")
-
+i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
+    "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
+)
 
 # pipe.transformer.to(memory_format=torch.channels_last)
 # pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
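The hunk above stops building the image-to-video pipeline at import time and keeps only its transformer resident; the full pipeline is assembled later, inside infer(), from components it shares with the text-to-video pipe. The sketch below is not part of the commit: it shows that deferred-construction pattern with the same diffusers classes and checkpoints used in app.py, and the build_i2v_pipeline helper is a hypothetical name used only for illustration.

# Minimal sketch of the deferred-construction / component-sharing pattern (not from app.py).
import torch
from diffusers import (
    CogVideoXPipeline,
    CogVideoXImageToVideoPipeline,
    CogVideoXTransformer3DModel,
)

# Load the text-to-video pipeline once; it owns the shared VAE, tokenizer,
# text encoder and scheduler.
pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16
).to("cpu")

# Keep only the image-to-video transformer in memory; it is the only weight set
# that differs between the two pipelines.
i2v_transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", subfolder="transformer", torch_dtype=torch.bfloat16
)

def build_i2v_pipeline(device: str = "cuda"):
    """Hypothetical helper: assemble the I2V pipeline on demand from shared parts."""
    return CogVideoXImageToVideoPipeline.from_pretrained(
        "THUDM/CogVideoX-5b-I2V",
        transformer=i2v_transformer,
        vae=pipe.vae,
        scheduler=pipe.scheduler,
        tokenizer=pipe.tokenizer,
        text_encoder=pipe.text_encoder,
        torch_dtype=torch.bfloat16,
    ).to(device)

Because the VAE, scheduler, tokenizer and text encoder are passed in rather than loaded a second time, only one copy of those weights ever lives in memory, and the I2V pipeline costs little more than its transformer.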
@@ -241,11 +232,20 @@ def infer(
             guidance_scale=guidance_scale,
             generator=torch.Generator(device="cpu").manual_seed(seed),
         ).frames
+        pipe_video.to("cpu")
         del pipe_video
         gc.collect()
         torch.cuda.empty_cache()
     elif image_input is not None:
-        pipe_image.to(device)
+        pipe_image = CogVideoXImageToVideoPipeline.from_pretrained(
+            "THUDM/CogVideoX-5b-I2V",
+            transformer=i2v_transformer,
+            vae=pipe.vae,
+            scheduler=pipe.scheduler,
+            tokenizer=pipe.tokenizer,
+            text_encoder=pipe.text_encoder,
+            torch_dtype=torch.bfloat16,
+        ).to(device)
         image_input = Image.fromarray(image_input).resize(size=(720, 480))  # Convert to PIL
         image = load_image(image_input)
         video_pt = pipe_image(
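This hunk moves the video pipeline back to the CPU before its reference is dropped, and builds pipe_image only when an image request arrives, placing it directly on the device. The self-contained sketch below is not taken from app.py; it demonstrates the same teardown order with a small stand-in module so it runs on any machine with PyTorch installed.

# Not part of the commit: the to("cpu") / del / gc.collect() / empty_cache() sequence
# demonstrated with a stand-in module instead of a CogVideoX pipeline.
import gc
import torch

def main() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Stand-in for pipe_video / pipe_image; any nn.Module behaves the same way here.
    model = torch.nn.Linear(4096, 4096).to(device)
    _ = model(torch.randn(1, 4096, device=device))

    # Same order as the diff: back to host memory first, then drop the reference,
    # then let Python reclaim the object, then release the cached CUDA blocks.
    model.to("cpu")
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

if __name__ == "__main__":
    main()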
@@ -259,7 +259,9 @@ def infer(
             generator=torch.Generator(device="cpu").manual_seed(seed),
         ).frames
         pipe_image.to("cpu")
+        del pipe_image
         gc.collect()
+        torch.cuda.empty_cache()
     else:
         pipe.to(device)
         video_pt = pipe(
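With this last hunk the image branch mirrors the video branch: after generation the pipeline is moved to the CPU, its reference deleted, and the CUDA cache cleared, so at most one CogVideoX transformer occupies the GPU at a time. The short check below, again not from app.py and assuming a CUDA device, illustrates why the explicit torch.cuda.empty_cache() call is worth adding: deleting the tensors lowers memory_allocated(), but memory_reserved() only drops once the caching allocator returns its blocks to the driver.

# Not part of the commit: a quick look at what each teardown step actually frees.
import gc
import torch

if torch.cuda.is_available():
    model = torch.nn.Linear(8192, 8192).to("cuda")
    print("loaded       allocated:", torch.cuda.memory_allocated(),
          "reserved:", torch.cuda.memory_reserved())

    model.to("cpu")
    del model
    gc.collect()
    # allocated drops because the CUDA tensors are gone; reserved stays high
    # because the caching allocator keeps the blocks for reuse.
    print("after del+gc allocated:", torch.cuda.memory_allocated(),
          "reserved:", torch.cuda.memory_reserved())

    torch.cuda.empty_cache()
    # only now is the cached VRAM handed back to the driver.
    print("empty_cache  reserved:", torch.cuda.memory_reserved())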