gwoe-antragspruefer/tests/test_database.py

555 lines
23 KiB
Python
Raw Permalink Normal View History

"""Unit-Tests für app/database.py (#134 Phase 2).
Alle Tests nutzen eine tmp-Datei als SQLite-DB (via tmp_path-Fixture).
settings.db_path wird per monkeypatch auf die tmp-Datei umgebogen.
Keine Prod-DB wird angetastet.
"""
from __future__ import annotations
import asyncio
import sys
import types
import pytest
# test_mail.py and test_monitoring.py stub out aiosqlite as an empty
# ModuleType object (without .connect). If those files are collected first
# or in parallel, the stub ends up in sys.modules and database.py imports it
# instead of the real package.
#
# Strategy: remove the stub now (if it is already there) and import the real
# aiosqlite. Do NOT drop app.database from sys.modules unconditionally — an
# already imported version that is bound to the real aiosqlite should be kept.
# For that reason aiosqlite and app.database are imported here at module level,
# so the binding inside database.py points at the real package before other
# test files inject their stub.
_aio = sys.modules.get("aiosqlite")
if _aio is not None and not hasattr(_aio, "connect"):
    # A cached "aiosqlite" without .connect is treated as a stub and evicted.
    del sys.modules["aiosqlite"]
# Now load the real aiosqlite and import app.database against it.
# This import runs at module level (collection time), i.e. before
# test_mail.py / test_monitoring.py can install their stubs.
# NOTE(review): that ordering relies on this file being collected before
# those files — re-check if collection order or parallelisation changes.
import aiosqlite as _real_aiosqlite  # noqa: E402
# Import the app package with the real aiosqlite and pin it in sys.modules.
# Later "from app import database" calls in fixtures get the cached module.
import importlib as _importlib
if "app.database" in sys.modules:
    # Already cached — check whether it is bound to the real aiosqlite.
    # NOTE(review): this assumes database.py does "import aiosqlite" and thus
    # exposes it as a module attribute; otherwise getattr() yields None and
    # the module is reloaded every time this file is imported.
    _db_mod = sys.modules["app.database"]
    if not hasattr(getattr(_db_mod, "aiosqlite", None), "connect"):
        # The cached version holds the stub → reload it.
        del sys.modules["app.database"]
        _importlib.import_module("app.database")
else:
    _importlib.import_module("app.database")
# aiosqlite must be genuinely importable — it is present in the test env.
# If it is not, the unconditional "import aiosqlite" above raises ImportError
# at collection time, so the suite fails early instead of hanging silently.
# ─── Helper for synchronous invocation ───────────────────────────────────────
# Dedicated event loop shared by every run() call; created lazily on first use.
_LOOP = None


def run(coro):
    """Run *coro* to completion on a dedicated event loop and return its result.

    The loop is created on first use and reused for all later calls, so
    consecutive calls share one loop just like before. It is recreated if
    something closed it.

    This replaces ``asyncio.get_event_loop()``, which is deprecated when no
    loop is running. It also raises ``RuntimeError`` in the main thread once
    any other code has called ``asyncio.run()``, because that resets the
    current event loop to ``None``.
    """
    global _LOOP
    if _LOOP is None or _LOOP.is_closed():
        _LOOP = asyncio.new_event_loop()
    return _LOOP.run_until_complete(coro)
# ─── DB fixture ──────────────────────────────────────────────────────────────
@pytest.fixture()
def db_path(tmp_path, monkeypatch):
    """Redirect ``settings.db_path`` to a fresh file under *tmp_path*.

    Returns the path of that (not yet created) SQLite file as a string. The
    monkeypatch is undone automatically at the end of the test.
    """
    from app.config import settings

    db_file = str(tmp_path / "test.db")
    monkeypatch.setattr(settings, "db_path", db_file)
    return db_file
@pytest.fixture()
def initialized_db(db_path):
    """Return *db_path* after ``database.init_db()`` has run against it once."""
    from app import database as db_module

    run(db_module.init_db())
    return db_path
# ─── Minimal assessment dict ─────────────────────────────────────────────────
def _assessment(drucksache: str = "18/1234", bundesland: str = "NRW",
                score: float = 7.5) -> dict:
    """Build a fresh, minimal assessment payload for ``upsert_assessment``.

    Only *drucksache* (also used in the title), *bundesland* and *score*
    vary; all other fields are fixed sample content. A new dict with new
    inner lists is built on every call, so callers may mutate the result.
    """
    title = f"Testantrag {drucksache}"
    payload = {
        "drucksache": drucksache,
        "title": title,
        "fraktionen": ["SPD", "GRÜNE"],
        "datum": "2026-04-15",
        "link": "https://example.com",
        "bundesland": bundesland,
        "gwoeScore": score,
        "gwoeBegründung": "Gut.",
        "gwoeMatrix": [{"dimension": "A1", "score": 5}],
        "gwoeSchwerpunkt": ["A1"],
        "wahlprogrammScores": [],
        "verbesserungen": [],
        "stärken": ["Klimaschutz"],
        "schwächen": [],
        "empfehlung": "Empfohlen",
        "empfehlungSymbol": "",
        "verbesserungspotenzial": "gering",
        "themen": ["Klimaschutz"],
        "antragZusammenfassung": "Zusammenfassung.",
        "antragKernpunkte": ["Punkt 1"],
        "source": "webapp",
        "model": "qwen-plus",
        "konfidenz": "hoch",
        "fehlendeProgramme": [],
    }
    return payload
# ─── init_db ─────────────────────────────────────────────────────────────────
class TestInitDb:
    """Schema creation by ``database.init_db()``.

    Tables are inspected through the module-level ``_real_aiosqlite`` rather
    than a local ``import aiosqlite``. By the time these tests run, other test
    files may already have put a stub without ``connect`` into
    ``sys.modules["aiosqlite"]``.
    """

    @staticmethod
    def _table_names(db_path):
        """Return the set of table names in the SQLite file at *db_path*."""
        async def fetch():
            async with _real_aiosqlite.connect(db_path) as db:
                cur = await db.execute(
                    "SELECT name FROM sqlite_master WHERE type='table'"
                )
                return {row[0] for row in await cur.fetchall()}

        return run(fetch())

    def test_creates_assessments_table(self, db_path):
        from app import database
        run(database.init_db())
        assert "assessments" in self._table_names(db_path)

    def test_creates_jobs_table(self, db_path):
        from app import database
        run(database.init_db())
        assert "jobs" in self._table_names(db_path)

    def test_creates_all_required_tables(self, db_path):
        from app import database
        run(database.init_db())
        expected = {
            "assessments", "jobs", "bookmarks", "comments", "votes",
            "assessment_versions", "email_subscriptions",
            "monitoring_scans", "monitoring_daily_summary",
        }
        assert expected <= self._table_names(db_path)

    def test_idempotent_double_call(self, db_path):
        """Calling init_db() twice must not raise and must keep the schema."""
        from app import database
        run(database.init_db())
        run(database.init_db())  # must not raise
        assert "assessments" in self._table_names(db_path)
# ─── upsert_assessment / get_assessment ───────────────────────────────────────
class TestUpsertGetAssessment:
    """Round-trips through ``upsert_assessment`` / ``get_assessment``."""

    def test_round_trip(self, initialized_db):
        from app import database
        data = _assessment("18/9999")
        run(database.upsert_assessment(data))
        result = run(database.get_assessment("18/9999"))
        assert result is not None
        assert result["drucksache"] == "18/9999"
        assert result["bundesland"] == "NRW"

    def test_title_stored(self, initialized_db):
        from app import database
        data = _assessment("18/0001")
        data["title"] = "Spezieller Titel"
        run(database.upsert_assessment(data))
        result = run(database.get_assessment("18/0001"))
        # Explicit check so a missing row fails as an assertion, not a TypeError.
        assert result is not None
        assert result["title"] == "Spezieller Titel"

    def test_gwoe_score_stored(self, initialized_db):
        from app import database
        data = _assessment("18/0002", score=8.5)
        run(database.upsert_assessment(data))
        result = run(database.get_assessment("18/0002"))
        assert result["gwoe_score"] == 8.5

    def test_json_fields_deserialized(self, initialized_db):
        from app import database
        data = _assessment("18/0003")
        run(database.upsert_assessment(data))
        result = run(database.get_assessment("18/0003"))
        assert isinstance(result["fraktionen"], list)
        assert isinstance(result["themen"], list)

    def test_missing_assessment_returns_none(self, initialized_db):
        from app import database
        result = run(database.get_assessment("99/9999"))
        assert result is None

    def test_upsert_updates_existing(self, initialized_db):
        from app import database
        data = _assessment("18/0004", score=5.0)
        run(database.upsert_assessment(data))
        data2 = _assessment("18/0004", score=9.0)
        run(database.upsert_assessment(data2))
        result = run(database.get_assessment("18/0004"))
        assert result["gwoe_score"] == 9.0

    def test_upsert_archives_old_version(self, initialized_db):
        """Re-saving archives the previous version in assessment_versions."""
        from app import database
        data = _assessment("18/0005", score=5.0)
        run(database.upsert_assessment(data))
        data2 = _assessment("18/0005", score=7.0)
        run(database.upsert_assessment(data2))

        async def count_versions():
            # Use the module-level real aiosqlite: a local "import aiosqlite"
            # could resolve to the stub other test files put in sys.modules.
            async with _real_aiosqlite.connect(initialized_db) as db:
                cur = await db.execute(
                    "SELECT COUNT(*) FROM assessment_versions WHERE drucksache='18/0005'"
                )
                return (await cur.fetchone())[0]

        assert run(count_versions()) == 1
# ─── get_all_assessments ──────────────────────────────────────────────────────
class TestGetAllAssessments:
    """Listing and Bundesland filtering in ``get_all_assessments``."""

    @staticmethod
    def _seed(database, *entries):
        """Upsert one minimal assessment per (drucksache, bundesland) pair."""
        for nr, land in entries:
            run(database.upsert_assessment(_assessment(nr, bundesland=land)))

    def test_returns_empty_list_initially(self, initialized_db):
        from app import database
        assert run(database.get_all_assessments()) == []

    def test_returns_inserted_assessments(self, initialized_db):
        from app import database
        self._seed(database, ("18/1001", "NRW"), ("18/1002", "NRW"))
        assert len(run(database.get_all_assessments())) == 2

    def test_bundesland_filter_none_returns_all(self, initialized_db):
        from app import database
        self._seed(database, ("18/1003", "NRW"), ("18/1004", "BY"))
        assert len(run(database.get_all_assessments(bundesland=None))) == 2

    def test_bundesland_filter_all_returns_all(self, initialized_db):
        from app import database
        self._seed(database, ("18/1005", "NRW"), ("18/1006", "BY"))
        assert len(run(database.get_all_assessments(bundesland="ALL"))) == 2

    def test_bundesland_filter_nrw_only(self, initialized_db):
        from app import database
        self._seed(database, ("18/1007", "NRW"), ("18/1008", "BY"))
        rows = run(database.get_all_assessments(bundesland="NRW"))
        assert len(rows) == 1
        assert rows[0]["bundesland"] == "NRW"
# ─── delete_assessment ────────────────────────────────────────────────────────
class TestDeleteAssessment:
    """``delete_assessment`` reports whether a row was actually removed."""

    def test_deletes_existing(self, initialized_db):
        from app import database
        nr = "18/2001"
        run(database.upsert_assessment(_assessment(nr)))
        assert run(database.delete_assessment(nr)) is True
        assert run(database.get_assessment(nr)) is None

    def test_returns_false_for_nonexistent(self, initialized_db):
        from app import database
        assert run(database.delete_assessment("99/9999")) is False
# ─── assessment_versions ─────────────────────────────────────────────────────
class TestAssessmentHistory:
    """Archived versions as returned by ``get_assessment_history``."""

    @staticmethod
    def _save_scores(database, drucksache, *scores):
        """Upsert *drucksache* once per score, in the given order."""
        for value in scores:
            run(database.upsert_assessment(_assessment(drucksache, score=value)))

    def test_empty_history_for_new_assessment(self, initialized_db):
        from app import database
        run(database.upsert_assessment(_assessment("18/3001")))
        assert run(database.get_assessment_history("18/3001")) == []

    def test_history_after_update(self, initialized_db):
        from app import database
        self._save_scores(database, "18/3002", 5.0, 7.0)
        history = run(database.get_assessment_history("18/3002"))
        assert len(history) == 1
        assert history[0]["gwoe_score"] == 5.0

    def test_version_increments_on_multiple_saves(self, initialized_db):
        from app import database
        self._save_scores(database, "18/3003", 4.0, 6.0, 8.0)
        history = run(database.get_assessment_history("18/3003"))
        assert len(history) == 2
        assert {entry["version"] for entry in history} == {1, 2}
# ─── bookmarks ───────────────────────────────────────────────────────────────
class TestBookmarks:
    """``toggle_bookmark`` flips state; ``get_bookmarks`` lists per user."""

    def test_toggle_adds_bookmark(self, initialized_db):
        from app import database
        assert run(database.toggle_bookmark("user1", "18/4001")) is True

    def test_toggle_removes_existing_bookmark(self, initialized_db):
        from app import database
        toggle = database.toggle_bookmark
        run(toggle("user1", "18/4002"))
        assert run(toggle("user1", "18/4002")) is False

    def test_get_bookmarks_returns_list(self, initialized_db):
        from app import database
        for nr in ("18/4003", "18/4004"):
            run(database.toggle_bookmark("user2", nr))
        marked = run(database.get_bookmarks("user2"))
        assert set(marked) == {"18/4003", "18/4004"}

    def test_get_bookmarks_empty_for_unknown_user(self, initialized_db):
        from app import database
        assert run(database.get_bookmarks("nobody")) == []
# ─── monitoring_scans ────────────────────────────────────────────────────────
class TestMonitoringScans:
    """``upsert_monitoring_scan`` reports whether a Drucksache was new."""

    @staticmethod
    def _upsert(database, **fields):
        """Upsert an NRW 'Antrag' scan dated 2026-04-20, overridden by *fields*."""
        params = {
            "bundesland": "NRW",
            "datum": "2026-04-20",
            "typ": "Antrag",
            "typ_normiert": "antrag",
        }
        params.update(fields)
        return run(database.upsert_monitoring_scan(**params))

    def test_new_scan_returns_true(self, initialized_db):
        from app import database
        is_new = self._upsert(
            database,
            drucksache="18/5001",
            title="Testantrag",
            fraktionen=["SPD"],
            link="https://example.com",
            now="2026-04-20T10:00:00",
        )
        assert is_new is True

    def test_second_upsert_returns_false(self, initialized_db):
        from app import database
        common = {"drucksache": "18/5002", "title": "T", "link": None}
        self._upsert(database, fraktionen=[], now="2026-04-20T10:00:00", **common)
        is_new = self._upsert(
            database, fraktionen=[], now="2026-04-20T11:00:00", **common
        )
        assert is_new is False
# ─── monitoring_daily_summary ─────────────────────────────────────────────────
class TestMonitoringDailySummary:
    """Per-day / per-Bundesland summary rows, updated on conflict."""

    DAY = "2026-04-20"

    def test_upsert_and_get_summary(self, initialized_db):
        from app import database
        run(database.upsert_monitoring_summary(
            scan_date=self.DAY,
            bundesland="NRW",
            total_seen=10,
            new_count=3,
            errors=None,
        ))
        rows = run(database.get_monitoring_summary(self.DAY))
        assert len(rows) == 1
        summary = rows[0]
        assert summary["total_seen"] == 10
        assert summary["new_count"] == 3

    def test_upsert_summary_updates_on_conflict(self, initialized_db):
        from app import database
        upsert = database.upsert_monitoring_summary
        run(upsert(self.DAY, "NRW", 5, 1, None))
        run(upsert(self.DAY, "NRW", 15, 4, "Fehler"))
        rows = run(database.get_monitoring_summary(self.DAY))
        assert len(rows) == 1
        assert rows[0]["total_seen"] == 15

    def test_get_summary_empty_for_unknown_date(self, initialized_db):
        from app import database
        assert run(database.get_monitoring_summary("1999-01-01")) == []
# ─── email_subscriptions ─────────────────────────────────────────────────────
class TestEmailSubscriptions:
    """Creation, listing and owner-checked deletion of e-mail subscriptions."""

    def test_create_and_list_subscription(self, initialized_db):
        from app import database
        new_id = run(database.create_subscription(
            user_id="u1",
            email="test@example.com",
            bundesland="NRW",
            partei="SPD",
        ))
        assert isinstance(new_id, int)
        listed = run(database.list_subscriptions("u1"))
        assert len(listed) == 1
        assert listed[0]["email"] == "test@example.com"

    def test_delete_subscription_own(self, initialized_db):
        from app import database
        new_id = run(database.create_subscription("u2", "a@b.com"))
        assert run(database.delete_subscription("u2", new_id)) is True
        assert run(database.list_subscriptions("u2")) == []

    def test_delete_subscription_wrong_user_fails(self, initialized_db):
        from app import database
        new_id = run(database.create_subscription("u3", "a@b.com"))
        assert run(database.delete_subscription("wrong_user", new_id)) is False

    def test_get_all_subscriptions_due_empty(self, initialized_db):
        from app import database
        assert run(database.get_all_subscriptions_due()) == []
# ─── _parse_search_query ─────────────────────────────────────────────────────
class TestParseSearchQuery:
    """Tokenisation of raw search input by ``_parse_search_query``."""

    @staticmethod
    def _parse(text):
        """Import lazily (like the other tests) and parse *text*."""
        from app.database import _parse_search_query
        return _parse_search_query(text)

    def test_single_term(self):
        terms, is_exact = self._parse("klimaschutz")
        assert terms == ["klimaschutz"]
        assert is_exact is False

    def test_multi_term_split(self):
        terms, is_exact = self._parse("Klimaschutz Energie")
        assert terms == ["klimaschutz", "energie"]
        assert is_exact is False

    def test_exact_phrase_in_quotes(self):
        terms, is_exact = self._parse('"Grüner Stahl"')
        assert terms == ["grüner stahl"]
        assert is_exact is True

    def test_whitespace_stripped(self):
        terms, _ = self._parse(" hallo ")
        assert terms[0] == "hallo"
# ─── Merkliste (#140) ────────────────────────────────────────────────────────
class TestMerkliste:
    """Per-user watch list: add, list, remove and bulk add."""

    def test_add_and_list(self, initialized_db):
        from app import database
        run(database.merkliste_add("user1", "18/1001"))
        run(database.merkliste_add("user1", "18/1002", notiz="Wichtig"))
        entries = run(database.merkliste_list("user1"))
        ids = [e["antrag_id"] for e in entries]
        assert "18/1001" in ids
        assert "18/1002" in ids

    def test_add_with_notiz(self, initialized_db):
        from app import database
        run(database.merkliste_add("user1", "18/2001", notiz="Mein Kommentar"))
        entries = run(database.merkliste_list("user1"))
        match = next((e for e in entries if e["antrag_id"] == "18/2001"), None)
        assert match is not None
        assert match["notiz"] == "Mein Kommentar"

    def test_remove(self, initialized_db):
        from app import database
        run(database.merkliste_add("user1", "18/3001"))
        removed = run(database.merkliste_remove("user1", "18/3001"))
        assert removed is True
        entries = run(database.merkliste_list("user1"))
        assert not any(e["antrag_id"] == "18/3001" for e in entries)

    def test_remove_nonexistent_returns_false(self, initialized_db):
        from app import database
        removed = run(database.merkliste_remove("user1", "18/9999"))
        assert removed is False

    def test_list_empty_for_unknown_user(self, initialized_db):
        from app import database
        entries = run(database.merkliste_list("unknown_user"))
        assert entries == []

    def test_user_isolation(self, initialized_db):
        from app import database
        run(database.merkliste_add("userA", "18/5001"))
        run(database.merkliste_add("userB", "18/5002"))
        a_ids = [e["antrag_id"] for e in run(database.merkliste_list("userA"))]
        b_ids = [e["antrag_id"] for e in run(database.merkliste_list("userB"))]
        # Compare exactly: an all(...) check is vacuously true for an empty
        # list, so it would not notice a list call that returns nothing.
        assert a_ids == ["18/5001"]
        assert b_ids == ["18/5002"]

    def test_upsert_idempotent(self, initialized_db):
        from app import database
        run(database.merkliste_add("user1", "18/6001"))
        run(database.merkliste_add("user1", "18/6001"))  # second time
        entries = run(database.merkliste_list("user1"))
        dupes = [e for e in entries if e["antrag_id"] == "18/6001"]
        assert len(dupes) == 1

    def test_bulk_add(self, initialized_db):
        from app import database
        entries = [
            {"antrag_id": "18/7001"},
            {"antrag_id": "18/7002", "notiz": "bulk"},
        ]
        count = run(database.merkliste_bulk_add("user1", entries))
        assert count == 2
        listed = run(database.merkliste_list("user1"))
        ids = [e["antrag_id"] for e in listed]
        assert "18/7001" in ids
        assert "18/7002" in ids

    def test_bulk_add_skips_missing_antrag_id(self, initialized_db):
        from app import database
        entries = [
            {"antrag_id": "18/8001"},
            {"notiz": "kein antrag_id"},  # must be skipped
        ]
        count = run(database.merkliste_bulk_add("user1", entries))
        assert count == 1

    def test_bulk_add_no_duplicates(self, initialized_db):
        from app import database
        run(database.merkliste_add("user1", "18/9001"))
        count = run(database.merkliste_bulk_add("user1", [{"antrag_id": "18/9001"}]))
        # Do-nothing on conflict → still counted as processed.
        assert count == 1
        listed = run(database.merkliste_list("user1"))
        assert len([e for e in listed if e["antrag_id"] == "18/9001"]) == 1