# Copyright © The Debusine Developers
# See the AUTHORS file at the top-level directory of this distribution
#
# This file is part of Debusine. It is subject to the license terms
# in the LICENSE file found in the top-level directory of this
# distribution. No part of Debusine, including this file, may be copied,
# modified, propagated, or distributed except according to the terms
# contained in the LICENSE file.

"""Middleware for maintaining debusine.db.context."""

from collections.abc import Awaitable, Callable
from inspect import iscoroutinefunction
from typing import TYPE_CHECKING

import django.http
from asgiref.typing import ASGIReceiveCallable, ASGISendCallable
from django.utils.deprecation import MiddlewareMixin

if TYPE_CHECKING:
    from channels.consumer import _ChannelScope


class ContextMiddleware(MiddlewareMixin):
    """
    Reset debusine's application context at the start of every request.

    The synchronous path goes through ``MiddlewareMixin``, which calls
    :py:meth:`process_request` before the rest of the handler chain.  The
    asynchronous path is handled by :py:meth:`__acall__`, which performs the
    reset directly in the async context.
    """

    def process_request(
        self,
        request: django.http.HttpRequest,  # noqa: ARG002, U100
    ) -> None:
        """
        Reset the application context before the request is handled.

        The request object itself is not inspected.
        """
        # Imported at call time rather than at module level, as in the rest
        # of this file (presumably to keep module import light / avoid
        # cycles -- NOTE(review): confirm).
        from debusine.db.context import context as app_context

        # Make application context request-local.
        app_context.reset()

    async def __acall__(
        self, request: django.http.HttpRequest
    ) -> django.http.HttpResponseBase:
        """
        Handle a request on the asynchronous code path.

        Resets the application context, awaits the next handler in the
        chain, and returns its response.
        """
        # MiddlewareMixin already implements __acall__, but it wraps
        # process_request in sync_to_async.  That is not enough here:
        # sync_to_async copies context variables that were already set in
        # the async context into and back out of its worker thread, but it
        # does not copy back context variables that were never set in the
        # async context.  Resetting directly in the async context avoids
        # that gap.
        self.process_request(request)

        next_handler = self.get_response
        # Narrow the type of get_response for the type checker: on this
        # path the next handler must be awaitable.
        assert iscoroutinefunction(next_handler)

        result = await next_handler(request)
        assert isinstance(result, django.http.HttpResponseBase)
        return result


class ContextMiddlewareChannels:
    """
    Channels middleware that resets debusine's application context.

    The context is reset each time this middleware is invoked.  Control is
    then handed to the wrapped ASGI application.
    """

    def __init__(
        self,
        app: Callable[
            ["_ChannelScope", ASGIReceiveCallable, ASGISendCallable],
            Awaitable[None],
        ],
    ) -> None:
        """
        Wrap an inner ASGI application.

        ``app`` is stored and awaited with the same ``scope``, ``receive``
        and ``send`` arguments on every call.
        """
        self.app = app

    async def __call__(
        self,
        scope: "_ChannelScope",
        receive: ASGIReceiveCallable,
        send: ASGISendCallable,
    ) -> None:
        """
        Reset the application context, then delegate to the wrapped app.

        Whatever the wrapped application returns is passed back unchanged.
        """
        # Imported at call time rather than at module level (NOTE(review):
        # presumably to avoid import-time dependencies -- confirm).
        from debusine.db.context import context as app_context

        # Make application context request-local before the app runs.
        app_context.reset()

        inner_app = self.app
        return await inner_app(scope, receive, send)
