Traly commited on
Commit
a089d4c
β€’
1 Parent(s): f0d5d42
Files changed (1) hide show
  1. app.py +7 -11
app.py CHANGED
@@ -52,7 +52,7 @@ def load_checkpoint(model, ckpt_path):
52
  new_state_dict[k] = v
53
 
54
  model.load_state_dict(new_state_dict)
55
- model.cuda()
56
  del checkpoint
57
  torch.cuda.empty_cache()
58
 
@@ -77,8 +77,8 @@ def image_infer(img_PIL):
77
  trainer.model.eval()
78
  img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)
79
 
80
- img_lr = img_lr.to('cuda')
81
- img_lr_up = img_lr_up.to('cuda')
82
 
83
  img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
84
 
@@ -89,25 +89,21 @@ def image_infer(img_PIL):
89
  return img_sr
90
 
91
 
92
- # cheetah = os.path.join(os.path.dirname(__file__), "images/cheetah1.jpg")
93
-
94
  root_path = os.path.dirname(__file__)
95
 
96
- cheetah = os.path.join(root_path, "images/lion.jpg")
97
  print(cheetah)
98
 
 
 
99
  demo = gr.Interface(image_infer, gr.Image(type="pil", value=cheetah), "image",
100
  # flagging_options=["blurry", "incorrect", "other"],
101
  examples=[
102
  os.path.join(root_path, "images/0801x4.png"),
103
- os.path.join(root_path, "images/0809x4.png"),
104
  os.path.join(root_path, "images/0809x4.png"),
105
  ]
106
  )
107
 
108
  if __name__ == "__main__":
109
- parent_path = Path(__file__).absolute().parent
110
- fill_root = os.path.abspath(parent_path)
111
- ckpt_path = os.path.join(fill_root, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
112
- trainer = model_init(ckpt_path)
113
  demo.launch()
 
52
  new_state_dict[k] = v
53
 
54
  model.load_state_dict(new_state_dict)
55
+ # model.cuda()
56
  del checkpoint
57
  torch.cuda.empty_cache()
58
 
 
77
  trainer.model.eval()
78
  img_lr, img_lr_up = get_img_data(img_PIL, hparams, sr_scale=4)
79
 
80
+ # img_lr = img_lr.to('cuda')
81
+ # img_lr_up = img_lr_up.to('cuda')
82
 
83
  img_sr, _ = trainer.model.sample(img_lr, img_lr_up, img_lr_up.shape)
84
 
 
89
  return img_sr
90
 
91
 
 
 
92
  root_path = os.path.dirname(__file__)
93
 
94
+ cheetah = os.path.join(root_path, "images/0801x4.png")
95
  print(cheetah)
96
 
97
+ ckpt_path = os.path.join(root_path, 'sam_diffsr/weight/model_ckpt_steps_400000.ckpt')
98
+ trainer = model_init(ckpt_path)
99
  demo = gr.Interface(image_infer, gr.Image(type="pil", value=cheetah), "image",
100
  # flagging_options=["blurry", "incorrect", "other"],
101
  examples=[
102
  os.path.join(root_path, "images/0801x4.png"),
103
+ os.path.join(root_path, "images/0804x4.png"),
104
  os.path.join(root_path, "images/0809x4.png"),
105
  ]
106
  )
107
 
108
  if __name__ == "__main__":
 
 
 
 
109
  demo.launch()