v0.148.0
  1from __future__ import annotations
  2
  3from typing import TYPE_CHECKING, Any
  4
  5import psycopg
  6from app.users.models import User
  7
  8from plain import postgres
  9from plain.exceptions import ValidationError
 10from plain.postgres import transaction, types
 11from plain.utils import timezone
 12
 13from .exceptions import OAuthUserAlreadyExistsError
 14
 15if TYPE_CHECKING:
 16    from .providers import OAuthToken, OAuthUser
 17
 18__all__ = ["OAuthConnection"]
 19
 20
 21@postgres.register_model
 22class OAuthConnection(postgres.Model):
 23    created_at = types.DateTimeField(create_now=True)
 24    updated_at = types.DateTimeField(create_now=True, update_now=True)
 25
 26    user = types.ForeignKeyField(
 27        "users.User",
 28        on_delete=postgres.CASCADE,
 29    )
 30
 31    # The key used to refer to this provider type (in settings)
 32    provider_key = types.TextField(max_length=100)
 33
 34    # The unique ID of the user on the provider's system
 35    provider_user_id = types.TextField(max_length=100)
 36
 37    # Token data
 38    access_token = types.EncryptedTextField(max_length=2000)
 39    refresh_token = types.EncryptedTextField(max_length=2000, required=False)
 40    access_token_expires_at = types.DateTimeField(required=False, allow_null=True)
 41    refresh_token_expires_at = types.DateTimeField(required=False, allow_null=True)
 42
 43    query: postgres.QuerySet[OAuthConnection] = postgres.QuerySet()
 44
 45    model_options = postgres.Options(
 46        indexes=[
 47            postgres.Index(
 48                name="plainoauth_oauthconnection_user_id_idx", fields=["user"]
 49            ),
 50        ],
 51        constraints=[
 52            postgres.UniqueConstraint(
 53                fields=["provider_key", "provider_user_id"],
 54                name="plainoauth_oauthconnection_unique_provider_key_user_id",
 55            )
 56        ],
 57        ordering=("provider_key",),
 58    )
 59
 60    def __str__(self) -> str:
 61        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"
 62
 63    def refresh_access_token(self) -> None:
 64        from .providers import OAuthToken, get_oauth_provider_instance
 65
 66        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
 67        oauth_token = OAuthToken(
 68            access_token=self.access_token,
 69            refresh_token=self.refresh_token,
 70            access_token_expires_at=self.access_token_expires_at,
 71            refresh_token_expires_at=self.refresh_token_expires_at,
 72        )
 73        refreshed_oauth_token = provider_instance.refresh_oauth_token(
 74            oauth_token=oauth_token
 75        )
 76        self.set_token_fields(refreshed_oauth_token)
 77        self.save()
 78
 79    def set_token_fields(self, oauth_token: OAuthToken) -> None:
 80        self.access_token = oauth_token.access_token
 81        self.refresh_token = oauth_token.refresh_token
 82        self.access_token_expires_at = oauth_token.access_token_expires_at
 83        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
 84
 85    def set_user_fields(self, oauth_user: OAuthUser) -> None:
 86        self.provider_user_id = oauth_user.provider_id
 87
 88    def access_token_expired(self) -> bool:
 89        return (
 90            self.access_token_expires_at is not None
 91            and self.access_token_expires_at < timezone.now()
 92        )
 93
 94    def refresh_token_expired(self) -> bool:
 95        return (
 96            self.refresh_token_expires_at is not None
 97            and self.refresh_token_expires_at < timezone.now()
 98        )
 99
100    @classmethod
101    def get_or_create_user(
102        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
103    ) -> OAuthConnection:
104        try:
105            connection = cls.query.get(
106                provider_key=provider_key,
107                provider_user_id=oauth_user.provider_id,
108            )
109            connection.set_token_fields(oauth_token)
110            connection.save()
111            return connection
112        except cls.DoesNotExist:
113            # If email needs to be unique, then we expect
114            # that to be taken care of on the user model itself
115            with transaction.atomic():
116                try:
117                    with transaction.atomic():
118                        user = User(
119                            **oauth_user.user_model_fields,
120                        )
121                        user.save()
122                except (psycopg.IntegrityError, ValidationError):
123                    raise OAuthUserAlreadyExistsError(
124                        provider_key=provider_key,
125                        user_model_fields=oauth_user.user_model_fields,
126                    )
127
128                return cls.connect(
129                    user=user,
130                    provider_key=provider_key,
131                    oauth_token=oauth_token,
132                    oauth_user=oauth_user,
133                )
134
135    @classmethod
136    def connect(
137        cls,
138        *,
139        user: Any,
140        provider_key: str,
141        oauth_token: OAuthToken,
142        oauth_user: OAuthUser,
143    ) -> OAuthConnection:
144        """
145        Connect will either create a new connection or update an existing connection
146        """
147        try:
148            connection = cls.query.get(
149                user=user,
150                provider_key=provider_key,
151                provider_user_id=oauth_user.provider_id,
152            )
153        except cls.DoesNotExist:
154            # Create our own instance (not using get_or_create)
155            # so that any created signals contain the token fields too
156            connection = cls(
157                user=user,
158                provider_key=provider_key,
159                provider_user_id=oauth_user.provider_id,
160            )
161
162        connection.set_user_fields(oauth_user)
163        connection.set_token_fields(oauth_token)
164        connection.save()
165
166        return connection