diff --git a/backend/commerce/models.py b/backend/commerce/models.py index 1e58abf..9f78dec 100644 --- a/backend/commerce/models.py +++ b/backend/commerce/models.py @@ -212,7 +212,8 @@ class Order(models.Model): def calculate_total_price(self): carrier_price = self.carrier.get_price() if self.carrier else Decimal("0.0") - if self.discount.exists(): + # Check if order has been saved (has an ID) before accessing many-to-many relationships + if self.pk and self.discount.exists(): discounts = list(self.discount.all()) total = Decimal('0.0') @@ -227,8 +228,10 @@ class Order(models.Model): total = Decimal('0.0') # getting all prices from order items (without discount) - using VAT-inclusive prices - for item in self.items.all(): - total = total + (item.product.get_price_with_vat() * item.quantity) + # Only try to access items if order has been saved + if self.pk: + for item in self.items.all(): + total = total + (item.product.get_price_with_vat() * item.quantity) return total + carrier_price diff --git a/backend/commerce/serializers.py b/backend/commerce/serializers.py index 3a0e37f..21f599d 100644 --- a/backend/commerce/serializers.py +++ b/backend/commerce/serializers.py @@ -251,7 +251,7 @@ class OrderCreateSerializer(serializers.Serializer): # přidame fieldy, které nejsou vyplněné for field in required_fields: - if attrs.get(field) not in required_fields: + if not attrs.get(field): missing_fields.append(field) if missing_fields: @@ -307,10 +307,13 @@ class OrderCreateSerializer(serializers.Serializer): # -- Slevové kódy -- + # Discount codes need to be added before payment/final save because calculate_total_price uses them if codes: discounts = list(DiscountCode.objects.filter(code__in=codes)) if discounts: order.discount.add(*discounts) + # Save to recalculate total with discounts + order.save(update_fields=["total_price", "updated_at"]) @@ -324,7 +327,7 @@ class OrderCreateSerializer(serializers.Serializer): # přiřadíme k orderu order.payment = payment - order.save(update_fields=["payment"]) + order.save(update_fields=["payment", "updated_at"]) return order diff --git a/backend/vontor_cz/settings.py b/backend/vontor_cz/settings.py index 54c266b..1c2a063 100644 --- a/backend/vontor_cz/settings.py +++ b/backend/vontor_cz/settings.py @@ -326,7 +326,10 @@ REST_FRAMEWORK = { 'DEFAULT_THROTTLE_RATES': { 'anon': '100/hour', # unauthenticated 'user': '2000/hour', # authenticated - } + }, + + 'EXCEPTION_HANDLER': 'trznice.utils.custom_exception_handler', + } #--------------------------------END REST FRAMEWORK 🛠️------------------------------------- diff --git a/backend/vontor_cz/utils.py b/backend/vontor_cz/utils.py new file mode 100644 index 0000000..dd8469e --- /dev/null +++ b/backend/vontor_cz/utils.py @@ -0,0 +1,44 @@ +import os +from datetime import datetime +from rest_framework.fields import DateTimeField +from django.conf import settings +from django.core.exceptions import ValidationError as DjangoValidationError +from rest_framework.views import exception_handler +from rest_framework.exceptions import ValidationError as DRFValidationError + + +def custom_exception_handler(exc, context): + """ + Custom exception handler to convert Django ValidationError to DRF ValidationError (400 instead of 500) + """ + # Convert Django ValidationError to DRF ValidationError + if isinstance(exc, DjangoValidationError): + if hasattr(exc, 'error_dict'): + # Multiple field errors + exc = DRFValidationError(detail=exc.message_dict) + elif hasattr(exc, 'error_list'): + # Single error or list of errors + exc = DRFValidationError(detail=exc.messages) + else: + # Fallback + exc = DRFValidationError(detail=str(exc)) + + # Call REST framework's default exception handler + return exception_handler(exc, context) + + +def truncate_to_minutes(dt: datetime) -> datetime: + return dt.replace(second=0, microsecond=0) + + +class RoundedDateTimeField(DateTimeField): + def to_internal_value(self, value): + dt = super().to_internal_value(value) + return truncate_to_minutes(dt) + +IMPORT_DIR = "data_imports" +def get_imports_dir(): + base_dir = settings.BASE_DIR # adjust if your BASE_DIR is different + imports_dir = os.path.join(base_dir, IMPORT_DIR) + os.makedirs(imports_dir, exist_ok=True) + return imports_dir