|
import argparse |
|
import os |
|
import re |
|
import time |
|
|
|
import torch |
|
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video |
|
from training.zoo.classifiers import DeepFakeClassifier |
|
|
|
import gradio as gr |
|
|
|
def model_fn(model_dir):
    """Build the classifier and load its weights from ``model_dir``.

    Args:
        model_dir: Directory that contains the ``b7_ns_best.pth`` checkpoint.

    Returns:
        A ``DeepFakeClassifier`` built with the ``tf_efficientnet_b7_ns``
        encoder, with the checkpoint weights loaded (tensors mapped to CPU)
        and switched to evaluation mode.
    """
    model_path = os.path.join(model_dir, 'b7_ns_best.pth')
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")

    checkpoint = torch.load(model_path, map_location="cpu")
    # The file may either wrap the weights under "state_dict" or be the
    # weights mapping itself; fall back to the whole object in that case.
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Drop a leading "module." from every key (presumably left behind by a
    # DataParallel wrapper at training time -- verify). The dot is escaped on
    # purpose: an unescaped "." matches any character, so "^module." would
    # also strip prefixes such as "modules" or "module_".
    cleaned_state = {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}
    model.load_state_dict(cleaned_state, strict=True)

    model.eval()
    return model
|
|
|
def convert_result(pred, class_names=("Real", "Fake")):
    """Turn a single score into a ``{class name: probability}`` mapping.

    The first class name is paired with ``pred`` and the second with
    ``1 - pred``.

    NOTE(review): this assumes ``pred`` is the probability of the *first*
    class ("Real" by default). Confirm against what ``predict_on_video``
    actually returns; if it is the probability of "Fake", the labels here
    are swapped.

    Args:
        pred: Numeric score; anything supporting ``1 - pred`` and ``float()``.
        class_names: Two labels, in the order (label for ``pred``,
            label for ``1 - pred``).

    Returns:
        dict mapping each class name to a Python ``float``.

    Raises:
        ValueError: If ``class_names`` does not contain exactly two entries.
    """
    preds = (pred, 1 - pred)
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and ``zip`` would then silently truncate to the shorter sequence.
    if len(class_names) != len(preds):
        raise ValueError("Class / Prediction should have the same length")
    return {name: float(p) for name, p in zip(class_names, preds)}
|
|
|
def predict_fn(video):
    """Run the classifier on one video and report the inference time.

    Reads the module-level ``meta`` dict (face extractor, batch size,
    input size, strategy) and the module-level ``model``; inference is
    requested on the CPU with compression disabled.

    NOTE(review): ``models`` receives the single module-level ``model``;
    confirm that ``predict_on_video`` accepts one model rather than a list.

    Args:
        video: Value forwarded unchanged as ``video_path`` to
            ``predict_on_video``.

    Returns:
        A ``(prediction, elapsed_time)`` tuple: the mapping produced by
        ``convert_result`` and the elapsed seconds rounded to two decimals.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or go negative) if the system wall clock is adjusted mid-inference.
    start = time.perf_counter()
    prediction = predict_on_video(
        face_extractor=meta["face_extractor"],
        video_path=video,
        batch_size=meta["fps"],
        input_size=meta["input_size"],
        models=model,
        strategy=meta["strategy"],
        apply_compression=False,
        device='cpu',
    )
    elapsed_time = round(time.perf_counter() - start, 2)

    return convert_result(prediction), elapsed_time
|
|
|
|
|
# Texts shown at the top of the Gradio page.
title = "Deepfake Detector (private)"
description = "A video Deepfake Classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"

# Every entry of the "examples/" directory becomes a clickable example.
# os.listdir returns names in arbitrary order, so sort them to keep the
# example order stable between runs and machines. The resulting strings are
# still of the form "examples/<name>".
example_list = [os.path.join("examples/", name) for name in sorted(os.listdir("examples/"))]
|
|
|
|
|
# --- Inference settings ---------------------------------------------------
model_dir = 'weights'          # directory handed to model_fn
frames_per_video = 32          # num_frames requested from the video reader
input_size = 380               # forwarded to predict_on_video by predict_fn
strategy = confident_strategy  # forwarded to predict_on_video by predict_fn
class_names = ["Real", "Fake"]  # not referenced elsewhere in the visible code

# --- Face extraction ------------------------------------------------------
video_reader = VideoReader()


def video_read_fn(x):
    """Call ``video_reader.read_frames`` on ``x`` with ``num_frames=frames_per_video``."""
    return video_reader.read_frames(x, num_frames=frames_per_video)


face_extractor = FaceExtractor(video_read_fn)

# Settings bundle looked up by predict_fn; note that predict_fn passes the
# "fps" entry on as ``batch_size``.
meta = dict(
    fps=32,
    face_extractor=face_extractor,
    input_size=input_size,
    strategy=strategy,
)

# Load the weights once at import time so every request reuses the same model.
model = model_fn(model_dir)

# Manual smoke test, kept disabled:
#
# if __name__ == '__main__':
#     video_path = "examples/nlurbvsozt.mp4"
#     model = model_fn(model_dir)
#     a, b = predict_fn(video_path)
#     print(a, b)
|
|
|
# Components are created in the same order as they appear in the interface:
# one video input, then the two outputs matching predict_fn's return tuple
# (class-probability mapping, elapsed seconds).
video_input = gr.Video()
prediction_label = gr.Label(num_top_classes=2, label="Predictions")
elapsed_box = gr.Number(label="Prediction time (s)")

demo = gr.Interface(
    fn=predict_fn,
    inputs=video_input,
    outputs=[prediction_label, elapsed_box],
    examples=example_list,
    title=title,
    description=description,
)

# Start serving as soon as the module is executed.
demo.launch(debug=False)
|
|