diff --git a/src/pretix/presale/checkoutflow.py b/src/pretix/presale/checkoutflow.py
index b60a543ce7..ff2b2b8452 100644
--- a/src/pretix/presale/checkoutflow.py
+++ b/src/pretix/presale/checkoutflow.py
@@ -298,11 +298,13 @@ class CustomerStep(CartMixin, TemplateFlowStep):
elif request.customer:
self.cart_session['customer_mode'] = 'login'
self.cart_session['customer'] = request.customer.pk
+ self.cart_session['customer_cart_tied_to_login'] = True
return redirect(self.get_next_url(request))
elif self.login_form.is_valid():
customer_login(self.request, self.login_form.get_customer())
self.cart_session['customer_mode'] = 'login'
self.cart_session['customer'] = self.login_form.get_customer().pk
+ self.cart_session['customer_cart_tied_to_login'] = True
return redirect(self.get_next_url(request))
else:
return self.render()
@@ -311,6 +313,7 @@ class CustomerStep(CartMixin, TemplateFlowStep):
customer = self.register_form.create()
self.cart_session['customer_mode'] = 'login'
self.cart_session['customer'] = customer.pk
+ self.cart_session['customer_cart_tied_to_login'] = False
return redirect(self.get_next_url(request))
else:
return self.render()
diff --git a/src/pretix/presale/templates/pretixpresale/fragment_login_status.html b/src/pretix/presale/templates/pretixpresale/fragment_login_status.html
index 1e558804a8..d4c18ddf60 100644
--- a/src/pretix/presale/templates/pretixpresale/fragment_login_status.html
+++ b/src/pretix/presale/templates/pretixpresale/fragment_login_status.html
@@ -15,7 +15,7 @@
{% else %}
-
+
{% trans "Log in" %}
{% endif %}
diff --git a/src/pretix/presale/utils.py b/src/pretix/presale/utils.py
index 85393046d6..72ec113528 100644
--- a/src/pretix/presale/utils.py
+++ b/src/pretix/presale/utils.py
@@ -31,7 +31,7 @@
# Unless required by applicable law or agreed to in writing, software distributed under the Apache License 2.0 is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under the License.
-
+import re
import warnings
from importlib import import_module
from urllib.parse import urljoin
@@ -64,17 +64,36 @@ def get_customer(request):
if not hasattr(request, '_cached_customer'):
session_key = f'customer_auth_id:{request.organizer.pk}'
hash_session_key = f'customer_auth_hash:{request.organizer.pk}'
+ dependency_key = f'customer_auth_session_dependency:{request.organizer.pk}'
+
+ # By default, we look at the regular django session
+ session = request.session
+
+ # However, if an event uses a custom domain, the event is at a different domain
+ # than our actual session cookie. The login state is therefore not determined
+ # by our request session, but by the "parent session", the user's session on the
+ # organizer level. This approach guarantees e.g. a global logout feature.
+ if session.get(dependency_key):
+ sparent = SessionStore(session[dependency_key])
+ try:
+ sparent.load()
+ except:
+ # parent session no longer exists
+ request._cached_customer = None
+ return
+ else:
+ session = sparent
with scope(organizer=request.organizer):
try:
customer = request.organizer.customers.get(
is_active=True, is_verified=True,
- pk=request.session[session_key]
+ pk=session[session_key]
)
except (Customer.DoesNotExist, KeyError):
request._cached_customer = None
else:
- session_hash = request.session.get(hash_session_key)
+ session_hash = session.get(hash_session_key)
session_hash_verified = session_hash and constant_time_compare(
session_hash,
customer.get_session_auth_hash()
@@ -82,7 +101,7 @@ def get_customer(request):
if session_hash_verified:
request._cached_customer = customer
else:
- request.session.flush()
+ session.flush()
request._cached_customer = None
return request._cached_customer
@@ -96,12 +115,46 @@ def update_customer_session_auth_hash(request, customer):
def add_customer_to_request(request):
+ if 'cross_domain_customer_auth' in request.GET and request.event_domain:
+ # The user is logged in on the main domain and now wants to take their session
+ # to a event-specific domain. We validate the one time token received via a
+ # query parameter and make sure we invalidate it right away. Then, we look up
+ # the users session on the main domain and store the dependency between the two
+ # sessions.
+ otp = re.sub('[^a-zA-Z0-9]', '', request.GET['cross_domain_customer_auth'])
+
+ otpstore = SessionStore(otp)
+ try:
+ otpstore.load()
+ except:
+ pass
+ else:
+ parent_session_key = otpstore.get(f'customer_cross_domain_auth_{request.organizer.pk}')
+
+ if parent_session_key: # not already invalidated, expired, …
+ # Make sure the OTP can't be used again
+ otpstore.delete()
+
+ sparent = SessionStore(parent_session_key)
+ try:
+ sparent.load()
+ except:
+ # parent session no longer exists
+ pass
+ else:
+ dependency_key = f'customer_auth_session_dependency:{request.organizer.pk}'
+ session_key = f'customer_auth_id:{request.organizer.pk}'
+ request.session[dependency_key] = parent_session_key
+ if session_key in request.session:
+ del request.session[session_key]
+
request.customer = SimpleLazyObject(lambda: get_customer(request))
def customer_login(request, customer):
session_key = f'customer_auth_id:{request.organizer.pk}'
hash_session_key = f'customer_auth_hash:{request.organizer.pk}'
+ dependency_key = f'customer_auth_session_dependency:{request.organizer.pk}'
session_auth_hash = customer.get_session_auth_hash()
if session_key in request.session:
@@ -114,6 +167,7 @@ def customer_login(request, customer):
else:
request.session.cycle_key()
+ request.session.pop(dependency_key, None)
request.session[session_key] = customer.pk
request.session[hash_session_key] = session_auth_hash
request.customer = customer
@@ -127,6 +181,12 @@ def customer_login(request, customer):
def customer_logout(request):
session_key = f'customer_auth_id:{request.organizer.pk}'
hash_session_key = f'customer_auth_hash:{request.organizer.pk}'
+ dependency_key = f'customer_auth_session_dependency:{request.organizer.pk}'
+
+ # Remove dependency on parent session
+ request.session.pop(dependency_key, None)
+ # We do not remove the actual parent session as we have no way of e.g. cycling its ID.
+ # Instead, LogoutView will redirect the user to the logout of the parent session.
# Remove user session
customer_id = request.session.pop(session_key, None)
diff --git a/src/pretix/presale/views/cart.py b/src/pretix/presale/views/cart.py
index be3a652da7..31e9b9da8c 100644
--- a/src/pretix/presale/views/cart.py
+++ b/src/pretix/presale/views/cart.py
@@ -343,31 +343,45 @@ def get_or_create_cart_id(request, create=True):
if current_id and current_id in request.session.get('carts', {}):
if current_id != orig_current_id:
request.session[session_keyname] = current_id
- return current_id
- else:
- cart_data = {}
- if prefix and 'take_cart_id' in request.GET and current_id:
- new_id = current_id
- cached_widget_data = widget_data_cache.get('widget_data_{}'.format(current_id))
- if cached_widget_data:
- cart_data['widget_data'] = cached_widget_data
+
+ cart_invalidated = (
+ request.session['carts'][current_id].get('customer_cart_tied_to_login', False) and
+ request.session['carts'][current_id].get('customer') and
+ (not request.customer or request.session['carts'][current_id].get('customer') != request.customer.pk)
+ )
+
+ if cart_invalidated:
+ # This cart was created with a login but the person is now logged out.
+ # Destroy the cart for privacy protection.
+ request.session['carts'][current_id] = {}
else:
- if not create:
- return None
- new_id = generate_cart_id(request, prefix=prefix)
+ return current_id
- if 'widget_data' not in cart_data and 'widget_data' in request.GET:
- try:
- cart_data['widget_data'] = json.loads(request.GET.get('widget_data'))
- except ValueError:
- pass
+ cart_data = {}
+ if prefix and 'take_cart_id' in request.GET and current_id:
+ new_id = current_id
+ cached_widget_data = widget_data_cache.get('widget_data_{}'.format(current_id))
+ if cached_widget_data:
+ cart_data['widget_data'] = cached_widget_data
+ else:
+ if not create:
+ return None
+ new_id = generate_cart_id(request, prefix=prefix)
- if 'carts' not in request.session:
- request.session['carts'] = {}
- if new_id not in request.session['carts']:
- request.session['carts'][new_id] = cart_data
- request.session[session_keyname] = new_id
- return new_id
+ if 'widget_data' not in cart_data and 'widget_data' in request.GET:
+ try:
+ cart_data['widget_data'] = json.loads(request.GET.get('widget_data'))
+ except ValueError:
+ pass
+
+ if 'carts' not in request.session:
+ request.session['carts'] = {}
+
+ if new_id not in request.session['carts']:
+ request.session['carts'][new_id] = cart_data
+
+ request.session[session_keyname] = new_id
+ return new_id
def cart_session(request):
diff --git a/src/pretix/presale/views/customer.py b/src/pretix/presale/views/customer.py
index 05c711a24c..bf29630168 100644
--- a/src/pretix/presale/views/customer.py
+++ b/src/pretix/presale/views/customer.py
@@ -19,8 +19,12 @@
# You should have received a copy of the GNU Affero General Public License along with this program. If not, see
# .
#
-from urllib.parse import quote
+from importlib import import_module
+from urllib.parse import (
+ parse_qs, quote, urlencode, urljoin, urlparse, urlsplit, urlunparse,
+)
+from django.conf import settings
from django.contrib import messages
from django.core.signing import BadSignature, dumps, loads
from django.db import transaction
@@ -38,6 +42,7 @@ from django.views.generic import DeleteView, FormView, ListView, View
from pretix.base.models import Customer, InvoiceAddress, Order, OrderPosition
from pretix.base.services.mail import mail
+from pretix.multidomain.models import KnownDomain
from pretix.multidomain.urlreverse import build_absolute_uri, eventreverse
from pretix.presale.forms.customer import (
AuthenticationForm, ChangeInfoForm, ChangePasswordForm, RegistrationForm,
@@ -47,6 +52,8 @@ from pretix.presale.utils import (
customer_login, customer_logout, update_customer_session_auth_hash,
)
+SessionStore = import_module(settings.SESSION_ENGINE).SessionStore
+
class RedirectBackMixin:
redirect_field_name = 'next'
@@ -57,9 +64,14 @@ class RedirectBackMixin:
self.redirect_field_name,
self.request.GET.get(self.redirect_field_name, '')
)
+ hosts = list(KnownDomain.objects.filter(event__organizer=self.request.organizer).values_list('domainname', flat=True))
+ siteurlsplit = urlsplit(settings.SITE_URL)
+ if siteurlsplit.port and siteurlsplit.port not in (80, 443):
+ hosts = ['%s:%d' % (h, siteurlsplit.port) for h in hosts]
+
url_is_safe = url_has_allowed_host_and_scheme(
url=redirect_to,
- allowed_hosts=None,
+ allowed_hosts=hosts,
require_https=self.request.is_secure(),
)
return redirect_to if url_is_safe else ''
@@ -96,11 +108,23 @@ class LoginView(RedirectBackMixin, FormView):
def get_success_url(self):
url = self.get_redirect_url()
- if getattr(self.request, 'event_domain', False):
- default_url = '/'
- else:
- default_url = eventreverse(self.request.organizer, 'presale:organizer.customer.profile', kwargs={})
- return url or default_url
+
+ if not url:
+ return eventreverse(self.request.organizer, 'presale:organizer.customer.profile', kwargs={})
+
+ if self.request.GET.get("request_cross_domain_customer_auth") == "true":
+ otpstore = SessionStore()
+ otpstore[f'customer_cross_domain_auth_{self.request.organizer.pk}'] = self.request.session.session_key
+ otpstore.set_expiry(60)
+ otpstore.save(must_create=True)
+ otp = otpstore.session_key
+
+ u = urlparse(url)
+ qsl = parse_qs(u.query)
+ qsl['cross_domain_customer_auth'] = otp
+ url = urlunparse((u.scheme, u.netloc, u.path, u.params, urlencode(qsl, doseq=True), u.fragment))
+
+ return url
def form_valid(self, form):
"""Security check complete. Log the user in."""
@@ -119,25 +143,39 @@ class LogoutView(View):
def get_next_page(self):
if getattr(self.request, 'event_domain', False):
- next_page = '/'
+ # After we cleared the cookies on this domain, redirect to the parent domain to clear cookies as well
+ next_page = eventreverse(self.request.organizer, 'presale:organizer.customer.logout', kwargs={})
+ if self.redirect_field_name in self.request.POST or self.redirect_field_name in self.request.GET:
+ after_next_page = self.request.POST.get(
+ self.redirect_field_name,
+ self.request.GET.get(self.redirect_field_name)
+ )
+ next_page += '?' + urlencode({
+ 'next': urljoin(f'{self.request.scheme}://{self.request.get_host()}', after_next_page)
+ })
else:
next_page = eventreverse(self.request.organizer, 'presale:organizer.index', kwargs={})
- if (self.redirect_field_name in self.request.POST or
- self.redirect_field_name in self.request.GET):
- next_page = self.request.POST.get(
- self.redirect_field_name,
- self.request.GET.get(self.redirect_field_name)
- )
- url_is_safe = url_has_allowed_host_and_scheme(
- url=next_page,
- allowed_hosts=None,
- require_https=self.request.is_secure(),
- )
- # Security check -- Ensure the user-originating redirection URL is
- # safe.
- if not url_is_safe:
- next_page = self.request.path
+ if (self.redirect_field_name in self.request.POST or
+ self.redirect_field_name in self.request.GET):
+ next_page = self.request.POST.get(
+ self.redirect_field_name,
+ self.request.GET.get(self.redirect_field_name)
+ )
+ hosts = list(KnownDomain.objects.filter(event__organizer=self.request.organizer).values_list('domainname', flat=True))
+ siteurlsplit = urlsplit(settings.SITE_URL)
+ if siteurlsplit.port and siteurlsplit.port not in (80, 443):
+ hosts = ['%s:%d' % (h, siteurlsplit.port) for h in hosts]
+ url_is_safe = url_has_allowed_host_and_scheme(
+ url=next_page,
+ allowed_hosts=hosts,
+ require_https=self.request.is_secure(),
+ )
+ # Security check -- Ensure the user-originating redirection URL is
+ # safe.
+ if not url_is_safe:
+ next_page = self.request.path
+
return next_page
diff --git a/src/tests/presale/test_checkout.py b/src/tests/presale/test_checkout.py
index 183f621a10..310f0fbc47 100644
--- a/src/tests/presale/test_checkout.py
+++ b/src/tests/presale/test_checkout.py
@@ -3815,6 +3815,26 @@ class CustomerCheckoutTestCase(BaseCheckoutTestCase, TestCase):
order = self._finish()
assert order.customer == self.customer
+ def test_login_valid_but_removed_after_logout(self):
+ response = self.client.get('/%s/%s/checkout/start' % (self.orga.slug, self.event.slug), follow=True)
+ self.assertRedirects(response, '/%s/%s/checkout/customer/' % (self.orga.slug, self.event.slug),
+ target_status_code=200)
+
+ response = self.client.post('/%s/%s/checkout/customer/' % (self.orga.slug, self.event.slug), {
+ 'customer_mode': 'login',
+ 'login-email': 'john@example.org',
+ 'login-password': 'foo',
+ }, follow=True)
+ self.assertRedirects(response, '/%s/%s/checkout/questions/' % (self.orga.slug, self.event.slug),
+ target_status_code=200)
+
+ self.client.get('/%s/account/logout' % (self.orga.slug,), follow=True)
+
+ response = self.client.get('/%s/%s/checkout/questions/' % (self.orga.slug, self.event.slug), follow=True)
+ self.assertRedirects(response, '/%s/%s/?require_cookie=true' % (self.orga.slug, self.event.slug),
+ target_status_code=200)
+ assert response.status_code == 200
+
def test_login_invalid(self):
response = self.client.get('/%s/%s/checkout/start' % (self.orga.slug, self.event.slug), follow=True)
self.assertRedirects(response, '/%s/%s/checkout/customer/' % (self.orga.slug, self.event.slug),
diff --git a/src/tests/presale/test_customer.py b/src/tests/presale/test_customer.py
index 11b17ecf19..3555e25167 100644
--- a/src/tests/presale/test_customer.py
+++ b/src/tests/presale/test_customer.py
@@ -22,14 +22,17 @@
import datetime
from datetime import timedelta
from decimal import Decimal
+from urllib.parse import parse_qs, urlparse
import pytest
from django.core import mail as djmail
from django.core.signing import dumps
+from django.test import Client
from django.utils.timezone import now
from django_scopes import scopes_disabled
from pretix.base.models import Event, Item, Order, OrderPosition, Organizer
+from pretix.multidomain.models import KnownDomain
from pretix.presale.forms.customer import TokenGenerator
@@ -412,3 +415,152 @@ def test_login_per_org(env, client):
})
assert client.get('/bigevents/account/').status_code == 200
assert client.get('/demo/account/').status_code == 302
+
+
+@pytest.fixture
+def client2():
+ # We need a second test client instance for cross domain stuff since the test client
+ # does not isolate sessions per-domain like browsers do
+ return Client()
+
+
+def _cross_domain_login(env, client, client2):
+ with scopes_disabled():
+ customer = env[0].customers.create(email='john@example.org', is_verified=True)
+ customer.set_password('foo')
+ customer.save()
+ KnownDomain.objects.create(domainname='org.test', organizer=env[0])
+ KnownDomain.objects.create(domainname='event.test', organizer=env[0], event=env[1])
+
+ # Log in on org domain
+ r = client.post('/account/login?next=https://event.test/redeem&request_cross_domain_customer_auth=true', {
+ 'email': 'john@example.org',
+ 'password': 'foo',
+ }, HTTP_HOST='org.test')
+ assert r.status_code == 302
+
+ u = urlparse(r.headers['Location'])
+ assert u.netloc == 'event.test'
+ assert u.path == '/redeem'
+ q = parse_qs(u.query)
+ assert 'cross_domain_customer_auth' in q
+
+ # Take session over to event domain
+ r = client2.get(f'/?{u.query}', HTTP_HOST='event.test')
+ assert r.status_code == 200
+ assert b'john@example.org' in r.content
+
+
+@pytest.mark.django_db
+def test_cross_domain_login(env, client, client2):
+ _cross_domain_login(env, client, client2)
+
+ # Logged in on org domain
+ r = client.get('/', HTTP_HOST='event.test')
+ assert r.status_code == 200
+ assert b'john@example.org' in r.content
+
+ # Logged in on event domain
+ r = client2.get('/', HTTP_HOST='org.test')
+ assert r.status_code == 200
+ assert b'john@example.org' in r.content
+
+
+@pytest.mark.django_db
+def test_cross_domain_logout_on_org_domain(env, client, client2):
+ _cross_domain_login(env, client, client2)
+
+ r = client.get('/account/logout', HTTP_HOST='org.test')
+ assert r.status_code == 302
+
+ # Logged out on org domain
+ r = client.get('/', HTTP_HOST='event.test')
+ assert r.status_code == 200
+ assert b'john@example.org' not in r.content
+
+ # Logged out on event domain
+ r = client2.get('/', HTTP_HOST='org.test')
+ assert r.status_code == 200
+ assert b'john@example.org' not in r.content
+
+
+@pytest.mark.django_db
+def test_cross_domain_logout_on_event_domain(env, client, client2):
+ _cross_domain_login(env, client, client2)
+
+ r = client2.get('/account/logout?next=/redeem', HTTP_HOST='event.test')
+ assert r.status_code == 302
+
+ u = urlparse(r.headers['Location'])
+ assert u.netloc == 'org.test'
+ assert u.path == '/account/logout'
+
+ r = client.get(f'{u.path}?{u.query}', HTTP_HOST='org.test')
+ assert r.status_code == 302
+ assert r.headers['Location'] == 'http://event.test/redeem'
+
+ # Logged out on org domain
+ r = client.get('/', HTTP_HOST='event.test')
+ assert r.status_code == 200
+ assert b'john@example.org' not in r.content
+
+ # Logged out on event domain
+ r = client2.get('/', HTTP_HOST='org.test')
+ assert r.status_code == 200
+ assert b'john@example.org' not in r.content
+
+
+@pytest.mark.django_db
+def test_cross_domain_login_otp_only_valid_once(env, client, client2):
+ with scopes_disabled():
+ customer = env[0].customers.create(email='john@example.org', is_verified=True)
+ customer.set_password('foo')
+ customer.save()
+ KnownDomain.objects.create(domainname='org.test', organizer=env[0])
+ KnownDomain.objects.create(domainname='event.test', organizer=env[0], event=env[1])
+
+ # Log in on org domain
+ r = client.post('/account/login?next=https://event.test/redeem&request_cross_domain_customer_auth=true', {
+ 'email': 'john@example.org',
+ 'password': 'foo',
+ }, HTTP_HOST='org.test')
+ assert r.status_code == 302
+
+ u = urlparse(r.headers['Location'])
+ assert u.netloc == 'event.test'
+ assert u.path == '/redeem'
+ q = parse_qs(u.query)
+ assert 'cross_domain_customer_auth' in q
+
+ # Take session over to event domain
+ r = client.get(f'/?{u.query}', HTTP_HOST='event.test')
+ assert r.status_code == 200
+ assert b'john@example.org' in r.content
+
+ # Try to use again
+ r = client2.get(f'/?{u.query}', HTTP_HOST='event.test')
+ assert r.status_code == 200
+ assert b'john@example.org' not in r.content
+
+
+@pytest.mark.django_db
+def test_cross_domain_login_validate_redirect_url(env, client, client2):
+ with scopes_disabled():
+ customer = env[0].customers.create(email='john@example.org', is_verified=True)
+ customer.set_password('foo')
+ customer.save()
+ KnownDomain.objects.create(domainname='org.test', organizer=env[0])
+ KnownDomain.objects.create(domainname='event.test', organizer=env[0], event=env[1])
+
+ # Log in on org domain
+ r = client.post('/account/login?next=https://evilcorp.test/redeem&request_cross_domain_customer_auth=true', {
+ 'email': 'john@example.org',
+ 'password': 'foo',
+ }, HTTP_HOST='org.test')
+ assert r.status_code == 302
+
+ u = urlparse(r.headers['Location'])
+ assert u.netloc == 'org.test'
+ assert u.path == '/account/'
+ q = parse_qs(u.query)
+ assert 'cross_domain_customer_auth' not in q