1from __future__ import annotations
2
3import datetime
4from typing import TYPE_CHECKING, Any
5
6from plain import models
7from plain.auth import get_user_model
8from plain.exceptions import ValidationError
9from plain.models import transaction, types
10from plain.models.db import IntegrityError
11from plain.runtime import SettingsReference
12from plain.utils import timezone
13
14from .exceptions import OAuthUserAlreadyExistsError
15
16if TYPE_CHECKING:
17 from .providers import OAuthToken, OAuthUser
18
# Names exported by ``from <this module> import *``.
__all__ = [
    "OAuthConnection",
]
20
21
@models.register_model
class OAuthConnection(models.Model):
    """Link between a local user and an account on an OAuth provider.

    Each row records which provider (``provider_key``) and which account on
    that provider (``provider_user_id``) belong to ``user``, along with the
    most recent token data. A given provider account can be linked at most
    once (see the unique constraint in ``model_options``).
    """

    created_at: datetime.datetime = types.DateTimeField(auto_now_add=True)
    updated_at: datetime.datetime = types.DateTimeField(auto_now=True)

    user = types.ForeignKeyField(
        SettingsReference("AUTH_USER_MODEL"),
        on_delete=models.CASCADE,
    )

    # The key used to refer to this provider type (in settings)
    provider_key: str = types.CharField(max_length=100)

    # The unique ID of the user on the provider's system
    provider_user_id: str = types.CharField(max_length=100)

    # Token data
    access_token: str = types.CharField(max_length=2000)
    # Optional (required=False) but not nullable (no allow_null), so a missing
    # refresh token presumably arrives as "" from OAuthToken -- confirm.
    refresh_token: str = types.CharField(max_length=2000, required=False)
    # A None expiry is treated as "never expires" by the *_expired() helpers.
    access_token_expires_at: datetime.datetime | None = types.DateTimeField(
        required=False, allow_null=True
    )
    refresh_token_expires_at: datetime.datetime | None = types.DateTimeField(
        required=False, allow_null=True
    )

    query: models.QuerySet[OAuthConnection] = models.QuerySet()

    model_options = models.Options(
        constraints=[
            models.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="plainoauth_oauthconnection_unique_provider_key_user_id",
            )
        ],
        ordering=("provider_key",),
    )

    def __str__(self) -> str:
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Ask the provider for new tokens and save them on this connection.

        The current token data is passed to the provider's
        ``refresh_oauth_token``; whatever it returns replaces the stored token
        fields, and the row is saved.
        """
        # .providers is only imported under TYPE_CHECKING at module level;
        # importing at call time presumably avoids a circular import.
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )
        self.set_token_fields(refreshed_oauth_token)
        self.save()

    def set_token_fields(self, oauth_token: OAuthToken) -> None:
        """Copy all token values from ``oauth_token`` onto this instance.

        Does not save; callers are responsible for calling ``save()``.
        """
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: OAuthUser) -> None:
        """Copy provider-side user identity onto this instance (no save)."""
        self.provider_user_id = oauth_user.provider_id

    def access_token_expired(self) -> bool:
        """Return True if an access-token expiry is set and already passed.

        A connection with no recorded expiry is never considered expired.
        """
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if a refresh-token expiry is set and already passed.

        A connection with no recorded expiry is never considered expired.
        """
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
    ) -> OAuthConnection:
        """Return the connection for ``oauth_user``, creating a new user if needed.

        If a connection already exists for this provider account, its token
        fields are updated from ``oauth_token`` and it is saved. Otherwise a new
        user is built from ``oauth_user.user_model_fields`` and connected via
        ``connect``.

        The new user and its connection are written in one transaction, so a
        failure while saving the connection rolls the user back instead of
        leaving a user with no OAuthConnection.

        Raises:
            OAuthUserAlreadyExistsError: saving the new user raised an
                IntegrityError or ValidationError (e.g. a unique field such as
                email is already taken).
        """
        # Keep the try limited to the lookup, so a DoesNotExist raised while
        # updating an existing connection is not mistaken for "not found".
        try:
            connection = cls.query.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            pass
        else:
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        with transaction.atomic():
            # If email needs to be unique, then we expect
            # that to be taken care of on the user model itself
            try:
                user = get_user_model()(
                    **oauth_user.user_model_fields,
                )
                user.save()
            except (IntegrityError, ValidationError) as e:
                raise OAuthUserAlreadyExistsError() from e

            # Connect inside the same transaction so that an error saving the
            # connection (e.g. the provider-account unique constraint) also
            # rolls back the user created above.
            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: Any,
        provider_key: str,
        oauth_token: OAuthToken,
        oauth_user: OAuthUser,
    ) -> OAuthConnection:
        """
        Connect will either create a new connection or update an existing connection

        The lookup matches on ``user``, ``provider_key`` and the provider user
        id together. The user and token fields are refreshed from
        ``oauth_user``/``oauth_token`` and the connection is saved.
        """
        try:
            connection = cls.query.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection