Source code for salesman.basket.models

from __future__ import annotations

from collections import OrderedDict
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Generator

from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models, transaction
from django.http import HttpRequest
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _

from salesman.conf import app_settings
from salesman.core.typing import Product
from salesman.core.utils import get_salesman_model

if TYPE_CHECKING:  # pragma: no cover
    from django.db.models.manager import RelatedManager

BASKET_ID_SESSION_KEY = "BASKET_ID"


class BasketManager(models.Manager["BaseBasket"]):
    """
    Manager that resolves the basket belonging to an incoming request.
    """

    def get_or_create_from_request(
        self,
        request: HttpRequest,
    ) -> tuple[BaseBasket, bool]:
        """
        Return the basket tied to ``request``, creating one when needed.

        Anonymous requests use the basket whose id is stored in the session
        under ``BASKET_ID_SESSION_KEY``. Authenticated requests use the basket
        owned by the user; a session basket, if one exists, is merged into it
        and the session key is removed.

        Args:
            request (HttpRequest): Django request

        Returns:
            tuple: (basket, created)
        """
        # Requests without session support get a plain dict as a stand-in.
        if not hasattr(request, "session"):
            request.session = {}

        # Look up an ownerless basket referenced by the session, if any.
        session_basket = None
        try:
            stored_id = request.session[BASKET_ID_SESSION_KEY]
            session_basket = self.get(id=stored_id, user=None)
        except (KeyError, self.model.DoesNotExist):
            pass

        authenticated = hasattr(request, "user") and request.user.is_authenticated

        if not authenticated:
            # Anonymous visitor: reuse the session basket or start a new one,
            # and remember its id in the session.
            if session_basket:
                basket, created = session_basket, False
            else:
                basket, created = self.create(), True
            request.session[BASKET_ID_SESSION_KEY] = basket.pk
            return basket, created

        user_id = request.user.id
        try:
            basket, created = self.get_or_create(user_id=user_id)
        except self.model.MultipleObjectsReturned:
            # User has multiple baskets, fold the extras into the first one.
            user_baskets = list(self.filter(user=user_id))
            basket, created = user_baskets[0], False
            for duplicate in user_baskets[1:]:
                basket.merge(duplicate)

        if session_basket:
            # Carry over whatever was collected before logging in.
            basket.merge(session_basket)

        if BASKET_ID_SESSION_KEY in request.session:
            # Drop the stored id so the session basket is not looked up
            # again while the user stays logged in.
            del request.session[BASKET_ID_SESSION_KEY]

        return basket, created


class BaseBasket(models.Model):
    """
    Abstract basket model that groups basket items for an optional owner.

    Items are reachable through the ``items`` related manager. A list of items
    may be cached on the instance in ``_cached_items``; every method that
    changes the stored items resets that cache to ``None``.
    """

    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        null=True,
        verbose_name=_("Owner"),
    )
    extra = models.JSONField(_("Extra"), blank=True, default=dict)
    date_created = models.DateTimeField(_("Date created"), auto_now_add=True)
    date_updated = models.DateTimeField(_("Date updated"), auto_now=True)

    objects = BasketManager()

    # Reverse relation declared by the ``basket`` foreign key on basket items.
    items: RelatedManager[BaseBasketItem]
    # Item list populated by ``get_items()`` / ``update()``; ``None`` = not loaded.
    _cached_items: list[BaseBasketItem] | None = None

    class Meta:
        abstract = True
        verbose_name = _("Basket")
        verbose_name_plural = _("Baskets")
        ordering = ["-date_created"]

    def __str__(self) -> str:
        return str(self.pk) if self.pk else "(unsaved)"

    def __iter__(self) -> Generator[BaseBasketItem, None, None]:
        """
        Iterate over basket items fetched from the database (bypasses cache).
        """
        yield from self.items.all()

    def update(self, request: HttpRequest) -> None:
        """
        Process basket with modifiers defined in ``SALESMAN_BASKET_MODIFIERS``.
        This method sets ``subtotal``, ``total`` and ``extra_rows`` attributes
        on the basket and updates the items. Should be called every time the
        basket item is added, removed or updated or basket extra is updated.

        Args:
            request (HttpRequest): Django request
        """
        # Imported here rather than at module level, as in the original code.
        from .modifiers import basket_modifiers_pool

        items = self.get_items()

        # Setup basket and items.
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.setup_basket(self, request)
            for item in items:
                modifier.setup_item(item, request)

        self.extra_rows: dict[str, Any] = OrderedDict()
        self.subtotal = Decimal(0)
        self.total = Decimal(0)

        # Process basket items.
        for item in items:
            item.update(request)
            self.subtotal += item.total
        self.total = self.subtotal

        # Finalize items and process basket.
        for modifier in basket_modifiers_pool.get_modifiers():
            for item in items:
                modifier.finalize_item(item, request)
            modifier.process_basket(self, request)

        # Finalize basket.
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.finalize_basket(self, request)

        self._cached_items = items

    def add(
        self,
        product: Product,
        quantity: int = 1,
        ref: str | None = None,
        extra: dict[str, Any] | None = None,
    ) -> BaseBasketItem:
        """
        Add product to the basket.

        If an item with the same ``ref`` already exists its quantity is
        increased and its ``extra`` is replaced when a non-empty ``extra``
        is given; otherwise a new item is created.

        Args:
            product (Product): Product instance to add
            quantity (int): Quantity to add. Defaults to 1.
            ref (str, optional): Item ref, derived from the product if empty.
            extra (dict, optional): Extra data stored on the item.

        Returns:
            BasketItem: BasketItem instance
        """
        BasketItem = get_salesman_model("BasketItem")
        if not ref:
            ref = BasketItem.get_product_ref(product)
        try:
            # Only the lookup is expected to raise ``DoesNotExist``.
            item = self.items.get(ref=ref)
        except BasketItem.DoesNotExist:
            item = BasketItem.objects.create(
                basket=self,
                product=product,
                quantity=quantity,
                ref=ref,
                extra=extra or {},
            )
        else:
            item.quantity += quantity
            item.extra = extra or item.extra
            # ``date_updated`` must be listed so that ``auto_now`` is applied.
            item.save(update_fields=["quantity", "extra", "date_updated"])
        self._cached_items = None
        return item

    def remove(self, ref: str) -> None:
        """
        Remove item with given ``ref`` from the basket.

        Args:
            ref (str): Item ref to remove
        """
        item = self.find(ref)
        if item:
            item.delete()
            self._cached_items = None

    def find(self, ref: str) -> BaseBasketItem | None:
        """
        Find item with given ``ref`` in the basket.

        Searches the cached items when they are loaded, otherwise queries
        the database.

        Args:
            ref (str): Item ref

        Returns:
            Optional[BaseBasketItem]: Basket item if found.
        """
        if self._cached_items is not None:
            return next(
                (item for item in self._cached_items if item.ref == ref),
                None,
            )
        return self.items.filter(ref=ref).first()

    def clear(self) -> None:
        """
        Clear all items from the basket.
        """
        self.items.all().delete()
        self._cached_items = None

    @transaction.atomic
    def merge(self, other: BaseBasket) -> None:
        """
        Merge other basket with this one, delete afterwards.

        Items whose ``ref`` already exists in this basket have their quantity
        added to the existing item; all other items are moved over. Merging a
        basket with itself is a no-op.

        Args:
            other (Basket): Basket which to merge
        """
        # Guard against self-merge, which would double every quantity and
        # then delete this very basket.
        if other is self or (other.pk is not None and other.pk == self.pk):
            return

        for item in other:
            try:
                existing = self.items.get(ref=item.ref)
            except item.DoesNotExist:
                item.basket = self
                item.save(update_fields=["basket"])
            else:
                existing.quantity += item.quantity
                # Include ``date_updated`` so ``auto_now`` is applied, the
                # same way ``add()`` does when it changes a quantity.
                existing.save(update_fields=["quantity", "date_updated"])
        other.delete()
        self._cached_items = None

    def get_items(self) -> list[BaseBasketItem]:
        """
        Returns items from cache or stores new ones.
        """
        if self._cached_items is None:
            self._cached_items = list(self.items.all().prefetch_related("product"))
        return self._cached_items

    @property
    def count(self) -> int:
        """
        Returns basket item count.
        """
        if self._cached_items is not None:
            return len(self._cached_items)
        return self.items.count()

    @property
    def quantity(self) -> int:
        """
        Returns the total quantity of all items in a basket.
        """
        if self._cached_items is not None:
            return sum(item.quantity for item in self._cached_items)
        aggr = self.items.aggregate(quantity=models.Sum("quantity"))
        # ``or 0`` covers a missing aggregate (e.g. a basket with no items).
        return aggr["quantity"] or 0
class Basket(BaseBasket):
    """
    Default basket model.

    Can be swapped by overriding the `SALESMAN_BASKET_MODEL` setting.
    """

    class Meta(BaseBasket.Meta):
        swappable = "SALESMAN_BASKET_MODEL"
class BaseBasketItem(models.Model):
    """
    Abstract basket item model linking a product to a basket with a quantity.

    The product is referenced through a generic relation and each item is
    identified within its basket by a unique ``ref``.
    """

    basket = models.ForeignKey(
        app_settings.SALESMAN_BASKET_MODEL,
        on_delete=models.CASCADE,
        related_name="items",
        verbose_name=_("Basket"),
    )

    # Reference to this basket item, used to determine item duplicates.
    ref = models.SlugField(_("Reference"), max_length=128)

    # Generic relation to product.
    product_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
    product_id = models.PositiveIntegerField(_("Product id"))
    product = GenericForeignKey("product_content_type", "product_id")

    quantity = models.PositiveIntegerField(_("Quantity"), default=1)
    extra = models.JSONField(_("Extra"), blank=True, default=dict)
    date_created = models.DateTimeField(_("Date created"), auto_now_add=True)
    date_updated = models.DateTimeField(_("Date updated"), auto_now=True)

    class Meta:
        abstract = True
        verbose_name = _("Item")
        verbose_name_plural = _("Items")
        unique_together = ("basket", "ref")
        ordering = ["date_created"]

    def __str__(self) -> str:
        return f"{self.quantity}x {self.product}"

    def save(self, *args: Any, **kwargs: Any) -> None:
        """
        Save the item, filling in a default ``ref`` from the product if empty.
        """
        # Set default ref.
        if not self.ref and self.product:
            self.ref = self.get_product_ref(self.product)
        super().save(*args, **kwargs)

    def update(self, request: HttpRequest) -> None:
        """
        Process items with modifiers defined in ``SALESMAN_BASKET_MODIFIERS``.
        This method sets ``unit_price``, ``subtotal``, ``total`` and
        ``extra_rows`` attributes on the item. Should be called every time
        the basket item is added, removed or updated.

        Args:
            request (HttpRequest): Django request
        """
        # Imported here rather than at module level, as in the original code.
        from .modifiers import basket_modifiers_pool

        self.extra_rows: dict[str, Any] = OrderedDict()
        # An item whose product cannot be resolved is priced at zero.
        if self.product:
            self.unit_price = Decimal(self.product.get_price(request))
        else:
            self.unit_price = Decimal(0)
        self.subtotal = self.unit_price * self.quantity
        self.total = self.subtotal

        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.process_item(self, request)

    @property
    def name(self) -> str:
        """
        Returns product `name`, or ``"(no name)"`` when there is no product.
        """
        return str(self.product.name) if self.product else "(no name)"

    @property
    def code(self) -> str:
        """
        Returns product `code`, or ``"(no code)"`` when there is no product.
        """
        return str(self.product.code) if self.product else "(no code)"

    @classmethod
    def get_product_ref(cls, product: Product) -> str:
        """
        Returns default item ``ref`` for given product.

        The ref is the slugified combination of the product's model label
        and its id.

        Args:
            product (Product): Product instance

        Returns:
            str: Item ref
        """
        return slugify(f"{product._meta.label}-{product.id}")
class BasketItem(BaseBasketItem):
    """
    Default basket item model.

    Can be swapped by overriding the `SALESMAN_BASKET_ITEM_MODEL` setting.
    """

    class Meta(BaseBasketItem.Meta):
        swappable = "SALESMAN_BASKET_ITEM_MODEL"