diff --git a/main.py b/main.py index 290f6bb..f535adf 100644 --- a/main.py +++ b/main.py @@ -64,7 +64,7 @@ def log(level: str, event: str, **kwargs): def get_real_ip(request: Request) -> str: - """Get real client IP for rate limiting (supports reverse proxy)""" + """Get real client IP for rate limiting and logging (supports reverse proxy)""" # Check X-Real-IP header first (set by reverse proxy) x_real_ip = request.headers.get("X-Real-IP") if x_real_ip: @@ -82,7 +82,7 @@ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) async def log_rate_limit(request: Request, exc: RateLimitExceeded): """Custom handler to log rate limit violations""" log("WARNING", "rate_limit_exceeded", - client_ip=get_client_ip(request), + client_ip=get_real_ip(request), user_agent=request.headers.get("User-Agent", "unknown"), endpoint=request.url.path) return await _rate_limit_exceeded_handler(request, exc) @@ -106,12 +106,6 @@ def generate_random_path(length: int = None) -> str: alphabet = string.ascii_letters + string.digits return ''.join(secrets.choice(alphabet) for _ in range(length)) -def get_client_ip(request: Request) -> str: - x_real_ip = request.headers.get("X-Real-IP") - if x_real_ip: - return x_real_ip.strip() - return request.client.host - def validate_upload_token(request: Request) -> bool: """Validate upload token if authentication is enabled""" @@ -122,7 +116,7 @@ def validate_upload_token(request: Request) -> bool: auth = request.headers.get("Authorization", "") if not auth.startswith("Bearer "): log("WARNING", "auth_failed", - client_ip=get_client_ip(request), + client_ip=get_real_ip(request), user_agent=request.headers.get("User-Agent", "unknown"), reason="missing_bearer") raise HTTPException( @@ -136,7 +130,7 @@ def validate_upload_token(request: Request) -> bool: # Use constant-time comparison to prevent timing attacks if not any(secrets.compare_digest(token, valid_token) for valid_token in UPLOAD_TOKENS): log("WARNING", "auth_failed", - client_ip=get_client_ip(request), + client_ip=get_real_ip(request), user_agent=request.headers.get("User-Agent", "unknown"), reason="invalid_token") raise HTTPException( @@ -168,7 +162,7 @@ def validate_content(content: str) -> bool: @limiter.limit(RATE_LIMIT) async def upload_text(request: Request, authorized: bool = Depends(validate_upload_token)): - client_ip = get_client_ip(request) + client_ip = get_real_ip(request) user_agent = request.headers.get("User-Agent", "unknown") body = await request.body() content = body.decode('utf-8', errors='ignore')