Allow to restrict payment methods by invoice address country

This commit is contained in:
Raphael Michel
2018-09-19 16:10:40 +02:00
parent 1155d18b7f
commit 06d9c48ed4
4 changed files with 101 additions and 6 deletions

View File

@@ -57,7 +57,7 @@ class SettingsForm(i18nfield.forms.I18nFormMixin, HierarkeyForm):
kwargs['locales'] = self.locales kwargs['locales'] = self.locales
kwargs['initial'] = self.obj.settings.freeze() kwargs['initial'] = self.obj.settings.freeze()
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
for f in self.fields.values(): for k, f in self.fields.items():
if isinstance(f, (RelativeDateTimeField, RelativeDateField)): if isinstance(f, (RelativeDateTimeField, RelativeDateField)):
f.set_event(self.obj) f.set_event(self.obj)

View File

@@ -14,12 +14,14 @@ from django.http import HttpRequest
from django.template.loader import get_template from django.template.loader import get_template
from django.utils.timezone import now from django.utils.timezone import now
from django.utils.translation import pgettext_lazy, ugettext_lazy as _ from django.utils.translation import pgettext_lazy, ugettext_lazy as _
from django_countries import Countries
from i18nfield.forms import I18nFormField, I18nTextarea, I18nTextInput from i18nfield.forms import I18nFormField, I18nTextarea, I18nTextInput
from i18nfield.strings import LazyI18nString from i18nfield.strings import LazyI18nString
from pretix.base.forms import PlaceholderValidator from pretix.base.forms import PlaceholderValidator
from pretix.base.models import ( from pretix.base.models import (
CartPosition, Event, Order, OrderPayment, OrderRefund, Quota, CartPosition, Event, InvoiceAddress, Order, OrderPayment, OrderRefund,
Quota,
) )
from pretix.base.reldate import RelativeDateField, RelativeDateWrapper from pretix.base.reldate import RelativeDateField, RelativeDateWrapper
from pretix.base.settings import SettingsSandbox from pretix.base.settings import SettingsSandbox
@@ -28,7 +30,7 @@ from pretix.base.templatetags.money import money_filter
from pretix.base.templatetags.rich_text import rich_text from pretix.base.templatetags.rich_text import rich_text
from pretix.helpers.money import DecimalTextInput from pretix.helpers.money import DecimalTextInput
from pretix.presale.views import get_cart_total from pretix.presale.views import get_cart_total
from pretix.presale.views.cart import get_or_create_cart_id from pretix.presale.views.cart import cart_session, get_or_create_cart_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -179,7 +181,7 @@ class BasePaymentProvider:
implementation. implementation.
""" """
places = settings.CURRENCY_PLACES.get(self.event.currency, 2) places = settings.CURRENCY_PLACES.get(self.event.currency, 2)
return OrderedDict([ d = OrderedDict([
('_enabled', ('_enabled',
forms.BooleanField( forms.BooleanField(
label=_('Enable payment method'), label=_('Enable payment method'),
@@ -250,7 +252,20 @@ class BasePaymentProvider:
'above!').format(docs_url='https://docs.pretix.eu/en/latest/user/payments/fees.html'), 'above!').format(docs_url='https://docs.pretix.eu/en/latest/user/payments/fees.html'),
required=False required=False
)), )),
('_restricted_countries',
forms.MultipleChoiceField(
label=_('Restrict to countries'),
choices=Countries(),
help_text=_('Only allow choosing this payment provider for invoice addresses in the selected '
'countries. If you don\'t select any country, all countries are allowed.'),
widget=forms.CheckboxSelectMultiple(
attrs={'class': 'scrolling-multiple-choice'}
),
required=False
)),
]) ])
d['_restricted_countries']._as_type = list
return d
def settings_content_render(self, request: HttpRequest) -> str: def settings_content_render(self, request: HttpRequest) -> str:
""" """
@@ -350,7 +365,8 @@ class BasePaymentProvider:
during checkout, not on retrying. during checkout, not on retrying.
The default implementation checks for the _availability_date setting to be either unset or in the future The default implementation checks for the _availability_date setting to be either unset or in the future
and for the _total_max and _total_min requirements to be met. and for the _total_max and _total_min requirements to be met. It also checks the ``_restrict_countries``
setting.
:param total: The total value without the payment method fee, after taxes. :param total: The total value without the payment method fee, after taxes.
@@ -371,6 +387,25 @@ class BasePaymentProvider:
if self.settings._total_min is not None: if self.settings._total_min is not None:
pricing = pricing and total >= Decimal(self.settings._total_min) pricing = pricing and total >= Decimal(self.settings._total_min)
def get_invoice_address():
if not hasattr(request, '_checkout_flow_invoice_address'):
cs = cart_session(request)
iapk = cs.get('invoice_address')
if not iapk:
request._checkout_flow_invoice_address = InvoiceAddress()
else:
try:
request._checkout_flow_invoice_address = InvoiceAddress.objects.get(pk=iapk, order__isnull=True)
except InvoiceAddress.DoesNotExist:
request._checkout_flow_invoice_address = InvoiceAddress()
return request._checkout_flow_invoice_address
restricted_countries = self.settings.get('_restricted_countries', as_type=list)
if restricted_countries:
ia = get_invoice_address()
if str(ia.country) not in restricted_countries:
return False
return timing and pricing return timing and pricing
def payment_form_render(self, request: HttpRequest, total: Decimal) -> str: def payment_form_render(self, request: HttpRequest, total: Decimal) -> str:
@@ -503,7 +538,8 @@ class BasePaymentProvider:
Will be called to check whether it is allowed to change the payment method of Will be called to check whether it is allowed to change the payment method of
an order to this one. an order to this one.
The default implementation checks for the _availability_date setting to be either unset or in the future. The default implementation checks for the _availability_date setting to be either unset or in the future,
as well as for the _total_max, _total_min and _restricted_countries settings.
:param order: The order object :param order: The order object
""" """
@@ -514,6 +550,16 @@ class BasePaymentProvider:
if self.settings._total_min is not None and ps < Decimal(self.settings._total_min): if self.settings._total_min is not None and ps < Decimal(self.settings._total_min):
return False return False
restricted_countries = self.settings.get('_restricted_countries', as_type=list)
if restricted_countries:
try:
ia = order.invoice_address
except InvoiceAddress.DoesNotExist:
return True
else:
if str(ia.country) not in restricted_countries:
return False
return self._is_still_available(order=order) return self._is_still_available(order=order)
def payment_prepare(self, request: HttpRequest, payment: OrderPayment) -> Union[bool, str]: def payment_prepare(self, request: HttpRequest, payment: OrderPayment) -> Union[bool, str]:

View File

@@ -497,6 +497,9 @@ class ProviderForm(SettingsForm):
elif isinstance(v, (RelativeDateTimeField, RelativeDateField)): elif isinstance(v, (RelativeDateTimeField, RelativeDateField)):
v.set_event(self.obj) v.set_event(self.obj)
if hasattr(v, '_as_type'):
self.initial[k] = self.obj.settings.get(k, as_type=v._as_type)
def clean(self): def clean(self):
cleaned_data = super().clean() cleaned_data = super().clean()
enabled = cleaned_data.get(self.settingspref + '_enabled') enabled = cleaned_data.get(self.settingspref + '_enabled')

View File

@@ -731,6 +731,52 @@ class CheckoutTestCase(TestCase):
doc = BeautifulSoup(response.rendered_content, "lxml") doc = BeautifulSoup(response.rendered_content, "lxml")
assert doc.select(".alert-danger") assert doc.select(".alert-danger")
def test_payment_country_allowed(self):
self.event.settings.set('payment_stripe__enabled', True)
self.event.settings.set('payment_banktransfer__restricted_countries', ['DE', 'AT'])
self.event.settings.set('payment_banktransfer__enabled', True)
ia = InvoiceAddress.objects.create(
is_business=True, vat_id='ATU1234567', vat_id_validated=True,
country=Country('DE')
)
self._set_session('invoice_address', ia.pk)
CartPosition.objects.create(
event=self.event, cart_id=self.session_key, item=self.ticket,
price=23, expires=now() + timedelta(minutes=10)
)
response = self.client.get('/%s/%s/checkout/payment/' % (self.orga.slug, self.event.slug), follow=True)
doc = BeautifulSoup(response.rendered_content, "lxml")
self.assertEqual(len(doc.select('input[name=payment]')), 2)
response = self.client.post('/%s/%s/checkout/payment/' % (self.orga.slug, self.event.slug), {
'payment': 'banktransfer'
}, follow=True)
self.assertEqual(response.status_code, 200)
doc = BeautifulSoup(response.rendered_content, "lxml")
assert not doc.select(".alert-danger")
def test_payment_country_blocked(self):
self.event.settings.set('payment_stripe__enabled', True)
self.event.settings.set('payment_banktransfer__restricted_countries', ['DE', 'AT'])
self.event.settings.set('payment_banktransfer__enabled', True)
ia = InvoiceAddress.objects.create(
is_business=True, vat_id='ATU1234567', vat_id_validated=True,
country=Country('CH')
)
self._set_session('invoice_address', ia.pk)
CartPosition.objects.create(
event=self.event, cart_id=self.session_key, item=self.ticket,
price=23, expires=now() + timedelta(minutes=10)
)
response = self.client.get('/%s/%s/checkout/payment/' % (self.orga.slug, self.event.slug), follow=True)
doc = BeautifulSoup(response.rendered_content, "lxml")
self.assertEqual(len(doc.select('input[name=payment]')), 1)
response = self.client.post('/%s/%s/checkout/payment/' % (self.orga.slug, self.event.slug), {
'payment': 'banktransfer'
}, follow=True)
self.assertEqual(response.status_code, 200)
doc = BeautifulSoup(response.rendered_content, "lxml")
assert doc.select(".alert-danger")
def test_premature_confirm(self): def test_premature_confirm(self):
response = self.client.get('/%s/%s/checkout/confirm/' % (self.orga.slug, self.event.slug), follow=True) response = self.client.get('/%s/%s/checkout/confirm/' % (self.orga.slug, self.event.slug), follow=True)
self.assertRedirects(response, '/%s/%s/?require_cookie=true' % (self.orga.slug, self.event.slug), self.assertRedirects(response, '/%s/%s/?require_cookie=true' % (self.orga.slug, self.event.slug),