auth: FIX websocket updates only to authenticated peers #23
This commit is contained in:
parent
c60381c59a
commit
812f891c8f
1 changed files with 18 additions and 9 deletions
27
app.py
27
app.py
|
|
@ -19,7 +19,7 @@ DESCRIPTION = os.getenv("DESCRIPTION", "powered by aukpad.com")
|
||||||
# Valkey/Redis client (initialized later if enabled)
|
# Valkey/Redis client (initialized later if enabled)
|
||||||
redis_client = None
|
redis_client = None
|
||||||
|
|
||||||
# In-memory rooms: {doc_id: {"text": str, "ver": int, "peers": set[WebSocket], "last_access": float, "pw_hash": bytes|None, "pw_salt": bytes|None}}
|
# In-memory rooms: {doc_id: {"text": str, "ver": int, "peers": set[WebSocket], "authed_peers": set[WebSocket], "last_access": float, "pw_hash": bytes|None, "pw_salt": bytes|None}}
|
||||||
rooms: dict[str, dict] = {}
|
rooms: dict[str, dict] = {}
|
||||||
|
|
||||||
# Rate limiting: {ip: [timestamp, timestamp, ...]}
|
# Rate limiting: {ip: [timestamp, timestamp, ...]}
|
||||||
|
|
@ -404,7 +404,7 @@ function connect(){
|
||||||
errEl.textContent = msg.message;
|
errEl.textContent = msg.message;
|
||||||
errEl.style.display = "block";
|
errEl.style.display = "block";
|
||||||
}
|
}
|
||||||
} else if (msg.type === "update" && msg.ver > ver && msg.clientId !== clientId) {
|
} else if (msg.type === "update" && isAuthed && msg.ver > ver && msg.clientId !== clientId) {
|
||||||
const {selectionStart:s, selectionEnd:e} = ta;
|
const {selectionStart:s, selectionEnd:e} = ta;
|
||||||
const oldText = ta.value;
|
const oldText = ta.value;
|
||||||
ta.value = msg.text; ver = msg.ver; updateGutter();
|
ta.value = msg.text; ver = msg.ver; updateGutter();
|
||||||
|
|
@ -615,8 +615,8 @@ async def create_pad_with_content(request: Request):
|
||||||
raise HTTPException(status_code=413, detail=f"Content too large. Max size: {MAX_TEXT_SIZE} bytes")
|
raise HTTPException(status_code=413, detail=f"Content too large. Max size: {MAX_TEXT_SIZE} bytes")
|
||||||
|
|
||||||
doc_id = random_id()
|
doc_id = random_id()
|
||||||
rooms[doc_id] = {"text": text_content, "ver": 1, "peers": set(), "last_access": time.time(),
|
rooms[doc_id] = {"text": text_content, "ver": 1, "peers": set(), "authed_peers": set(),
|
||||||
"pw_hash": None, "pw_salt": None}
|
"last_access": time.time(), "pw_hash": None, "pw_salt": None}
|
||||||
|
|
||||||
# Save to cache if enabled
|
# Save to cache if enabled
|
||||||
save_room_data_to_cache(doc_id, rooms[doc_id])
|
save_room_data_to_cache(doc_id, rooms[doc_id])
|
||||||
|
|
@ -641,6 +641,7 @@ def get_raw_pad_content(doc_id: str, pw: str = ""):
|
||||||
"text": cached_data.get("text", ""),
|
"text": cached_data.get("text", ""),
|
||||||
"ver": cached_data.get("ver", 0),
|
"ver": cached_data.get("ver", 0),
|
||||||
"peers": set(),
|
"peers": set(),
|
||||||
|
"authed_peers": set(),
|
||||||
"last_access": time.time(),
|
"last_access": time.time(),
|
||||||
"pw_hash": cached_data.get("pw_hash"),
|
"pw_hash": cached_data.get("pw_hash"),
|
||||||
"pw_salt": cached_data.get("pw_salt"),
|
"pw_salt": cached_data.get("pw_salt"),
|
||||||
|
|
@ -662,12 +663,13 @@ def get_raw_pad_content(doc_id: str, pw: str = ""):
|
||||||
update_room_access_time(doc_id)
|
update_room_access_time(doc_id)
|
||||||
return PlainTextResponse(room["text"])
|
return PlainTextResponse(room["text"])
|
||||||
|
|
||||||
async def _broadcast(doc_id: str, message: dict, exclude: WebSocket | None = None):
|
async def _broadcast(doc_id: str, message: dict, exclude: WebSocket | None = None, authed_only: bool = False):
|
||||||
room = rooms.get(doc_id)
|
room = rooms.get(doc_id)
|
||||||
if not room: return
|
if not room: return
|
||||||
dead = []
|
dead = []
|
||||||
payload = json.dumps(message)
|
payload = json.dumps(message)
|
||||||
for peer in room["peers"]:
|
targets = room["authed_peers"] if authed_only else room["peers"]
|
||||||
|
for peer in list(targets):
|
||||||
if peer is exclude:
|
if peer is exclude:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
|
|
@ -676,6 +678,7 @@ async def _broadcast(doc_id: str, message: dict, exclude: WebSocket | None = Non
|
||||||
dead.append(peer)
|
dead.append(peer)
|
||||||
for d in dead:
|
for d in dead:
|
||||||
room["peers"].discard(d)
|
room["peers"].discard(d)
|
||||||
|
room["authed_peers"].discard(d)
|
||||||
|
|
||||||
@app.websocket("/ws/{doc_id}")
|
@app.websocket("/ws/{doc_id}")
|
||||||
async def ws(doc_id: str, ws: WebSocket):
|
async def ws(doc_id: str, ws: WebSocket):
|
||||||
|
|
@ -698,13 +701,14 @@ async def ws(doc_id: str, ws: WebSocket):
|
||||||
"text": cached_data.get("text", ""),
|
"text": cached_data.get("text", ""),
|
||||||
"ver": cached_data.get("ver", 0),
|
"ver": cached_data.get("ver", 0),
|
||||||
"peers": set(),
|
"peers": set(),
|
||||||
|
"authed_peers": set(),
|
||||||
"last_access": time.time(),
|
"last_access": time.time(),
|
||||||
"pw_hash": cached_data.get("pw_hash"),
|
"pw_hash": cached_data.get("pw_hash"),
|
||||||
"pw_salt": cached_data.get("pw_salt"),
|
"pw_salt": cached_data.get("pw_salt"),
|
||||||
}
|
}
|
||||||
|
|
||||||
room = rooms.setdefault(doc_id, {"text": "", "ver": 0, "peers": set(), "last_access": time.time(),
|
room = rooms.setdefault(doc_id, {"text": "", "ver": 0, "peers": set(), "authed_peers": set(),
|
||||||
"pw_hash": None, "pw_salt": None})
|
"last_access": time.time(), "pw_hash": None, "pw_salt": None})
|
||||||
room["peers"].add(ws)
|
room["peers"].add(ws)
|
||||||
|
|
||||||
# Update access time
|
# Update access time
|
||||||
|
|
@ -715,6 +719,8 @@ async def ws(doc_id: str, ws: WebSocket):
|
||||||
|
|
||||||
# Per-connection auth state: already authed if pad has no password
|
# Per-connection auth state: already authed if pad has no password
|
||||||
authed = room["pw_hash"] is None
|
authed = room["pw_hash"] is None
|
||||||
|
if authed:
|
||||||
|
room["authed_peers"].add(ws)
|
||||||
|
|
||||||
# Send init; withhold text if protected and not yet authed
|
# Send init; withhold text if protected and not yet authed
|
||||||
await ws.send_text(json.dumps({
|
await ws.send_text(json.dumps({
|
||||||
|
|
@ -732,6 +738,7 @@ async def ws(doc_id: str, ws: WebSocket):
|
||||||
if data.get("type") == "auth":
|
if data.get("type") == "auth":
|
||||||
if room["pw_hash"] is None:
|
if room["pw_hash"] is None:
|
||||||
authed = True
|
authed = True
|
||||||
|
room["authed_peers"].add(ws)
|
||||||
await ws.send_text(json.dumps({"type": "auth_ok"}))
|
await ws.send_text(json.dumps({"type": "auth_ok"}))
|
||||||
await ws.send_text(json.dumps({
|
await ws.send_text(json.dumps({
|
||||||
"type": "init", "text": room["text"], "ver": room["ver"], "protected": False,
|
"type": "init", "text": room["text"], "ver": room["ver"], "protected": False,
|
||||||
|
|
@ -742,6 +749,7 @@ async def ws(doc_id: str, ws: WebSocket):
|
||||||
)
|
)
|
||||||
if candidate == room["pw_hash"]:
|
if candidate == room["pw_hash"]:
|
||||||
authed = True
|
authed = True
|
||||||
|
room["authed_peers"].add(ws)
|
||||||
await ws.send_text(json.dumps({"type": "auth_ok"}))
|
await ws.send_text(json.dumps({"type": "auth_ok"}))
|
||||||
await ws.send_text(json.dumps({
|
await ws.send_text(json.dumps({
|
||||||
"type": "init", "text": room["text"], "ver": room["ver"], "protected": True,
|
"type": "init", "text": room["text"], "ver": room["ver"], "protected": True,
|
||||||
|
|
@ -774,7 +782,7 @@ async def ws(doc_id: str, ws: WebSocket):
|
||||||
"text": room["text"],
|
"text": room["text"],
|
||||||
"ver": room["ver"],
|
"ver": room["ver"],
|
||||||
"clientId": data.get("clientId"),
|
"clientId": data.get("clientId"),
|
||||||
})
|
}, authed_only=True)
|
||||||
|
|
||||||
elif data.get("type") == "set_password":
|
elif data.get("type") == "set_password":
|
||||||
pw = str(data.get("password", ""))
|
pw = str(data.get("password", ""))
|
||||||
|
|
@ -792,6 +800,7 @@ async def ws(doc_id: str, ws: WebSocket):
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
room["peers"].discard(ws)
|
room["peers"].discard(ws)
|
||||||
|
room["authed_peers"].discard(ws)
|
||||||
await _broadcast(doc_id, {"type": "peers_changed", "count": len(room["peers"])})
|
await _broadcast(doc_id, {"type": "peers_changed", "count": len(room["peers"])})
|
||||||
# Decrement connection count for this IP
|
# Decrement connection count for this IP
|
||||||
connections_per_ip[client_ip] = max(0, connections_per_ip[client_ip] - 1)
|
connections_per_ip[client_ip] = max(0, connections_per_ip[client_ip] - 1)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue