OpenID Connect OP support for customer accounts

This commit is contained in:
Raphael Michel
2022-08-10 14:22:30 +02:00
committed by Raphael Michel
parent 7f5518dbf6
commit a4171ef819
20 changed files with 1735 additions and 23 deletions

View File

@@ -19,17 +19,34 @@
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# <https://www.gnu.org/licenses/>.
#
import base64
import hashlib
import logging
import time
from datetime import datetime
from urllib.parse import urlencode, urljoin
import jwt
import requests
from cryptography.hazmat.primitives.asymmetric.rsa import generate_private_key
from cryptography.hazmat.primitives.serialization import (
Encoding, NoEncryption, PrivateFormat, PublicFormat,
)
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
from requests import RequestException
from pretix.multidomain.urlreverse import build_absolute_uri
logger = logging.getLogger(__name__)
"""
This module contains utilities for implementing OpenID Connect for customer authentication both as a receiving party (RP)
as well as an OpenID Provider (OP).
"""
def _urljoin(base, path):
if not base.endswith("/"):
base += "/"
@@ -205,3 +222,74 @@ def oidc_validate_authorization(provider, code, redirect_uri):
)
return profile
def _hash_scheme(value):
# As described in https://openid.net/specs/openid-connect-core-1_0.html#HybridIDToken
digest = hashlib.sha256(value.encode()).digest()
digest_truncated = digest[:(len(digest) // 2)]
return base64.urlsafe_b64encode(digest_truncated).decode().rstrip("=")
def customer_claims(customer, scope):
scope = scope.split(' ')
claims = {
'sub': customer.identifier,
'locale': customer.locale,
}
if 'profile' in scope:
if customer.name:
claims['name'] = customer.name
if 'given_name' in customer.name_parts:
claims['given_name'] = customer.name_parts['given_name']
if 'family_name' in customer.name_parts:
claims['family_name'] = customer.name_parts['family_name']
if 'middle_name' in customer.name_parts:
claims['middle_name'] = customer.name_parts['middle_name']
if 'calling_name' in customer.name_parts:
claims['nickname'] = customer.name_parts['calling_name']
if 'email' in scope and customer.email:
claims['email'] = customer.email
claims['email_verified'] = customer.is_verified
if 'phone' in scope and customer.phone:
claims['phone_number'] = customer.phone.as_international
return claims
def _get_or_create_server_keypair(organizer):
if not organizer.settings.sso_server_signing_key_rsa256_private:
privkey = generate_private_key(key_size=4096, public_exponent=65537)
pubkey = privkey.public_key()
organizer.settings.sso_server_signing_key_rsa256_private = privkey.private_bytes(
Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()
).decode()
organizer.settings.sso_server_signing_key_rsa256_public = pubkey.public_bytes(
Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
).decode()
return organizer.settings.sso_server_signing_key_rsa256_private, organizer.settings.sso_server_signing_key_rsa256_public
def generate_id_token(customer, client, auth_time, nonce, scope, expires: datetime, scope_claims=False, with_code=None, with_access_token=None):
payload = {
'iss': build_absolute_uri(client.organizer, 'presale:organizer.index').rstrip('/'),
'aud': client.client_id,
'exp': int(expires.timestamp()),
'iat': int(time.time()),
'auth_time': auth_time,
**customer_claims(customer, client.evaluated_scope(scope) if scope_claims else ''),
}
if nonce:
payload['nonce'] = nonce
if with_code:
payload['c_hash'] = _hash_scheme(with_code)
if with_access_token:
payload['at_hash'] = _hash_scheme(with_access_token)
privkey, pubkey = _get_or_create_server_keypair(client.organizer)
return jwt.encode(
payload,
privkey,
headers={
"kid": hashlib.sha256(pubkey.encode()).hexdigest()[:16]
},
algorithm="RS256",
)

View File

@@ -0,0 +1,68 @@
# Generated by Django 3.2.12 on 2022-08-11 10:02
import django.db.models.deletion
from django.db import migrations, models
import pretix.base.models.base
import pretix.base.models.customers
import pretix.base.models.fields
class Migration(migrations.Migration):
dependencies = [
('pretixbase', '0219_auto_20220706_0913'),
]
operations = [
migrations.CreateModel(
name='CustomerSSOClient',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('name', models.CharField(max_length=255)),
('is_active', models.BooleanField(default=True)),
('client_id', models.CharField(db_index=True, default=pretix.base.models.customers.generate_client_id, max_length=100, unique=True)),
('client_secret', models.CharField(max_length=255)),
('client_type', models.CharField(default='confidential', max_length=32)),
('authorization_grant_type', models.CharField(default='authorization-code', max_length=32)),
('redirect_uris', models.TextField()),
('allowed_scopes', pretix.base.models.fields.MultiStringField(default=['openid', 'profile', 'email', 'phone'])),
('organizer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sso_clients', to='pretixbase.organizer')),
],
options={
'abstract': False,
},
bases=(models.Model, pretix.base.models.base.LoggingMixin),
),
migrations.AlterField(
model_name='customer',
name='provider',
field=models.ForeignKey(null=True, on_delete=django.db.models.deletion.PROTECT, related_name='customers', to='pretixbase.customerssoprovider'),
),
migrations.CreateModel(
name='CustomerSSOGrant',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('code', models.CharField(max_length=255, unique=True)),
('nonce', models.CharField(max_length=255, null=True)),
('auth_time', models.IntegerField()),
('expires', models.DateTimeField()),
('redirect_uri', models.TextField()),
('scope', models.TextField()),
('client', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='grants', to='pretixbase.customerssoclient')),
('customer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sso_grants', to='pretixbase.customer')),
],
),
migrations.CreateModel(
name='CustomerSSOAccessToken',
fields=[
('id', models.BigAutoField(primary_key=True, serialize=False)),
('from_code', models.CharField(max_length=255, null=True)),
('token', models.CharField(max_length=255, unique=True)),
('expires', models.DateTimeField()),
('scope', models.TextField()),
('client', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='access_tokens', to='pretixbase.customerssoclient')),
('customer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='sso_access_tokens', to='pretixbase.customer')),
],
),
]

View File

@@ -24,7 +24,7 @@ from django.conf import settings
from django.contrib.auth.hashers import (
check_password, is_password_usable, make_password,
)
from django.core.validators import RegexValidator
from django.core.validators import RegexValidator, URLValidator
from django.db import models
from django.db.models import F, Q
from django.utils.crypto import get_random_string, salted_hmac
@@ -35,6 +35,7 @@ from phonenumber_field.modelfields import PhoneNumberField
from pretix.base.banlist import banned
from pretix.base.models.base import LoggedModel
from pretix.base.models.fields import MultiStringField
from pretix.base.models.organizer import Organizer
from pretix.base.settings import PERSON_NAME_SCHEMES
from pretix.helpers.countries import FastCountryField
@@ -348,3 +349,134 @@ class AttendeeProfile(models.Model):
parts.append(f'{a["field_label"]}: {val}')
return '\n'.join([str(p).strip() for p in parts if p and str(p).strip()])
def generate_client_id():
return get_random_string(40)
def generate_client_secret():
return get_random_string(40)
class CustomerSSOClient(LoggedModel):
CLIENT_CONFIDENTIAL = "confidential"
CLIENT_PUBLIC = "public"
CLIENT_TYPES = (
(CLIENT_CONFIDENTIAL, pgettext_lazy("openidconnect", "Confidential")),
(CLIENT_PUBLIC, pgettext_lazy("openidconnect", "Public")),
)
GRANT_AUTHORIZATION_CODE = "authorization-code"
GRANT_IMPLICIT = "implicit"
GRANT_TYPES = (
(GRANT_AUTHORIZATION_CODE, pgettext_lazy("openidconnect", "Authorization code")),
(GRANT_IMPLICIT, pgettext_lazy("openidconnect", "Implicit")),
)
SCOPE_CHOICES = (
('openid', _('OpenID Connect access (required)')),
('profile', _('Profile data (name, addresses)')),
('email', _('E-mail address')),
('phone', _('Phone number')),
)
id = models.BigAutoField(primary_key=True)
organizer = models.ForeignKey(Organizer, related_name='sso_clients', on_delete=models.CASCADE)
name = models.CharField(verbose_name=_("Application name"), max_length=255, blank=False)
is_active = models.BooleanField(default=True, verbose_name=_('Active'))
client_id = models.CharField(
verbose_name=_("Client ID"),
max_length=100, unique=True, default=generate_client_id, db_index=True
)
client_secret = models.CharField(
max_length=255, blank=False,
)
client_type = models.CharField(
max_length=32, choices=CLIENT_TYPES, verbose_name=_("Client type"), default=CLIENT_CONFIDENTIAL,
)
authorization_grant_type = models.CharField(
max_length=32, choices=GRANT_TYPES, verbose_name=_("Grant type"), default=GRANT_AUTHORIZATION_CODE,
)
redirect_uris = models.TextField(
blank=False,
verbose_name=_("Redirection URIs"),
help_text=_("Allowed URIs list, space separated")
)
allowed_scopes = MultiStringField(
default=['openid', 'profile', 'email', 'phone'],
delimiter=" ",
blank=True,
verbose_name=_('Allowed access scopes'),
help_text=_('Separate multiple values with spaces'),
)
def is_usable(self):
return self.is_active
def allow_redirect_uri(self, redirect_uri):
return self.redirect_uris and any(r.strip() == redirect_uri for r in self.redirect_uris.split(' '))
def allow_delete(self):
return True
def evaluated_scope(self, scope):
scope = set(scope.split(' '))
allowed_scopes = set(self.allowed_scopes)
return ' '.join(scope & allowed_scopes)
def clean(self):
redirect_uris = self.redirect_uris.strip().split()
if redirect_uris:
validator = URLValidator()
for uri in redirect_uris:
validator(uri)
def set_client_secret(self):
secret = get_random_string(64)
self.client_secret = make_password(secret)
return secret
def check_client_secret(self, raw_secret):
"""
Return a boolean of whether the ra_secret was correct. Handles
hashing formats behind the scenes.
"""
def setter(raw_secret):
self.client_secret = make_password(raw_secret)
self.save(update_fields=["client_secret"])
return check_password(raw_secret, self.client_secret, setter)
class CustomerSSOGrant(models.Model):
id = models.BigAutoField(primary_key=True)
client = models.ForeignKey(
CustomerSSOClient, on_delete=models.CASCADE, related_name="grants"
)
customer = models.ForeignKey(
Customer, on_delete=models.CASCADE, related_name="sso_grants"
)
code = models.CharField(max_length=255, unique=True)
nonce = models.CharField(max_length=255, null=True, blank=True)
auth_time = models.IntegerField()
expires = models.DateTimeField()
redirect_uri = models.TextField()
scope = models.TextField(blank=True)
class CustomerSSOAccessToken(models.Model):
id = models.BigAutoField(primary_key=True)
client = models.ForeignKey(
CustomerSSOClient, on_delete=models.CASCADE, related_name="access_tokens"
)
customer = models.ForeignKey(
Customer, on_delete=models.CASCADE, related_name="sso_access_tokens"
)
from_code = models.CharField(max_length=255, null=True, blank=True)
token = models.CharField(max_length=255, unique=True)
expires = models.DateTimeField()
scope = models.TextField(blank=True)

View File

@@ -33,7 +33,8 @@ class MultiStringField(TextField):
'delimiter_found': _('No value can contain the delimiter character.')
}
def __init__(self, verbose_name=None, name=None, **kwargs):
def __init__(self, verbose_name=None, name=None, delimiter=DELIMITER, **kwargs):
self.delimiter = delimiter
super().__init__(verbose_name, name, **kwargs)
def deconstruct(self):
@@ -44,13 +45,13 @@ class MultiStringField(TextField):
if isinstance(value, (list, tuple)):
return value
elif value:
return [v for v in value.split(DELIMITER) if v]
return [v for v in value.split(self.delimiter) if v]
else:
return []
def get_prep_value(self, value):
if isinstance(value, (list, tuple)):
return DELIMITER + DELIMITER.join(value) + DELIMITER
return self.delimiter + self.delimiter.join(value) + self.delimiter
elif value is None:
if self.null:
return None
@@ -63,14 +64,14 @@ class MultiStringField(TextField):
def from_db_value(self, value, expression, connection):
if value:
return [v for v in value.split(DELIMITER) if v]
return [v for v in value.split(self.delimiter) if v]
else:
return []
def validate(self, value, model_instance):
super().validate(value, model_instance)
for l in value:
if DELIMITER in l:
if self.delimiter in l:
raise exceptions.ValidationError(
self.error_messages['delimiter_found'],
code='delimiter_found',
@@ -78,9 +79,9 @@ class MultiStringField(TextField):
def get_lookup(self, lookup_name):
if lookup_name == 'contains':
return MultiStringContains
return make_multistring_contains_lookup(self.delimiter)
elif lookup_name == 'icontains':
return MultiStringIContains
return make_multistring_icontains_lookup(self.delimiter)
elif lookup_name == 'isnull':
return builtin_lookups.IsNull
raise NotImplementedError(
@@ -88,18 +89,22 @@ class MultiStringField(TextField):
)
class MultiStringContains(builtin_lookups.Contains):
def process_rhs(self, qn, connection):
sql, params = super().process_rhs(qn, connection)
params[0] = "%" + DELIMITER + params[0][1:-1] + DELIMITER + "%"
return sql, params
def make_multistring_contains_lookup(delimiter):
class Cls(builtin_lookups.Contains):
def process_rhs(self, qn, connection):
sql, params = super().process_rhs(qn, connection)
params[0] = "%" + delimiter + params[0][1:-1] + delimiter + "%"
return sql, params
return Cls
class MultiStringIContains(builtin_lookups.IContains):
def process_rhs(self, qn, connection):
sql, params = super().process_rhs(qn, connection)
params[0] = "%" + DELIMITER + params[0][1:-1] + DELIMITER + "%"
return sql, params
def make_multistring_icontains_lookup(delimiter):
class Cls(builtin_lookups.IContains):
def process_rhs(self, qn, connection):
sql, params = super().process_rhs(qn, connection)
params[0] = "%" + delimiter + params[0][1:-1] + delimiter + "%"
return sql, params
return Cls
class MultiStringSerializer(serializers.Field):

View File

@@ -28,6 +28,7 @@ from django.utils.timezone import now
from django_scopes import scopes_disabled
from pretix.base.models import CachedCombinedTicket, CachedTicket
from pretix.base.models.customers import CustomerSSOGrant
from ..models import CachedFile, CartPosition, InvoiceAddress
from ..signals import periodic_task
@@ -68,3 +69,9 @@ def clean_cached_tickets(sender, **kwargs):
@scopes_disabled()
def clearsessions(sender, **kwargs):
call_command('clearsessions')
@receiver(signal=periodic_task)
@scopes_disabled()
def clear_oidc_data(sender, **kwargs):
CustomerSSOGrant.objects.filter(expires__lt=now() - timedelta(days=14)).delete()