1from typing import TYPE_CHECKING, Any
2
3from plain import models
4from plain.auth import get_user_model
5from plain.exceptions import ValidationError
6from plain.models import transaction
7from plain.models.db import IntegrityError
8from plain.runtime import SettingsReference
9from plain.utils import timezone
10
11from .exceptions import OAuthUserAlreadyExistsError
12
13if TYPE_CHECKING:
14 from .providers import OAuthToken, OAuthUser
15
16
@models.register_model
class OAuthConnection(models.Model):
    """Link between a local user and their account on an OAuth provider.

    Stores the provider's ID for the user together with the most recently
    issued access/refresh tokens and their optional expiry times.
    """

    created_at = models.DateTimeField(auto_now_add=True)
    updated_at = models.DateTimeField(auto_now=True)

    # Deleting the user deletes their connections (CASCADE).
    user = models.ForeignKey(
        SettingsReference("AUTH_USER_MODEL"),
        on_delete=models.CASCADE,
        related_name="oauth_connections",
    )

    # The key used to refer to this provider type (in settings)
    provider_key = models.CharField(max_length=100)

    # The unique ID of the user on the provider's system
    provider_user_id = models.CharField(max_length=100)

    # Token data
    access_token = models.CharField(max_length=2000)
    refresh_token = models.CharField(max_length=2000, required=False)
    access_token_expires_at = models.DateTimeField(required=False, allow_null=True)
    refresh_token_expires_at = models.DateTimeField(required=False, allow_null=True)

    model_options = models.Options(
        constraints=[
            # A given provider account (provider_key + provider_user_id)
            # may only be stored once.
            models.UniqueConstraint(
                fields=["provider_key", "provider_user_id"],
                name="plainoauth_oauthconnection_unique_provider_key_user_id",
            )
        ],
        ordering=("provider_key",),
    )

    def __str__(self) -> str:
        """Return ``provider_key[user:provider_user_id]``."""
        return f"{self.provider_key}[{self.user}:{self.provider_user_id}]"

    def refresh_access_token(self) -> None:
        """Ask the provider for a refreshed token, then store and save it.

        The currently stored token fields are packaged into an ``OAuthToken``
        and handed to the provider instance registered under
        ``self.provider_key``; whatever token it returns replaces the stored
        fields and the connection is saved.
        """
        # Imported at call time rather than module level
        # (presumably to avoid a circular import with .providers -- confirm).
        from .providers import OAuthToken, get_oauth_provider_instance

        provider_instance = get_oauth_provider_instance(provider_key=self.provider_key)
        oauth_token = OAuthToken(
            access_token=self.access_token,
            refresh_token=self.refresh_token,
            access_token_expires_at=self.access_token_expires_at,
            refresh_token_expires_at=self.refresh_token_expires_at,
        )
        refreshed_oauth_token = provider_instance.refresh_oauth_token(
            oauth_token=oauth_token
        )
        self.set_token_fields(refreshed_oauth_token)
        self.save()

    def set_token_fields(self, oauth_token: "OAuthToken") -> None:
        """Copy the token values from ``oauth_token`` onto this instance.

        Does not save; the caller is responsible for calling ``save()``.
        """
        self.access_token = oauth_token.access_token
        self.refresh_token = oauth_token.refresh_token
        self.access_token_expires_at = oauth_token.access_token_expires_at
        self.refresh_token_expires_at = oauth_token.refresh_token_expires_at

    def set_user_fields(self, oauth_user: "OAuthUser") -> None:
        """Copy the provider's user ID from ``oauth_user`` onto this instance.

        Does not save; the caller is responsible for calling ``save()``.
        """
        self.provider_user_id = oauth_user.provider_id

    def access_token_expired(self) -> bool:
        """Return True if an access token expiry is set and is in the past.

        A missing expiry (``None``) is treated as not expired.
        """
        return (
            self.access_token_expires_at is not None
            and self.access_token_expires_at < timezone.now()
        )

    def refresh_token_expired(self) -> bool:
        """Return True if a refresh token expiry is set and is in the past.

        A missing expiry (``None``) is treated as not expired.
        """
        return (
            self.refresh_token_expires_at is not None
            and self.refresh_token_expires_at < timezone.now()
        )

    @classmethod
    def get_or_create_user(
        cls, *, provider_key: str, oauth_token: "OAuthToken", oauth_user: "OAuthUser"
    ) -> "OAuthConnection":
        """Return the connection for this provider account, creating a user if needed.

        If a connection already exists for ``provider_key`` and
        ``oauth_user.provider_id``, its token fields are updated and saved.
        Otherwise a new user is built from ``oauth_user.user_model_fields``
        and connected, all inside a single transaction.

        Raises:
            OAuthUserAlreadyExistsError: if saving the new user raises
                ``IntegrityError`` or ``ValidationError``. The original error
                is attached as the cause.
        """
        try:
            connection = cls.query.get(
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # No connection yet: fall through to the creation path below, so
            # errors raised there are not chained to this lookup miss.
            pass
        else:
            connection.set_token_fields(oauth_token)
            connection.save()
            return connection

        with transaction.atomic():
            # If email needs to be unique, then we expect
            # that to be taken care of on the user model itself
            try:
                user = get_user_model()(
                    **oauth_user.user_model_fields,
                )
                user.save()
            except (IntegrityError, ValidationError) as e:
                # Leaving the atomic block via an exception rolls it back.
                raise OAuthUserAlreadyExistsError() from e

            return cls.connect(
                user=user,
                provider_key=provider_key,
                oauth_token=oauth_token,
                oauth_user=oauth_user,
            )

    @classmethod
    def connect(
        cls,
        *,
        user: Any,
        provider_key: str,
        oauth_token: "OAuthToken",
        oauth_user: "OAuthUser",
    ) -> "OAuthConnection":
        """
        Connect will either create a new connection or update an existing connection

        The lookup matches on ``user``, ``provider_key`` and
        ``oauth_user.provider_id``. In both cases the user and token fields
        are (re)applied from ``oauth_user`` / ``oauth_token`` and the
        connection is saved before being returned.
        """
        try:
            connection = cls.query.get(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )
        except cls.DoesNotExist:
            # Create our own instance (not using get_or_create)
            # so that any created signals contain the token fields too
            connection = cls(
                user=user,
                provider_key=provider_key,
                provider_user_id=oauth_user.provider_id,
            )

        connection.set_user_fields(oauth_user)
        connection.set_token_fields(oauth_token)
        connection.save()

        return connection