diff --git a/robyn/router.py b/robyn/router.py index d6301776..83af186f 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -47,65 +47,60 @@ def __init__(self) -> None: super().__init__() self.routes: List[Route] = [] + def _format_tuple_response(self, res: tuple) -> Response: + if len(res) != 3: + raise ValueError("Tuple should have 3 elements") + + description, headers, status_code = res + description = self._format_response(description).description + new_headers: Headers = Headers(headers) + if new_headers.contains("Content-Type"): + headers.set("Content-Type", new_headers.get("Content-Type")) + + return Response( + status_code=status_code, + headers=headers, + description=description, + ) + def _format_response( self, res: Union[Dict, Response, bytes, tuple, str], ) -> Response: - headers = Headers({"Content-Type": "text/plain"}) + if isinstance(res, Response): + return res - response = {} if isinstance(res, dict): - # this should change - headers = Headers({}) - if "Content-Type" not in headers: - headers.set("Content-Type", "application/json") - - description = jsonify(res) - - response = Response( + return Response( status_code=status_codes.HTTP_200_OK, - headers=headers, - description=description, + headers=Headers({"Content-Type": "application/json"}), + description=jsonify(res), ) - elif isinstance(res, Response): - response = res - elif isinstance(res, FileResponse): - response = Response( + + if isinstance(res, FileResponse): + response: Response = Response( status_code=res.status_code, headers=res.headers, description=res.file_path, ) response.file_path = res.file_path + return response - elif isinstance(res, bytes): - headers = Headers({"Content-Type": "application/octet-stream"}) - response = Response( + if isinstance(res, bytes): + return Response( status_code=status_codes.HTTP_200_OK, - headers=headers, + headers=Headers({"Content-Type": "application/octet-stream"}), description=res, ) - elif isinstance(res, tuple): - if len(res) != 3: - raise ValueError("Tuple should have 3 elements") - else: - description, headers, status_code = res - description = self._format_response(description).description - new_headers = Headers(headers) - if "Content-Type" in new_headers: - headers.set("Content-Type", new_headers.get("Content-Type")) - - response = Response( - status_code=status_code, - headers=headers, - description=description, - ) - else: - response = Response( - status_code=status_codes.HTTP_200_OK, - headers=headers, - description=str(res).encode("utf-8"), - ) - return response + + if isinstance(res, tuple): + return self._format_tuple_response(tuple(res)) + + return Response( + status_code=status_codes.HTTP_200_OK, + headers=Headers({"Content-Type": "text/plain"}), + description=str(res).encode("utf-8"), + ) def add_route( self, @@ -291,7 +286,7 @@ def add_auth_middleware(self, endpoint: str): This method adds an authentication middleware to the specified endpoint. """ - injected_dependencies = {} + injected_dependencies: dict = {} def decorator(handler): @wraps(handler) @@ -320,7 +315,7 @@ def inner_handler(request: Request, *args): # Arguments are returned as they could be modified by the middlewares. def add_middleware(self, middleware_type: MiddlewareType, endpoint: Optional[str]) -> Callable[..., None]: # no dependency injection here - injected_dependencies = {} + injected_dependencies: dict = {} def inner(handler): @wraps(handler) @@ -383,7 +378,7 @@ def get_global_middlewares(self) -> List[GlobalMiddleware]: class WebSocketRouter(BaseRouter): def __init__(self) -> None: super().__init__() - self.routes = {} + self.routes: dict = {} def add_route(self, endpoint: str, web_socket: WebSocket) -> None: self.routes[endpoint] = web_socket