# Captured from a Hugging Face Space page (status banner at capture time: "Runtime error").
import base64
from collections import defaultdict
from typing import Annotated

import soundfile as sf
from fastapi import Depends, FastAPI, Header, HTTPException, UploadFile, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.concurrency import run_in_threadpool
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse

from google_sheet import create_repositories
from login import AuthService
from model import (
    BaseResponse,
    LoginRequest,
    LoginResponse,
    SynthesisRequest,
    SynthesisResponse,
    TransferRequest,
    TransferResponse,
)
from tts import TTSService
# Web-layer objects: the FastAPI application and an HTTP bearer security scheme.
# Neither depends on the service layer constructed below.
app = FastAPI()
auth = HTTPBearer()  # NOTE(review): not referenced in the visible code; confirm it is still needed

# Service layer: the account repository backs the auth service; the TTS
# service is constructed without arguments.
account_repo = create_repositories()
auth_service = AuthService(account_repo=account_repo)
tts_service = TTSService()
async def http_exception_handler(request, exc: HTTPException):
    """Render an ``HTTPException`` inside the app's ``BaseResponse`` envelope.

    The exception's own status code is kept, so, for example, the 401s raised
    by ``get_current_user`` reach the client as 401 rather than 400. Any
    headers attached to the exception are forwarded as well.

    NOTE(review): no ``@app.exception_handler(HTTPException)`` registration is
    visible here. Confirm this handler is registered (decorator or
    ``app.add_exception_handler``); otherwise FastAPI never calls it.

    Args:
        request: The incoming request (unused).
        exc: The raised HTTP exception; its ``detail`` becomes the message.

    Returns:
        A ``JSONResponse`` with body ``BaseResponse(status=0, message=exc.detail)``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(BaseResponse(status=0, message=exc.detail)),
        # Forward headers (e.g. WWW-Authenticate) that the raiser attached.
        headers=getattr(exc, "headers", None),
    )
def validation_exception_handler(request, exc: RequestValidationError) -> JSONResponse:
    """Render request-validation failures as a 400 ``BaseResponse``.

    Errors are grouped by field path. The leading location segment
    (``body``/``query``/``path``) is dropped, and the remaining segments are
    joined with ``.``. Each path maps to the list of its error messages.

    NOTE(review): no ``@app.exception_handler(RequestValidationError)``
    registration is visible here; confirm it is registered elsewhere.

    Args:
        request: The incoming request (unused).
        exc: The validation error raised by FastAPI.

    Returns:
        A 400 ``JSONResponse`` whose ``result`` maps field paths to messages.
    """
    reformatted_message = defaultdict(list)
    for pydantic_error in exc.errors():
        loc, msg = pydantic_error["loc"], pydantic_error["msg"]
        # Strip the request-part prefix; guard against an empty location.
        filtered_loc = loc[1:] if loc and loc[0] in ("body", "query", "path") else loc
        # Locations can contain ints (list indices), so stringify before joining;
        # a bare str.join would raise TypeError on them.
        field_string = ".".join(str(part) for part in filtered_loc)
        reformatted_message[field_string].append(msg)
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=jsonable_encoder(
            BaseResponse(status=0, message="Invalid request", result=reformatted_message)
        ),
    )
async def get_current_user(access_token: Annotated[str, Header(convert_underscores=False)] = None):
    """Resolve the caller's identity from the ``access_token`` request header.

    ``convert_underscores=False`` makes FastAPI look up a header literally
    named ``access_token`` instead of ``access-token``.

    Returns:
        The value ``auth_service.validate_token`` returns for a valid token.

    Raises:
        HTTPException: 401 "Token missing" when the header is absent or empty,
            or 401 "Invalid Token" when validation returns a falsy value.
    """
    # An empty header value carries no token either; reject it here rather
    # than handing "" to the validator.
    if not access_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token missing")
    username = await auth_service.validate_token(access_token)
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
    return username
async def login(request: LoginRequest):
    """Exchange an email/password pair for an access token.

    Returns:
        ``BaseResponse`` whose ``result`` is ``{"access_token": <token>}``, the
        token coming from ``auth_service.create_token(email)``.

    Raises:
        HTTPException: 400 when ``auth_service.authenticate_user`` returns a
            falsy value.
    """
    email, password = request.email, request.password
    # Guard clause: bail out before any token is minted.
    if not await auth_service.authenticate_user(email, password):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect username or password",
        )
    token = await auth_service.create_token(email)
    return BaseResponse(result={"access_token": token})
def test_auth(username: str = Depends(get_current_user)):
    """Echo the identity resolved by ``get_current_user`` to check that a token works."""
    payload = {"email": username}
    return BaseResponse(result=payload)
def synthesis(request: SynthesisRequest):  # TODO: require auth via `username: str = Depends(get_current_user)`
    """Synthesize speech for ``request.input_text``.

    The output of ``tts_service.synthesis`` is passed through unchanged as the
    ``data`` of a ``SynthesisResponse`` wrapped in a ``BaseResponse``. As a
    plain ``def``, FastAPI runs this in its threadpool when it is registered
    as a path operation.
    """
    return BaseResponse(
        result=SynthesisResponse(data=tts_service.synthesis(request.input_text))
    )
async def transfer(input_text: str, ref_audio: UploadFile):  # TODO: accept TransferRequest; require auth via Depends(get_current_user)
    """Voice-transfer ``input_text`` using an uploaded clip as the reference.

    Args:
        input_text: Text to synthesize.
        ref_audio: Uploaded reference clip; only ``audio/mpeg`` is accepted.

    Returns:
        ``BaseResponse`` whose ``result`` is a ``TransferResponse`` carrying the
        base64-encoded bytes of ``tts_service.transfer``'s output.

    Raises:
        HTTPException: 400 if the upload is not ``audio/mpeg`` or soundfile
            cannot decode it.
    """
    if ref_audio.content_type != "audio/mpeg":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only audio files allowed")

    # Decoding and TTS inference are blocking calls. Run them in the threadpool
    # so this coroutine does not stall the event loop for every other request.
    try:
        audio_np_array = await run_in_threadpool(_read_audio_as_float32, ref_audio.file)
    except RuntimeError as err:
        # soundfile reports unreadable/corrupt input via RuntimeError subclasses;
        # that is a client error, not a server fault.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Could not decode audio file"
        ) from err
    audio_out = await run_in_threadpool(tts_service.transfer, input_text, audio_np_array)
    audio_data = base64.b64encode(audio_out)
    return BaseResponse(result=TransferResponse(data=audio_data))


def _read_audio_as_float32(fileobj):
    """Decode an audio file object into a float32 NumPy array using soundfile.

    NOTE(review): the sample rate (``f.samplerate``) is discarded, so
    ``tts_service.transfer`` must assume one; confirm against TTSService.
    """
    with sf.SoundFile(fileobj, "rb") as f:
        return f.read(dtype="float32")
# Fully permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): combining allow_origins=["*"] with allow_credentials=True makes
# Starlette reflect the request's Origin (instead of "*") when cookies are sent,
# so any site can make credentialed calls. Auth here uses the `access_token`
# header, so credentials may be unnecessary; confirm and consider narrowing.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)