#!/usr/bin/env python3
"""BMAD Code Complexity Gate.

Constraints:
- Max 250 effective lines per .py file (excluding blank/comment-only/docstring lines)
- Max 15 top-level function defs per file
- Max 15 methods per class
- No nested functions
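- Functions spanning more than 40 physical lines (def line through end of body, docstring and blank lines included) fail the gate; any justification recorded in the story is a process-level exception and is not auto-detected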

Usage:
  python bmad_code_complexity_gate.py .
"""

from __future__ import annotations

import ast
import sys
import tokenize
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Set, Tuple

# Directory names that exclude a .py file from scanning when they appear in its path.
EXCLUDE_DIR_NAMES = {
    ".git",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
    ".venv",
    "__pycache__",
    "build",
    "dist",
    "node_modules",
    "venv",
    "worktrees",
}

# Gate thresholds: a check fails only when the measured value is strictly
# greater than its limit.
MAX_EFFECTIVE_LINES = 250  # non-blank, non-comment, non-docstring lines per file
MAX_TOP_LEVEL_FUNCS = 15  # def/async def statements directly in the module body
MAX_METHODS_PER_CLASS = 15  # def/async def statements directly in a class body
MAX_FUNC_RAW_SPAN = 40  # physical lines (lineno..end_lineno) one function may span


@dataclass(frozen=True)
class Violation:
    """One broken complexity rule, attributed to the file it was found in.

    Instances are immutable (``frozen=True``) and compare by value.
    """

    path: Path  # file the rule was checked against
    message: str  # human-readable description printed in the gate report


def iter_py_files(root: Path) -> Iterable[Path]:
    for p in root.rglob("*.py"):
        if any(part in EXCLUDE_DIR_NAMES for part in p.parts):
            continue
        yield p


def _docstring_line_ranges(tree: ast.AST) -> Set[int]:
    doc_lines: Set[int] = set()

    _add_docstring_lines(tree, doc_lines)
    for node in ast.walk(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
            _add_docstring_lines(node, doc_lines)

    return doc_lines


def _add_docstring_lines(node: ast.AST, doc_lines: Set[int]) -> None:
    body = getattr(node, "body", None)
    if not isinstance(body, list) or not body:
        return
    first = body[0]
    if (
        isinstance(first, ast.Expr)
        and isinstance(getattr(first, "value", None), ast.Constant)
        and isinstance(first.value.value, str)
    ):
        lineno = getattr(first, "lineno", None)
        end_lineno = getattr(first, "end_lineno", None)
        if lineno is not None and end_lineno is not None:
            for ln in range(lineno, end_lineno + 1):
                doc_lines.add(ln)


def effective_line_count(src: str, doc_lines: Set[int]) -> int:
    """Count the lines of *src* that are not blank, comment-only, or docstring.

    Args:
        src: Python source text.
        doc_lines: 1-based line numbers (as in AST ``lineno``) to ignore.

    A line is comment-only when its first non-whitespace character is ``#``.
    This test is purely textual, so such lines inside a multi-line string are
    skipped as well.
    """
    # str.splitlines() also breaks on \f, \x1c-\x1e, \x85, \u2028 and \u2029,
    # which the Python parser does not treat as line ends. Using it would shift
    # later lines out of step with the AST numbering used by doc_lines, so split
    # only on the newline sequences the parser recognises.
    normalized = src.replace("\r\n", "\n").replace("\r", "\n")
    count = 0
    for lineno, line in enumerate(normalized.split("\n"), start=1):
        stripped = line.strip()
        if not stripped or stripped.startswith("#") or lineno in doc_lines:
            continue
        count += 1
    return count


def count_top_level_functions(tree: ast.Module) -> int:
    """Count ``def``/``async def`` statements that sit directly in the module body."""
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    total = 0
    for stmt in tree.body:
        if isinstance(stmt, function_types):
            total += 1
    return total


def class_method_counts(tree: ast.Module) -> List[Tuple[str, int]]:
    """Return ``(class_name, method_count)`` for every class in *tree*.

    Classes nested inside other classes or functions are included, so none
    escapes the per-class limit. Only ``def``/``async def`` statements placed
    directly in a class body count as its methods. Results follow ``ast.walk``
    (breadth-first) order, so top-level classes keep their source order and
    come first.
    """
    method_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    return [
        (node.name, sum(isinstance(item, method_types) for item in node.body))
        for node in ast.walk(tree)
        if isinstance(node, ast.ClassDef)
    ]


def build_parent_map(tree: ast.AST) -> dict[int, ast.AST]:
    """Map ``id(child)`` to its parent node for every node below *tree*.

    Keys are ``id()`` values, so the map is only meaningful while *tree* (and
    therefore every node in it) is kept alive. The root has no entry.
    """
    parents: dict[int, ast.AST] = {}
    for parent in ast.walk(tree):
        for child in ast.iter_child_nodes(parent):
            parents[id(child)] = parent
    return parents


def find_nested_functions(tree: ast.AST) -> List[Tuple[str, int]]:
    """Return ``(name, lineno)`` for every function defined inside another function.

    The nearest enclosing function-or-class ancestor decides. A ``def`` placed
    under ``if``/``for``/``while``/``with``/``try``/``match`` inside a function
    body is reported. A method of a class is not reported, even when that class
    is itself defined inside a function. Results follow ``ast.walk`` order.
    """
    func_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    scope_types = func_types + (ast.ClassDef,)
    parents = build_parent_map(tree)
    nested: List[Tuple[str, int]] = []
    for node in ast.walk(tree):
        if not isinstance(node, func_types):
            continue
        # Climb past control-flow statements to the nearest scope-defining node.
        ancestor = parents.get(id(node))
        while ancestor is not None and not isinstance(ancestor, scope_types):
            ancestor = parents.get(id(ancestor))
        if isinstance(ancestor, func_types):
            nested.append((node.name, getattr(node, "lineno", -1)))
    return nested


def find_long_functions(tree: ast.AST) -> List[Tuple[str, int, int]]:
    """Return ``(name, lineno, span)`` for functions longer than ``MAX_FUNC_RAW_SPAN``.

    The span is the raw physical length from ``lineno`` to ``end_lineno``
    inclusive, so docstrings, comments and blank lines in the body all count.
    Nodes without position information are skipped.
    """
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    too_long: List[Tuple[str, int, int]] = []
    for candidate in ast.walk(tree):
        if not isinstance(candidate, function_types):
            continue
        first = getattr(candidate, "lineno", None)
        last = getattr(candidate, "end_lineno", None)
        if first is None or last is None:
            continue
        length = last - first + 1
        if length > MAX_FUNC_RAW_SPAN:
            too_long.append((candidate.name, first, length))
    return too_long


def check_file(path: Path) -> List[Violation]:
    """Run every complexity rule against one Python file; return its violations.

    Source is read via ``tokenize.open``, so a UTF-8 BOM or a PEP 263 coding
    cookie is honoured as the interpreter would honour it. A plain UTF-8 read
    keeps the BOM, which ``ast.parse`` then rejects as a syntax error.

    Raises SyntaxError for unparsable source (or a bad coding cookie), and
    OSError / UnicodeDecodeError when the file cannot be read or decoded.
    """
    with tokenize.open(path) as handle:
        src = handle.read()
    # filename= makes any SyntaxError name the offending file.
    tree = ast.parse(src, filename=str(path))
    doc_lines = _docstring_line_ranges(tree)
    violations: List[Violation] = []

    eff_lines = effective_line_count(src, doc_lines)
    if eff_lines > MAX_EFFECTIVE_LINES:
        violations.append(Violation(path, f"Effective lines {eff_lines} > {MAX_EFFECTIVE_LINES} (max)."))

    # ast.parse in its default "exec" mode always returns an ast.Module.
    top_funcs = count_top_level_functions(tree)
    if top_funcs > MAX_TOP_LEVEL_FUNCS:
        violations.append(Violation(path, f"Top-level functions {top_funcs} > {MAX_TOP_LEVEL_FUNCS} (max)."))

    for cls_name, mcount in class_method_counts(tree):
        if mcount > MAX_METHODS_PER_CLASS:
            violations.append(Violation(path, f"Class '{cls_name}' methods {mcount} > {MAX_METHODS_PER_CLASS} (max)."))

    for name, lineno in find_nested_functions(tree):
        violations.append(Violation(path, f"Nested function '{name}' at line {lineno} (not allowed)."))

    for name, lineno, span in find_long_functions(tree):
        violations.append(Violation(path, f"Function '{name}' span {span} lines (> {MAX_FUNC_RAW_SPAN}) at line {lineno}."))

    return violations


def main(argv: list[str] | None = None) -> int:
    """Check every Python file under ROOT and print a pass/fail report.

    ROOT is the first argument (default ``.``). *argv* excludes the program name
    and defaults to ``sys.argv[1:]``, so the gate can also be driven from code.
    Returns 0 when everything passes, 1 when any violation or unreadable file is
    found, and 2 when ROOT does not exist.
    """
    args = sys.argv[1:] if argv is None else argv
    root = Path(args[0]) if args else Path(".")
    if not root.exists():
        # Without this, a mistyped ROOT scans zero files and reports success.
        print(f"FAIL: Code complexity gate root not found: {root}")
        return 2

    files = list(iter_py_files(root))
    all_violations: List[Violation] = []
    for f in files:
        try:
            all_violations.extend(check_file(f))
        except SyntaxError as e:
            all_violations.append(Violation(f, f"SyntaxError: {e.msg} at line {e.lineno}"))
        except Exception as e:  # any other per-file failure is reported, not fatal
            all_violations.append(Violation(f, f"Error: {e}"))

    if all_violations:
        print("FAIL: Code complexity gate failed:")
        for v in all_violations:
            print(f"  - {v.path}: {v.message}")
        return 1

    print(f"OK: Code complexity gate passed. Checked {len(files)} Python files.")
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
