#!/usr/bin/env python3
"""CI Access Layer compliance linter.

Scans workflow YAML files in .github/workflows/ and fails if any forbidden
cross-repo access pattern is found.

Forbidden patterns:
  1. git clone using GITHUB_TOKEN to access an AbilityBI repo
     (pattern: git clone.*GITHUB_TOKEN.*github.com/AbilityBI)
  2. continue-on-error: true on or near a provider checkout step
     (a `continue-on-error: true` line within 5 lines of a line mentioning a
     git clone of AbilityBI, the checkout-providers action, or a known abi-* repo)

Exit code: 0 = clean, 1 = violations found (a workflow file that cannot be
read is also reported and counted as a violation).
"""
from __future__ import annotations

import re
import sys
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parent.parent
WORKFLOWS_DIR = REPO_ROOT / ".github" / "workflows"

# ── Patterns ────────────────────────────────────────────────────────────────

# Pattern 1: a `git clone` line that embeds GITHUB_TOKEN and targets the
# AbilityBI org. Matching is done one line at a time, so a clone command split
# across several lines (e.g. with `\` continuations) is not detected.
_PAT_GITHUB_TOKEN_CLONE = re.compile(
    r"git\s+clone\s+.*GITHUB_TOKEN.*github\.com/AbilityBI",
    re.IGNORECASE,
)

# Pattern 2: continue-on-error adjacent to a provider clone step.
# A `continue-on-error: true` line is flagged when any line within _WINDOW
# lines before or after it matches _PAT_PROVIDER_STEP (a git clone of
# AbilityBI, the checkout-providers action, or a known abi-* repo name).
# The window is purely line-based and can therefore cross step boundaries.
_PAT_CONTINUE_ON_ERROR = re.compile(r"continue-on-error\s*:\s*true", re.IGNORECASE)
_PAT_PROVIDER_STEP = re.compile(
    r"(git\s+clone.*AbilityBI|checkout-providers|abi-core|abi-policy|abi-architecture"
    r"|abi-constraints|abi-evals|abi-observability|abi-plugins|abi-prd|abi-promotion"
    r"|abi-promptware|abi-runtime|abi-spec|abi-template)",
    re.IGNORECASE,
)

# How many lines before/after a continue-on-error to look for provider context
_WINDOW = 5


def _display_path(path: Path) -> Path:
    """Return *path* relative to REPO_ROOT, or unchanged if it lies outside it."""
    try:
        return path.relative_to(REPO_ROOT)
    except ValueError:
        # relative_to raises for paths outside the repo (e.g. a file handed in
        # directly); report the path as given instead of crashing the linter.
        return path


def _is_comment(line: str) -> bool:
    """Return True for a comment-only line (YAML and shell both use ``#``)."""
    return line.lstrip().startswith("#")


def _scan_file(path: Path) -> list[str]:
    """Return a list of violation messages for a single workflow file.

    Args:
        path: Workflow YAML file to scan.

    Returns:
        One ``"<path>:<line>: FORBIDDEN — ..."`` message per violation. All
        GITHUB_TOKEN-clone hits come first, then continue-on-error hits, each
        group in line order. If the file cannot be read or is not valid UTF-8,
        a single ``"WARN: could not read ..."`` entry is returned instead; the
        caller in this script counts it as a violation.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except (OSError, UnicodeDecodeError) as exc:
        # UnicodeDecodeError is a ValueError, not an OSError: without it one
        # non-UTF-8 file would abort the entire run with a traceback.
        return [f"WARN: could not read {path}: {exc}"]

    display = _display_path(path)
    lines = text.splitlines()
    violations: list[str] = []

    # Check 1: GITHUB_TOKEN clone. Comment-only lines are skipped because they
    # have no effect in either YAML or an embedded shell script.
    for lineno, line in enumerate(lines, start=1):
        if _is_comment(line):
            continue
        if _PAT_GITHUB_TOKEN_CLONE.search(line):
            violations.append(
                f"{display}:{lineno}: "
                f"FORBIDDEN — git clone with GITHUB_TOKEN targeting AbilityBI. "
                f"Use checkout-providers action with ABI_GH_APP_ID/ABI_GH_APP_PRIVATE_KEY."
            )

    # Check 2: an active continue-on-error near a provider step. Only the
    # trigger line must be uncommented; the surrounding context is taken
    # verbatim, so a comment naming an abi-* repo still counts as context.
    for idx, line in enumerate(lines):
        if _is_comment(line) or not _PAT_CONTINUE_ON_ERROR.search(line):
            continue
        window_start = max(0, idx - _WINDOW)
        window_end = min(len(lines), idx + _WINDOW + 1)
        context = "\n".join(lines[window_start:window_end])
        if _PAT_PROVIDER_STEP.search(context):
            violations.append(
                f"{display}:{idx + 1}: "
                f"FORBIDDEN — continue-on-error: true adjacent to a provider checkout step. "
                f"Provider checkouts must hard-fail on error."
            )

    return violations


def main() -> int:
    """Scan every workflow file under WORKFLOWS_DIR and print a report.

    Returns:
        The process exit code: 0 when there is no workflow directory, no
        workflow files, or no findings; 1 when ``_scan_file`` reported
        anything for any file.
    """
    print("=== check-provider-access ===")

    if not WORKFLOWS_DIR.exists():
        print(f"  INFO: no .github/workflows/ directory found in {REPO_ROOT} — nothing to check")
        return 0

    # Sort the combined list rather than each glob separately, so .yml and
    # .yaml files are scanned and reported in a single alphabetical order.
    workflow_files = sorted([*WORKFLOWS_DIR.glob("*.yml"), *WORKFLOWS_DIR.glob("*.yaml")])
    if not workflow_files:
        print("  INFO: no workflow files found — nothing to check")
        return 0

    print(f"  Scanning {len(workflow_files)} workflow file(s) in {WORKFLOWS_DIR.relative_to(REPO_ROOT)}/")

    all_violations: list[str] = [v for wf in workflow_files for v in _scan_file(wf)]

    if all_violations:
        print(f"\n  FAIL: {len(all_violations)} violation(s) found:\n")
        for v in all_violations:
            print(f"    {v}")
        print()
        print("  Remediation: replace git clone hacks with checkout-providers action.")
        print("  See docs/CI_ACCESS_LAYER.md for usage.")
        return 1

    print("  PASS: no forbidden provider access patterns found")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
