Private
Public Access
0
0

feat(batcher): implement Batch dataclass and plan() function

This commit is contained in:
2026-06-08 00:46:12 -04:00
parent 246f293c56
commit e07036ad5d
+146
View File
@@ -0,0 +1,146 @@
import os
from dataclasses import dataclass
from pathlib import Path
from categorizer import CategoryRecord, FixtureClass, Speed
@dataclass(frozen=True)
class Batch:
tier: str
label: str
files: list[Path]
pytest_args: list[str]
estimated_seconds: float
skip_reason: str | None = None
# Every tier code known to the planner, in the order batches are sorted by plan().
_TIER_ORDER: tuple[str, ...] = ("0", "1", "2", "3", "H", "P")
# Estimated runtime in seconds for a single test file, keyed by the record's
# speed value. _est() falls back to 3.0 for keys that are missing here.
_SPEED_SECONDS: dict[str, float] = {
    "fast": 0.5,
    "medium": 3.0,
    "slow": 15.0,
    "very_slow": 60.0,
}
def _est(r: CategoryRecord) -> float:
    """Return the estimated runtime, in seconds, for one categorized test file.

    The lookup key is ``r.speed.value``; keys without an entry in
    ``_SPEED_SECONDS`` yield the fallback of 3.0 seconds.
    """
    speed_key = r.speed.value
    if speed_key in _SPEED_SECONDS:
        return _SPEED_SECONDS[speed_key]
    return 3.0
def _env_set(name: str) -> bool:
    """Report whether environment variable *name* exists with a non-empty value."""
    value = os.environ.get(name)
    # An unset variable and an empty string both count as "not set".
    return value is not None and value != ""
def _batches_for_unit(records: list[CategoryRecord], xdist: bool) -> list[Batch]:
    """Build the tier-1 batches: one per batch group, ordered by group name.

    Records whose ``batch_group`` is falsy are collected under "core".
    Every batch carries ``--maxfail=10``; when *xdist* is true the arguments
    are preceded by ``-n auto``.
    """
    grouped: dict[str, list[CategoryRecord]] = {}
    for record in records:
        key = record.batch_group or "core"
        if key not in grouped:
            grouped[key] = []
        grouped[key].append(record)

    # The argument list is built inside the comprehension so that each batch
    # owns its own list object.
    return [
        Batch(
            tier="1",
            label=f"tier-1-unit-{name}",
            files=[Path("tests") / member.filename for member in members],
            pytest_args=(
                ["-n", "auto", "--maxfail=10"] if xdist else ["--maxfail=10"]
            ),
            estimated_seconds=sum(_est(member) for member in members),
        )
        for name, members in sorted(grouped.items(), key=lambda item: item[0])
    ]
def _batches_for_mock_app(records: list[CategoryRecord]) -> list[Batch]:
    """Build the tier-2 batches: one per batch group, ordered by group name.

    Records whose ``batch_group`` is falsy are collected under "core".
    Every batch is given ``--maxfail=5``.
    """
    grouped: dict[str, list[CategoryRecord]] = {}
    for record in records:
        bucket = grouped.setdefault(record.batch_group or "core", [])
        bucket.append(record)

    result: list[Batch] = []
    for name in sorted(grouped):
        members = grouped[name]
        total = sum(_est(member) for member in members)
        paths = [Path("tests") / member.filename for member in members]
        batch = Batch(
            tier="2",
            label=f"tier-2-mock_app-{name}",
            files=paths,
            pytest_args=["--maxfail=5"],
            estimated_seconds=total,
        )
        result.append(batch)
    return result
def _batches_for_live_gui(records: list[CategoryRecord]) -> list[Batch]:
    """Build the single tier-3 batch holding every live-GUI record.

    Returns an empty list when *records* is empty; otherwise a one-element
    list whose batch is labelled "tier-3-live_gui" and uses ``--maxfail=1``.
    """
    if records:
        paths = [Path("tests") / record.filename for record in records]
        total = sum(_est(record) for record in records)
        batch = Batch(
            tier="3",
            label="tier-3-live_gui",
            files=paths,
            pytest_args=["--maxfail=1"],
            estimated_seconds=total,
        )
        return [batch]
    return []
def _batches_for_headless(records: list[CategoryRecord]) -> list[Batch]:
    """Build the single tier-H batch holding every headless record.

    Returns an empty list when *records* is empty; otherwise a one-element
    list whose batch is labelled "tier-H-headless" and uses ``--maxfail=5``.
    """
    return [
        Batch(
            tier="H",
            label="tier-H-headless",
            files=[Path("tests") / record.filename for record in records],
            pytest_args=["--maxfail=5"],
            estimated_seconds=sum(_est(record) for record in records),
        )
    ] if records else []
def _batches_for_performance(records: list[CategoryRecord]) -> list[Batch]:
    """Build the single tier-P batch holding every performance record.

    Returns an empty list when *records* is empty; otherwise a one-element
    list whose batch is labelled "tier-P-performance" and uses ``--maxfail=1``.
    """
    planned: list[Batch] = []
    if not records:
        return planned
    planned.append(
        Batch(
            tier="P",
            label="tier-P-performance",
            files=[Path("tests") / record.filename for record in records],
            pytest_args=["--maxfail=1"],
            estimated_seconds=sum(_est(record) for record in records),
        )
    )
    return planned
def _batches_for_opt_in(records: list[CategoryRecord], include_opt_in: bool) -> list[Batch]:
    """Build one tier-0 batch per opt-in record, in input order.

    Each batch holds a single file and uses ``--maxfail=1``. ``skip_reason``
    is filled in when *include_opt_in* is false, or when the file name starts
    with a gated prefix whose environment variable is unset or empty;
    otherwise it stays ``None``.
    """
    # (file-name prefix, gating environment variable, skip message), checked
    # in this order; the first gate that matches with its variable unset wins.
    gates = (
        ("test_clean_install", "RUN_CLEAN_INSTALL_TEST", "RUN_CLEAN_INSTALL_TEST not set"),
        ("test_docker_build", "RUN_DOCKER_TEST", "RUN_DOCKER_TEST not set"),
    )

    planned: list[Batch] = []
    for record in records:
        reason = None
        if not include_opt_in:
            reason = "--include-opt-in not set"
        else:
            for prefix, env_name, message in gates:
                if record.filename.startswith(prefix) and not _env_set(env_name):
                    reason = message
                    break

        stem = record.filename.removeprefix("test_").removesuffix(".py")
        planned.append(
            Batch(
                tier="0",
                label=f"tier-0-opt_in-{stem}",
                files=[Path("tests") / record.filename],
                pytest_args=["--maxfail=1"],
                estimated_seconds=_est(record),
                skip_reason=reason,
            )
        )
    return planned
def plan(
    records: list[CategoryRecord],
    *,
    tiers: set[str] | None = None,
    include_opt_in: bool = False,
    xdist: bool = True,
) -> list[Batch]:
    """Turn categorized test records into an ordered list of batches.

    Args:
        records: Categorized test files; each is routed by its
            ``fixture_class``.
        tiers: Tier codes to include. ``None`` (the default) selects every
            tier in ``_TIER_ORDER``. Codes that match no known tier are
            ignored.
        include_opt_in: Passed to the opt-in builder; when false, opt-in
            batches are still emitted but carry a skip reason.
        xdist: Passed to the unit builder; when true, unit batches get
            ``-n auto``.

    Returns:
        Batches sorted by tier position in ``_TIER_ORDER``, then by label.

    Raises:
        KeyError: If a record's ``fixture_class`` is not a ``FixtureClass``
            member.
    """
    # Resolve the default per call rather than sharing one mutable set object
    # created at definition time.
    selected = set(_TIER_ORDER) if tiers is None else tiers

    by_fc: dict[FixtureClass, list[CategoryRecord]] = {fc: [] for fc in FixtureClass}
    for r in records:
        by_fc[r.fixture_class].append(r)

    out: list[Batch] = []
    if "0" in selected:
        out.extend(_batches_for_opt_in(by_fc[FixtureClass.OPT_IN], include_opt_in))
    if "1" in selected:
        out.extend(_batches_for_unit(by_fc[FixtureClass.UNIT], xdist))
    if "2" in selected:
        out.extend(_batches_for_mock_app(by_fc[FixtureClass.MOCK_APP]))
    if "3" in selected:
        out.extend(_batches_for_live_gui(by_fc[FixtureClass.LIVE_GUI]))
    if "H" in selected:
        out.extend(_batches_for_headless(by_fc[FixtureClass.HEADLESS]))
    if "P" in selected:
        out.extend(_batches_for_performance(by_fc[FixtureClass.PERFORMANCE]))

    # Opt-in batches arrive in input order, so the sort is what orders them
    # by label within their tier.
    out.sort(key=lambda b: (_TIER_ORDER.index(b.tier), b.label))
    return out