from sklearn.cluster import Birch, SpectralClustering, AgglomerativeClustering, KMeans
import os
import numpy as np
from config import config
import yaml
import argparse
import shutil


def ensure_dir(directory):
    # Create the directory if it does not exist yet.
    if not os.path.exists(directory):
        os.makedirs(directory)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a", "--algorithm", default="k", type=str,
        help="clustering algorithm: k=KMeans, b=Birch, s=SpectralClustering, a=AgglomerativeClustering",
    )
    parser.add_argument("-n", "--num_clusters", default=4, type=int, help="number of clusters")
    parser.add_argument(
        "-r", "--range", default=4, type=int,
        help="maximum number of files to sample from each class",
    )
    args = parser.parse_args()

    filelist_dict = {}
    yml_result = {}
    base_dir = "D:/Vits2/Bert-VITS2/Data/BanGDream/filelists"
    output_dir = "D:/Vits2/classifedSample"
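
    # The filelist is assumed to follow the Bert-VITS2 pipe-separated convention,
    # e.g. {wav_path}|{speaker}|{language}|{text}; only the first two fields are
    # used below, and the exact trailing fields are an assumption.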
with open(os.path.join(base_dir, "Mygo.list"), mode="r", encoding="utf-8") as f: | |
embs = [] | |
wavnames = [] | |
for line in f: | |
parts = line.strip().split("|") | |
speaker = parts[1] # 假设 speaker 是第二个部分 | |
filepath = parts[0] # 假设 filepath 是第一个部分 | |
# ... 其余部分可以根据需要使用 | |
if speaker not in filelist_dict: | |
filelist_dict[speaker] = [] | |
yml_result[speaker] = {} | |
filelist_dict[speaker].append(filepath) | |

    for speaker in filelist_dict:
        print("\nspeaker: " + speaker)
        # Collect this speaker's emotion embeddings (one .emo.npy file per wav).
        embs = []
        wavnames = []
        for file in filelist_dict[speaker]:
            try:
                embs.append(np.expand_dims(np.load(f"{os.path.splitext(file)[0]}.emo.npy"), axis=0))
                wavnames.append(file)
            except Exception as e:
                print(e)

        if embs:
            n_clusters = args.num_clusters
            x = np.concatenate(embs, axis=0)
            x = np.squeeze(x)
            # Any of the sklearn clustering algorithms below can be swapped in freely.
            if args.algorithm == "b":
                model = Birch(n_clusters=n_clusters, threshold=0.2)
            elif args.algorithm == "s":
                model = SpectralClustering(n_clusters=n_clusters)
            elif args.algorithm == "a":
                model = AgglomerativeClustering(n_clusters=n_clusters)
            else:
                model = KMeans(n_clusters=n_clusters, random_state=10)
            y_predict = model.fit_predict(x)
            classes = [[] for _ in range(y_predict.max() + 1)]
            for idx, wavname in enumerate(wavnames):
                classes[y_predict[idx]].append(wavname)
            for i in range(y_predict.max() + 1):
                print("class:", i, "samples in this class:", len(classes[i]))
                yml_result[speaker][f"class{i}"] = []
                class_dir = os.path.join(output_dir, speaker, f"class{i}")
                num_samples_in_class = len(classes[i])
                for j in range(min(args.range, num_samples_in_class)):
                    wav_file = classes[i][j]
                    print(wav_file)
                    # Copy the file into the per-class output directory.
                    ensure_dir(class_dir)
                    shutil.copy(os.path.join(base_dir, wav_file), os.path.join(class_dir, os.path.basename(wav_file)))
                    yml_result[speaker][f"class{i}"].append(wav_file)
with open(os.path.join(base_dir, "emo_clustering.yml"), "w", encoding="utf-8") as f: | |
yaml.dump(yml_result, f) | |
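
# Example invocation (a sketch; the script filename "emo_clustering.py" is an
# assumption, and the hard-coded paths above must exist on the local machine):
#   python emo_clustering.py -a k -n 4 -r 4   # KMeans, 4 clusters, copy up to 4 wavs per class
#   python emo_clustering.py -a s -n 3        # SpectralClustering with 3 clusters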