import os
import re
import time

import gradio as gr
import torch

from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
from training.zoo.classifiers import DeepFakeClassifier


def model_fn(model_dir):
    """Load the EfficientNet-B7 deepfake classifier checkpoint on CPU."""
    model_path = os.path.join(model_dir, "b7_ns_best.pth")
    model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns")
    checkpoint = torch.load(model_path, map_location="cpu")  # default: CPU
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Strip the "module." prefix left over from DataParallel training.
    model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=True)
    model.eval()
    del checkpoint
    # models.append(model.half())
    return model


def convert_result(pred, class_names=("Real", "Fake")):
    """Map the fake probability returned by predict_on_video to a label dict."""
    preds = [1 - pred, pred]  # pred is P(fake), so order is [real, fake]
    assert len(class_names) == len(preds), "Classes and predictions should have the same length"
    return {n: float(p) for n, p in zip(class_names, preds)}


def predict_fn(video):
    start = time.time()
    prediction = predict_on_video(face_extractor=meta["face_extractor"],
                                  video_path=video,
                                  batch_size=meta["batch_size"],
                                  input_size=meta["input_size"],
                                  models=model,
                                  strategy=meta["strategy"],
                                  apply_compression=False,
                                  device="cpu")
    elapsed_time = round(time.time() - start, 2)
    return convert_result(prediction), elapsed_time


# Title, description, and example strings for the UI
title = "Deepfake Detector (private)"
description = "A video deepfake classifier (code: https://github.com/selimsef/dfdc_deepfake_challenge)"
example_list = [os.path.join("examples", p) for p in os.listdir("examples")]

# Environment / model setup
model_dir = "weights"
frames_per_video = 32  # frames sampled per video; also used as the batch size
input_size = 380
video_reader = VideoReader()
video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
face_extractor = FaceExtractor(video_read_fn)
strategy = confident_strategy
meta = {"batch_size": frames_per_video, "face_extractor": face_extractor,
        "input_size": input_size, "strategy": strategy}
model = model_fn(model_dir)

# Local smoke test (the model is already loaded above):
# if __name__ == "__main__":
#     prediction, elapsed = predict_fn("examples/nlurbvsozt.mp4")
#     print(prediction, elapsed)

# Create the Gradio demo
demo = gr.Interface(fn=predict_fn,  # maps the input video to the two outputs below
                    inputs=gr.Video(),
                    outputs=[gr.Label(num_top_classes=2, label="Predictions"),
                             gr.Number(label="Prediction time (s)")],
                    examples=example_list,
                    title=title,
                    description=description)

# Launch the demo; Hugging Face Spaces don't need share=True
demo.launch(debug=False)
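
# Optional GPU path: a minimal, untested sketch. The commented-out
# `models.append(model.half())` above hints that the upstream challenge code
# runs this model in fp16 on CUDA; the `device` keyword is assumed to behave
# as in the `predict_on_video` call inside `predict_fn`.
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model_fn(model_dir).to(device)
#   if device == "cuda":
#       model = model.half()  # fp16 halves memory and speeds up B7 inference
#   # ...then pass device=device (instead of "cpu") to predict_on_video.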