#!/usr/bin/env python3
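"""BMAD Story Lint Checker - BMM Playbook compliance.

Usage: python <this script> [DIR]

Lints every DIR/story-*.md file (non-recursive; DIR defaults to "worktrees")
and exits with status 0 when all stories pass, 1 otherwise.

Checks:
- all required section headers exist (see REQUIRED_SECTIONS)
- a "Status: ..." line exists
- each AC block under "## Acceptance Criteria" has exactly one
  "- Verification Command: `...`" line and an "- Expected Result: ..." line
- the "## Tasks / Subtasks" section contains at least one explicit
  "(AC: X)" reference
"""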

from __future__ import annotations

import re
import sys
from pathlib import Path
from typing import List, Tuple


# Patterns that must each match at least once in a story file. They are
# searched with re.MULTILINE, so "^"/"$" anchor at line boundaries. When one is
# missing, the raw pattern string itself is what appears in the lint report.
REQUIRED_SECTIONS = [
    r"^# Story .+$",  # top-level title line starting with "# Story "
    r"^Status: .+$",  # status line with a non-empty value
    r"^## Story$",
    r"^## Acceptance Criteria$",
    r"^## Tasks / Subtasks$",
    r"^## Dev Notes$",
    r"^### Project Structure Notes$",
    r"^### References$",
    r"^## Dev Agent Record$",
    r"^### Context Reference$",
    r"^### Agent Model Used$",
    r"^### Debug Log References$",
    r"^### Completion Notes List$",
    r"^### File List$",
]


# Matches one AC block: a line starting with "**AC-<n>" (group 1 = n), text
# up to a closing "**", then everything up to the next line starting with
# "**AC-" or the end of the text. (?s) lets ".+?"/".*?" cross newlines and
# (?m) makes "^" anchor at every line start.
AC_BLOCK_RE = re.compile(
    r"(?ms)^\*\*AC-(\d+).+?\*\*.*?(?=^\*\*AC-|\Z)"
)
# The two patterns below separate label and value with "[ \t]+" rather than
# "\s+": "\s" also matches newlines, which let an empty "- Expected Result:"
# borrow the following line as its value, and let the backticked command sit
# on the next line instead of the labelled one.
# A "- Verification Command:" line whose backtick-wrapped command ends the line.
VERIFY_RE = re.compile(r"(?m)^- Verification Command:[ \t]+`.+`$")
# A "- Expected Result:" line with non-empty text on the same line.
EXPECTED_RE = re.compile(r"(?m)^- Expected Result:[ \t]+.+$")
# An "(AC: X)" reference; group 1 captures everything up to ")", e.g. "1, 3".
TASK_AC_RE = re.compile(r"\(AC:\s*([^)]+)\)")


def _has_required_sections(text: str) -> List[str]:
    """Return the REQUIRED_SECTIONS patterns that never match ``text``.

    Despite the name, the result lists what is *missing*: an empty list means
    every required section is present. Each pattern is searched with
    re.MULTILINE so ``^``/``$`` anchor at line boundaries, and the result keeps
    the declaration order of REQUIRED_SECTIONS.
    """
    return [
        pattern
        for pattern in REQUIRED_SECTIONS
        if not re.search(pattern, text, flags=re.MULTILINE)
    ]


def _lint_acs(text: str) -> List[str]:
    """Validate every AC block inside the ``## Acceptance Criteria`` section.

    An AC block starts at a line beginning with ``**AC-<n>`` and runs until the
    next such line or the end of the section. Each block must contain exactly
    one ``- Verification Command: `...``` line and at least one
    ``- Expected Result: ...`` line.

    Args:
        text: Full markdown text of the story file.

    Returns:
        Human-readable error messages, each prefixed with the offending AC's
        label (e.g. ``AC-2``); an empty list when every AC passes.
    """
    ac_section = split_section(text, "## Acceptance Criteria")
    if ac_section is None:
        return ["Missing '## Acceptance Criteria' section (cannot lint ACs)."]

    # Block boundaries are the same as a plain "^\*\*AC-\d+" split; the
    # capture group only exists so errors can name the AC they refer to.
    ac_block_re = re.compile(r"(?ms)^\*\*AC-(\d+).*?(?=^\*\*AC-|\Z)")
    matches = list(ac_block_re.finditer(ac_section))
    if not matches:
        return ["No AC blocks found. Expect format '**AC-1 — title**' etc."]

    errors: List[str] = []
    for match in matches:
        block = match.group(0)
        label = f"AC-{match.group(1)}"
        verify_count = len(VERIFY_RE.findall(block))
        # Report a missing command once, rather than also tripping the
        # "exactly 1" rule for the same root cause.
        if verify_count == 0:
            errors.append(f"{label} missing '- Verification Command: `...`'.")
        elif verify_count > 1:
            errors.append(
                f"{label} must have exactly 1 Verification Command line "
                f"(found {verify_count})."
            )
        if EXPECTED_RE.search(block) is None:
            errors.append(f"{label} missing '- Expected Result: ...'.")
    return errors


def split_section(text: str, header: str) -> str | None:
    """Return the body of the markdown section introduced by ``header``.

    ``header`` must occupy a whole line (trailing whitespace allowed). The body
    runs from just after that line up to, but not including, the next line
    that starts with ``##`` followed by whitespace, so ``###`` subsections stay
    inside the body. If no such line follows, the body runs to the end of
    ``text``. Leading and trailing whitespace is stripped from the result.

    Returns None when the header line is not present.
    """
    header_match = re.search(rf"(?m)^{re.escape(header)}\s*$", text)
    if header_match is None:
        return None
    body = text[header_match.end():]
    next_header = re.search(r"(?m)^##\s+", body)
    cut = len(body) if next_header is None else next_header.start()
    return body[:cut].strip()


def _lint_tasks(text: str) -> List[str]:
    """Check that the ``## Tasks / Subtasks`` section maps work back to ACs.

    Only presence is checked: at least one ``(AC: X)`` reference anywhere in
    the section satisfies the rule.

    Returns:
        A list with a single error message, or an empty list when the rule
        is satisfied.
    """
    section = split_section(text, "## Tasks / Subtasks")
    if section is None:
        return ["Missing '## Tasks / Subtasks' section (cannot lint task mapping)."]
    has_reference = TASK_AC_RE.search(section) is not None
    return [] if has_reference else ["No '(AC: X)' references found in Tasks / Subtasks."]


def lint_story(path: Path) -> Tuple[bool, List[str]]:
    """Lint one story file and collect every problem found.

    Args:
        path: Markdown story file to check.

    Returns:
        ``(ok, errors)`` where ``ok`` is True iff ``errors`` is empty. A file
        that cannot be read, or is not valid UTF-8, is reported as a single
        error instead of raising, so one bad file does not abort a whole run.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        return (False, [f"Cannot read file: {exc}"])

    errors: List[str] = []

    missing = _has_required_sections(text)
    if missing:
        errors.append("Missing required section headers/patterns:")
        errors.extend([f"  - {m}" for m in missing])

    errors.extend(_lint_acs(text))
    errors.extend(_lint_tasks(text))

    return (len(errors) == 0, errors)


def main() -> int:
    """Lint every ``story-*.md`` file in a directory (non-recursive).

    The directory comes from the first command-line argument and defaults to
    ``worktrees``. Per-file OK/FAIL results go to stdout; setup errors
    (directory missing, no story files) go to stderr.

    Returns:
        Process exit code: 0 when every story passes, 1 otherwise.
    """
    base = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("worktrees")
    # is_dir() rather than exists(): a regular file passed by mistake would
    # otherwise glob nothing and be misreported as "no story files".
    if not base.is_dir():
        print(f"ERROR: stories directory not found: {base}", file=sys.stderr)
        return 1

    stories = sorted(base.glob("story-*.md"))
    if not stories:
        # Report the directory actually searched, not the default one.
        print(f"ERROR: No story files found: {base / 'story-*.md'}", file=sys.stderr)
        return 1

    failed = False
    for md in stories:
        ok, errs = lint_story(md)
        if ok:
            print(f"OK   {md}")
            continue
        failed = True
        print(f"FAIL {md}")
        for e in errs:
            print(f"  - {e}")

    return 1 if failed else 0


if __name__ == "__main__":
    # sys.exit raises SystemExit, so main()'s return value becomes the exit status.
    sys.exit(main())
