fix(api): fix rate limiting
This commit is contained in:
		@@ -1,8 +1,6 @@
 | 
			
		||||
from fastapi import APIRouter
 | 
			
		||||
from starlette.requests import Request
 | 
			
		||||
 | 
			
		||||
from src.neo_neo_todo.utils.rate_limit import limiter
 | 
			
		||||
 | 
			
		||||
router = APIRouter(
 | 
			
		||||
    prefix="/control",
 | 
			
		||||
    tags=["control"],
 | 
			
		||||
@@ -10,6 +8,5 @@ router = APIRouter(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.get("/ping")
 | 
			
		||||
@limiter.exempt
 | 
			
		||||
async def ping(request: Request):
 | 
			
		||||
    return {"ping": "pong"}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,14 +1,31 @@
 | 
			
		||||
from contextlib import asynccontextmanager
 | 
			
		||||
 | 
			
		||||
import redis.asyncio as redis
 | 
			
		||||
from fastapi import FastAPI
 | 
			
		||||
from slowapi import _rate_limit_exceeded_handler
 | 
			
		||||
from slowapi.errors import RateLimitExceeded
 | 
			
		||||
from slowapi.middleware import SlowAPIMiddleware
 | 
			
		||||
from fastapi_limiter import FastAPILimiter
 | 
			
		||||
 | 
			
		||||
from src.neo_neo_todo.control import control
 | 
			
		||||
from src.neo_neo_todo.utils.rate_limit import limiter
 | 
			
		||||
from src.neo_neo_todo.sessions import sessions
 | 
			
		||||
from src.neo_neo_todo.utils.database import pool
 | 
			
		||||
 | 
			
		||||
app = FastAPI()
 | 
			
		||||
 | 
			
		||||
@asynccontextmanager
 | 
			
		||||
async def lifespan(_: FastAPI):
 | 
			
		||||
    # Set up pool
 | 
			
		||||
    redis_connection = redis.from_url("redis://localhost:6379", encoding="utf8")
 | 
			
		||||
    await FastAPILimiter.init(redis_connection)
 | 
			
		||||
    print("Starting PostgreSQL DB connection pool and Redis from FastAPI")
 | 
			
		||||
    await pool.open()
 | 
			
		||||
 | 
			
		||||
    yield
 | 
			
		||||
    # Clean up the DB pool
 | 
			
		||||
    await pool.close()
 | 
			
		||||
    await FastAPILimiter.close()
 | 
			
		||||
    print("Closing PostgreSQL DB connection pool and Redis from FastAPI")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
app = FastAPI(lifespan=lifespan)
 | 
			
		||||
 | 
			
		||||
# include API routes
 | 
			
		||||
app.include_router(control.router)
 | 
			
		||||
 | 
			
		||||
app.state.limiter = limiter
 | 
			
		||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)  # type: ignore
 | 
			
		||||
app.add_middleware(SlowAPIMiddleware)
 | 
			
		||||
app.include_router(sessions.router)
 | 
			
		||||
 
 | 
			
		||||
@@ -53,7 +53,6 @@ def generate_test_data() -> None:
 | 
			
		||||
            conn.commit()
 | 
			
		||||
 | 
			
		||||
    print("Finished PostgreSQL test data generation")
 | 
			
		||||
    print("=================================================================")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										10
									
								
								backend/src/neo_neo_todo/models/session.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								backend/src/neo_neo_todo/models/session.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,10 @@
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
 | 
			
		||||
from pydantic import BaseModel, EmailStr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass
 | 
			
		||||
class LoginData(BaseModel):
 | 
			
		||||
    email: EmailStr
 | 
			
		||||
    password: str
 | 
			
		||||
    remember: bool = False
 | 
			
		||||
							
								
								
									
										0
									
								
								backend/src/neo_neo_todo/sessions/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								backend/src/neo_neo_todo/sessions/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										20
									
								
								backend/src/neo_neo_todo/sessions/sessions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								backend/src/neo_neo_todo/sessions/sessions.py
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,20 @@
 | 
			
		||||
from fastapi import APIRouter, Depends
 | 
			
		||||
from fastapi_limiter.depends import RateLimiter
 | 
			
		||||
from starlette.requests import Request
 | 
			
		||||
 | 
			
		||||
from src.neo_neo_todo.models.session import LoginData
 | 
			
		||||
 | 
			
		||||
router = APIRouter(
 | 
			
		||||
    prefix="/sessions",
 | 
			
		||||
    tags=["sessions"],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@router.post("", dependencies=[Depends(RateLimiter(times=100, seconds=600))])
 | 
			
		||||
async def login(request: Request, login_data: LoginData):
 | 
			
		||||
    """
 | 
			
		||||
    Login to the todo app
 | 
			
		||||
 | 
			
		||||
    If successful, save the returned cookie
 | 
			
		||||
    """
 | 
			
		||||
    pass
 | 
			
		||||
@@ -3,6 +3,14 @@ from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
import psycopg
 | 
			
		||||
from psycopg import sql
 | 
			
		||||
from psycopg_pool import AsyncConnectionPool
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    postgres_db_url = os.environ["TODO_DB_DATABASE_URL"]
 | 
			
		||||
except KeyError:
 | 
			
		||||
    raise KeyError("Can't find postgres DB URL")
 | 
			
		||||
 | 
			
		||||
pool = AsyncConnectionPool(postgres_db_url, open=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def recreate_db() -> None:
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +0,0 @@
 | 
			
		||||
from slowapi import Limiter
 | 
			
		||||
from slowapi.util import get_remote_address
 | 
			
		||||
 | 
			
		||||
limiter = Limiter(
 | 
			
		||||
    key_func=get_remote_address,
 | 
			
		||||
    default_limits=["70/minute"],
 | 
			
		||||
    storage_uri="redis://localhost:6379/1",
 | 
			
		||||
)
 | 
			
		||||
		Reference in New Issue
	
	Block a user