1from __future__ import annotations
  2
  3import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6from plain import models
  7from plain.auth import get_user_model
  8from plain.exceptions import ValidationError
  9from plain.models import transaction, types
 10from plain.models.db import IntegrityError
 11from plain.runtime import SettingsReference
 12from plain.utils import timezone
 13
 14from .exceptions import OAuthUserAlreadyExistsError
 15
 16if TYPE_CHECKING:
 17    from .providers import OAuthToken, OAuthUser
 18
 19__all__ = ["OAuthConnection"]
 20
 21
 22@models.register_model
 23class OAuthConnection(models.Model):
 24    created_at: datetime.datetime = types.DateTimeField(auto_now_add=True)
 25    updated_at: datetime.datetime = types.DateTimeField(auto_now=True)
 26
 27    user = types.ForeignKeyField(
 28        SettingsReference("AUTH_USER_MODEL"),
 29        on_delete=models.CASCADE,
 30    )
 31
 32    # The key used to refer to this provider type (in settings)
 33    provider_key: str = types.CharField(max_length=100)
 34
 35    # The unique ID of the user on the provider's system
 36    provider_user_id: str = types.CharField(max_length=100)
 37
 38    # Token data
 39    access_token: str = types.CharField(max_length=2000)
 40    refresh_token: str = types.CharField(max_length=2000, required=False)
 41    access_token_expires_at: datetime.datetime | None = types.DateTimeField(
 42        required=False, allow_null=True
 43    )
 44    refresh_token_expires_at: datetime.datetime | None = types.DateTimeField(
 45        required=False, allow_null=True
 46    )
 47
 48    query: models.QuerySet[OAuthConnection] = models.QuerySet()
 49
 50    model_options = models.Options(
 51        constraints=[
 52            models.UniqueConstraint(
 53                fields=["provider_key", "provider_user_id"],
 54                name="plainoauth_oauthconnection_unique_provider_key_user_id",
 55            )
 56        ],
 57        ordering=("provider_key",),
 58    )
 59
 60    def __str__(self) -> str:
 61        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"
 62
 63    def refresh_access_token(self) -> None:
 64        from .providers import OAuthToken, get_oauth_provider_instance
 65
 66        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
 67        oauth_token = OAuthToken(
 68            access_token=self.access_token,
 69            refresh_token=self.refresh_token,
 70            access_token_expires_at=self.access_token_expires_at,
 71            refresh_token_expires_at=self.refresh_token_expires_at,
 72        )
 73        refreshed_oauth_token = provider_instance.refresh_oauth_token(
 74            oauth_token=oauth_token
 75        )
 76        self.set_token_fields(refreshed_oauth_token)
 77        self.save()
 78
 79    def set_token_fields(self, oauth_token: OAuthToken) -> None:
 80        self.access_token = oauth_token.access_token
 81        self.refresh_token = oauth_token.refresh_token
 82        self.access_token_expires_at = oauth_token.access_token_expires_at
 83        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
 84
 85    def set_user_fields(self, oauth_user: OAuthUser) -> None:
 86        self.provider_user_id = oauth_user.provider_id
 87
 88    def access_token_expired(self) -> bool:
 89        return (
 90            self.access_token_expires_at is not None
 91            and self.access_token_expires_at < timezone.now()
 92        )
 93
 94    def refresh_token_expired(self) -> bool:
 95        return (
 96            self.refresh_token_expires_at is not None
 97            and self.refresh_token_expires_at < timezone.now()
 98        )
 99
100    @classmethod
101    def get_or_create_user(
102        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
103    ) -> OAuthConnection:
104        try:
105            connection = cls.query.get(
106                provider_key=provider_key,
107                provider_user_id=oauth_user.provider_id,
108            )
109            connection.set_token_fields(oauth_token)
110            connection.save()
111            return connection
112        except cls.DoesNotExist:
113            with transaction.atomic():
114                # If email needs to be unique, then we expect
115                # that to be taken care of on the user model itself
116                try:
117                    user = get_user_model()(
118                        **oauth_user.user_model_fields,
119                    )
120                    user.save()
121                except (IntegrityError, ValidationError):
122                    raise OAuthUserAlreadyExistsError()
123
124                return cls.connect(
125                    user=user,
126                    provider_key=provider_key,
127                    oauth_token=oauth_token,
128                    oauth_user=oauth_user,
129                )
130
131    @classmethod
132    def connect(
133        cls,
134        *,
135        user: Any,
136        provider_key: str,
137        oauth_token: OAuthToken,
138        oauth_user: OAuthUser,
139    ) -> OAuthConnection:
140        """
141        Connect will either create a new connection or update an existing connection
142        """
143        try:
144            connection = cls.query.get(
145                user=user,
146                provider_key=provider_key,
147                provider_user_id=oauth_user.provider_id,
148            )
149        except cls.DoesNotExist:
150            # Create our own instance (not using get_or_create)
151            # so that any created signals contain the token fields too
152            connection = cls(
153                user=user,
154                provider_key=provider_key,
155                provider_user_id=oauth_user.provider_id,
156            )
157
158        connection.set_user_fields(oauth_user)
159        connection.set_token_fields(oauth_token)
160        connection.save()
161
162        return connection