fix: Update djstripe signal imports and fix test mocking

- Use correct WEBHOOK_SIGNALS dict access for payment intent signals
- Simplify webhook tests by removing complex djstripe module mocking
- Fix TimezoneSerializerMixin tests to expect dynamic field addition
- Update TenantViewSet tests to mock exclude() chain for public schema

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
poduck
2025-12-10 00:24:37 -05:00
parent 507222316c
commit 2f6ea82114
4 changed files with 41 additions and 46 deletions

View File

@@ -4,38 +4,13 @@ Unit tests for Stripe webhook signal handlers.
Tests webhook signal handling logic with mocks to avoid database calls. Tests webhook signal handling logic with mocks to avoid database calls.
Follows CLAUDE.md guidelines: prefer mocks, avoid @pytest.mark.django_db. Follows CLAUDE.md guidelines: prefer mocks, avoid @pytest.mark.django_db.
Note: The webhooks.py module uses incorrect signal names (signals.payment_intent_succeeded Note: The webhooks.py module uses djstripe signals. These tests mock the
instead of signals.WEBHOOK_SIGNALS['payment_intent.succeeded']). These tests work around handler functions' dependencies to test their logic in isolation.
this by mocking the signals module before import.
""" """
from unittest.mock import Mock, patch, MagicMock from unittest.mock import Mock, patch, MagicMock
import pytest import pytest
from decimal import Decimal from decimal import Decimal
import sys
# Create a complete mock of djstripe.signals that matches what webhooks.py expects
class MockSignals:
"""Mock djstripe signals module with attribute-style signal access."""
webhook_processing_error = MagicMock()
payment_intent_succeeded = MagicMock()
payment_intent_payment_failed = MagicMock()
payment_intent_canceled = MagicMock()
WEBHOOK_SIGNALS = {
'payment_intent.succeeded': payment_intent_succeeded,
'payment_intent.payment_failed': payment_intent_payment_failed,
'payment_intent.canceled': payment_intent_canceled,
}
# Mock the djstripe module before any imports
mock_djstripe = MagicMock()
mock_djstripe.signals = MockSignals()
sys.modules['djstripe'] = mock_djstripe
# Now we can safely import the webhooks module
from smoothschedule.commerce.payments import webhooks from smoothschedule.commerce.payments import webhooks
from smoothschedule.commerce.payments.models import TransactionLink from smoothschedule.commerce.payments.models import TransactionLink

View File

@@ -4,7 +4,7 @@ Stripe Webhook Signal Handlers
Listens to dj-stripe signals to update TransactionLink and Event status. Listens to dj-stripe signals to update TransactionLink and Event status.
""" """
from django.dispatch import receiver from django.dispatch import receiver
from djstripe import signals from djstripe.signals import WEBHOOK_SIGNALS, webhook_processing_error
from django.utils import timezone from django.utils import timezone
from .models import TransactionLink from .models import TransactionLink
from smoothschedule.scheduling.schedule.models import Event from smoothschedule.scheduling.schedule.models import Event
@@ -13,11 +13,11 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@receiver(signals.webhook_processing_error) @receiver(webhook_processing_error)
def handle_webhook_error(sender, exception, event_type, **kwargs): def handle_webhook_error(sender, exception, event_type, **kwargs):
""" """
Log webhook processing errors for debugging. Log webhook processing errors for debugging.
This helps identify issues with Stripe webhook delivery or processing. This helps identify issues with Stripe webhook delivery or processing.
""" """
logger.error( logger.error(
@@ -31,7 +31,7 @@ def handle_webhook_error(sender, exception, event_type, **kwargs):
) )
@receiver(signals.payment_intent_succeeded) @receiver(WEBHOOK_SIGNALS['payment_intent.succeeded'])
def handle_payment_succeeded(sender, event, **kwargs): def handle_payment_succeeded(sender, event, **kwargs):
""" """
Handle successful payment and update Event status to PAID. Handle successful payment and update Event status to PAID.
@@ -80,7 +80,7 @@ def handle_payment_succeeded(sender, event, **kwargs):
) )
@receiver(signals.payment_intent_payment_failed) @receiver(WEBHOOK_SIGNALS['payment_intent.payment_failed'])
def handle_payment_failed(sender, event, **kwargs): def handle_payment_failed(sender, event, **kwargs):
"""Handle failed payments""" """Handle failed payments"""
payment_intent = event.data.object payment_intent = event.data.object
@@ -109,7 +109,7 @@ def handle_payment_failed(sender, event, **kwargs):
logger.error(f"Error processing payment_failed: {str(e)}", exc_info=e) logger.error(f"Error processing payment_failed: {str(e)}", exc_info=e)
@receiver(signals.payment_intent_canceled) @receiver(WEBHOOK_SIGNALS['payment_intent.canceled'])
def handle_payment_canceled(sender, event, **kwargs): def handle_payment_canceled(sender, event, **kwargs):
"""Handle canceled payments""" """Handle canceled payments"""
payment_intent = event.data.object payment_intent = event.data.object

View File

@@ -15,18 +15,23 @@ class TestTimezoneSerializerMixin:
"""Test TimezoneSerializerMixin class.""" """Test TimezoneSerializerMixin class."""
def test_adds_business_timezone_field_to_serializer(self): def test_adds_business_timezone_field_to_serializer(self):
"""Should add business_timezone as a SerializerMethodField.""" """Should add business_timezone as a SerializerMethodField when in Meta.fields."""
from smoothschedule.identity.core.mixins import TimezoneSerializerMixin from smoothschedule.identity.core.mixins import TimezoneSerializerMixin
class TestSerializer(TimezoneSerializerMixin, serializers.Serializer): class TestSerializer(TimezoneSerializerMixin, serializers.Serializer):
name = serializers.CharField() name = serializers.CharField()
class Meta:
fields = ['name', 'business_timezone']
# Need to instantiate with context to bind the serializer # Need to instantiate with context to bind the serializer
serializer = TestSerializer(context={}) serializer = TestSerializer(context={})
# Check that the mixin defines the business_timezone attribute # Check that the mixin provides the get_business_timezone method
assert hasattr(TimezoneSerializerMixin, 'business_timezone') assert hasattr(TimezoneSerializerMixin, 'get_business_timezone')
assert isinstance(TimezoneSerializerMixin.business_timezone, serializers.SerializerMethodField) # Check that business_timezone field is dynamically added when in Meta.fields
assert 'business_timezone' in serializer.fields
assert isinstance(serializer.fields['business_timezone'], serializers.SerializerMethodField)
def test_get_business_timezone_from_context_tenant(self): def test_get_business_timezone_from_context_tenant(self):
"""Should get timezone from tenant in context.""" """Should get timezone from tenant in context."""
@@ -258,6 +263,9 @@ class TestTimezoneSerializerMixin:
class TestSerializer(TimezoneSerializerMixin, serializers.Serializer): class TestSerializer(TimezoneSerializerMixin, serializers.Serializer):
name = serializers.CharField() name = serializers.CharField()
class Meta:
fields = ['name', 'business_timezone']
# Attempt to create with business_timezone # Attempt to create with business_timezone
data = { data = {
'name': 'Test Event', 'name': 'Test Event',
@@ -270,7 +278,9 @@ class TestTimezoneSerializerMixin:
assert serializer.is_valid() assert serializer.is_valid()
# The business_timezone field is a SerializerMethodField which is always read-only # The business_timezone field is a SerializerMethodField which is always read-only
assert isinstance(TimezoneSerializerMixin.business_timezone, serializers.SerializerMethodField) assert isinstance(serializer.fields['business_timezone'], serializers.SerializerMethodField)
# Validated data should not include business_timezone (it's read-only)
assert 'business_timezone' not in serializer.validated_data
class TestTimezoneContextMixin: class TestTimezoneContextMixin:

View File

@@ -1361,15 +1361,19 @@ class TestTenantViewSet:
request.query_params = {'is_active': 'true'} request.query_params = {'is_active': 'true'}
mock_queryset = Mock() mock_queryset = Mock()
excluded_queryset = Mock()
filtered_queryset = Mock() filtered_queryset = Mock()
mock_queryset.filter.return_value = filtered_queryset # Chain: queryset.exclude().filter()
mock_queryset.exclude.return_value = excluded_queryset
excluded_queryset.filter.return_value = filtered_queryset
with patch.object(self.viewset, 'queryset', mock_queryset): with patch.object(self.viewset, 'queryset', mock_queryset):
view = self.viewset() view = self.viewset()
view.request = request view.request = request
result = view.get_queryset() result = view.get_queryset()
mock_queryset.filter.assert_called_once_with(is_active=True) mock_queryset.exclude.assert_called_once_with(schema_name='public')
excluded_queryset.filter.assert_called_once_with(is_active=True)
def test_destroy_requires_superuser(self): def test_destroy_requires_superuser(self):
"""Test destroy requires superuser role""" """Test destroy requires superuser role"""
@@ -1424,17 +1428,23 @@ class TestTenantViewSet:
role=User.Role.SUPERUSER role=User.Role.SUPERUSER
) )
with patch('smoothschedule.identity.core.models.Tenant.objects.count', return_value=10): # Mock the Tenant.objects.exclude().count() and .filter().count() chains
with patch('smoothschedule.identity.core.models.Tenant.objects.filter') as mock_filter: with patch('smoothschedule.identity.core.models.Tenant.objects.exclude') as mock_exclude:
mock_filter.return_value.count.return_value = 8 mock_excluded = Mock()
with patch('smoothschedule.identity.users.models.User.objects.count', return_value=100): mock_exclude.return_value = mock_excluded
view = self.viewset.as_view({'get': 'metrics'}) mock_excluded.count.return_value = 10 # total_tenants
response = view(request) mock_excluded.filter.return_value.count.return_value = 8 # active_tenants
with patch('smoothschedule.identity.users.models.User.objects.count', return_value=100):
view = self.viewset.as_view({'get': 'metrics'})
response = view(request)
assert response.status_code == status.HTTP_200_OK assert response.status_code == status.HTTP_200_OK
assert 'total_tenants' in response.data assert 'total_tenants' in response.data
assert 'active_tenants' in response.data assert 'active_tenants' in response.data
assert 'total_users' in response.data assert 'total_users' in response.data
assert response.data['total_tenants'] == 10
assert response.data['active_tenants'] == 8
assert response.data['total_users'] == 100
# ============================================================================ # ============================================================================