|
from typing import Optional

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse, JSONResponse
from loguru import logger
|
|
|
class OpenAIError(Exception):
    """Root of this module's exception hierarchy."""
|
|
|
class APIError(OpenAIError):
    """Error that the application renders as a JSON response to the client.

    Attributes:
        message: Human-readable message included in the response body.
        code: HTTP status code used for the response (defaults to 500).
        param: Optional name of the request parameter the error refers to;
            included in the response body.
        type: Name of the concrete exception class.
        internal_message: Extra detail that is only logged server-side and
            never included in the response body.
    """

    message: str
    code: Optional[int] = None
    param: Optional[str] = None
    type: Optional[str] = None

    def __init__(self, message: str, code: int = 500, param: Optional[str] = None, internal_message: str = ''):
        super().__init__(message)
        self.message = message
        self.code = code
        self.param = param
        # A plain string (no trailing comma) so it is the class name, not a 1-tuple.
        self.type = self.__class__.__name__
        self.internal_message = internal_message

    def __repr__(self):
        # %s rather than %d so repr() never raises if code is not an int;
        # output is identical for integer codes.
        return "%s(message=%r, code=%s, param=%s)" % (
            self.__class__.__name__,
            self.message,
            self.code,
            self.param,
        )
|
|
|
class InternalServerError(APIError):
    """Generic server-side failure; uses APIError's default code of 500."""
|
|
|
class ServiceUnavailableError(APIError):
    """Error reported with HTTP 503 and a default "try again later" message.

    Args:
        message: Client-facing message.
        code: HTTP status code (defaults to 503).
        internal_message: Detail that is logged server-side only.
    """

    def __init__(self, message="Service unavailable, please try again later.", code=503, internal_message=''):
        # internal_message must be passed by keyword: APIError's third
        # positional parameter is ``param``, which is sent back to the client.
        super().__init__(message, code, internal_message=internal_message)
|
|
|
class APIStatusError(APIError):
    """APIError whose HTTP code comes from the class attribute ``status_code``.

    Subclasses override ``status_code``; callers supply only the message and
    the optional ``param`` / ``internal_message``.
    """

    status_code: int = 400

    def __init__(self, message: str, param: str = None, internal_message: str = ''):
        http_status = self.status_code
        super().__init__(
            message,
            code=http_status,
            param=param,
            internal_message=internal_message,
        )
|
|
|
class BadRequestError(APIStatusError):
    """Client error reported with HTTP 400 (Bad Request)."""

    status_code: int = 400
|
|
|
class AuthenticationError(APIStatusError):
    """Client error reported with HTTP 401 (Unauthorized)."""

    status_code: int = 401
|
|
|
class PermissionDeniedError(APIStatusError):
    """Client error reported with HTTP 403 (Forbidden)."""

    status_code: int = 403
|
|
|
class NotFoundError(APIStatusError):
    """Client error reported with HTTP 404 (Not Found)."""

    status_code: int = 404
|
|
|
class ConflictError(APIStatusError):
    """Client error reported with HTTP 409 (Conflict)."""

    status_code: int = 409
|
|
|
class UnprocessableEntityError(APIStatusError):
    """Client error reported with HTTP 422 (Unprocessable Entity)."""

    status_code: int = 422
|
|
|
class RateLimitError(APIStatusError):
    """Client error reported with HTTP 429 (Too Many Requests)."""

    status_code: int = 429
|
|
|
class OpenAIStub(FastAPI):
    """FastAPI app pre-wired with OpenAI-style plumbing.

    Construction installs a CORS middleware allowing all origins, methods and
    headers; JSON exception handlers; a debug request/response logger; and a
    few stub routes (``/``, ``/health``, billing usage and model listing).
    Servers built on top register model names with :meth:`register_model`
    and add their own routes.
    """

    def __init__(self, **kwargs) -> None:
        """Create the app; every keyword argument is forwarded to ``FastAPI``."""
        super().__init__(**kwargs)
        # Public model name -> backing model id (see register_model).
        self.models = {}

        self.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"]
        )

        def api_error_response(exc: APIError) -> JSONResponse:
            """Log ``exc.internal_message`` (if any) and render ``exc`` as JSON."""
            if exc.internal_message:
                # Server-side detail only; it is not part of the response body.
                logger.info(exc.internal_message)

            return JSONResponse(status_code=exc.code, content={
                'message': exc.message,
                'code': exc.code,
                'type': exc.__class__.__name__,
                'param': exc.param,
            })

        @self.exception_handler(Exception)
        def openai_exception_handler(request: Request, exc: Exception) -> JSONResponse:
            # Catch-all for unexpected errors: a generic body so no internal
            # details reach the client.
            # NOTE(review): nothing is logged here — presumably because Starlette
            # re-raises after this handler and the server logs the traceback;
            # verify before adding a log line (it would duplicate).
            return JSONResponse(status_code=500, content={
                'message': 'InternalServerError',
                'code': 500,
            })

        @self.exception_handler(APIError)
        def openai_apierror_handler(request: Request, exc: APIError) -> JSONResponse:
            logger.opt(exception=exc).error("Logging exception traceback")
            return api_error_response(exc)

        @self.exception_handler(APIStatusError)
        def openai_statuserror_handler(request: Request, exc: APIStatusError) -> JSONResponse:
            # Expected, client-caused errors: one-line summary, no traceback.
            logger.info(repr(exc))
            return api_error_response(exc)

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            # NOTE(review): headers may carry the client's Authorization token
            # and the full body is read and logged; debug level only.
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")
            logger.debug(f"Request body: {await request.body()}")

            response = await call_next(request)

            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")

            return response

        @self.get('/v1/billing/usage')
        @self.get('/v1/dashboard/billing/usage')
        async def handle_billing_usage():
            return { 'total_usage': 0 }

        @self.get("/", response_class=PlainTextResponse)
        @self.head("/", response_class=PlainTextResponse)
        @self.options("/", response_class=PlainTextResponse)
        async def root():
            # 503 until at least one model is registered.
            return PlainTextResponse(content="", status_code=200 if self.models else 503)

        @self.get("/health")
        async def health():
            return {"status": "ok" if self.models else "unk" }

        @self.get("/v1/models")
        async def get_model_list():
            return self.model_list()

        @self.get("/v1/models/{model}")
        async def get_model_info(model: str):
            # The parameter name must match the ``{model}`` path segment;
            # otherwise FastAPI treats it as a required query parameter.
            return self.model_info(model)

    def register_model(self, name: str, model: str = None) -> None:
        """Expose ``name`` as a model, backed by ``model`` (defaults to ``name``)."""
        self.models[name] = model if model else name

    def deregister_model(self, name: str) -> None:
        """Remove ``name`` from the registered models; no-op if it is absent."""
        self.models.pop(name, None)

    def model_info(self, model: str) -> dict:
        """Return an OpenAI-style model object for ``model`` with placeholder metadata."""
        result = {
            "id": model,
            "object": "model",
            "created": 0,
            "owned_by": "user"
        }
        return result

    def model_list(self) -> dict:
        """Return an OpenAI-style model list of every registered name and backing id.

        Both the public names (keys) and backing ids (values) of
        ``self.models`` are listed, de-duplicated, with falsy entries skipped,
        in sorted order. With nothing registered ``data`` is an empty list,
        so clients always receive the ``object``/``data`` shape.
        """
        names = {name for name in (*self.models, *self.models.values()) if name}
        return {
            "object": "list",
            "data": [self.model_info(name) for name in sorted(names)],
        }
|
|