from __future__ import annotations

import json
import logging
import os
import sys
from datetime import datetime, timezone
from typing import Any, Dict

from src.errors import UserError

# Environment variable consulted when the CLI does not state a preference.
LOG_JSON_ENV_VAR = "EVENTFOTO_LOG_JSON"

# Module-wide switch; updated by configure_logging() and read by log_event().
_JSON_LOGGING_ENABLED = False

# Attribute names that the logging module itself places on a LogRecord (plus
# "message" and "asctime", which formatters add). Anything on a record that is
# NOT listed here is treated as a caller-supplied ``extra`` field.
_STANDARD_RECORD_KEYS = {
    "args",
    "asctime",
    "created",
    "exc_info",
    "exc_text",
    "filename",
    "funcName",
    "levelname",
    "levelno",
    "lineno",
    "message",
    "module",
    "msecs",
    "msg",
    "name",
    "pathname",
    "process",
    "processName",
    "relativeCreated",
    "stack_info",
    # Present on every LogRecord since Python 3.12; without this entry each
    # JSON line would carry a spurious "taskName" extra.
    "taskName",
    "thread",
    "threadName",
}


class JsonFormatter(logging.Formatter):
    """Render each log record as a single-line JSON object.

    Every object carries ``ts`` (the record's creation time as an ISO-8601
    UTC timestamp), ``level``, ``logger`` and ``message``. Non-standard
    record attributes returned by ``_extra_fields`` are merged in afterwards,
    and formatted exception / stack text is attached when the record has it.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Return *record* serialized as ASCII-only JSON with sorted keys."""
        payload: Dict[str, Any] = {
            "ts": datetime.fromtimestamp(record.created, tz=timezone.utc).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
        }
        # Extras are applied after the core fields, so an extra that reuses a
        # core key (e.g. "level") replaces the core value.
        payload.update(_extra_fields(record))
        if record.exc_info:
            payload["exc_info"] = self.formatException(record.exc_info)
        if record.stack_info:
            # Keep the stack the caller explicitly requested via stack_info=True.
            payload["stack_info"] = self.formatStack(record.stack_info)
        # default=str: extras are arbitrary objects (paths, datetimes, sets...).
        # A formatter must not raise on them, otherwise the handler reports an
        # internal logging error and the log line is lost.
        return json.dumps(payload, ensure_ascii=True, sort_keys=True, default=str)


def configure_logging(*, cli_log_json: bool | None) -> bool:
    """Decide whether JSON logging is wanted and, if so, install it.

    The decision comes from ``_resolve_log_json`` and is stored in the
    module-level ``_JSON_LOGGING_ENABLED`` flag. When JSON logging is enabled,
    the root logger is reconfigured (``force=True`` replaces any existing root
    handlers) with a single stderr handler using ``JsonFormatter`` at INFO
    level. When it is disabled, the logging configuration is left untouched.

    Args:
        cli_log_json: Explicit choice from the command line, or ``None`` to
            defer to the environment.

    Returns:
        ``True`` if JSON logging was enabled, ``False`` otherwise.

    Raises:
        UserError: Propagated from ``_resolve_log_json`` for an unrecognized
            environment value.
    """
    global _JSON_LOGGING_ENABLED

    enabled = _resolve_log_json(cli_log_json)
    _JSON_LOGGING_ENABLED = enabled

    if enabled:
        stderr_handler = logging.StreamHandler(sys.stderr)
        stderr_handler.setFormatter(JsonFormatter())
        logging.basicConfig(level=logging.INFO, handlers=[stderr_handler], force=True)
        return True
    return False


def log_event(logger: logging.Logger, event: str, *, level: str = "info", **fields: Any) -> None:
    """Emit a structured event through *logger* when JSON logging is enabled.

    The event name is used both as the log message and as the ``event`` extra
    field; all keyword *fields* are attached as further extras. The call is a
    no-op (the level is not even validated) while JSON logging is disabled.

    NOTE: the logging module rejects extras whose names clash with built-in
    LogRecord attributes (for example ``name`` or ``filename``) by raising
    ``KeyError``; callers must avoid such field names.

    Args:
        logger: Logger that receives the record.
        event: Event name.
        level: One of ``"info"``, ``"warning"`` or ``"error"``.
        **fields: Additional key/value pairs to attach to the record.

    Raises:
        UserError: If *level* is not one of the supported names.
    """
    if not _JSON_LOGGING_ENABLED:
        return

    if level not in ("info", "warning", "error"):
        raise UserError(f"Invalid log level: {level!r}")

    record_extras = {"event": event}
    record_extras.update(fields)

    # Each supported level name is also the name of the Logger method to call.
    emit = getattr(logger, level)
    emit(event, extra=record_extras)


def _resolve_log_json(cli_log_json: bool | None) -> bool:
    if cli_log_json is not None:
        return bool(cli_log_json)
    raw = os.environ.get(LOG_JSON_ENV_VAR, "").strip()
    if not raw:
        return False
    value = raw.lower()
    if value in {"1", "true", "yes", "on"}:
        return True
    if value in {"0", "false", "no", "off"}:
        return False
    raise UserError(
        f"Invalid {LOG_JSON_ENV_VAR} value: {raw!r} (expected 1/0/true/false/yes/no/on/off)."
    )


def _extra_fields(record: logging.LogRecord) -> Dict[str, Any]:
    """Return the attributes of *record* that are not standard record fields.

    Every entry of the record's attribute dictionary whose name is absent
    from ``_STANDARD_RECORD_KEYS`` is copied into a new dict, preserving the
    record's attribute order.
    """
    return {
        attr: val
        for attr, val in vars(record).items()
        if attr not in _STANDARD_RECORD_KEYS
    }
