1from __future__ import annotations
2
3import datetime
4from typing import TYPE_CHECKING, Any
5
6from plain import models
7from plain.auth import get_user_model
8from plain.exceptions import ValidationError
9from plain.models import transaction, types
10from plain.models.db import IntegrityError
11from plain.runtime import SettingsReference
12from plain.utils import timezone
13
14from .exceptions import OAuthUserAlreadyExistsError
15
16if TYPE_CHECKING:
17 from .providers import OAuthToken, OAuthUser
18
19
@models.register_model
class OAuthConnection(models.Model):
    """A link between a local user and an account on an OAuth provider.

    Each row stores the provider account's ID plus the OAuth token data that
    was last obtained for it. A given provider account
    (``provider_key`` + ``provider_user_id``) can be connected at most once.
    """

    created_at: datetime.datetime = types.DateTimeField(auto_now_add=True)
    updated_at: datetime.datetime = types.DateTimeField(auto_now=True)

    # The local user this provider account is connected to
    user = types.ForeignKeyField(
        SettingsReference("AUTH_USER_MODEL"),
        on_delete=models.CASCADE,
    )

    # The key used to refer to this provider type (in settings)
    provider_key: str = types.CharField(max_length=100)

    # The unique ID of the user on the provider's system
    provider_user_id: str = types.CharField(max_length=100)

    # Token data
    access_token: str = types.CharField(max_length=2000)
    # Not required: some providers never issue a refresh token
    refresh_token: str = types.CharField(max_length=2000, required=False)
    # None means the expiry is unknown / not provided
    access_token_expires_at: datetime.datetime | None = types.DateTimeField(
        required=False, allow_null=True
    )
    refresh_token_expires_at: datetime.datetime | None = types.DateTimeField(
        required=False, allow_null=True
    )

    query: models.QuerySet[OAuthConnection] = models.QuerySet()

    model_options = models.Options(
        constraints=[
            # One provider account can only be connected to one local user
            models.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="plainoauth_oauthconnection_unique_provider_key_user_id",
            )
        ],
        ordering=("provider_key",),
    )

    def __str__(self) -> str:
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Use the stored refresh token to obtain new tokens, then save.

        Builds an ``OAuthToken`` from the stored token fields, asks the
        provider configured under ``self.provider_key`` to refresh it, copies
        the result onto this connection and saves it.

        Per RFC 6749 section 6 the authorization server *may* issue a new
        refresh token when refreshing, but is not required to. If the
        refreshed token carries no refresh token, the currently stored one is
        kept instead of being overwritten with an empty value, so the
        connection can still be refreshed again later.
        """
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )

        # Capture the current refresh token before set_token_fields replaces it.
        previous_refresh_token = self.refresh_token
        previous_refresh_token_expires_at = self.refresh_token_expires_at

        self.set_token_fields(refreshed_oauth_token)

        if not refreshed_oauth_token.refresh_token:
            # The refresh token was not rotated: keep using the existing one
            # (and its expiry, unless the response supplied a new expiry).
            self.refresh_token = previous_refresh_token
            if refreshed_oauth_token.refresh_token_expires_at is None:
                self.refresh_token_expires_at = previous_refresh_token_expires_at

        self.save()

    def set_token_fields(self, oauth_token: OAuthToken) -> None:
        """Copy all token values from ``oauth_token`` onto this instance (no save)."""
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: OAuthUser) -> None:
        """Copy provider-account values from ``oauth_user`` onto this instance (no save)."""
        self.provider_user_id = oauth_user.provider_id

    def access_token_expired(self) -> bool:
        """Return True if the access token has a known expiry that is in the past."""
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if the refresh token has a known expiry that is in the past."""
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
    ) -> OAuthConnection:
        """Return the connection for a provider account, creating a user if needed.

        If a connection already exists for
        ``(provider_key, oauth_user.provider_id)``, its token fields are
        updated from ``oauth_token``, saved, and the connection is returned.

        Otherwise a new user is built from ``oauth_user.user_model_fields`` and
        saved, and then connected via :meth:`connect`. Both writes happen
        inside a single ``transaction.atomic()`` block.

        Raises:
            OAuthUserAlreadyExistsError: if saving the new user raises
                ``IntegrityError`` or ``ValidationError``. An example is a
                conflicting unique field on the user model.
        """
        try:
            connection = cls.query.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            pass
        else:
            # Existing connection: just refresh the stored tokens.
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        with transaction.atomic():
            # If email needs to be unique, then we expect
            # that to be taken care of on the user model itself
            try:
                user = get_user_model()(
                    **oauth_user.user_model_fields,
                )
                user.save()
            except (IntegrityError, ValidationError) as err:
                raise OAuthUserAlreadyExistsError() from err

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: Any,
        provider_key: str,
        oauth_token: OAuthToken,
        oauth_user: OAuthUser,
    ) -> OAuthConnection:
        """
        Connect will either create a new connection or update an existing connection

        Looks up a connection matching ``user``, ``provider_key`` and
        ``oauth_user.provider_id``. If none exists, a new unsaved instance is
        built. In both cases the user and token fields are refreshed from
        ``oauth_user`` / ``oauth_token`` and the connection is saved.

        NOTE(review): if this provider account is already connected to a
        *different* user, the lookup misses and the save will presumably hit
        the unique constraint on (provider_key, provider_user_id). Confirm
        that callers expect that error.
        """
        try:
            connection = cls.query.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection