import ipaddress
import secrets
from base64 import b32encode

from django.conf import settings
from django.contrib.auth import login
from django.core.exceptions import ObjectDoesNotExist
from django.db import transaction
from django.urls import reverse
from django.utils import timezone
from django.utils.http import url_has_allowed_host_and_scheme
from django_otp.plugins.otp_totp.models import TOTPDevice

from clients.models import ClientPortalAccess

from .models import RecoveryCode, SecurityEvent, UserProfile

# Session keys used by the two-step login flow in this module.
# Set to True by complete_login() right after django.contrib.auth.login() runs.
# NOTE(review): presumably checked by 2FA-enforcement code elsewhere — not visible here.
SESSION_2FA_VERIFIED = "k4w_2fa_verified"
# Pre-authentication state written by remember_preauth() and removed by
# clear_preauth(): the user's primary key, the authentication backend path,
# and the URL to redirect to once login completes.
SESSION_PREAUTH_USER = "k4w_preauth_user_id"
SESSION_PREAUTH_BACKEND = "k4w_preauth_backend"
SESSION_PREAUTH_NEXT = "k4w_preauth_next"


def get_profile(user):
    """Return the ``UserProfile`` row for *user*, creating it on first access."""
    return UserProfile.objects.get_or_create(user=user)[0]


def confirmed_device(user):
    """Return the oldest confirmed TOTP device of *user*, or ``None``.

    Objects that are not authenticated (or lack ``is_authenticated`` entirely)
    short-circuit to ``None`` without touching the database.
    """
    is_authed = getattr(user, "is_authenticated", False)
    if is_authed:
        devices = TOTPDevice.objects.filter(user=user, confirmed=True)
        return devices.order_by("pk").first()
    return None


def pending_device(user):
    """Return the newest unconfirmed TOTP device of *user*, or ``None``.

    Like ``confirmed_device``, objects that are not authenticated (e.g. an
    anonymous user) return ``None`` immediately instead of being passed into
    an ORM filter on the ``user`` foreign key, where they would error.
    """
    if not getattr(user, "is_authenticated", False):
        return None
    return TOTPDevice.objects.filter(user=user, confirmed=False).order_by("-pk").first()


def two_factor_enabled(user):
    """Report whether *user* owns at least one confirmed TOTP device."""
    device = confirmed_device(user)
    return device is not None


def two_factor_required(user):
    """Decide whether *user* must complete a second factor.

    Unauthenticated objects never require it. Staff do when the
    ``K4W_REQUIRE_2FA_FOR_STAFF`` setting is truthy; otherwise the per-user
    ``UserProfile.two_factor_required`` flag decides.
    """
    if getattr(user, "is_authenticated", False):
        if user.is_staff:
            # The setting is only consulted for staff users.
            if getattr(settings, "K4W_REQUIRE_2FA_FOR_STAFF", False):
                return True
        return get_profile(user).two_factor_required
    return False


def user_destination(user):
    """Return the default post-login URL for *user*.

    Staff and superusers go to the Django admin index. Other users go to the
    client portal dashboard when their ``client_portal_access`` record is
    active and its client has portal access enabled. Everyone else goes to
    the ``home`` route.
    """
    if user.is_staff or user.is_superuser:
        return reverse("admin:index")
    try:
        access = user.client_portal_access
    except (ObjectDoesNotExist, AttributeError):
        # No related access row (presumably a reverse one-to-one accessor, whose
        # RelatedObjectDoesNotExist subclasses both of these), or a user object
        # without the accessor at all. Other errors (e.g. database failures)
        # propagate instead of silently demoting the user to the home page.
        return reverse("home")
    if access.is_active and access.client.portal_access_enabled:
        return reverse("client_portal:dashboard")
    return reverse("home")


def safe_next(request, fallback=None):
    """Return the request's ``next`` redirect target if it is safe, else *fallback*.

    ``next`` is read from POST first, then GET. A target is accepted only when
    it points at the current host, and only over HTTPS when the request itself
    is secure.
    """
    target = request.POST.get("next") or request.GET.get("next") or ""
    if not target:
        return fallback
    is_safe = url_has_allowed_host_and_scheme(
        target,
        allowed_hosts={request.get_host()},
        require_https=request.is_secure(),
    )
    return target if is_safe else fallback


def client_ip(request):
    """Best-effort client IP address for audit logging, or ``None``.

    Prefers the first entry of ``X-Forwarded-For``, but only when it parses as
    an IPv4/IPv6 address. Otherwise it falls back to ``REMOTE_ADDR``. That way
    a garbage header value never reaches an IP-address column.

    NOTE(review): ``X-Forwarded-For`` is client-controlled unless a trusted
    proxy overwrites it, so this value can be spoofed. Do not use it for
    access-control decisions.
    """
    forwarded = request.META.get("HTTP_X_FORWARDED_FOR", "")
    if forwarded:
        candidate = forwarded.split(",")[0].strip()
        try:
            ipaddress.ip_address(candidate)
        except ValueError:
            pass  # Malformed header entry: fall through to REMOTE_ADDR.
        else:
            return candidate
    return request.META.get("REMOTE_ADDR") or None


def log_event(request, event_type, *, user=None, username_entered="", details=None):
    """Create and return a ``SecurityEvent`` for *request*.

    Args:
        request: The current HTTP request; its client IP and User-Agent are recorded.
        event_type: Event type value stored on the record.
        user: Optional user the event relates to.
        username_entered: Username typed by the client (may be ``None``), truncated to 254 chars.
        details: Optional dict of extra data; an empty dict is stored when omitted.
    """
    return SecurityEvent.objects.create(
        user=user,
        event_type=event_type,
        # Tolerate None (e.g. a missing form field forwarded directly by a caller).
        # The 254/500 limits presumably match the model's max_length — confirm.
        username_entered=(username_entered or "")[:254],
        ip_address=client_ip(request),
        user_agent=(request.META.get("HTTP_USER_AGENT") or "")[:500],
        details=details or {},
    )


def remember_preauth(request, user, next_url=""):
    """Stash a half-authenticated *user* in the session until the second factor is checked.

    Stores the user's pk, the backend that authenticated them (falling back to
    the first configured backend), and the post-login URL (*next_url* or the
    user's default destination).
    """
    session = request.session
    # Written one key at a time, in this order, on purpose.
    session[SESSION_PREAUTH_USER] = user.pk
    session[SESSION_PREAUTH_BACKEND] = getattr(user, "backend", settings.AUTHENTICATION_BACKENDS[0])
    session[SESSION_PREAUTH_NEXT] = next_url or user_destination(user)
    session.modified = True


def clear_preauth(request):
    """Remove all pre-authentication keys from the session; absent keys are ignored."""
    session = request.session
    preauth_keys = (SESSION_PREAUTH_USER, SESSION_PREAUTH_BACKEND, SESSION_PREAUTH_NEXT)
    for name in preauth_keys:
        session.pop(name, None)


def complete_login(request, user, *, backend=None, next_url=None):
    """Log *user* in, record login bookkeeping, and return the redirect URL.

    Steps:
    * calls ``django.contrib.auth.login`` with *backend*, the user's
      ``backend`` attribute, or the first configured backend;
    * marks the session as 2FA-verified;
    * stamps ``UserProfile.last_two_factor_at`` when the user has a confirmed
      TOTP device (the profile is saved either way, as before);
    * stamps portal login / first-acceptance times on the user's portal-access
      record, if one exists;
    * clears pre-authentication session state.

    Returns *next_url* if given, otherwise the user's default destination.
    """
    login(
        request,
        user,
        backend=backend or getattr(user, "backend", settings.AUTHENTICATION_BACKENDS[0]),
    )
    request.session[SESSION_2FA_VERIFIED] = True

    # One timestamp for every field written during this login, so they agree.
    now = timezone.now()

    profile = get_profile(user)
    if two_factor_enabled(user):
        profile.last_two_factor_at = now
    profile.save(update_fields=["last_two_factor_at", "updated_at"])

    try:
        access = user.client_portal_access
    except (ObjectDoesNotExist, AttributeError):
        # No portal-access row (or no accessor). Unexpected errors propagate.
        access = None
    if access is not None:
        access.last_portal_login_at = now
        if not access.accepted_at:
            access.accepted_at = now
        access.save(update_fields=["last_portal_login_at", "accepted_at", "updated_at"])

    clear_preauth(request)
    return next_url or user_destination(user)


def generate_recovery_codes(user, count=10):
    """Replace all of *user*'s recovery codes with *count* new ones.

    Each code is three dash-separated groups of four uppercase hex digits
    (e.g. ``"1A2B-3C4D-5E6F"``), drawn from :mod:`secrets`.

    Returns the list of raw codes. NOTE(review): presumably
    ``RecoveryCode.create_for_user`` stores only a derived form, making this
    return value the only place the raw codes appear — confirm.
    """
    raw_codes = []
    # Delete + recreate atomically. If any create fails, the user keeps their
    # previous codes instead of ending up with a partial set they never saw.
    with transaction.atomic():
        RecoveryCode.objects.filter(user=user).delete()
        for _ in range(count):
            raw = "-".join(secrets.token_hex(2).upper() for _group in range(3))
            RecoveryCode.create_for_user(user, raw)
            raw_codes.append(raw)
    return raw_codes


def consume_recovery_code(user, raw_code):
    """Try to spend one of *user*'s unused recovery codes.

    The input is whitespace-stripped and uppercased, then checked against each
    unused code. The first match is consumed and ``True`` is returned. If
    nothing matches (or the input is empty/``None``), returns ``False``.
    """
    candidate = (raw_code or "").strip().upper()
    unused = RecoveryCode.objects.filter(user=user, used_at__isnull=True)
    match = next((code for code in unused if code.matches(candidate)), None)
    if match is None:
        return False
    match.consume()
    return True


def manual_totp_secret(device):
    """Return *device*'s binary key as base32 text with ``=`` padding stripped."""
    encoded = b32encode(device.bin_key)
    return str(encoded, "ascii").rstrip("=")
