Plain is headed towards 1.0! Subscribe for development updates →

  1from typing import TYPE_CHECKING
  2
  3from plain import models
  4from plain.auth import get_user_model
  5from plain.exceptions import ValidationError
  6from plain.models import transaction
  7from plain.models.db import IntegrityError, OperationalError, ProgrammingError
  8from plain.preflight import Error
  9from plain.runtime import SettingsReference
 10from plain.utils import timezone
 11
 12from .exceptions import OAuthUserAlreadyExistsError
 13
 14if TYPE_CHECKING:
 15    from .providers import OAuthToken, OAuthUser
 16
 17
 18# TODO preflight check for deploy that ensures all provider keys in db are also in settings?
 19
 20
 21@models.register_model
 22class OAuthConnection(models.Model):
 23    created_at = models.DateTimeField(auto_now_add=True)
 24    updated_at = models.DateTimeField(auto_now=True)
 25
 26    user = models.ForeignKey(
 27        SettingsReference("AUTH_USER_MODEL"),
 28        on_delete=models.CASCADE,
 29        related_name="oauth_connections",
 30    )
 31
 32    # The key used to refer to this provider type (in settings)
 33    provider_key = models.CharField(max_length=100)
 34
 35    # The unique ID of the user on the provider's system
 36    provider_user_id = models.CharField(max_length=100)
 37
 38    # Token data
 39    access_token = models.CharField(max_length=2000)
 40    refresh_token = models.CharField(max_length=2000, required=False)
 41    access_token_expires_at = models.DateTimeField(required=False, allow_null=True)
 42    refresh_token_expires_at = models.DateTimeField(required=False, allow_null=True)
 43
 44    class Meta:
 45        constraints = [
 46            models.UniqueConstraint(
 47                fields=["provider_key", "provider_user_id"],
 48                name="plainoauth_oauthconnection_unique_provider_key_user_id",
 49            )
 50        ]
 51        ordering = ("provider_key",)
 52
 53    def __str__(self):
 54        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"
 55
 56    def refresh_access_token(self) -> None:
 57        from .providers import OAuthToken, get_oauth_provider_instance
 58
 59        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
 60        oauth_token = OAuthToken(
 61            access_token=self.access_token,
 62            refresh_token=self.refresh_token,
 63            access_token_expires_at=self.access_token_expires_at,
 64            refresh_token_expires_at=self.refresh_token_expires_at,
 65        )
 66        refreshed_oauth_token = provider_instance.refresh_oauth_token(
 67            oauth_token=oauth_token
 68        )
 69        self.set_token_fields(refreshed_oauth_token)
 70        self.save()
 71
 72    def set_token_fields(self, oauth_token: "OAuthToken"):
 73        self.access_token = oauth_token.access_token
 74        self.refresh_token = oauth_token.refresh_token
 75        self.access_token_expires_at = oauth_token.access_token_expires_at
 76        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
 77
 78    def set_user_fields(self, oauth_user: "OAuthUser"):
 79        self.provider_user_id = oauth_user.id
 80
 81    def access_token_expired(self) -> bool:
 82        return (
 83            self.access_token_expires_at is not None
 84            and self.access_token_expires_at < timezone.now()
 85        )
 86
 87    def refresh_token_expired(self) -> bool:
 88        return (
 89            self.refresh_token_expires_at is not None
 90            and self.refresh_token_expires_at < timezone.now()
 91        )
 92
 93    @classmethod
 94    def get_or_create_user(
 95        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
 96    ) -> "OAuthConnection":
 97        try:
 98            connection = cls.objects.get(
 99                provider_key=provider_key,
100                provider_user_id=oauth_user.id,
101            )
102            connection.set_token_fields(oauth_token)
103            connection.save()
104            return connection
105        except cls.DoesNotExist:
106            with transaction.atomic():
107                # If email needs to be unique, then we expect
108                # that to be taken care of on the user model itself
109                try:
110                    user = get_user_model()(
111                        **oauth_user.user_model_fields,
112                    )
113                    user.save()
114                except (IntegrityError, ValidationError):
115                    raise OAuthUserAlreadyExistsError()
116
117                return cls.connect(
118                    user=user,
119                    provider_key=provider_key,
120                    oauth_token=oauth_token,
121                    oauth_user=oauth_user,
122                )
123
124    @classmethod
125    def connect(
126        cls,
127        *,
128        user,
129        provider_key: str,
130        oauth_token: "OAuthToken",
131        oauth_user: "OAuthUser",
132    ) -> "OAuthConnection":
133        """
134        Connect will either create a new connection or update an existing connection
135        """
136        try:
137            connection = cls.objects.get(
138                user=user,
139                provider_key=provider_key,
140                provider_user_id=oauth_user.id,
141            )
142        except cls.DoesNotExist:
143            # Create our own instance (not using get_or_create)
144            # so that any created signals contain the token fields too
145            connection = cls(
146                user=user,
147                provider_key=provider_key,
148                provider_user_id=oauth_user.id,
149            )
150
151        connection.set_user_fields(oauth_user)
152        connection.set_token_fields(oauth_token)
153        connection.save()
154
155        return connection
156
157    @classmethod
158    def check(cls, **kwargs):
159        """
160        A system check for ensuring that provider_keys in the database are also present in settings.
161
162        Note that the --database flag is required for this to work:
163          python manage.py check --database default
164        """
165        errors = super().check(**kwargs)
166
167        databases = kwargs.get("databases", None)
168        if not databases:
169            return errors
170
171        from .providers import get_provider_keys
172
173        for database in databases:
174            try:
175                keys_in_db = set(
176                    cls.objects.using(database)
177                    .values_list("provider_key", flat=True)
178                    .distinct()
179                )
180            except (OperationalError, ProgrammingError):
181                # Check runs on manage.py migrate, and the table may not exist yet
182                # or it may not be installed on the particular database intentionally
183                continue
184
185            keys_in_settings = set(get_provider_keys())
186
187            if keys_in_db - keys_in_settings:
188                errors.append(
189                    Error(
190                        "The following OAuth providers are in the database but not in the settings: {}".format(
191                            ", ".join(keys_in_db - keys_in_settings)
192                        ),
193                        id="plain.oauth.E001",
194                    )
195                )
196
197        return errors