from __future__ import annotations
from collections import OrderedDict
from decimal import Decimal
from typing import TYPE_CHECKING, Any, Generator
from django.conf import settings
from django.contrib.contenttypes.fields import GenericForeignKey
from django.contrib.contenttypes.models import ContentType
from django.db import models, transaction
from django.http import HttpRequest
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from salesman.conf import app_settings
from salesman.core.typing import Product
from salesman.core.utils import get_salesman_model
if TYPE_CHECKING: # pragma: no cover
from django.db.models.manager import RelatedManager
# Session key under which the id of an anonymous (user-less) basket is kept.
# BasketManager.get_or_create_from_request reads it to find the session
# basket, writes it for anonymous visitors, and removes it once the visitor
# is authenticated.
BASKET_ID_SESSION_KEY: str = "BASKET_ID"
class BasketManager(models.Manager["BaseBasket"]):
    """Manager attached to basket models as ``objects``."""

    def get_or_create_from_request(
        self,
        request: HttpRequest,
    ) -> tuple[BaseBasket, bool]:
        """
        Resolve the basket for the current request, creating one if needed.

        Anonymous visitors are tracked through a basket id kept in the
        session. For an authenticated user the user's own basket is used:
        any anonymous basket referenced from the session is merged into it
        and the session id is dropped. If the user owns several baskets,
        they are all merged into the first one returned by the query.

        Args:
            request (HttpRequest): Django request

        Returns:
            tuple: (basket, created)
        """
        if not hasattr(request, "session"):
            # Requests without a session get a plain dict so the lookups
            # below still work.
            request.session = {}

        try:
            anonymous_basket = self.get(
                id=request.session[BASKET_ID_SESSION_KEY],
                user=None,
            )
        except (KeyError, self.model.DoesNotExist):
            anonymous_basket = None

        is_logged_in = hasattr(request, "user") and request.user.is_authenticated
        if not is_logged_in:
            basket = anonymous_basket or self.create()
            request.session[BASKET_ID_SESSION_KEY] = basket.pk
            return basket, not anonymous_basket

        try:
            basket, created = self.get_or_create(user_id=request.user.id)
        except self.model.MultipleObjectsReturned:
            # The user owns more than one basket: keep the first one and
            # fold every other one into it.
            basket, *duplicates = list(self.filter(user=request.user.id))
            created = False
            for duplicate in duplicates:
                basket.merge(duplicate)

        if anonymous_basket:
            # Carry over whatever the visitor collected before logging in.
            basket.merge(anonymous_basket)

        if BASKET_ID_SESSION_KEY in request.session:
            # Forget the anonymous basket id so it is not fetched again while
            # the user stays logged in.
            del request.session[BASKET_ID_SESSION_KEY]

        return basket, created
class BaseBasket(models.Model):
    """
    Abstract basket model.

    A basket either belongs to a user (``user`` is set) or is an anonymous
    basket whose id is tracked in the session (``user`` is ``None``). Its
    items are reached through the ``items`` reverse relation.
    """

    user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        on_delete=models.CASCADE,
        null=True,
        verbose_name=_("Owner"),
    )
    extra = models.JSONField(_("Extra"), blank=True, default=dict)
    date_created = models.DateTimeField(_("Date created"), auto_now_add=True)
    date_updated = models.DateTimeField(_("Date updated"), auto_now=True)

    objects = BasketManager()

    # Reverse relation from ``BaseBasketItem.basket`` (related_name="items").
    items: RelatedManager[BaseBasketItem]
    # Items loaded by ``get_items``/``update``. Reset to ``None`` whenever the
    # basket contents are changed through this instance.
    _cached_items: list[BaseBasketItem] | None = None

    class Meta:
        abstract = True
        verbose_name = _("Basket")
        verbose_name_plural = _("Baskets")
        ordering = ["-date_created"]

    def __str__(self) -> str:
        return str(self.pk) if self.pk else "(unsaved)"

    def __iter__(self) -> Generator[BaseBasketItem, None, None]:
        """Iterate over the basket items, always querying the database."""
        yield from self.items.all()

    def update(self, request: HttpRequest) -> None:
        """
        Process basket with modifiers defined in ``SALESMAN_BASKET_MODIFIERS``.

        This method sets ``subtotal``, ``total`` and ``extra_rows`` attributes on the
        basket and updates the items. Should be called every time the basket item is
        added, removed or updated or basket extra is updated.

        Args:
            request (HttpRequest): Django request
        """
        from .modifiers import basket_modifiers_pool

        items = self.get_items()

        # Setup basket and items.
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.setup_basket(self, request)
            for item in items:
                modifier.setup_item(item, request)

        self.extra_rows: dict[str, Any] = OrderedDict()
        self.subtotal = Decimal(0)
        self.total = Decimal(0)

        # Process basket items.
        for item in items:
            item.update(request)
            self.subtotal += item.total

        # The basket total starts at the item subtotal; modifiers may adjust it.
        self.total = self.subtotal

        # Finalize items and process basket.
        for modifier in basket_modifiers_pool.get_modifiers():
            for item in items:
                modifier.finalize_item(item, request)
            modifier.process_basket(self, request)

        # Finalize basket.
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.finalize_basket(self, request)

        self._cached_items = items

    def add(
        self,
        product: Product,
        quantity: int = 1,
        ref: str | None = None,
        extra: dict[str, Any] | None = None,
    ) -> BaseBasketItem:
        """
        Add product to the basket.

        If an item with the same ``ref`` already exists, its quantity is
        increased instead of creating a new item.

        Args:
            product (Product): Product to add
            quantity (int): Quantity to add, defaults to 1
            ref (str, optional): Item ref, defaults to
                ``BasketItem.get_product_ref(product)``
            extra (dict, optional): Item extra; on an existing item it replaces
                the stored extra only when truthy

        Returns:
            BasketItem: BasketItem instance
        """
        BasketItem = get_salesman_model("BasketItem")
        if not ref:
            ref = BasketItem.get_product_ref(product)
        try:
            item = self.items.get(ref=ref)
            item.quantity += quantity
            item.extra = extra or item.extra
            item.save(update_fields=["quantity", "extra", "date_updated"])
        except BasketItem.DoesNotExist:
            item = BasketItem.objects.create(
                basket=self,
                product=product,
                quantity=quantity,
                ref=ref,
                extra=extra or {},
            )
        self._cached_items = None
        return item

    def remove(self, ref: str) -> None:
        """
        Remove item with given ``ref`` from the basket.

        Does nothing if no such item exists.

        Args:
            ref (str): Item ref to remove
        """
        item = self.find(ref)
        if item:
            item.delete()
            self._cached_items = None

    def find(self, ref: str) -> BaseBasketItem | None:
        """
        Find item with given ``ref`` in the basket.

        Uses the cached items when available, otherwise queries the database.

        Args:
            ref (str): Item ref

        Returns:
            Optional[BaseBasketItem]: Basket item if found.
        """
        if self._cached_items is not None:
            return next(
                (item for item in self._cached_items if item.ref == ref),
                None,
            )
        return self.items.filter(ref=ref).first()

    def clear(self) -> None:
        """
        Clear all items from the basket.
        """
        self.items.all().delete()
        self._cached_items = None

    @transaction.atomic
    def merge(self, other: BaseBasket) -> None:
        """
        Merge other basket with this one, delete afterwards.

        Items whose ``ref`` already exists in this basket add their quantity to
        the existing item; all other items are moved to this basket.

        Args:
            other (Basket): Basket which to merge
        """
        if other.pk is not None and other.pk == self.pk:
            # Merging a basket into itself would double every quantity and
            # then delete the basket, so treat it as a no-op.
            return
        for item in other:
            try:
                existing = self.items.get(ref=item.ref)
            except item.DoesNotExist:
                # No matching item here yet: move it over.
                item.basket = self
                # ``auto_now`` fields are only saved when listed in
                # ``update_fields``, so include ``date_updated`` (as ``add`` does).
                item.save(update_fields=["basket", "date_updated"])
            else:
                existing.quantity += item.quantity
                existing.save(update_fields=["quantity", "date_updated"])
        other.delete()
        self._cached_items = None

    def get_items(self) -> list[BaseBasketItem]:
        """
        Returns items from cache or stores new ones.

        Returns:
            list: Basket items with their products prefetched.
        """
        if self._cached_items is None:
            self._cached_items = list(self.items.all().prefetch_related("product"))
        return self._cached_items

    @property
    def count(self) -> int:
        """
        Returns basket item count (number of distinct items, not quantity).
        """
        if self._cached_items is not None:
            return len(self._cached_items)
        return self.items.count()

    @property
    def quantity(self) -> int:
        """
        Returns the total quantity of all items in a basket.
        """
        if self._cached_items is not None:
            return sum(item.quantity for item in self._cached_items)
        aggr = self.items.aggregate(quantity=models.Sum("quantity"))
        # ``Sum`` yields ``None`` for an empty basket.
        return aggr["quantity"] or 0
class Basket(BaseBasket):
    """
    Default concrete basket model.

    Projects may replace it with their own model through the
    ``SALESMAN_BASKET_MODEL`` setting.
    """

    class Meta(BaseBasket.Meta):
        # Makes this model swappable via the setting named below.
        swappable = "SALESMAN_BASKET_MODEL"
class BaseBasketItem(models.Model):
    """
    Abstract basket item model.

    Links a basket to a product of any model through a generic relation and
    stores the ordered quantity. ``ref`` identifies the item and is unique
    within its basket.
    """

    basket = models.ForeignKey(
        app_settings.SALESMAN_BASKET_MODEL,
        on_delete=models.CASCADE,
        related_name="items",
        verbose_name=_("Basket"),
    )
    # Reference to this basket item, used to determine item duplicates.
    ref = models.SlugField(_("Reference"), max_length=128)
    # Generic relation to product.
    product_content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
    product_id = models.PositiveIntegerField(_("Product id"))
    product = GenericForeignKey("product_content_type", "product_id")
    quantity = models.PositiveIntegerField(_("Quantity"), default=1)
    extra = models.JSONField(_("Extra"), blank=True, default=dict)
    date_created = models.DateTimeField(_("Date created"), auto_now_add=True)
    date_updated = models.DateTimeField(_("Date updated"), auto_now=True)

    class Meta:
        abstract = True
        verbose_name = _("Item")
        verbose_name_plural = _("Items")
        unique_together = ("basket", "ref")
        ordering = ["date_created"]

    def __str__(self) -> str:
        return f"{self.quantity}x {self.product}"

    def save(self, *args: Any, **kwargs: Any) -> None:
        """
        Save the item, deriving ``ref`` from the product when it is empty.
        """
        # Set default ref.
        if not self.ref and self.product:
            self.ref = self.get_product_ref(self.product)
        super().save(*args, **kwargs)

    def update(self, request: HttpRequest) -> None:
        """
        Process items with modifiers defined in ``SALESMAN_BASKET_MODIFIERS``.

        This method sets ``unit_price``, ``subtotal``, ``total`` and ``extra_rows``
        attributes on the item. Should be called every time the basket item
        is added, removed or updated.

        Args:
            request (HttpRequest): Django request
        """
        from .modifiers import basket_modifiers_pool

        self.extra_rows: dict[str, Any] = OrderedDict()
        if self.product:
            self.unit_price = Decimal(self.product.get_price(request))
        else:
            # Product is missing (e.g. the related object no longer resolves).
            self.unit_price = Decimal(0)
        self.subtotal = self.unit_price * self.quantity
        self.total = self.subtotal
        for modifier in basket_modifiers_pool.get_modifiers():
            modifier.process_item(self, request)

    @property
    def name(self) -> str:
        """
        Returns product ``name``, or ``"(no name)"`` when there is no product.
        """
        return str(self.product.name) if self.product else "(no name)"

    @property
    def code(self) -> str:
        """
        Returns product ``code``, or ``"(no code)"`` when there is no product.
        """
        return str(self.product.code) if self.product else "(no code)"

    @classmethod
    def get_product_ref(cls, product: Product) -> str:
        """
        Returns default item ``ref`` for given product.

        The ref is the slugified ``"<app_label.ModelName>-<id>"`` of the product.

        Args:
            product (Product): Product instance

        Returns:
            str: Item ref
        """
        return slugify(f"{product._meta.label}-{product.id}")
class BasketItem(BaseBasketItem):
    """
    Default concrete basket item model.

    Projects may replace it with their own model through the
    ``SALESMAN_BASKET_ITEM_MODEL`` setting.
    """

    class Meta(BaseBasketItem.Meta):
        # Makes this model swappable via the setting named below.
        swappable = "SALESMAN_BASKET_ITEM_MODEL"