1from __future__ import annotations
2
3from typing import TYPE_CHECKING, Any
4
5import psycopg
6from app.users.models import User
7
8from plain import postgres
9from plain.exceptions import ValidationError
10from plain.postgres import transaction, types
11from plain.utils import timezone
12
13from .exceptions import OAuthUserAlreadyExistsError
14
15if TYPE_CHECKING:
16 from .providers import OAuthToken, OAuthUser
17
# Public API of this module: only the OAuthConnection model is exported.
__all__ = ["OAuthConnection"]
19
20
@postgres.register_model
class OAuthConnection(postgres.Model):
    """A link between a local user and their account on an OAuth provider.

    Each row stores which provider it belongs to (``provider_key``), the
    user's ID on that provider (``provider_user_id``) and the latest token
    data. The pair (provider_key, provider_user_id) is unique, so a given
    provider account can be attached to at most one connection row.
    """

    created_at = types.DateTimeField(create_now=True)
    updated_at = types.DateTimeField(create_now=True, update_now=True)

    # Connections are removed together with their user (CASCADE).
    user = types.ForeignKeyField(
        "users.User",
        on_delete=postgres.CASCADE,
    )

    # The key used to refer to this provider type (in settings)
    provider_key = types.TextField(max_length=100)

    # The unique ID of the user on the provider's system
    provider_user_id = types.TextField(max_length=100)

    # Token data
    access_token = types.EncryptedTextField(max_length=2000)
    refresh_token = types.EncryptedTextField(max_length=2000, required=False)
    access_token_expires_at = types.DateTimeField(required=False, allow_null=True)
    refresh_token_expires_at = types.DateTimeField(required=False, allow_null=True)

    query: postgres.QuerySet[OAuthConnection] = postgres.QuerySet()

    model_options = postgres.Options(
        indexes=[
            postgres.Index(
                name="plainoauth_oauthconnection_user_id_idx", fields=["user"]
            ),
        ],
        constraints=[
            postgres.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="plainoauth_oauthconnection_unique_provider_key_user_id",
            )
        ],
        ordering=("provider_key",),
    )

    def __str__(self) -> str:
        """Return ``"<provider_key>[<user>:<provider_user_id>]"``."""
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Ask the provider for fresh tokens and persist them on this connection.

        The stored token fields are packed into an ``OAuthToken`` and handed to
        the configured provider's ``refresh_oauth_token``. The result is copied
        onto this row and saved.

        The authorization server is allowed to omit a new refresh token in its
        refresh response (RFC 6749, section 6). When the refreshed token carries
        no refresh token, the previously stored refresh token and its expiry
        are kept. Otherwise the connection could never be refreshed again.
        """
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )

        previous_refresh_token = self.refresh_token
        previous_refresh_token_expires_at = self.refresh_token_expires_at
        self.set_token_fields(refreshed_oauth_token)
        if not self.refresh_token:
            # No new refresh token was issued: keep the one we already had
            # instead of overwriting it with an empty value.
            self.refresh_token = previous_refresh_token
            self.refresh_token_expires_at = previous_refresh_token_expires_at
        self.save()

    def set_token_fields(self, oauth_token: OAuthToken) -> None:
        """Copy the access/refresh tokens and their expiries from ``oauth_token``.

        Only assigns attributes; the caller is responsible for saving.
        """
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: OAuthUser) -> None:
        """Copy provider-side user identity onto this connection (no save).

        Currently this only sets ``provider_user_id``. It is kept as a separate
        method, presumably as an extension point for subclasses — confirm.
        """
        self.provider_user_id = oauth_user.provider_id

    def access_token_expired(self) -> bool:
        """Return True if an access-token expiry is set and lies in the past.

        A missing expiry (``None``) is treated as "not expired".
        """
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if a refresh-token expiry is set and lies in the past.

        A missing expiry (``None``) is treated as "not expired".
        """
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: OAuthToken, oauth_user: OAuthUser
    ) -> OAuthConnection:
        """Return the connection for this provider account, creating a user if needed.

        If a connection for (provider_key, oauth_user.provider_id) already
        exists, its token fields are updated from ``oauth_token`` and saved.
        Otherwise a new ``User`` is built from ``oauth_user.user_model_fields``.
        It is then attached via :meth:`connect`, and both writes run inside one
        transaction.

        Raises:
            OAuthUserAlreadyExistsError: if saving the new User raises
                ``psycopg.IntegrityError`` or ``ValidationError`` (for example
                a unique field on the user model that is already taken). The
                original error is chained as ``__cause__``.
        """
        try:
            connection = cls.query.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            pass
        else:
            # Known provider account: just store the newest tokens.
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        # If email needs to be unique, then we expect
        # that to be taken care of on the user model itself
        with transaction.atomic():
            try:
                # The nested atomic() scopes the User INSERT on its own
                # (presumably as a savepoint) before the error is translated.
                with transaction.atomic():
                    user = User(
                        **oauth_user.user_model_fields,
                    )
                    user.save()
            except (psycopg.IntegrityError, ValidationError) as e:
                raise OAuthUserAlreadyExistsError(
                    provider_key=provider_key,
                    user_model_fields=oauth_user.user_model_fields,
                ) from e

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: Any,
        provider_key: str,
        oauth_token: OAuthToken,
        oauth_user: OAuthUser,
    ) -> OAuthConnection:
        """
        Connect will either create a new connection or update an existing connection

        The existing connection is looked up by (user, provider_key,
        oauth_user.provider_id). Either way, user fields and token fields are
        copied from ``oauth_user``/``oauth_token`` and the row is saved.

        NOTE(review): if this provider account is already connected to a
        *different* user, the lookup misses. The new row then violates the
        unique (provider_key, provider_user_id) constraint, so ``save()`` is
        expected to raise. Callers should be prepared for that.
        """
        try:
            connection = cls.query.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection