diff --git a/app.py b/app.py index 33b2c84..f9a4b09 100644 --- a/app.py +++ b/app.py @@ -19,7 +19,7 @@ DESCRIPTION = os.getenv("DESCRIPTION", "powered by aukpad.com") # Valkey/Redis client (initialized later if enabled) redis_client = None -# In-memory rooms: {doc_id: {"text": str, "ver": int, "peers": set[WebSocket], "last_access": float, "pw_hash": bytes|None, "pw_salt": bytes|None}} +# In-memory rooms: {doc_id: {"text": str, "ver": int, "peers": set[WebSocket], "authed_peers": set[WebSocket], "last_access": float, "pw_hash": bytes|None, "pw_salt": bytes|None}} rooms: dict[str, dict] = {} # Rate limiting: {ip: [timestamp, timestamp, ...]} @@ -404,7 +404,7 @@ function connect(){ errEl.textContent = msg.message; errEl.style.display = "block"; } - } else if (msg.type === "update" && msg.ver > ver && msg.clientId !== clientId) { + } else if (msg.type === "update" && isAuthed && msg.ver > ver && msg.clientId !== clientId) { const {selectionStart:s, selectionEnd:e} = ta; const oldText = ta.value; ta.value = msg.text; ver = msg.ver; updateGutter(); @@ -615,8 +615,8 @@ async def create_pad_with_content(request: Request): raise HTTPException(status_code=413, detail=f"Content too large. Max size: {MAX_TEXT_SIZE} bytes") doc_id = random_id() - rooms[doc_id] = {"text": text_content, "ver": 1, "peers": set(), "last_access": time.time(), - "pw_hash": None, "pw_salt": None} + rooms[doc_id] = {"text": text_content, "ver": 1, "peers": set(), "authed_peers": set(), + "last_access": time.time(), "pw_hash": None, "pw_salt": None} # Save to cache if enabled save_room_data_to_cache(doc_id, rooms[doc_id]) @@ -641,6 +641,7 @@ def get_raw_pad_content(doc_id: str, pw: str = ""): "text": cached_data.get("text", ""), "ver": cached_data.get("ver", 0), "peers": set(), + "authed_peers": set(), "last_access": time.time(), "pw_hash": cached_data.get("pw_hash"), "pw_salt": cached_data.get("pw_salt"), @@ -662,12 +663,13 @@ def get_raw_pad_content(doc_id: str, pw: str = ""): update_room_access_time(doc_id) return PlainTextResponse(room["text"]) -async def _broadcast(doc_id: str, message: dict, exclude: WebSocket | None = None): +async def _broadcast(doc_id: str, message: dict, exclude: WebSocket | None = None, authed_only: bool = False): room = rooms.get(doc_id) if not room: return dead = [] payload = json.dumps(message) - for peer in room["peers"]: + targets = room["authed_peers"] if authed_only else room["peers"] + for peer in list(targets): if peer is exclude: continue try: @@ -676,6 +678,7 @@ async def _broadcast(doc_id: str, message: dict, exclude: WebSocket | None = Non dead.append(peer) for d in dead: room["peers"].discard(d) + room["authed_peers"].discard(d) @app.websocket("/ws/{doc_id}") async def ws(doc_id: str, ws: WebSocket): @@ -698,13 +701,14 @@ async def ws(doc_id: str, ws: WebSocket): "text": cached_data.get("text", ""), "ver": cached_data.get("ver", 0), "peers": set(), + "authed_peers": set(), "last_access": time.time(), "pw_hash": cached_data.get("pw_hash"), "pw_salt": cached_data.get("pw_salt"), } - room = rooms.setdefault(doc_id, {"text": "", "ver": 0, "peers": set(), "last_access": time.time(), - "pw_hash": None, "pw_salt": None}) + room = rooms.setdefault(doc_id, {"text": "", "ver": 0, "peers": set(), "authed_peers": set(), + "last_access": time.time(), "pw_hash": None, "pw_salt": None}) room["peers"].add(ws) # Update access time @@ -715,6 +719,8 @@ async def ws(doc_id: str, ws: WebSocket): # Per-connection auth state: already authed if pad has no password authed = room["pw_hash"] is None + if authed: + room["authed_peers"].add(ws) # Send init; withhold text if protected and not yet authed await ws.send_text(json.dumps({ @@ -732,6 +738,7 @@ async def ws(doc_id: str, ws: WebSocket): if data.get("type") == "auth": if room["pw_hash"] is None: authed = True + room["authed_peers"].add(ws) await ws.send_text(json.dumps({"type": "auth_ok"})) await ws.send_text(json.dumps({ "type": "init", "text": room["text"], "ver": room["ver"], "protected": False, @@ -742,6 +749,7 @@ async def ws(doc_id: str, ws: WebSocket): ) if candidate == room["pw_hash"]: authed = True + room["authed_peers"].add(ws) await ws.send_text(json.dumps({"type": "auth_ok"})) await ws.send_text(json.dumps({ "type": "init", "text": room["text"], "ver": room["ver"], "protected": True, @@ -774,7 +782,7 @@ async def ws(doc_id: str, ws: WebSocket): "text": room["text"], "ver": room["ver"], "clientId": data.get("clientId"), - }) + }, authed_only=True) elif data.get("type") == "set_password": pw = str(data.get("password", "")) @@ -792,6 +800,7 @@ async def ws(doc_id: str, ws: WebSocket): pass finally: room["peers"].discard(ws) + room["authed_peers"].discard(ws) await _broadcast(doc_id, {"type": "peers_changed", "count": len(room["peers"])}) # Decrement connection count for this IP connections_per_ip[client_ip] = max(0, connections_per_ip[client_ip] - 1)