1from typing import TYPE_CHECKING
2
3from plain import models
4from plain.auth import get_user_model
5from plain.exceptions import ValidationError
6from plain.models import transaction
7from plain.models.db import IntegrityError, OperationalError, ProgrammingError
8from plain.preflight import Error
9from plain.runtime import SettingsReference
10from plain.utils import timezone
11
12from .exceptions import OAuthUserAlreadyExistsError
13
14if TYPE_CHECKING:
15 from .providers import OAuthToken, OAuthUser
16
17
18# TODO preflight check for deploy that ensures all provider keys in db are also in settings?
19
20
@models.register_model
class OAuthConnection(models.Model):
    """
    Link between a local user and an account on an OAuth provider.

    Stores the provider's ID for the account together with the most recently
    obtained token data, so the connection can be reused on later logins and
    its access token refreshed.
    """

    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    user = models.ForeignKey(
        SettingsReference("AUTH_USER_MODEL"),
        on_delete=models.CASCADE,
        related_name="oauth_connections",
    )

    # The key used to refer to this provider type (in settings)
    provider_key = models.CharField(max_length=100)

    # The unique ID of the user on the provider's system
    provider_user_id = models.CharField(max_length=100)

    # Token data
    access_token = models.CharField(max_length=2000)
    refresh_token = models.CharField(max_length=2000, required=False)
    access_token_expires_at = models.DateTimeField(required=False, allow_null=True)
    refresh_token_expires_at = models.DateTimeField(required=False, allow_null=True)

    class Meta:
        constraints = [
            # A given provider account can only be linked once.
            models.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="plainoauth_oauthconnection_unique_provider_key_user_id",
            )
        ]
        ordering = ("provider_key",)

    def __str__(self):
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """
        Ask the provider for fresh tokens and save them on this connection.

        The provider instance is looked up by ``provider_key``. All four token
        fields are replaced by whatever ``refresh_oauth_token`` returns, and
        the connection is saved.
        """
        # Imported here: at module level .providers is only imported under
        # TYPE_CHECKING (presumably to avoid an import cycle).
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )
        self.set_token_fields(refreshed_oauth_token)
        self.save()

    def set_token_fields(self, oauth_token: "OAuthToken"):
        """Copy all token data from ``oauth_token`` onto this connection (does not save)."""
        self.access_token = oauth_token.access_token
        # NOTE(review): refresh_token is a non-nullable CharField; confirm that
        # OAuthToken never carries None here for providers without refresh tokens.
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: "OAuthUser"):
        """Copy provider user data onto this connection (does not save); currently only the ID."""
        self.provider_user_id = oauth_user.id

    def access_token_expired(self) -> bool:
        """Return True if the access token has an expiry time that is already past.

        A missing expiry (None) is treated as "never expires".
        """
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if the refresh token has an expiry time that is already past.

        A missing expiry (None) is treated as "never expires".
        """
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
    ) -> "OAuthConnection":
        """
        Return the connection for this provider account, creating a new user if needed.

        If a connection for ``(provider_key, oauth_user.id)`` already exists, its
        token fields are updated and saved. Otherwise a new user is built from
        ``oauth_user.user_model_fields``, saved, and connected via ``connect``,
        all inside a single transaction.

        Raises:
            OAuthUserAlreadyExistsError: if saving the new user raises an
                IntegrityError or ValidationError (e.g. a unique field on the
                user model is already taken).
        """
        try:
            connection = cls.objects.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            pass
        else:
            # Known provider account: keep the connection, refresh its tokens.
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        # New provider account: create the user and the connection together, so
        # a failure in either step rolls back both. This runs outside the
        # DoesNotExist handler so errors are not chained to that lookup miss.
        with transaction.atomic():
            # If email needs to be unique, then we expect
            # that to be taken care of on the user model itself
            try:
                user = get_user_model()(
                    **oauth_user.user_model_fields,
                )
                user.save()
            except (IntegrityError, ValidationError) as e:
                raise OAuthUserAlreadyExistsError() from e

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user,
        provider_key: str,
        oauth_token: "OAuthToken",
        oauth_user: "OAuthUser",
    ) -> "OAuthConnection":
        """
        Connect will either create a new connection or update an existing connection.

        The existing connection is looked up by ``user``, ``provider_key`` and
        ``oauth_user.id``. In both cases the user and token fields are copied
        from the arguments and the connection is saved.

        NOTE(review): if this provider account is already linked to a
        *different* user, the lookup misses and the save will presumably
        violate the (provider_key, provider_user_id) unique constraint.
        """
        try:
            connection = cls.objects.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection

    @classmethod
    def check(cls, **kwargs):
        """
        A system check for ensuring that provider_keys in the database are also present in settings.

        The database comparison only runs when a non-empty ``databases`` kwarg
        is passed. Note that the --database flag is required for this to work:
        python manage.py check --database default
        (NOTE(review): that command is the Django form; confirm the plain equivalent.)

        Returns the parent checks' errors plus one ``plain.oauth.E001`` error per
        database that contains provider keys missing from settings.
        """
        errors = super().check(**kwargs)

        databases = kwargs.get("databases", None)
        if not databases:
            return errors

        from .providers import get_provider_keys

        # Settings do not vary per database, so read them once.
        keys_in_settings = set(get_provider_keys())

        for database in databases:
            try:
                keys_in_db = set(
                    cls.objects.using(database)
                    .values_list("provider_key", flat=True)
                    .distinct()
                )
            except (OperationalError, ProgrammingError):
                # Check runs on manage.py migrate, and the table may not exist yet
                # or it may not be installed on the particular database intentionally
                continue

            missing_keys = keys_in_db - keys_in_settings
            if missing_keys:
                # Sort so the message is stable (set order varies between runs).
                missing_list = ", ".join(sorted(missing_keys))
                errors.append(
                    Error(
                        f"The following OAuth providers are in the database but not in the settings: {missing_list}",
                        id="plain.oauth.E001",
                    )
                )

        return errors