1from __future__ import annotations
  2
  3import datetime
  4from typing import TYPE_CHECKING, Any
  5
  6import psycopg
  7from app.users.models import User
  8
  9from plain import postgres
 10from plain.exceptions import ValidationError
 11from plain.postgres import transaction, types
 12from plain.utils import timezone
 13
 14from .exceptions import OAuthUserAlreadyExistsError
 15
 16if TYPE_CHECKING:
 17    from .providers import OAuthToken, OAuthUser
 18
 19__all__ = ["OAuthConnection"]
 20
 21
 22@postgres.register_model
 23class OAuthConnection(postgres.Model):
 24    created_at: datetime.datetime = types.DateTimeField(create_now=True)
 25    updated_at: datetime.datetime = types.DateTimeField(
 26        create_now=True, update_now=True
 27    )
 28
 29    user = types.ForeignKeyField(
 30        "users.User",
 31        on_delete=postgres.CASCADE,
 32    )
 33
 34    # The key used to refer to this provider type (in settings)
 35    provider_key: str = types.TextField(max_length=100)
 36
 37    # The unique ID of the user on the provider's system
 38    provider_user_id: str = types.TextField(max_length=100)
 39
 40    # Token data
 41    access_token: str = types.EncryptedTextField(max_length=2000)
 42    refresh_token: str = types.EncryptedTextField(max_length=2000, required=False)
 43    access_token_expires_at: datetime.datetime | None = types.DateTimeField(
 44        required=False, allow_null=True
 45    )
 46    refresh_token_expires_at: datetime.datetime | None = types.DateTimeField(
 47        required=False, allow_null=True
 48    )
 49
 50    query: postgres.QuerySet[OAuthConnection] = postgres.QuerySet()
 51
 52    model_options = postgres.Options(
 53        indexes=[
 54            postgres.Index(
 55                name="plainoauth_oauthconnection_user_id_idx", fields=["user"]
 56            ),
 57        ],
 58        constraints=[
 59            postgres.UniqueConstraint(
 60                fields=["provider_key", "provider_user_id"],
 61                name="plainoauth_oauthconnection_unique_provider_key_user_id",
 62            )
 63        ],
 64        ordering=("provider_key",),
 65    )
 66
 67    def __str__(self) -> str:
 68        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"
 69
 70    def refresh_access_token(self) -> None:
 71        from .providers import OAuthToken, get_oauth_provider_instance
 72
 73        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
 74        oauth_token = OAuthToken(
 75            access_token=self.access_token,
 76            refresh_token=self.refresh_token,
 77            access_token_expires_at=self.access_token_expires_at,
 78            refresh_token_expires_at=self.refresh_token_expires_at,
 79        )
 80        refreshed_oauth_token = provider_instance.refresh_oauth_token(
 81            oauth_token=oauth_token
 82        )
 83        self.set_token_fields(refreshed_oauth_token)
 84        self.save()
 85
 86    def set_token_fields(self, oauth_token: OAuthToken) -> None:
 87        self.access_token = oauth_token.access_token
 88        self.refresh_token = oauth_token.refresh_token
 89        self.access_token_expires_at = oauth_token.access_token_expires_at
 90        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at
 91
 92    def set_user_fields(self, oauth_user: OAuthUser) -> None:
 93        self.provider_user_id = oauth_user.provider_id
 94
 95    def access_token_expired(self) -> bool:
 96        return (
 97            self.access_token_expires_at is not None
 98            and self.access_token_expires_at < timezone.now()
 99        )
100
101    def refresh_token_expired(self) -> bool:
102        return (
103            self.refresh_token_expires_at is not None
104            and self.refresh_token_expires_at < timezone.now()
105        )
106
107    @classmethod
108    def get_or_create_user(
109        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
110    ) -> OAuthConnection:
111        try:
112            connection = cls.query.get(
113                provider_key=provider_key,
114                provider_user_id=oauth_user.provider_id,
115            )
116            connection.set_token_fields(oauth_token)
117            connection.save()
118            return connection
119        except cls.DoesNotExist:
120            # If email needs to be unique, then we expect
121            # that to be taken care of on the user model itself
122            with transaction.atomic():
123                try:
124                    with transaction.atomic():
125                        user = User(
126                            **oauth_user.user_model_fields,
127                        )
128                        user.save()
129                except (psycopg.IntegrityError, ValidationError):
130                    raise OAuthUserAlreadyExistsError(
131                        provider_key=provider_key,
132                        user_model_fields=oauth_user.user_model_fields,
133                    )
134
135                return cls.connect(
136                    user=user,
137                    provider_key=provider_key,
138                    oauth_token=oauth_token,
139                    oauth_user=oauth_user,
140                )
141
142    @classmethod
143    def connect(
144        cls,
145        *,
146        user: Any,
147        provider_key: str,
148        oauth_token: OAuthToken,
149        oauth_user: OAuthUser,
150    ) -> OAuthConnection:
151        """
152        Connect will either create a new connection or update an existing connection
153        """
154        try:
155            connection = cls.query.get(
156                user=user,
157                provider_key=provider_key,
158                provider_user_id=oauth_user.provider_id,
159            )
160        except cls.DoesNotExist:
161            # Create our own instance (not using get_or_create)
162            # so that any created signals contain the token fields too
163            connection = cls(
164                user=user,
165                provider_key=provider_key,
166                provider_user_id=oauth_user.provider_id,
167            )
168
169        connection.set_user_fields(oauth_user)
170        connection.set_token_fields(oauth_token)
171        connection.save()
172
173        return connection