diff --git a/src/pretix/base/auth.py b/src/pretix/base/auth.py index a359ed93c..fa1cca25d 100644 --- a/src/pretix/base/auth.py +++ b/src/pretix/base/auth.py @@ -85,6 +85,16 @@ class BaseAuthBackend: """ return + def get_next_url(self, request): + """ + This method will be called after a successful login to determine the next URL. Pretix in general uses the + ``'next'`` query parameter. However, external authentication methods could use custom attributes with hardcoded + names for security purposes. For example, OAuth uses ``'state'`` for keeping track of application state. + """ + if "next" in request.GET: + return request.GET.get("next") + return None + class NativeAuthBackend(BaseAuthBackend): identifier = 'native' diff --git a/src/pretix/control/views/auth.py b/src/pretix/control/views/auth.py index 5bbf52787..176520a46 100644 --- a/src/pretix/control/views/auth.py +++ b/src/pretix/control/views/auth.py @@ -39,18 +39,19 @@ def process_login(request, user, keep_logged_in): :return: This method returns a ``HttpResponse``. """ request.session['pretix_auth_long_session'] = settings.PRETIX_LONG_SESSIONS and keep_logged_in + next_url = get_auth_backends()[user.auth_backend].get_next_url(request) if user.require_2fa: request.session['pretix_auth_2fa_user'] = user.pk request.session['pretix_auth_2fa_time'] = str(int(time.time())) twofa_url = reverse('control:auth.login.2fa') - if "next" in request.GET and is_safe_url(request.GET.get("next"), allowed_hosts=None): - twofa_url += '?next=' + quote(request.GET.get('next')) + if next_url and is_safe_url(next_url, allowed_hosts=None): + twofa_url += '?next=' + quote(next_url) return redirect(twofa_url) else: auth_login(request, user) request.session['pretix_auth_login_time'] = int(time.time()) - if "next" in request.GET and is_safe_url(request.GET.get("next"), allowed_hosts=None): - return redirect(request.GET.get("next")) + if next_url and is_safe_url(next_url, allowed_hosts=None): + return redirect(next_url) return redirect(reverse('control:index')) @@ -72,7 +73,8 @@ def login(request): if not backend.visible: backend = [b for b in backends if b.visible][0] if request.user.is_authenticated: - return redirect(request.GET.get("next", 'control:index')) + next_url = backend.get_next_url(request) or 'control:index' + return redirect(next_url) if request.method == 'POST': form = LoginForm(backend=backend, data=request.POST) if form.is_valid() and form.user_cache and form.user_cache.auth_backend == backend.identifier: diff --git a/src/pretix/control/views/user.py b/src/pretix/control/views/user.py index 120ca1064..21dd22f2c 100644 --- a/src/pretix/control/views/user.py +++ b/src/pretix/control/views/user.py @@ -101,18 +101,21 @@ class ReauthView(TemplateView): t = int(time.time()) request.session['pretix_auth_login_time'] = t request.session['pretix_auth_last_used'] = t - if "next" in request.GET and is_safe_url(request.GET.get("next"), allowed_hosts=None): - return redirect(request.GET.get("next")) + next_url = get_auth_backends()[request.user.auth_backend].get_next_url(request) + if next_url and is_safe_url(next_url, allowed_hosts=None): + return redirect(next_url) return redirect(reverse('control:index')) else: messages.error(request, _('The password you entered was invalid, please try again.')) return self.get(request, *args, **kwargs) def get(self, request, *args, **kwargs): - u = get_auth_backends()[request.user.auth_backend].request_authenticate(request) + backend = get_auth_backends()[request.user.auth_backend] + u = backend.request_authenticate(request) if u and u == request.user: - if "next" in request.GET and is_safe_url(request.GET.get("next"), allowed_hosts=None): - return redirect(request.GET.get("next")) + next_url = backend.get_next_url(request) + if next_url and is_safe_url(next_url, allowed_hosts=None): + return redirect(next_url) return redirect(reverse('control:index')) return super().get(request, *args, **kwargs) diff --git a/src/tests/control/test_auth.py b/src/tests/control/test_auth.py index cf3fc8007..69c7effbd 100644 --- a/src/tests/control/test_auth.py +++ b/src/tests/control/test_auth.py @@ -148,6 +148,11 @@ class LoginFormTest(TestCase): response = self.client.get('/control/') assert b'hallo@example.org' in response.content + def test_custom_get_next_url(self): + response = self.client.get('/control/login?state=/control/events/', HTTP_X_LOGIN_EMAIL='hallo@example.org') + self.assertEqual(response.status_code, 302) + self.assertIn('/control/events/', response['Location']) + class RegistrationFormTest(TestCase): diff --git a/src/tests/testdummy/auth.py b/src/tests/testdummy/auth.py index 639f1b8da..7a601dc32 100644 --- a/src/tests/testdummy/auth.py +++ b/src/tests/testdummy/auth.py @@ -36,3 +36,8 @@ class TestRequestAuthBackend(BaseAuthBackend): email=request.headers['X-Login-Email'], auth_backend='test_request' )[0] + + def get_next_url(self, request): + if 'state' in request.GET: + return request.GET.get('state') + return None