Plain is headed towards 1.0! Subscribe for development updates →

  1from __future__ import annotations
  2
  3import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6from plain import models
  7from plain.auth import get_user_model
  8from plain.exceptions import ValidationError
  9from plain.models import transaction, types
 10from plain.models.db import IntegrityError
 11from plain.runtime import SettingsReference
 12from plain.utils import timezone
 13
 14from .exceptions import OAuthUserAlreadyExistsError
 15
 16if TYPE_CHECKING:
 17    from .providers import OAuthToken, OAuthUser
 18
 19
 20@models.register_model
 21class OAuthConnection(models.Model):
 22    created_at: datetime.datetime = types.DateTimeField(auto_now_add=True)
 23    updated_at: datetime.datetime = types.DateTimeField(auto_now=True)
 24
 25    user = types.ForeignKeyField(
 26        SettingsReference("AUTH_USER_MODEL"),
 27        on_delete=models.CASCADE,
 28    )
 29
 30    # The key used to refer to this provider type (in settings)
 31    provider_key: str = types.CharField(max_length=100)
 32
 33    # The unique ID of the user on the provider's system
 34    provider_user_id: str = types.CharField(max_length=100)
 35
 36    # Token data
 37    access_token: str = types.CharField(max_length=2000)
 38    refresh_token: str = types.CharField(max_length=2000, required=False)
 39    access_token_expires_at: datetime.datetime | None = types.DateTimeField(
 40        required=False, allow_null=True
 41    )
 42    refresh_token_expires_at: datetime.datetime | None = types.DateTimeField(
 43        required=False, allow_null=True
 44    )
 45
 46    query: models.QuerySet[OAuthConnection] = models.QuerySet()
 47
 48    model_options = models.Options(
 49        constraints=[
 50            models.UniqueConstraint(
 51                fields=["provider_key", "provider_user_id"],
 52                name="plainoauth_oauthconnection_unique_provider_key_user_id",
 53            )
 54        ],
 55        ordering=("provider_key",),
 56    )
 57
 58    def __str__(self) -> str:
 59        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"
 60
 61    def refresh_access_token(self) -> None:
 62        from .providers import OAuthToken, get_oauth_provider_instance
 63
 64        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
 65        oauth_token = OAuthToken(
 66            access_token=self.access_token,
 67            refresh_token=self.refresh_token,
 68            access_token_expires_at=self.access_token_expires_at,
 69            refresh_token_expires_at=self.refresh_token_expires_at,
 70        )
 71        refreshed_oauth_token = provider_instance.refresh_oauth_token(
 72            oauth_token=oauth_token
 73        )
 74        self.set_token_fields(refreshed_oauth_token)
 75        self.save()
 76
 77    def set_token_fields(self, oauth_token: OAuthToken) -> None:
 78        self.access_token = oauth_token.access_token
 79        self.refresh_token = oauth_token.refresh_token
 80        self.access_token_expires_at = oauth_token.access_token_expires_at
 81        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
 82
 83    def set_user_fields(self, oauth_user: OAuthUser) -> None:
 84        self.provider_user_id = oauth_user.provider_id
 85
 86    def access_token_expired(self) -> bool:
 87        return (
 88            self.access_token_expires_at is not None
 89            and self.access_token_expires_at < timezone.now()
 90        )
 91
 92    def refresh_token_expired(self) -> bool:
 93        return (
 94            self.refresh_token_expires_at is not None
 95            and self.refresh_token_expires_at < timezone.now()
 96        )
 97
 98    @classmethod
 99    def get_or_create_user(
100        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
101    ) -> OAuthConnection:
102        try:
103            connection = cls.query.get(
104                provider_key=provider_key,
105                provider_user_id=oauth_user.provider_id,
106            )
107            connection.set_token_fields(oauth_token)
108            connection.save()
109            return connection
110        except cls.DoesNotExist:
111            with transaction.atomic():
112                # If email needs to be unique, then we expect
113                # that to be taken care of on the user model itself
114                try:
115                    user = get_user_model()(
116                        **oauth_user.user_model_fields,
117                    )
118                    user.save()
119                except (IntegrityError, ValidationError):
120                    raise OAuthUserAlreadyExistsError()
121
122                return cls.connect(
123                    user=user,
124                    provider_key=provider_key,
125                    oauth_token=oauth_token,
126                    oauth_user=oauth_user,
127                )
128
129    @classmethod
130    def connect(
131        cls,
132        *,
133        user: Any,
134        provider_key: str,
135        oauth_token: OAuthToken,
136        oauth_user: OAuthUser,
137    ) -> OAuthConnection:
138        """
139        Connect will either create a new connection or update an existing connection
140        """
141        try:
142            connection = cls.query.get(
143                user=user,
144                provider_key=provider_key,
145                provider_user_id=oauth_user.provider_id,
146            )
147        except cls.DoesNotExist:
148            # Create our own instance (not using get_or_create)
149            # so that any created signals contain the token fields too
150            connection = cls(
151                user=user,
152                provider_key=provider_key,
153                provider_user_id=oauth_user.provider_id,
154            )
155
156        connection.set_user_fields(oauth_user)
157        connection.set_token_fields(oauth_token)
158        connection.save()
159
160        return connection