# hf-repo-info / pages / 2_download_tb.py
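"""Streamlit page that downloads TensorBoard event logs from Hugging Face
model repositories into a local directory, skipping files that have already
been fetched."""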
import os
import shutil
import pandas as pd
import streamlit as st
# from streamlit_tensorboard import st_tensorboard
from huggingface_hub import list_models
from huggingface_hub import HfApi
# ==============================================================
st.set_page_config(layout="wide")
# ==============================================================
# Local directory that mirrors the downloaded event files, one subfolder per repo
logdir = "/tmp/tensorboard_logs"
os.makedirs(logdir, exist_ok=True)
def clean_logdir(logdir):
    """Remove the whole log directory; errors are only printed."""
    try:
        shutil.rmtree(logdir)
    except Exception as e:
        print(e)
@st.cache_resource
def get_models():
    """List the author's "emofs2" models, newest first (cached across reruns)."""
    _author = "hahunavth"
    _filter = "emofs2"
    return list(
        list_models(author=_author, filter=_filter, sort="last_modified", direction=-1)
    )
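# TensorBoard event files follow the naming convention
# "events.out.tfevents.<timestamp>.<hostname>...", so matching on this
# prefix picks out the log files inside each repo.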
TB_FILE_PREFIX = "events.out.tfevents"
def download_repo_tb(repo_id, api, log_dir, df):
    """Download every non-empty TensorBoard event file from `repo_id` into
    `log_dir/<repo_name>/`, skipping files already recorded in `df`."""
    repo_name = repo_id.split("/")[-1]
    if api.repo_exists(repo_id):
        files = api.list_repo_files(repo_id)
    else:
        raise ValueError(f"Repo {repo_id} does not exist")
    tb_files = [f for f in files if f.split("/")[-1].startswith(TB_FILE_PREFIX)]
    # NOTE: list_files_info is deprecated in recent huggingface_hub releases;
    # get_paths_info is the suggested replacement.
    tb_files_info = list(api.list_files_info(repo_id, tb_files))
    tb_files_info = [f for f in tb_files_info if f.size > 0]
    for repo_file in tb_files_info:
        path = repo_file.path
        size = repo_file.size
        fname = path.split("/")[-1]
        sub_folder = os.path.dirname(path) or None  # None for files at the repo root
        if ((df["repo"] == repo_name) & (df["path"] == path) & (df["size"] == size)).any() \
                and os.path.exists(os.path.join(log_dir, repo_name, path)):
            print(f"Skipping {repo_name}/{path}")
            continue
        print(f"Downloading {repo_name}/{path}")
        api.hf_hub_download(repo_id=repo_id, filename=fname, subfolder=sub_folder, local_dir=os.path.join(log_dir, repo_name))
        new_df = pd.DataFrame([{
            "repo": repo_name,
            "path": path,
            "size": size,
        }])
        df = pd.concat([df, new_df], ignore_index=True)
    return df
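# Minimal usage sketch outside Streamlit (the repo id below is hypothetical):
#   api = HfApi()
#   df = pd.DataFrame(columns=["repo", "path", "size"])
#   df = download_repo_tb("hahunavth/example-emofs2-model", api, logdir, df)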
@st.cache_resource
def create_cache_dataframe():
    """Bookkeeping table of already-downloaded files (cached across reruns)."""
    return pd.DataFrame(columns=["repo", "path", "size"])
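# Caveat: pd.concat returns a *new* DataFrame, so rows appended inside
# download_repo_tb are not written back to this cached instance; the
# bookkeeping therefore only deduplicates within a single script run,
# and the os.path.exists check is what holds across reruns.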
# ==============================================================
api = HfApi()
df = create_cache_dataframe()
models = get_models()
model_ids = [model.id for model in models]
# Select one or more models to download
with st.expander("Download tf", expanded=True):
    with st.form("my_form"):
        selected_models = st.multiselect("Select models", model_ids, default=None)
        submit = st.form_submit_button("Download logs")
    if submit:
        # Download the TensorBoard logs for every selected model
        with st.spinner("Downloading logs..."):
            for model_id in selected_models:
                st.write(f"Downloading logs for {model_id}")
                df = download_repo_tb(model_id, api, logdir, df)
        st.write("Done")
clean_btn = st.button("Clean all")
if clean_btn:
    clean_logdir(logdir)
    create_cache_dataframe.clear()
    get_models.clear()
# with st.expander("...", expanded=True):
# st_tensorboard(logdir=logdir, port=6006, width=1760, scrolling=False)
# st.text(st)
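# To inspect the downloaded logs, TensorBoard can also be launched manually
# against the same directory:
#   tensorboard --logdir /tmp/tensorboard_logs --port 6006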