Plain is headed towards 1.0! Subscribe for development updates →

  1from typing import TYPE_CHECKING, Any
  2
  3from plain import models
  4from plain.auth import get_user_model
  5from plain.exceptions import ValidationError
  6from plain.models import transaction
  7from plain.models.db import IntegrityError
  8from plain.runtime import SettingsReference
  9from plain.utils import timezone
 10
 11from .exceptions import OAuthUserAlreadyExistsError
 12
 13if TYPE_CHECKING:
 14    from .providers import OAuthToken, OAuthUser
 15
 16
 17@models.register_model
 18class OAuthConnection(models.Model):
 19    created_at = models.DateTimeField(auto_now_add=True)
 20    updated_at = models.DateTimeField(auto_now=True)
 21
 22    user = models.ForeignKey(
 23        SettingsReference("AUTH_USER_MODEL"),
 24        on_delete=models.CASCADE,
 25        related_name="oauth_connections",
 26    )
 27
 28    # The key used to refer to this provider type (in settings)
 29    provider_key = models.CharField(max_length=100)
 30
 31    # The unique ID of the user on the provider's system
 32    provider_user_id = models.CharField(max_length=100)
 33
 34    # Token data
 35    access_token = models.CharField(max_length=2000)
 36    refresh_token = models.CharField(max_length=2000, required=False)
 37    access_token_expires_at = models.DateTimeField(required=False, allow_null=True)
 38    refresh_token_expires_at = models.DateTimeField(required=False, allow_null=True)
 39
 40    model_options = models.Options(
 41        constraints=[
 42            models.UniqueConstraint(
 43                fields=["provider_key", "provider_user_id"],
 44                name="plainoauth_oauthconnection_unique_provider_key_user_id",
 45            )
 46        ],
 47        ordering=("provider_key",),
 48    )
 49
 50    def __str__(self) -> str:
 51        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"
 52
 53    def refresh_access_token(self) -> None:
 54        from .providers import OAuthToken, get_oauth_provider_instance
 55
 56        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
 57        oauth_token = OAuthToken(
 58            access_token=self.access_token,
 59            refresh_token=self.refresh_token,
 60            access_token_expires_at=self.access_token_expires_at,
 61            refresh_token_expires_at=self.refresh_token_expires_at,
 62        )
 63        refreshed_oauth_token = provider_instance.refresh_oauth_token(
 64            oauth_token=oauth_token
 65        )
 66        self.set_token_fields(refreshed_oauth_token)
 67        self.save()
 68
 69    def set_token_fields(self, oauth_token: "OAuthToken") -> None:
 70        self.access_token = oauth_token.access_token
 71        self.refresh_token = oauth_token.refresh_token
 72        self.access_token_expires_at = oauth_token.access_token_expires_at
 73        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
 74
 75    def set_user_fields(self, oauth_user: "OAuthUser") -> None:
 76        self.provider_user_id = oauth_user.provider_id
 77
 78    def access_token_expired(self) -> bool:
 79        return (
 80            self.access_token_expires_at is not None
 81            and self.access_token_expires_at < timezone.now()
 82        )
 83
 84    def refresh_token_expired(self) -> bool:
 85        return (
 86            self.refresh_token_expires_at is not None
 87            and self.refresh_token_expires_at < timezone.now()
 88        )
 89
 90    @classmethod
 91    def get_or_create_user(
 92        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
 93    ) -> "OAuthConnection":
 94        try:
 95            connection = cls.query.get(
 96                provider_key=provider_key,
 97                provider_user_id=oauth_user.provider_id,
 98            )
 99            connection.set_token_fields(oauth_token)
100            connection.save()
101            return connection
102        except cls.DoesNotExist:
103            with transaction.atomic():
104                # If email needs to be unique, then we expect
105                # that to be taken care of on the user model itself
106                try:
107                    user = get_user_model()(
108                        **oauth_user.user_model_fields,
109                    )
110                    user.save()
111                except (IntegrityError, ValidationError):
112                    raise OAuthUserAlreadyExistsError()
113
114                return cls.connect(
115                    user=user,
116                    provider_key=provider_key,
117                    oauth_token=oauth_token,
118                    oauth_user=oauth_user,
119                )
120
121    @classmethod
122    def connect(
123        cls,
124        *,
125        user: Any,
126        provider_key: str,
127        oauth_token: "OAuthToken",
128        oauth_user: "OAuthUser",
129    ) -> "OAuthConnection":
130        """
131        Connect will either create a new connection or update an existing connection
132        """
133        try:
134            connection = cls.query.get(
135                user=user,
136                provider_key=provider_key,
137                provider_user_id=oauth_user.provider_id,
138            )
139        except cls.DoesNotExist:
140            # Create our own instance (not using get_or_create)
141            # so that any created signals contain the token fields too
142            connection = cls(
143                user=user,
144                provider_key=provider_key,
145                provider_user_id=oauth_user.provider_id,
146            )
147
148        connection.set_user_fields(oauth_user)
149        connection.set_token_fields(oauth_token)
150        connection.save()
151
152        return connection