Spaces:
Sleeping
Sleeping
JarvisLabs
commited on
Commit
•
8276f79
1
Parent(s):
f3299a1
Update src/rep_api.py
Browse files- src/rep_api.py +210 -190
src/rep_api.py
CHANGED
@@ -1,190 +1,210 @@
|
|
1 |
-
import replicate
|
2 |
-
import os
|
3 |
-
from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile
|
4 |
-
import json
|
5 |
-
import time
|
6 |
-
style_json="model_dict.json"
|
7 |
-
model_dict=json.load(open(style_json,"r"))
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model
|
12 |
-
print(prompt,lora_model,api_path,aspect_ratio)
|
13 |
-
#if model=="dev":
|
14 |
-
num_inference_steps=30
|
15 |
-
if model=="schnell":
|
16 |
-
num_inference_steps=5
|
17 |
-
|
18 |
-
if lora_model is not None:
|
19 |
-
api_path=model_dict[lora_model]
|
20 |
-
|
21 |
-
inputs={
|
22 |
-
"model": model,
|
23 |
-
"prompt": prompt,
|
24 |
-
"lora_scale":lora_scale,
|
25 |
-
"aspect_ratio": aspect_ratio,
|
26 |
-
"num_outputs":num_outputs,
|
27 |
-
"num_inference_steps":num_inference_steps,
|
28 |
-
"guidance_scale":guidance_scale,
|
29 |
-
}
|
30 |
-
if seed is not None:
|
31 |
-
inputs["seed"]=seed
|
32 |
-
output = replicate.run(
|
33 |
-
api_path,
|
34 |
-
input=inputs
|
35 |
-
)
|
36 |
-
print(output)
|
37 |
-
if gallery is None:
|
38 |
-
gallery=[]
|
39 |
-
gallery.append(output[0])
|
40 |
-
return output[0],gallery
|
41 |
-
|
42 |
-
|
43 |
-
def replicate_caption_api(image,model,context_text):
|
44 |
-
base64_image = image_to_base64(image)
|
45 |
-
if model=="blip":
|
46 |
-
output = replicate.run(
|
47 |
-
"andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
|
48 |
-
input={
|
49 |
-
"image": base64_image,
|
50 |
-
"caption": True,
|
51 |
-
"question": context_text,
|
52 |
-
"temperature": 1,
|
53 |
-
"use_nucleus_sampling": False
|
54 |
-
}
|
55 |
-
)
|
56 |
-
print(output)
|
57 |
-
|
58 |
-
elif model=="llava-16":
|
59 |
-
output = replicate.run(
|
60 |
-
# "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
|
61 |
-
"yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
|
62 |
-
input={
|
63 |
-
"image": base64_image,
|
64 |
-
"top_p": 1,
|
65 |
-
"prompt": context_text,
|
66 |
-
"max_tokens": 1024,
|
67 |
-
"temperature": 0.2
|
68 |
-
}
|
69 |
-
)
|
70 |
-
print(output)
|
71 |
-
output = "".join(output)
|
72 |
-
|
73 |
-
elif model=="img2prompt":
|
74 |
-
output = replicate.run(
|
75 |
-
"methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
|
76 |
-
input={
|
77 |
-
"image":base64_image
|
78 |
-
}
|
79 |
-
)
|
80 |
-
print(output)
|
81 |
-
return output
|
82 |
-
|
83 |
-
def update_replicate_api_key(api_key):
|
84 |
-
os.environ["REPLICATE_API_TOKEN"] = api_key
|
85 |
-
return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
#
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import replicate
|
2 |
+
import os
|
3 |
+
from src.utils import image_to_base64 , update_model_dicts, BB_uploadfile,numpy_to_base64
|
4 |
+
import json
|
5 |
+
import time
|
6 |
+
style_json="model_dict.json"
|
7 |
+
model_dict=json.load(open(style_json,"r"))
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def generate_image_replicate(prompt,lora_model,api_path,aspect_ratio,gallery,model,lora_scale=1,num_outputs=1,guidance_scale=3.5,seed=None):
|
12 |
+
print(prompt,lora_model,api_path,aspect_ratio)
|
13 |
+
#if model=="dev":
|
14 |
+
num_inference_steps=30
|
15 |
+
if model=="schnell":
|
16 |
+
num_inference_steps=5
|
17 |
+
|
18 |
+
if lora_model is not None:
|
19 |
+
api_path=model_dict[lora_model]
|
20 |
+
|
21 |
+
inputs={
|
22 |
+
"model": model,
|
23 |
+
"prompt": prompt,
|
24 |
+
"lora_scale":lora_scale,
|
25 |
+
"aspect_ratio": aspect_ratio,
|
26 |
+
"num_outputs":num_outputs,
|
27 |
+
"num_inference_steps":num_inference_steps,
|
28 |
+
"guidance_scale":guidance_scale,
|
29 |
+
}
|
30 |
+
if seed is not None:
|
31 |
+
inputs["seed"]=seed
|
32 |
+
output = replicate.run(
|
33 |
+
api_path,
|
34 |
+
input=inputs
|
35 |
+
)
|
36 |
+
print(output)
|
37 |
+
if gallery is None:
|
38 |
+
gallery=[]
|
39 |
+
gallery.append(output[0])
|
40 |
+
return output[0],gallery
|
41 |
+
|
42 |
+
|
43 |
+
def replicate_caption_api(image,model,context_text):
|
44 |
+
base64_image = image_to_base64(image)
|
45 |
+
if model=="blip":
|
46 |
+
output = replicate.run(
|
47 |
+
"andreasjansson/blip-2:f677695e5e89f8b236e52ecd1d3f01beb44c34606419bcc19345e046d8f786f9",
|
48 |
+
input={
|
49 |
+
"image": base64_image,
|
50 |
+
"caption": True,
|
51 |
+
"question": context_text,
|
52 |
+
"temperature": 1,
|
53 |
+
"use_nucleus_sampling": False
|
54 |
+
}
|
55 |
+
)
|
56 |
+
print(output)
|
57 |
+
|
58 |
+
elif model=="llava-16":
|
59 |
+
output = replicate.run(
|
60 |
+
# "yorickvp/llava-13b:80537f9eead1a5bfa72d5ac6ea6414379be41d4d4f6679fd776e9535d1eb58bb",
|
61 |
+
"yorickvp/llava-v1.6-34b:41ecfbfb261e6c1adf3ad896c9066ca98346996d7c4045c5bc944a79d430f174",
|
62 |
+
input={
|
63 |
+
"image": base64_image,
|
64 |
+
"top_p": 1,
|
65 |
+
"prompt": context_text,
|
66 |
+
"max_tokens": 1024,
|
67 |
+
"temperature": 0.2
|
68 |
+
}
|
69 |
+
)
|
70 |
+
print(output)
|
71 |
+
output = "".join(output)
|
72 |
+
|
73 |
+
elif model=="img2prompt":
|
74 |
+
output = replicate.run(
|
75 |
+
"methexis-inc/img2prompt:50adaf2d3ad20a6f911a8a9e3ccf777b263b8596fbd2c8fc26e8888f8a0edbb5",
|
76 |
+
input={
|
77 |
+
"image":base64_image
|
78 |
+
}
|
79 |
+
)
|
80 |
+
print(output)
|
81 |
+
return output
|
82 |
+
|
83 |
+
def update_replicate_api_key(api_key):
|
84 |
+
os.environ["REPLICATE_API_TOKEN"] = api_key
|
85 |
+
return f"Replicate API key updated: {api_key[:5]}..." if api_key else "Replicate API key cleared"
|
86 |
+
|
87 |
+
|
88 |
+
def virtual_try_on(crop, seed, steps, category, garm_img, human_img, garment_des):
|
89 |
+
output = replicate.run(
|
90 |
+
"cuuupid/idm-vton:906425dbca90663ff5427624839572cc56ea7d380343d13e2a4c4b09d3f0c30f",
|
91 |
+
input={
|
92 |
+
"crop": crop,
|
93 |
+
"seed": seed,
|
94 |
+
"steps": steps,
|
95 |
+
"category": category,
|
96 |
+
# "force_dc": force_dc,
|
97 |
+
"garm_img": numpy_to_base64( garm_img),
|
98 |
+
"human_img": numpy_to_base64(human_img),
|
99 |
+
#"mask_only": mask_only,
|
100 |
+
"garment_des": garment_des
|
101 |
+
}
|
102 |
+
)
|
103 |
+
print(output)
|
104 |
+
return output
|
105 |
+
|
106 |
+
|
107 |
+
from src.utils import create_zip
|
108 |
+
from PIL import Image
|
109 |
+
|
110 |
+
|
111 |
+
def process_images(files,model,context_text,token_string):
|
112 |
+
images = []
|
113 |
+
textbox =""
|
114 |
+
for file in files:
|
115 |
+
print(file)
|
116 |
+
image = Image.open(file)
|
117 |
+
if model=="None":
|
118 |
+
caption="[Insert cap here]"
|
119 |
+
else:
|
120 |
+
caption = replicate_caption_api(image,model,context_text)
|
121 |
+
textbox += f"Tags: {caption}, file: " + os.path.basename(file) + "\n"
|
122 |
+
images.append(image)
|
123 |
+
#texts.append(textbox)
|
124 |
+
zip_path=create_zip(files,textbox,token_string)
|
125 |
+
|
126 |
+
return images, textbox,zip_path
|
127 |
+
|
128 |
+
def replicate_create_model(owner,name,visibility="private",hardware="gpu-a40-large"):
|
129 |
+
try:
|
130 |
+
model = replicate.models.create(
|
131 |
+
owner=owner,
|
132 |
+
name=name,
|
133 |
+
visibility=visibility,
|
134 |
+
hardware=hardware,
|
135 |
+
)
|
136 |
+
print(model)
|
137 |
+
return True
|
138 |
+
except Exception as e:
|
139 |
+
print(e)
|
140 |
+
if "A model with that name and owner already exists" in str(e):
|
141 |
+
return True
|
142 |
+
return False
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
def traning_function(zip_path,training_model,training_destination,seed,token_string,max_train_steps,hf_repo_id=None,hf_token=None):
|
147 |
+
##Place holder for now
|
148 |
+
BB_bucket_name="jarvisdataset"
|
149 |
+
BB_defult="https://f005.backblazeb2.com/file/"
|
150 |
+
if BB_defult not in zip_path:
|
151 |
+
zip_path=BB_uploadfile(zip_path,os.path.basename(zip_path),BB_bucket_name)
|
152 |
+
print(zip_path)
|
153 |
+
training_logs = f"Using zip traning file at: {zip_path}\n"
|
154 |
+
yield training_logs, None
|
155 |
+
input={
|
156 |
+
"steps": max_train_steps,
|
157 |
+
"lora_rank": 16,
|
158 |
+
"batch_size": 1,
|
159 |
+
"autocaption": True,
|
160 |
+
"trigger_word": token_string,
|
161 |
+
"learning_rate": 0.0004,
|
162 |
+
"seed": seed,
|
163 |
+
"input_images": zip_path
|
164 |
+
}
|
165 |
+
print(training_destination)
|
166 |
+
username,model_name=training_destination.split("/")
|
167 |
+
assert replicate_create_model(username,model_name,visibility="private",hardware="gpu-a40-large"),"Error in creating model on replicate, check API key and username is correct "
|
168 |
+
|
169 |
+
print(input)
|
170 |
+
try:
|
171 |
+
training = replicate.trainings.create(
|
172 |
+
destination=training_destination,
|
173 |
+
version="ostris/flux-dev-lora-trainer:1296f0ab2d695af5a1b5eeee6e8ec043145bef33f1675ce1a2cdb0f81ec43f02",
|
174 |
+
input=input,
|
175 |
+
)
|
176 |
+
|
177 |
+
training_logs = f"Training started with model: {training_model}\n"
|
178 |
+
training_logs += f"Destination: {training_destination}\n"
|
179 |
+
training_logs += f"Seed: {seed}\n"
|
180 |
+
training_logs += f"Token string: {token_string}\n"
|
181 |
+
training_logs += f"Max train steps: {max_train_steps}\n"
|
182 |
+
|
183 |
+
# Poll the training status
|
184 |
+
while training.status != "succeeded":
|
185 |
+
training.reload()
|
186 |
+
training_logs += f"Training status: {training.status}\n"
|
187 |
+
training_logs += f"{training.logs}\n"
|
188 |
+
if training.status == "failed":
|
189 |
+
training_logs += "Training failed!\n"
|
190 |
+
return training_logs, training
|
191 |
+
|
192 |
+
yield training_logs, None
|
193 |
+
time.sleep(10) # Wait for 10 seconds before checking again
|
194 |
+
|
195 |
+
training_logs += "Training completed!\n"
|
196 |
+
if hf_repo_id and hf_token:
|
197 |
+
training_logs += f"Uploading to Hugging Face repo: {hf_repo_id}\n"
|
198 |
+
# Here you would implement the logic to upload to Hugging Face
|
199 |
+
|
200 |
+
traning_finnal=training.output
|
201 |
+
|
202 |
+
# In a real scenario, you might want to download and display some result images
|
203 |
+
# For now, we'll just return the original images
|
204 |
+
#images = [Image.open(file) for file in files]
|
205 |
+
_= update_model_dicts(traning_finnal["version"],token_string,style_json="model_dict.json")
|
206 |
+
traning_finnal["replicate_link"]="https://replicate.com/"+traning_finnal["version"].replace(":","/")
|
207 |
+
yield training_logs, traning_finnal
|
208 |
+
|
209 |
+
except Exception as e:
|
210 |
+
yield f"An error occurred: {str(e)}", None
|