Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import os
  4from html.parser import HTMLParser
  5from typing import TYPE_CHECKING, Any
  6from urllib.parse import urlparse, urlunparse
  7
  8import mistune
  9from pygments import highlight
 10from pygments.formatters import html
 11from pygments.lexers import get_lexer_by_name
 12
 13from plain.urls import reverse
 14from plain.utils.text import slugify
 15
 16if TYPE_CHECKING:
 17    from .registry import PagesRegistry
 18
 19
 20class PagesRenderer(mistune.HTMLRenderer):
 21    def __init__(
 22        self, current_page_path: str, pages_registry: PagesRegistry, **kwargs: Any
 23    ):
 24        super().__init__(**kwargs)
 25        self.current_page_path = current_page_path
 26        self.pages_registry = pages_registry
 27
 28    def link(self, text: str, url: str, title: str | None = None) -> str:
 29        """Convert relative markdown links to proper page URLs."""
 30        # Check if it's a relative link (starts with ./ or ../, or is just a filename)
 31        is_relative = url.startswith(("./", "../")) or (
 32            not url.startswith(("http://", "https://", "/", "#")) and ":" not in url
 33        )
 34
 35        if is_relative:
 36            # Parse URL to extract components
 37            parsed_url = urlparse(url)
 38
 39            # Resolve relative to current page's directory using just the path component
 40            current_dir = os.path.dirname(self.current_page_path)
 41            resolved_path = os.path.normpath(os.path.join(current_dir, parsed_url.path))
 42            page = self.pages_registry.get_page_from_path(resolved_path)
 43
 44            # Get the primary URL name for link conversion
 45            url_name = page.get_url_name()
 46            if url_name:
 47                base_url = reverse(f"pages:{url_name}")
 48                # Reconstruct URL with preserved query params and fragment
 49                url = str(
 50                    urlunparse(
 51                        (
 52                            parsed_url.scheme,  # scheme (empty for relative)
 53                            parsed_url.netloc,  # netloc (empty for relative)
 54                            base_url,  # path (our converted URL)
 55                            parsed_url.params,  # params
 56                            parsed_url.query,  # query
 57                            parsed_url.fragment,  # fragment
 58                        )
 59                    )
 60                )
 61
 62        return super().link(text, url, title)
 63
 64    def heading(self, text: str, level: int, **attrs: Any) -> str:
 65        """Automatically add an ID to headings if one is not provided."""
 66
 67        if "id" not in attrs:
 68            inner_text = get_inner_text(text)
 69            inner_text = inner_text.replace(
 70                ".", "-"
 71            )  # Replace dots with hyphens (slugify won't)
 72            attrs["id"] = slugify(inner_text)
 73
 74        return super().heading(text, level, **attrs)
 75
 76    def block_code(self, code: str, info: str | None = None) -> str:
 77        """Highlight code blocks using Pygments."""
 78
 79        if info:
 80            lexer = get_lexer_by_name(info, stripall=True)
 81            formatter = html.HtmlFormatter(wrapcode=True)
 82            return highlight(code, lexer, formatter)
 83
 84        return "<pre><code>" + mistune.escape(code) + "</code></pre>"
 85
 86
 87def render_markdown(content: str, current_page_path: str) -> str:
 88    from .registry import pages_registry
 89
 90    renderer = PagesRenderer(
 91        current_page_path=current_page_path, pages_registry=pages_registry, escape=False
 92    )
 93    markdown = mistune.create_markdown(
 94        renderer=renderer, plugins=["strikethrough", "table"]
 95    )
 96    return markdown(content)  # type: ignore[return-value]
 97
 98
 99class InnerTextParser(HTMLParser):
100    def __init__(self):
101        super().__init__()
102        self.text_content: list[str] = []
103
104    def handle_data(self, data: str) -> None:
105        # Collect all text data
106        self.text_content.append(data.strip())
107
108
def get_inner_text(html_content: str) -> str:
    """Return the visible text of an HTML fragment, joined with spaces."""
    parser = InnerTextParser()
    parser.feed(html_content)
    # Drop the empty fragments produced by whitespace-only text nodes.
    return " ".join(fragment for fragment in parser.text_content if fragment)