1from __future__ import annotations
2
3import datetime
4from typing import TYPE_CHECKING, Any
5
6import psycopg
7from app.users.models import User
8
9from plain import postgres
10from plain.exceptions import ValidationError
11from plain.postgres import transaction, types
12from plain.utils import timezone
13
14from .exceptions import OAuthUserAlreadyExistsError
15
16if TYPE_CHECKING:
17 from .providers import OAuthToken, OAuthUser
18
# Names exported by `from <this module> import *`.
__all__ = [
    "OAuthConnection",
]
20
21
@postgres.register_model
class OAuthConnection(postgres.Model):
    """A link between a local user and their account on an OAuth provider.

    Stores the provider's user ID alongside the encrypted access/refresh
    tokens issued for that account. The unique constraint below guarantees a
    given provider account (``provider_key`` + ``provider_user_id``) is
    connected at most once.
    """

    created_at: datetime.datetime = types.DateTimeField(create_now=True)
    updated_at: datetime.datetime = types.DateTimeField(
        create_now=True, update_now=True
    )

    user = types.ForeignKeyField(
        "users.User",
        on_delete=postgres.CASCADE,
    )

    # The key used to refer to this provider type (in settings)
    provider_key: str = types.TextField(max_length=100)

    # The unique ID of the user on the provider's system
    provider_user_id: str = types.TextField(max_length=100)

    # Token data
    access_token: str = types.EncryptedTextField(max_length=2000)
    refresh_token: str = types.EncryptedTextField(max_length=2000, required=False)
    access_token_expires_at: datetime.datetime | None = types.DateTimeField(
        required=False, allow_null=True
    )
    refresh_token_expires_at: datetime.datetime | None = types.DateTimeField(
        required=False, allow_null=True
    )

    query: postgres.QuerySet[OAuthConnection] = postgres.QuerySet()

    model_options = postgres.Options(
        indexes=[
            postgres.Index(
                name="plainoauth_oauthconnection_user_id_idx", fields=["user"]
            ),
        ],
        constraints=[
            postgres.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="plainoauth_oauthconnection_unique_provider_key_user_id",
            )
        ],
        ordering=("provider_key",),
    )

    def __str__(self) -> str:
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Ask the provider for fresh tokens and save them on this connection.

        The current token fields are packed into an ``OAuthToken`` and passed
        to the provider's ``refresh_oauth_token``. The result replaces the
        stored token fields, and the connection is saved.

        An authorization server is not required to issue a new refresh token
        on refresh (RFC 6749 §6). If the refreshed token carries no refresh
        token, the previously stored one is kept, along with its expiry
        unless a new expiry was supplied.
        """
        # Imported locally, presumably to avoid a circular import with
        # .providers (which is otherwise only imported under TYPE_CHECKING).
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )

        # Remember the current refresh token so a refresh response that omits
        # one does not wipe out the only credential able to refresh again.
        previous_refresh_token = self.refresh_token
        previous_refresh_token_expires_at = self.refresh_token_expires_at

        self.set_token_fields(refreshed_oauth_token)

        if not self.refresh_token:
            self.refresh_token = previous_refresh_token
            if self.refresh_token_expires_at is None:
                self.refresh_token_expires_at = previous_refresh_token_expires_at

        self.save()

    def set_token_fields(self, oauth_token: OAuthToken) -> None:
        """Copy the four token fields from ``oauth_token`` onto this instance.

        Does not save; callers are responsible for calling ``save()``.
        """
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: OAuthUser) -> None:
        """Copy provider-side identity (the provider user ID) from ``oauth_user``.

        Does not save.
        """
        self.provider_user_id = oauth_user.provider_id

    def access_token_expired(self) -> bool:
        """Return True if the access token has a known expiry strictly in the past.

        A token without an expiry (``None``) is never considered expired.
        """
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if the refresh token has a known expiry strictly in the past.

        A token without an expiry (``None``) is never considered expired.
        """
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
    ) -> OAuthConnection:
        """Return the connection for this provider account, creating a user if needed.

        If a connection already exists for ``provider_key`` and
        ``oauth_user.provider_id``, its token fields are updated from
        ``oauth_token``, saved, and the connection is returned.

        Otherwise a new ``User`` is built from ``oauth_user.user_model_fields``
        and connected via ``connect()``, all inside one transaction.

        Raises:
            OAuthUserAlreadyExistsError: saving the new user raised
                ``psycopg.IntegrityError`` or ``ValidationError`` (for example
                a uniqueness conflict on the user model). The original error
                is chained as ``__cause__``.
        """
        try:
            connection = cls.query.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # Fall through to user creation below, outside the except handler,
            # so later errors are not reported as "during handling of
            # DoesNotExist".
            pass
        else:
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        # If email needs to be unique, then we expect
        # that to be taken care of on the user model itself
        with transaction.atomic():
            try:
                # Nested atomic() isolates the user insert (presumably as a
                # savepoint) so its failure can be caught and translated here.
                with transaction.atomic():
                    user = User(
                        **oauth_user.user_model_fields,
                    )
                    user.save()
            except (psycopg.IntegrityError, ValidationError) as e:
                raise OAuthUserAlreadyExistsError(
                    provider_key=provider_key,
                    user_model_fields=oauth_user.user_model_fields,
                ) from e

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: Any,
        provider_key: str,
        oauth_token: OAuthToken,
        oauth_user: OAuthUser,
    ) -> OAuthConnection:
        """
        Connect will either create a new connection or update an existing connection

        The existing connection is looked up by ``user``, ``provider_key``, and
        ``oauth_user.provider_id``. In either case the user and token fields
        are (re)applied from ``oauth_user`` / ``oauth_token`` and the
        connection is saved and returned.

        NOTE(review): if this provider account is already connected to a
        *different* user, the lookup misses and the insert will presumably
        violate the unique constraint on (provider_key, provider_user_id).
        """
        try:
            connection = cls.query.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection