1from __future__ import annotations
  2
  3import json
  4from collections.abc import AsyncIterator
  5from typing import Any
  6
  7from plain.http import AsyncStreamingResponse, Response
  8
  9from .base import View
 10
 11
 12class ServerSentEventsView(View):
 13    """Server-Sent Events view.
 14
 15    Subclass this and implement `stream()` to yield ServerSentEvent instances:
 16
 17        class TimeView(ServerSentEventsView):
 18            async def stream(self):
 19                while True:
 20                    yield ServerSentEvent(data={"time": datetime.now().isoformat()})
 21                    await asyncio.sleep(1)
 22    """
 23
 24    async def get(self) -> AsyncStreamingResponse:
 25        return AsyncStreamingResponse(
 26            streaming_content=self._format_events(),
 27            content_type="text/event-stream",
 28            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
 29        )
 30
 31    def head(self) -> Response:
 32        return Response(
 33            content_type="text/event-stream",
 34            headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
 35        )
 36
 37    async def stream(self) -> AsyncIterator[ServerSentEvent]:
 38        """Override this to yield ServerSentEvent instances."""
 39        raise NotImplementedError(f"{self.__class__.__name__} must implement stream()")
 40        yield  # noqa: RET503 — unreachable, marks this as an async generator
 41
 42    async def _format_events(self) -> AsyncIterator[str]:
 43        async for event in self.stream():
 44            yield event.format()
 45
 46
 47class ServerSentEvent:
 48    """An SSE event with optional event type, id, and retry fields.
 49
 50    Usage:
 51        yield ServerSentEvent(data="hello")
 52        yield ServerSentEvent(data={"count": 1}, event="update")
 53        yield ServerSentEvent(data="hello", id="msg-1", retry=5000)
 54        yield ServerSentEvent.comment("keepalive")
 55    """
 56
 57    __slots__ = ("_comment", "data", "event", "id", "retry")
 58
 59    def __init__(
 60        self,
 61        data: Any,
 62        *,
 63        event: str | None = None,
 64        id: str | None = None,
 65        retry: int | None = None,
 66    ) -> None:
 67        self._comment: str | None = None
 68        self.data = data
 69        self.event = event
 70        self.id = id
 71        self.retry = retry
 72
 73    @classmethod
 74    def comment(cls, text: str = "") -> ServerSentEvent:
 75        """Create an SSE comment (line starting with ':').
 76
 77        Comments are ignored by EventSource but useful as keepalives
 78        to prevent proxies and browsers from closing idle connections.
 79        """
 80        instance = cls.__new__(cls)
 81        instance._comment = text
 82        instance.data = None
 83        instance.event = None
 84        instance.id = None
 85        instance.retry = None
 86        return instance
 87
 88    def __repr__(self) -> str:
 89        if self._comment is not None:
 90            return f"ServerSentEvent.comment({self._comment!r})"
 91        parts = [repr(self.data)]
 92        if self.event is not None:
 93            parts.append(f"event={self.event!r}")
 94        if self.id is not None:
 95            parts.append(f"id={self.id!r}")
 96        if self.retry is not None:
 97            parts.append(f"retry={self.retry!r}")
 98        return f"ServerSentEvent({', '.join(parts)})"
 99
100    def format(self) -> str:
101        """Format this event as an SSE event string."""
102        # Comment-only event (keepalive)
103        if self._comment is not None:
104            return f": {self._comment}\n\n"
105
106        lines: list[str] = []
107
108        if self.event is not None:
109            lines.append(f"event: {self.event}")
110
111        if self.id is not None:
112            lines.append(f"id: {self.id}")
113
114        if self.retry is not None:
115            lines.append(f"retry: {self.retry}")
116
117        serialized = self.data if isinstance(self.data, str) else json.dumps(self.data)
118
119        # SSE spec: each line of data gets its own "data:" prefix.
120        # Use split("\n") instead of splitlines() to preserve empty and
121        # trailing lines — splitlines() would drop them, altering the
122        # payload clients receive.
123        for line in serialized.split("\n"):
124            lines.append(f"data: {line}")
125
126        # Double newline terminates the event
127        return "\n".join(lines) + "\n\n"