Add progress bar to some large exports

This commit is contained in:
Raphael Michel
2020-07-23 21:27:14 +02:00
parent fc5c3caf66
commit a310c33497
8 changed files with 114 additions and 17 deletions

View File

@@ -1,6 +1,6 @@
import io
import tempfile
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from decimal import Decimal
from typing import Tuple
@@ -20,8 +20,9 @@ class BaseExporter:
This is the base class for all data exporters
"""
def __init__(self, event):
def __init__(self, event, progress_callback=lambda v: None):
self.event = event
self.progress_callback = progress_callback
self.is_multievent = isinstance(event, QuerySet)
if isinstance(event, QuerySet):
self.events = event
@@ -94,6 +95,7 @@ class BaseExporter:
class ListExporter(BaseExporter):
ProgressSetTotal = namedtuple('ProgressSetTotal', 'total')
@property
def export_form_fields(self) -> dict:
@@ -127,21 +129,39 @@ class ListExporter(BaseExporter):
def _render_csv(self, form_data, output_file=None, **kwargs):
if output_file:
writer = csv.writer(output_file, **kwargs)
total = 0
counter = 0
for line in self.iterate_list(form_data):
if isinstance(line, self.ProgressSetTotal):
total = line.total
continue
line = [
localize(f) if isinstance(f, Decimal) else f
for f in line
]
if total:
counter += 1
if counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100)
writer.writerow(line)
return self.get_filename() + '.csv', 'text/csv', None
else:
output = io.StringIO()
writer = csv.writer(output, **kwargs)
total = 0
counter = 0
for line in self.iterate_list(form_data):
if isinstance(line, self.ProgressSetTotal):
total = line.total
continue
line = [
localize(f) if isinstance(f, Decimal) else f
for f in line
]
if total:
counter += 1
if counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100)
writer.writerow(line)
return self.get_filename() + '.csv', 'text/csv', output.getvalue().encode("utf-8")
@@ -152,11 +172,20 @@ class ListExporter(BaseExporter):
ws.title = str(self.verbose_name)
except:
pass
total = 0
counter = 0
for i, line in enumerate(self.iterate_list(form_data)):
if isinstance(line, self.ProgressSetTotal):
total = line.total
continue
ws.append([
str(val) if not isinstance(val, KNOWN_TYPES) else val
for val in line
])
if total:
counter += 1
if counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100)
if output_file:
wb.save(output_file)
@@ -214,35 +243,61 @@ class MultiSheetListExporter(ListExporter):
raise NotImplementedError() # noqa
def _render_sheet_csv(self, form_data, sheet, output_file=None, **kwargs):
total = 0
counter = 0
if output_file:
writer = csv.writer(output_file, **kwargs)
for line in self.iterate_sheet(form_data, sheet):
if isinstance(line, self.ProgressSetTotal):
total = line.total
continue
line = [
localize(f) if isinstance(f, Decimal) else f
for f in line
]
writer.writerow(line)
if total:
counter += 1
if counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100)
return self.get_filename() + '.csv', 'text/csv', None
else:
output = io.StringIO()
writer = csv.writer(output, **kwargs)
for line in self.iterate_sheet(form_data, sheet):
if isinstance(line, self.ProgressSetTotal):
total = line.total
continue
line = [
localize(f) if isinstance(f, Decimal) else f
for f in line
]
writer.writerow(line)
if total:
counter += 1
if counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100)
return self.get_filename() + '.csv', 'text/csv', output.getvalue().encode("utf-8")
def _render_xlsx(self, form_data, output_file=None):
wb = Workbook(write_only=True)
for s, l in self.sheets:
n_sheets = len(self.sheets)
for i_sheet, (s, l) in enumerate(self.sheets):
ws = wb.create_sheet(str(l))
total = 0
counter = 0
for i, line in enumerate(self.iterate_sheet(form_data, sheet=s)):
if isinstance(line, self.ProgressSetTotal):
total = line.total
continue
ws.append([
str(val) if not isinstance(val, KNOWN_TYPES) else val
for val in line
])
if total:
counter += 1
if counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100 / n_sheets + 100 / n_sheets * i_sheet)
if output_file:
wb.save(output_file)

View File

@@ -99,7 +99,12 @@ class InvoiceExporter(InvoiceExporterMixin, BaseExporter):
qs = self.invoices_queryset(form_data).filter(shredded=False)
with tempfile.TemporaryDirectory() as d:
any = False
total = qs.count()
if not total:
return None
counter = 0
with ZipFile(output_file or os.path.join(d, 'tmp.zip'), 'w') as zipf:
for i in qs.iterator():
try:
@@ -108,18 +113,16 @@ class InvoiceExporter(InvoiceExporterMixin, BaseExporter):
i.refresh_from_db()
i.file.open('rb')
zipf.writestr('{}.pdf'.format(i.number), i.file.read())
any = True
i.file.close()
except FileNotFoundError:
invoice_pdf_task.apply(args=(i.pk,))
i.refresh_from_db()
i.file.open('rb')
zipf.writestr('{}.pdf'.format(i.number), i.file.read())
any = True
i.file.close()
if not any:
return None
counter += 1
if total and counter % max(10, total // 100) == 0:
self.progress_callback(counter / total * 100)
if self.is_multievent:
filename = '{}_invoices.zip'.format(self.events.first().organizer.slug)
@@ -222,6 +225,7 @@ class InvoiceDataExporter(InvoiceExporterMixin, MultiSheetListExporter):
)
all_ids = base_qs.order_by('full_invoice_no').values_list('pk', flat=True)
yield self.ProgressSetTotal(total=len(all_ids))
for ids in chunked_iterable(all_ids, 1000):
invs = sorted(qs.filter(id__in=ids), key=lambda k: ids.index(k.pk))
@@ -326,6 +330,7 @@ class InvoiceDataExporter(InvoiceExporterMixin, MultiSheetListExporter):
).order_by('invoice__full_invoice_no', 'position').select_related(
'invoice', 'invoice__order', 'invoice__refers'
)
yield self.ProgressSetTotal(total=qs.count())
for l in qs.iterator():
i = l.invoice

View File

@@ -170,6 +170,7 @@ class OrderListExporter(MultiSheetListExporter):
)
}
yield self.ProgressSetTotal(total=qs.count())
for order in qs.order_by('datetime').iterator():
tz = pytz.timezone(self.event_object_cache[order.event_id].settings.timezone)
@@ -278,6 +279,7 @@ class OrderListExporter(MultiSheetListExporter):
headers.append(_('Payment providers'))
yield headers
yield self.ProgressSetTotal(total=qs.count())
for op in qs.order_by('order__datetime').iterator():
order = op.order
tz = pytz.timezone(order.event.settings.timezone)
@@ -411,7 +413,8 @@ class OrderListExporter(MultiSheetListExporter):
yield headers
all_ids = base_qs.order_by('order__datetime', 'positionid').values_list('pk', flat=True)
all_ids = list(base_qs.order_by('order__datetime', 'positionid').values_list('pk', flat=True))
yield self.ProgressSetTotal(total=len(all_ids))
for ids in chunked_iterable(all_ids, 1000):
ops = sorted(qs.filter(id__in=ids), key=lambda k: ids.index(k.pk))
@@ -561,6 +564,7 @@ class PaymentListExporter(ListExporter):
]
yield headers
yield self.ProgressSetTotal(total=len(objs))
for obj in objs:
tz = pytz.timezone(obj.order.event.settings.timezone)
if isinstance(obj, OrderPayment) and obj.payment_date:

View File

@@ -4,6 +4,7 @@ import sys
from django.core.management.base import BaseCommand
from django.utils.timezone import override
from django_scopes import scope
from tqdm import tqdm
from pretix.base.i18n import language
from pretix.base.models import Event, Organizer
@@ -34,10 +35,15 @@ class Command(BaseCommand):
self.stderr.write(self.style.ERROR('Event not found.'))
sys.exit(1)
pbar = tqdm(total=100)
def report_status(val):
pbar.update(round(val, 2) - pbar.n)
with language(e.settings.locale), override(e.settings.timezone):
responses = register_data_exporters.send(e)
for receiver, response in responses:
ex = response(e)
ex = response(e, report_status)
if ex.identifier == options['export_provider'][0]:
params = json.loads(options.get('parameters') or '{}')
with open(options['output_file'][0], 'wb') as f:
@@ -53,6 +59,7 @@ class Command(BaseCommand):
f.write(d[2])
sys.exit(0)
pbar.close()
self.stderr.write(self.style.ERROR('Export provider not found.'))
sys.exit(1)

View File

@@ -21,13 +21,20 @@ class ExportError(LazyLocaleException):
pass
@app.task(base=ProfiledEventTask, throws=(ExportError,))
def export(event: Event, fileid: str, provider: str, form_data: Dict[str, Any]) -> None:
@app.task(base=ProfiledEventTask, throws=(ExportError,), bind=True)
def export(self, event: Event, fileid: str, provider: str, form_data: Dict[str, Any]) -> None:
def set_progress(val):
if not self.request.called_directly:
self.update_state(
state='PROGRESS',
meta={'value': val}
)
file = CachedFile.objects.get(id=fileid)
with language(event.settings.locale), override(event.settings.timezone):
responses = register_data_exporters.send(event)
for receiver, response in responses:
ex = response(event)
ex = response(event, set_progress)
if ex.identifier == provider:
d = ex.render(form_data)
if d is None:
@@ -40,8 +47,15 @@ def export(event: Event, fileid: str, provider: str, form_data: Dict[str, Any])
return file.pk
@app.task(base=ProfiledOrganizerUserTask, throws=(ExportError,))
def multiexport(organizer: Organizer, user: User, fileid: str, provider: str, form_data: Dict[str, Any]) -> None:
@app.task(base=ProfiledOrganizerUserTask, throws=(ExportError,), bind=True)
def multiexport(self, organizer: Organizer, user: User, fileid: str, provider: str, form_data: Dict[str, Any]) -> None:
def set_progress(val):
if not self.request.called_directly:
self.update_state(
state='PROGRESS',
meta={'value': val}
)
file = CachedFile.objects.get(id=fileid)
with language(user.locale), override(user.timezone):
allowed_events = user.get_events_with_permission('can_view_orders')
@@ -52,7 +66,7 @@ def multiexport(organizer: Organizer, user: User, fileid: str, provider: str, fo
for receiver, response in responses:
if not response:
continue
ex = response(events)
ex = response(events, set_progress)
if ex.identifier == provider:
d = ex.render(form_data)
if d is None:

View File

@@ -100,6 +100,10 @@ class AsyncAction:
'success': False,
'message': str(self.get_error_message(res.info))
})
elif res.state == 'PROGRESS':
data.update({
'percentage': res.result.get('value', 0)
})
return data
def get_result(self, request):