Use namedtuple

This commit is contained in:
Mira Weller
2024-07-12 15:20:28 +02:00
parent d29b183801
commit 6bf16f1510
2 changed files with 22 additions and 26 deletions

View File

@@ -20,11 +20,11 @@
# <https://www.gnu.org/licenses/>.
#
from collections import defaultdict
from collections import defaultdict, namedtuple
from decimal import Decimal
from itertools import groupby
from math import ceil, inf
from typing import Dict, Optional, Tuple
from typing import Dict
from django.core.exceptions import ValidationError
from django.core.validators import MinValueValidator
@@ -36,11 +36,7 @@ from django_scopes import ScopedManager
from pretix.base.decimal import round_decimal
from pretix.base.models.base import LoggedModel
ITEM_ID = 0
SUBEVENT_ID = 1
LINE_PRICE_GROSS = 2
IS_ADDON_TO = 3
VOUCHER_DISCOUNT = 4
PositionInfo = namedtuple('PositionInfo', ['item_id', 'subevent_id', 'line_price_gross', 'is_addon_to', 'voucher_discount'])
class Discount(LoggedModel):
@@ -252,14 +248,14 @@ class Discount(LoggedModel):
return True
def _apply_min_value(self, positions, condition_idx_group, benefit_idx_group, result, collect_potential_discounts):
if self.condition_min_value and sum(positions[idx][LINE_PRICE_GROSS] for idx in condition_idx_group) < self.condition_min_value:
if self.condition_min_value and sum(positions[idx].line_price_gross for idx in condition_idx_group) < self.condition_min_value:
return
if self.condition_min_count or self.benefit_only_apply_to_cheapest_n_matches:
raise ValueError('Validation invariant violated.')
for idx in benefit_idx_group:
previous_price = positions[idx][LINE_PRICE_GROSS]
previous_price = positions[idx].line_price_gross
new_price = round_decimal(
previous_price * (Decimal('100.00') - self.benefit_discount_matching_percent) / Decimal('100.00'),
self.event.currency,
@@ -282,8 +278,8 @@ class Discount(LoggedModel):
raise ValueError('Validation invariant violated.')
# sort by line_price
condition_idx_group = sorted(condition_idx_group, key=lambda idx: (positions[idx][LINE_PRICE_GROSS], -idx))
benefit_idx_group = sorted(benefit_idx_group, key=lambda idx: (positions[idx][LINE_PRICE_GROSS], -idx))
condition_idx_group = sorted(condition_idx_group, key=lambda idx: (positions[idx].line_price_gross, -idx))
benefit_idx_group = sorted(benefit_idx_group, key=lambda idx: (positions[idx].line_price_gross, -idx))
# Prevent over-consuming of items, i.e. if our discount is "buy 2, get 1 free", we only
# want to match multiples of 3
@@ -324,7 +320,7 @@ class Discount(LoggedModel):
collect_potential_discounts[idx] = [(self, inf, -1)]
for idx in benefit_idx:
previous_price = positions[idx][LINE_PRICE_GROSS]
previous_price = positions[idx].line_price_gross
new_price = round_decimal(
previous_price * (Decimal('100.00') - self.benefit_discount_matching_percent) / Decimal('100.00'),
self.event.currency,
@@ -332,15 +328,14 @@ class Discount(LoggedModel):
result[idx] = new_price
for idx in consume_idx:
result.setdefault(idx, positions[idx][LINE_PRICE_GROSS])
result.setdefault(idx, positions[idx].line_price_gross)
def apply(self, positions: Dict[int, Tuple[int, Optional[int], Decimal, bool, Decimal]],
def apply(self, positions: Dict[int, PositionInfo],
collect_potential_discounts=None) -> Dict[int, Decimal]:
"""
Tries to apply this discount to a cart
:param positions: Dictionary mapping IDs to tuples of the form
``(item_id, subevent_id, line_price_gross, is_addon_to, voucher_discount)``.
:param positions: Dictionary mapping IDs to PositionInfo tuples.
Bundled positions may not be included.
:return: A dictionary mapping keys from the input dictionary to new prices. All positions
@@ -389,7 +384,7 @@ class Discount(LoggedModel):
elif self.subevent_mode == self.SUBEVENT_MODE_SAME:
def key(idx):
return positions[idx][1] or 0 # subevent_id
return positions[idx].subevent_id or 0
# Build groups of candidates with the same subevent, then apply our regular algorithm
# to each group
@@ -398,7 +393,7 @@ class Discount(LoggedModel):
candidate_groups = [(k, list(g)) for k, g in _groups]
for subevent_id, g in candidate_groups:
benefit_g = [idx for idx in benefit_candidates if positions[idx][SUBEVENT_ID] == subevent_id]
benefit_g = [idx for idx in benefit_candidates if positions[idx].subevent_id == subevent_id]
if self.condition_min_count:
self._apply_min_count(positions, g, benefit_g, result, collect_potential_discounts)
else:
@@ -418,9 +413,9 @@ class Discount(LoggedModel):
# Build a list of subevent IDs in descending order of frequency
subevent_to_idx = defaultdict(list)
for idx, p in positions.items():
subevent_to_idx[p[SUBEVENT_ID]].append(idx)
subevent_to_idx[p.subevent_id].append(idx)
for v in subevent_to_idx.values():
v.sort(key=lambda idx: positions[idx][LINE_PRICE_GROSS])
v.sort(key=lambda idx: positions[idx].line_price_gross)
subevent_order = sorted(list(subevent_to_idx.keys()), key=lambda s: len(subevent_to_idx[s]), reverse=True)
# Build groups of exactly condition_min_count distinct subevents
@@ -435,7 +430,7 @@ class Discount(LoggedModel):
l = [ll for ll in l if ll in condition_candidates and ll not in current_group]
if cardinality and len(l) != cardinality:
continue
if se not in {positions[idx][SUBEVENT_ID] for idx in current_group}:
if se not in {positions[idx].subevent_id for idx in current_group}:
candidates += l
cardinality = len(l)
@@ -444,7 +439,7 @@ class Discount(LoggedModel):
# Sort the list by prices, then pick one. For "buy 2 get 1 free" we apply a "pick 1 from the start
# and 2 from the end" scheme to optimize price distribution among groups
candidates = sorted(candidates, key=lambda idx: positions[idx][LINE_PRICE_GROSS])
candidates = sorted(candidates, key=lambda idx: positions[idx].line_price_gross)
if len(current_group) < (self.benefit_only_apply_to_cheapest_n_matches or 0):
candidate = candidates[0]
else:
@@ -456,14 +451,14 @@ class Discount(LoggedModel):
if len(current_group) >= max(self.condition_min_count, 1):
candidate_groups.append(current_group)
for c in current_group:
subevent_to_idx[positions[c][SUBEVENT_ID]].remove(c)
subevent_to_idx[positions[c].subevent_id].remove(c)
current_group = []
# Distribute "leftovers"
for se in subevent_order:
if subevent_to_idx[se]:
for group in candidate_groups:
if se not in {positions[idx][SUBEVENT_ID] for idx in group}:
if se not in {positions[idx].subevent_id for idx in group}:
group.append(subevent_to_idx[se].pop())
if not subevent_to_idx[se]:
break

View File

@@ -31,6 +31,7 @@ from pretix.base.models import (
AbstractPosition, InvoiceAddress, Item, ItemAddOn, ItemVariation,
SalesChannel, Voucher,
)
from pretix.base.models.discount import PositionInfo
from pretix.base.models.event import Event, SubEvent
from pretix.base.models.tax import TAXED_ZERO, TaxedPrice, TaxRule
from pretix.base.timemachine import time_machine_now
@@ -156,7 +157,7 @@ def get_line_price(price_after_voucher: Decimal, custom_price_input: Decimal, cu
def apply_discounts(event: Event, sales_channel: str,
positions: List[Tuple[int, Optional[int], Decimal, bool, bool]],
positions: List[Tuple[int, Optional[int], Decimal, bool, bool, Decimal]],
collect_potential_discounts=None) -> List[Decimal]:
"""
Applies any dynamic discounts to a cart
@@ -178,7 +179,7 @@ def apply_discounts(event: Event, sales_channel: str,
).prefetch_related('condition_limit_products', 'benefit_limit_products').order_by('position', 'pk')
for discount in discount_qs:
result = discount.apply({
idx: (item_id, subevent_id, line_price_gross, is_addon_to, voucher_discount)
idx: PositionInfo(item_id, subevent_id, line_price_gross, is_addon_to, voucher_discount)
for idx, (item_id, subevent_id, line_price_gross, is_addon_to, is_bundled, voucher_discount) in enumerate(positions)
if not is_bundled and idx not in new_prices
}, collect_potential_discounts)