vlsp2023-tts / main.py
hahunavth
add source code
614861a
raw
history blame
4.08 kB
from typing import Annotated
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from starlette.middleware.cors import CORSMiddleware
from fastapi import FastAPI, Header, UploadFile, Depends, HTTPException, status
import base64
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.responses import JSONResponse
import soundfile as sf
from collections import defaultdict
from model import SynthesisRequest, SynthesisResponse, TransferRequest, TransferResponse, LoginRequest, LoginResponse, BaseResponse
from google_sheet import create_repositories
from login import AuthService
from tts import TTSService
# Module-level singletons shared by the request handlers in this file.
# They are constructed once, at import time.

# Account storage, built by google_sheet.create_repositories().
account_repo = create_repositories()
# Authentication service; it is handed the account repository created above.
auth_service = AuthService(account_repo=account_repo)
# Speech service used by the /tts/* endpoints.
tts_service = TTSService()
# The FastAPI application that the handlers and routes below are registered on.
app = FastAPI()
# NOTE(review): this HTTPBearer instance is not referenced anywhere in the
# visible code (tokens are read from an `access_token` header instead) --
# confirm whether it is still needed.
auth = HTTPBearer()
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException):
    """Render a raised ``HTTPException`` in the API's ``BaseResponse`` envelope.

    Args:
        request: The incoming request (unused; required by the handler
            signature).
        exc: The exception raised by a route or dependency.

    Returns:
        A ``JSONResponse`` whose body is
        ``BaseResponse(status=0, message=exc.detail)`` and whose HTTP status
        code and headers are the ones carried by ``exc``.
    """
    # Propagate the status the raiser chose (e.g. the 401 raised by
    # get_current_user) instead of flattening every error to 400.
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(BaseResponse(status=0, message=exc.detail)),
        headers=getattr(exc, "headers", None),
    )
@app.exception_handler(RequestValidationError)
def validation_exception_handler(request, exc: RequestValidationError) -> JSONResponse:
    """Turn request-validation failures into a ``BaseResponse`` error payload.

    The individual errors are grouped by field: each key of ``result`` is the
    dotted location of the offending field and each value is the list of
    messages reported for it.

    Args:
        request: The incoming request (unused; required by the handler
            signature).
        exc: The validation error raised while parsing the request.

    Returns:
        A 400 ``JSONResponse`` with ``status=0``, ``message="Invalid request"``
        and the per-field messages in ``result``.
    """
    reformatted_message = defaultdict(list)
    for pydantic_error in exc.errors():
        loc, msg = pydantic_error["loc"], pydantic_error["msg"]
        # Drop the leading request-section marker when present; tolerate an
        # empty location instead of failing on loc[0].
        filtered_loc = loc[1:] if loc and loc[0] in ("body", "query", "path") else loc
        # Locations can contain integer list indexes, which str.join rejects
        # with a TypeError, so stringify every part first.
        field_string = ".".join(str(part) for part in filtered_loc)
        reformatted_message[field_string].append(msg)
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=jsonable_encoder(BaseResponse(status=0, message="Invalid request", result=reformatted_message))
    )
# return JSONResponse(content=jsonable_encoder(BaseResponse(status=0, message="RequestValidationError", result=str(exc))))
async def get_current_user(access_token: Annotated[str, Header(convert_underscores=False)] = None):
    """Dependency that resolves the caller from the ``access_token`` header.

    The header name keeps its underscore (``convert_underscores=False``).

    Args:
        access_token: Token taken from the request header, or ``None`` when
            the header is absent.

    Returns:
        The truthy value produced by ``auth_service.validate_token`` for the
        token (used by callers as the username).

    Raises:
        HTTPException: 401 "Token missing" when no header was sent, and
            401 "Invalid Token" when validation yields a falsy result.
    """
    if access_token is not None:
        identity = await auth_service.validate_token(access_token)
        if identity:
            return identity
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token missing")
@app.post("/login", response_model=BaseResponse)
async def login(request: LoginRequest):
    """Authenticate an account and hand back an access token.

    Args:
        request: Login payload carrying ``email`` and ``password``.

    Returns:
        A ``BaseResponse`` whose ``result`` is ``{"access_token": <token>}``,
        the token being created by ``auth_service.create_token`` for the
        e-mail.

    Raises:
        HTTPException: 400 when ``auth_service.authenticate_user`` returns a
            falsy value for the supplied credentials.
    """
    email, password = request.email, request.password
    authenticated = await auth_service.authenticate_user(email, password)
    # Guard clause: reject bad credentials before issuing anything.
    if not authenticated:
        raise HTTPException(status_code=400, detail="Incorrect username or password")
    token = await auth_service.create_token(email)
    return BaseResponse(result={"access_token": token})
@app.post("/test-auth", response_model=BaseResponse)
def test_auth(username: str = Depends(get_current_user)):
    """Echo the authenticated caller; useful for checking that a token works.

    Args:
        username: Value supplied by the ``get_current_user`` dependency.

    Returns:
        A ``BaseResponse`` whose ``result`` is ``{"email": username}``.
    """
    payload = {"email": username}
    return BaseResponse(result=payload)
@app.post("/tts/sub-task-1", response_model=BaseResponse)
def synthesis(request: SynthesisRequest):  # TODO: add `username: str = Depends(get_current_user)`
    """Synthesize speech for the text in the request.

    Args:
        request: Payload whose ``input_text`` is passed to
            ``tts_service.synthesis``.

    Returns:
        A ``BaseResponse`` wrapping a ``SynthesisResponse`` whose ``data`` is
        whatever the TTS service returned.
    """
    synthesized = SynthesisResponse(data=tts_service.synthesis(request.input_text))
    return BaseResponse(result=synthesized)
@app.post("/tts/sub-task-2", response_model=BaseResponse)
async def transfer(input_text: str, ref_audio: UploadFile):  # TODO: add `username: str = Depends(get_current_user)`
    """Synthesize ``input_text`` using an uploaded reference recording.

    Args:
        input_text: Text handed to ``tts_service.transfer``.
        ref_audio: Uploaded reference audio; its declared content type must
            be ``audio/mpeg``.

    Returns:
        A ``BaseResponse`` wrapping a ``TransferResponse`` whose ``data`` is
        the base64 encoding of the audio returned by ``tts_service.transfer``.

    Raises:
        HTTPException: 400 when the content type is not ``audio/mpeg`` or
            when the uploaded bytes cannot be decoded as audio.
    """
    if ref_audio.content_type != "audio/mpeg":
        raise HTTPException(status_code=400, detail="Only audio files allowed")
    # Decode the upload into a float32 NumPy array. The declared content type
    # is client-controlled, so the payload may still be undecodable; report
    # that as a client error rather than letting it surface as a 500.
    try:
        with sf.SoundFile(ref_audio.file, 'rb') as f:
            audio_np_array = f.read(dtype='float32')
    except RuntimeError as err:  # soundfile's decode errors derive from RuntimeError
        raise HTTPException(status_code=400, detail="Could not decode audio file") from err
    audio_out = tts_service.transfer(input_text, audio_np_array)
    audio_data = base64.b64encode(audio_out)
    return BaseResponse(result=TransferResponse(data=audio_data))
# Install CORS handling on the app: every origin, method and request header is
# accepted, and credentials are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# very permissive policy -- confirm it is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)