refactor cross-selling logic into its own module

This commit is contained in:
Mira Weller
2024-07-15 11:27:01 +02:00
parent cb635b2c37
commit ac771b8ca8
3 changed files with 188 additions and 156 deletions

View File

@@ -38,11 +38,10 @@ import os
import sys
import uuid
import warnings
from collections import Counter, OrderedDict, defaultdict
from collections import Counter, OrderedDict
from datetime import date, datetime, time, timedelta
from decimal import Decimal, DecimalException
from itertools import groupby
from typing import List, Optional, Tuple
from typing import Optional, Tuple
from zoneinfo import ZoneInfo
import dateutil.parser
@@ -65,7 +64,7 @@ from django_scopes import ScopedManager
from i18nfield.fields import I18nCharField, I18nTextField
from pretix.base.media import MEDIA_TYPES
from pretix.base.models import CartPosition, Event, SalesChannel, SubEvent
from pretix.base.models import Event, SubEvent
from pretix.base.models.base import LoggedModel
from pretix.base.models.fields import MultiStringField
from pretix.base.models.tax import TaxedPrice
@@ -144,78 +143,6 @@ class ItemCategory(LoggedModel):
verbose_name_plural = _("Product categories")
ordering = ('position', 'id')
def cross_sell_visible(self, cartpositions: List[CartPosition], sales_channel: SalesChannel):
"""
If this category should be visible in the cross-selling step for a given cart and sales_channel, this method
returns a queryset of the items that should be displayed, as well as a dict giving additional information on them.
:returns: (QuerySet<Item>, dict<item_pk: (max_count, discount_rule)>)
max_count is `inf` if the item should not be limited
discount_rule is None if the item will not be discounted
"""
if self.cross_selling_mode is None:
return None, {}
if self.cross_selling_condition == 'always':
return self.items.all(), {}
if self.cross_selling_condition == 'products':
match = set(match.pk for match in self.cross_selling_match_products.only('pk')) # TODO prefetch this
return (self.items.all(), {}) if any(pos.item.pk in match for pos in cartpositions) else (None, {})
if self.cross_selling_condition == 'discounts':
if not hasattr(self.event, '_potential_discounts_by_item_for_current_cart'):
potential_discounts_by_cartpos = defaultdict(list)
from ..services.pricing import apply_discounts
apply_discounts(
self.event,
sales_channel,
[
(cp.item_id, cp.subevent_id, cp.line_price_gross, bool(cp.addon_to), cp.is_bundled,
cp.listed_price - cp.price_after_voucher)
for cp in cartpositions
],
collect_potential_discounts=potential_discounts_by_cartpos
)
# flatten potential_discounts_by_cartpos (a dict of lists of potential discounts) into a set of potential discounts
# (which is technically stored as a dict, but we use it as an OrderedSet here)
potential_discount_set = dict.fromkeys(info for lst in potential_discounts_by_cartpos.values() for info in lst)
# sum up the max_counts and pass them on (also pass on the discount_rules so we can calculate actual discounted prices later):
# group by benefit product
# - max_count for product: sum up max_counts
# - discount_rule for product: take first discount_rule
def discount_info(item, infos_for_item):
infos_for_item = list(infos_for_item)
return (
item,
sum(max_count for (item, discount_rule, max_count, i) in infos_for_item),
next(discount_rule for (item, discount_rule, max_count, i) in infos_for_item)
)
self.event._potential_discounts_by_item_for_current_cart = [
discount_info(item, infos_for_item) for item, infos_for_item in
groupby(
sorted(
(
(item, discount_rule, max_count, i)
for (discount_rule, max_count, i) in potential_discount_set.keys()
for item in discount_rule.benefit_limit_products.all()
),
key=lambda tup: tup[0].pk
),
lambda tup: tup[0])
]
my_item_pks = self.items.values_list('pk', flat=True)
potential_discount_items = {
item.pk: (max_count, discount_rule)
for item, max_count, discount_rule in self.event._potential_discounts_by_item_for_current_cart
if max_count > 0 and item.pk in my_item_pks and item.is_available()
}
return self.items.filter(pk__in=potential_discount_items), potential_discount_items
def __str__(self):
name = self.internal_name or self.name
category_type = self.get_category_type_display()