Source code for salesman.basket.models

from __future__ import annotations

from collections import OrderedDict
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Generator

from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models, transaction
from django.http import HttpRequest
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _

from salesman.conf import app_settings
from salesman.core.typing import Product
from salesman.core.utils import get_salesman_model

if TYPE_CHECKING:  # pragma: no cover
    from django.db.models.manager import RelatedManager

# Session key under which the id of the session-bound (user-less) basket is stored.
BASKET_ID_SESSION_KEY = "BASKET_ID"


class BasketManager(models.Manager["BaseBasket"]):
    """
    Manager that resolves the basket belonging to an incoming request.
    """

    def get_or_create_from_request(
        self,
        request: HttpRequest,
    ) -> tuple[BaseBasket, bool]:
        """
        Resolve the basket for the given request, creating one when needed.

        For an authenticated user the user's basket is returned; a basket
        referenced by the session is merged into it and the session key is
        removed. Otherwise the session basket is returned (or created) and
        its id is stored in the session.

        Args:
            request (HttpRequest): Django request

        Returns:
            tuple: (basket, created)
        """
        if not hasattr(request, "session"):
            # Request carries no session; use a plain dict as a stand-in.
            request.session = {}

        # Look up the user-less basket referenced by the session, if any.
        try:
            session_basket = self.get(
                id=request.session[BASKET_ID_SESSION_KEY],
                user=None,
            )
        except (KeyError, self.model.DoesNotExist):
            session_basket = None

        authenticated = hasattr(request, "user") and request.user.is_authenticated
        if not authenticated:
            # Anonymous request: reuse the session basket or start a new one.
            if session_basket:
                basket, created = session_basket, False
            else:
                basket, created = self.create(), True
            request.session[BASKET_ID_SESSION_KEY] = basket.pk
            return basket, created

        try:
            basket, created = self.get_or_create(user_id=request.user.id)
        except self.model.MultipleObjectsReturned:
            # User has multiple baskets, fold all of them into the first one.
            user_baskets = list(self.filter(user=request.user.id))
            basket = user_baskets[0]
            created = False
            for duplicate in user_baskets[1:]:
                basket.merge(duplicate)

        if session_basket:
            # Carry over items collected before the user logged in.
            basket.merge(session_basket)

        if BASKET_ID_SESSION_KEY in request.session:
            # Drop the session basket id so that it doesn't get re-fetched
            # while the user is still logged in.
            del request.session[BASKET_ID_SESSION_KEY]

        return basket, created


class BaseBasket(models.Model):
    """
    Abstract basket model holding items for a user or for a session.

    Items are reachable through the ``items`` related manager. Fetched items
    are cached on the instance in ``_cached_items``; every method that changes
    the item set resets that cache.
    """

    # Owner of the basket; left empty for baskets that are bound to a session.
    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        null=True,
        verbose_name=_("Owner"),
    )

    extra = models.JSONField(_("Extra"), blank=True, default=dict)

    date_created = models.DateTimeField(_("Date created"), auto_now_add=True)
    date_updated = models.DateTimeField(_("Date updated"), auto_now=True)

    objects = BasketManager()
    items: RelatedManager[BaseBasketItem]

    # Items loaded by ``get_items()`` / ``update()``, ``None`` when not loaded.
    _cached_items: list[BaseBasketItem] | None = None

    class Meta:
        abstract = True
        verbose_name = _("Basket")
        verbose_name_plural = _("Baskets")
        ordering = ["-date_created"]

    def __str__(self) -> str:
        return str(self.pk) if self.pk else "(unsaved)"

    def __iter__(self) -> Generator[BaseBasketItem, None, None]:
        """
        Iterate over basket items, always querying the database.
        """
        yield from self.items.all()

    def update(self, request: HttpRequest) -> None:
        """
        Process basket with modifiers defined in ``SALESMAN_BASKET_MODIFIERS``.
        This method sets ``subtotal``, ``total`` and ``extra_rows`` attributes on the
        basket and updates the items. Should be called every time the basket item is
        added, removed or updated or basket extra is updated.

        Args:
            request (HttpRequest): Django request
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from .modifiers import basket_modifiers_pool

        items = self.get_items()

        # Setup basket and items.
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.setup_basket(self, request)
            for item in items:
                modifier.setup_item(item, request)

        self.extra_rows: dict[str, Any] = OrderedDict()
        self.subtotal = Decimal(0)
        self.total = Decimal(0)

        # Process basket items, basket subtotal is the sum of item totals.
        for item in items:
            item.update(request)
            self.subtotal += item.total
        self.total = self.subtotal

        # Finalize items and process basket.
        for modifier in basket_modifiers_pool.get_modifiers():
            for item in items:
                modifier.finalize_item(item, request)
            modifier.process_basket(self, request)

        # Finalize basket.
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.finalize_basket(self, request)

        self._cached_items = items

    def add(
        self,
        product: Product,
        quantity: int = 1,
        ref: str | None = None,
        extra: dict[str, Any] | None = None,
    ) -> BaseBasketItem:
        """
        Add product to the basket.

        If an item with the same ``ref`` already exists its quantity is
        increased and its extra is replaced when a non-empty ``extra`` is
        given, otherwise a new item is created.

        Args:
            product (Product): Product instance to add
            quantity (int): Quantity to add, defaults to 1
            ref (Optional[str]): Item ref, defaults to the product ref
            extra (Optional[dict]): Extra data to store on the item

        Returns:
            BasketItem: BasketItem instance
        """
        BasketItem = get_salesman_model("BasketItem")
        if not ref:
            ref = BasketItem.get_product_ref(product)
        try:
            item = self.items.get(ref=ref)
            item.quantity += quantity
            item.extra = extra or item.extra
            # ``date_updated`` must be listed for ``auto_now`` to be written.
            item.save(update_fields=["quantity", "extra", "date_updated"])
        except BasketItem.DoesNotExist:
            item = BasketItem.objects.create(
                basket=self,
                product=product,
                quantity=quantity,
                ref=ref,
                extra=extra or {},
            )
        self._cached_items = None
        return item

    def remove(self, ref: str) -> None:
        """
        Remove item with given ``ref`` from the basket.

        Args:
            ref (str): Item ref to remove
        """
        item = self.find(ref)
        if item:
            item.delete()
            self._cached_items = None

    def find(self, ref: str) -> BaseBasketItem | None:
        """
        Find item with given ``ref`` in the basket.

        Cached items are searched when available, otherwise the database
        is queried.

        Args:
            ref (str): Item ref

        Returns:
            Optional[BaseBasketItem]: Basket item if found.
        """
        if self._cached_items is not None:
            return next(
                (item for item in self._cached_items if item.ref == ref),
                None,
            )
        return self.items.filter(ref=ref).first()

    def clear(self) -> None:
        """
        Clear all items from the basket.
        """
        self.items.all().delete()
        self._cached_items = None

    @transaction.atomic
    def merge(self, other: BaseBasket) -> None:
        """
        Merge other basket with this one, delete afterwards.

        Items whose ``ref`` already exists in this basket have their quantity
        added to the existing item, the rest are moved over to this basket.

        Args:
            other (Basket): Basket which to merge
        """
        for item in other:
            try:
                existing = self.items.get(ref=item.ref)
                existing.quantity += item.quantity
                # Include ``date_updated`` so the ``auto_now`` timestamp is
                # refreshed, matching the quantity update done in ``add()``.
                existing.save(update_fields=["quantity", "date_updated"])
            except item.DoesNotExist:
                item.basket = self
                item.save(update_fields=["basket"])
        other.delete()
        self._cached_items = None

    def get_items(self) -> list[BaseBasketItem]:
        """
        Returns items from cache or stores new ones.
        """
        if self._cached_items is None:
            self._cached_items = list(self.items.all().prefetch_related("product"))
        return self._cached_items

    @property
    def count(self) -> int:
        """
        Returns basket item count.
        """
        if self._cached_items is not None:
            return len(self._cached_items)
        return self.items.count()

    @property
    def quantity(self) -> int:
        """
        Returns the total quantity of all items in a basket.
        """
        if self._cached_items is not None:
            return sum(item.quantity for item in self._cached_items)
        aggr = self.items.aggregate(quantity=models.Sum("quantity"))
        # Aggregate yields ``None`` for an empty basket.
        return aggr["quantity"] or 0


class Basket(BaseBasket):
    """
    Model that can be swapped by overriding `SALESMAN_BASKET_MODEL` setting.
    """

    class Meta(BaseBasket.Meta):
        swappable = "SALESMAN_BASKET_MODEL"
class BaseBasketItem(models.Model):
    """
    Abstract basket item model linking a product to a basket.
    """

    basket = models.ForeignKey(
        app_settings.SALESMAN_BASKET_MODEL,
        on_delete=models.CASCADE,
        related_name="items",
        verbose_name=_("Basket"),
    )

    # Reference to this basket item, used to determine item duplicates.
    ref = models.SlugField(_("Reference"), max_length=128)

    # Generic relation to product.
    product_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
    product_id = models.PositiveIntegerField(_("Product id"))
    product = GenericForeignKey("product_content_type", "product_id")

    quantity = models.PositiveIntegerField(_("Quantity"), default=1)

    extra = models.JSONField(_("Extra"), blank=True, default=dict)

    date_created = models.DateTimeField(_("Date created"), auto_now_add=True)
    date_updated = models.DateTimeField(_("Date updated"), auto_now=True)

    class Meta:
        abstract = True
        verbose_name = _("Item")
        verbose_name_plural = _("Items")
        unique_together = ("basket", "ref")
        ordering = ["date_created"]

    def __str__(self) -> str:
        return f"{self.quantity}x {self.product}"

    def save(self, *args: Any, **kwargs: Any) -> None:
        """
        Save the item, filling in the default ``ref`` from the product when
        no ``ref`` is set.
        """
        # Set default ref.
        if not self.ref and self.product:
            self.ref = self.get_product_ref(self.product)
        super().save(*args, **kwargs)

    def update(self, request: HttpRequest) -> None:
        """
        Process items with modifiers defined in ``SALESMAN_BASKET_MODIFIERS``.
        This method sets ``unit_price``, ``subtotal``, ``total`` and ``extra_rows``
        attributes on the item. Should be called every time the basket item is
        added, removed or updated.

        Args:
            request (HttpRequest): Django request
        """
        from .modifiers import basket_modifiers_pool

        self.extra_rows: dict[str, Any] = OrderedDict()

        # An item without a resolvable product is priced at zero.
        if self.product:
            self.unit_price = Decimal(self.product.get_price(request))
        else:
            self.unit_price = Decimal(0)
        self.subtotal = self.unit_price * self.quantity
        self.total = self.subtotal

        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.process_item(self, request)

    @property
    def name(self) -> str:
        """
        Returns product `name`.
        """
        return str(self.product.name) if self.product else "(no name)"

    @property
    def code(self) -> str:
        """
        Returns product `code`.
        """
        return str(self.product.code) if self.product else "(no code)"

    @classmethod
    def get_product_ref(cls, product: Product) -> str:
        """
        Returns default item ``ref`` for given product.

        The ref is a slug built from the product model label and its id.

        Args:
            product (Product): Product instance

        Returns:
            str: Item ref
        """
        return slugify(f"{product._meta.label}-{product.id}")
class BasketItem(BaseBasketItem):
    """
    Model that can be swapped by overriding `SALESMAN_BASKET_ITEM_MODEL` setting.
    """

    class Meta(BaseBasketItem.Meta):
        swappable = "SALESMAN_BASKET_ITEM_MODEL"