fix(security): Multi-tenancy isolation and customer appointment filtering

- Add request tenant validation to all ViewSets (EventViewSet, ResourceViewSet,
  ParticipantViewSet, CustomerViewSet, StaffViewSet) to prevent cross-tenant
  data access via subdomain/header manipulation
- Change permission_classes from AllowAny to IsAuthenticated for EventViewSet
  and ResourceViewSet
- Filter events for customers to only show appointments where they are a
  participant
- Add customer field to EventSerializer to create Customer participants when
  appointments are created
- Update CustomerDashboard to fetch appointments from API instead of mock data
- Fix TenantViewSet.destroy() to properly handle cross-schema cascade when
  deleting tenants

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
poduck
2025-12-04 11:05:01 -05:00
parent dbe91ec2ff
commit 65faaae864
4 changed files with 250 additions and 51 deletions

View File

@@ -1,30 +1,60 @@
import React, { useState, useMemo } from 'react'; import React, { useState, useMemo } from 'react';
import { useOutletContext, Link } from 'react-router-dom'; import { useOutletContext, Link } from 'react-router-dom';
import { User, Business, Appointment } from '../../types'; import { User, Business, Appointment } from '../../types';
import { APPOINTMENTS, SERVICES } from '../../mockData'; import { useAppointments, useUpdateAppointment } from '../../hooks/useAppointments';
import { Calendar, Clock, MapPin, AlertTriangle } from 'lucide-react'; import { useServices } from '../../hooks/useServices';
import { Calendar, Clock, MapPin, AlertTriangle, Loader2 } from 'lucide-react';
const AppointmentList: React.FC<{ user: User, business: Business }> = ({ user, business }) => { const AppointmentList: React.FC<{ user: User, business: Business }> = ({ user, business }) => {
const [appointments, setAppointments] = useState(APPOINTMENTS);
const [activeTab, setActiveTab] = useState<'upcoming' | 'past'>('upcoming'); const [activeTab, setActiveTab] = useState<'upcoming' | 'past'>('upcoming');
const myAppointments = useMemo(() => appointments.filter(apt => apt.customerName.includes(user.name.split(' ')[0])).sort((a, b) => b.startTime.getTime() - a.startTime.getTime()), [user.name, appointments]); // Fetch appointments from API - backend filters for current customer
const { data: appointments = [], isLoading, error } = useAppointments();
const upcomingAppointments = myAppointments.filter(apt => new Date(apt.startTime) >= new Date() && apt.status !== 'CANCELLED'); const { data: services = [] } = useServices();
const pastAppointments = myAppointments.filter(apt => new Date(apt.startTime) < new Date() || apt.status === 'CANCELLED'); const updateAppointment = useUpdateAppointment();
const handleCancel = (appointment: Appointment) => { // Sort appointments by start time (newest first)
const sortedAppointments = useMemo(() =>
[...appointments].sort((a, b) => b.startTime.getTime() - a.startTime.getTime()),
[appointments]
);
const upcomingAppointments = sortedAppointments.filter(apt => new Date(apt.startTime) >= new Date() && apt.status !== 'CANCELLED');
const pastAppointments = sortedAppointments.filter(apt => new Date(apt.startTime) < new Date() || apt.status === 'CANCELLED');
const handleCancel = async (appointment: Appointment) => {
const hoursBefore = (new Date(appointment.startTime).getTime() - new Date().getTime()) / 3600000; const hoursBefore = (new Date(appointment.startTime).getTime() - new Date().getTime()) / 3600000;
if (hoursBefore < business.cancellationWindowHours) { if (hoursBefore < business.cancellationWindowHours) {
const service = SERVICES.find(s => s.id === appointment.serviceId); const service = services.find(s => s.id === appointment.serviceId);
const fee = service ? (service.price * (business.lateCancellationFeePercent / 100)).toFixed(2) : 'a fee'; const fee = service ? (service.price * (business.lateCancellationFeePercent / 100)).toFixed(2) : 'a fee';
if (!window.confirm(`Cancelling within the ${business.cancellationWindowHours}-hour window may incur a fee of $${fee}. Are you sure?`)) return; if (!window.confirm(`Cancelling within the ${business.cancellationWindowHours}-hour window may incur a fee of $${fee}. Are you sure?`)) return;
} else { } else {
if (!window.confirm("Are you sure you want to cancel this appointment?")) return; if (!window.confirm("Are you sure you want to cancel this appointment?")) return;
} }
setAppointments(prev => prev.map(apt => apt.id === appointment.id ? {...apt, status: 'CANCELLED'} : apt)); try {
await updateAppointment.mutateAsync({ id: appointment.id, updates: { status: 'CANCELLED' } });
} catch (err) {
console.error('Failed to cancel appointment:', err);
alert('Failed to cancel appointment. Please try again.');
}
}; };
if (isLoading) {
return (
<div className="mt-8 flex items-center justify-center py-12">
<Loader2 className="w-8 h-8 animate-spin text-brand-500" />
</div>
);
}
if (error) {
return (
<div className="mt-8 text-center py-8 bg-red-50 dark:bg-red-900/20 rounded-lg border border-red-200 dark:border-red-800">
<p className="text-red-600 dark:text-red-400">Failed to load appointments. Please try again later.</p>
</div>
);
}
return ( return (
<div className="mt-8"> <div className="mt-8">
<h2 className="text-xl font-bold mb-4">Your Appointments</h2> <h2 className="text-xl font-bold mb-4">Your Appointments</h2>
@@ -34,14 +64,22 @@ const AppointmentList: React.FC<{ user: User, business: Business }> = ({ user, b
</div> </div>
<div className="space-y-4"> <div className="space-y-4">
{(activeTab === 'upcoming' ? upcomingAppointments : pastAppointments).map(apt => { {(activeTab === 'upcoming' ? upcomingAppointments : pastAppointments).map(apt => {
const service = SERVICES.find(s => s.id === apt.serviceId); const service = services.find(s => s.id === apt.serviceId);
return ( return (
<div key={apt.id} className="bg-white dark:bg-gray-800 p-4 rounded-lg border border-gray-200 dark:border-gray-700 flex items-center justify-between"> <div key={apt.id} className="bg-white dark:bg-gray-800 p-4 rounded-lg border border-gray-200 dark:border-gray-700 flex items-center justify-between">
<div> <div>
<h3 className="font-semibold">{service?.name}</h3> <h3 className="font-semibold">{service?.name || 'Appointment'}</h3>
<p className="text-sm text-gray-500">{new Date(apt.startTime).toLocaleString()}</p> <p className="text-sm text-gray-500">{new Date(apt.startTime).toLocaleString()}</p>
</div> </div>
{activeTab === 'upcoming' && <button onClick={() => handleCancel(apt)} className="text-sm font-medium text-red-600 hover:underline">Cancel</button>} {activeTab === 'upcoming' && (
<button
onClick={() => handleCancel(apt)}
disabled={updateAppointment.isPending}
className="text-sm font-medium text-red-600 hover:underline disabled:opacity-50"
>
{updateAppointment.isPending ? 'Cancelling...' : 'Cancel'}
</button>
)}
</div> </div>
); );
})} })}

View File

@@ -834,6 +834,36 @@ class TenantViewSet(viewsets.ModelViewSet):
status=status.HTTP_403_FORBIDDEN status=status.HTTP_403_FORBIDDEN
) )
# First, unlink staff_resources from users WITHIN the tenant's schema
# This prevents cross-schema SET_NULL cascade issues when users are deleted
with schema_context(tenant.schema_name):
from schedule.models import Resource
# Unlink all resources from users (set user_id to NULL)
Resource.objects.filter(user__isnull=False).update(user=None)
# Delete all users associated with this tenant
# Use _raw_delete to avoid triggering cascades
# (cascades would try to access tenant schema tables which may not exist from public)
user_ids = list(User.objects.filter(tenant=tenant).values_list('id', flat=True))
if user_ids:
# Delete related objects that are in the public schema first
from rest_framework.authtoken.models import Token
Token.objects.filter(user_id__in=user_ids).delete()
# Delete MFA-related objects
from smoothschedule.users.models import EmailVerificationToken, MFAVerificationCode, TrustedDevice
EmailVerificationToken.objects.filter(user_id__in=user_ids).delete()
MFAVerificationCode.objects.filter(user_id__in=user_ids).delete()
TrustedDevice.objects.filter(user_id__in=user_ids).delete()
# Now delete users using raw SQL to skip Django's cascade
from django.db import connection
with connection.cursor() as cursor:
cursor.execute(
"DELETE FROM users_user WHERE id = ANY(%s)",
[user_ids]
)
# Delete the tenant (this will drop the schema due to django-tenants) # Delete the tenant (this will drop the schema due to django-tenants)
tenant.delete() tenant.delete()

View File

@@ -186,14 +186,36 @@ class ResourceSerializer(serializers.ModelSerializer):
ret['user_id'] = instance.user_id ret['user_id'] = instance.user_id
return ret return ret
def _get_valid_user(self, user_id):
"""
Get a user by ID, validating they belong to the same tenant as the request user.
Returns None if user doesn't exist or doesn't belong to the same tenant.
CRITICAL: This prevents cross-tenant user linking (multi-tenancy security).
"""
if not user_id:
return None
request = self.context.get('request')
if not request or not request.user.is_authenticated:
return None
try:
user = User.objects.get(id=user_id)
# Verify user belongs to the same tenant
if request.user.tenant and user.tenant == request.user.tenant:
return user
return None
except User.DoesNotExist:
return None
def create(self, validated_data): def create(self, validated_data):
"""Handle user_id when creating a resource""" """Handle user_id when creating a resource"""
user_id = validated_data.pop('user_id', None) user_id = validated_data.pop('user_id', None)
if user_id: if user_id:
try: user = self._get_valid_user(user_id)
validated_data['user'] = User.objects.get(id=user_id) if user:
except User.DoesNotExist: validated_data['user'] = user
pass
return super().create(validated_data) return super().create(validated_data)
def update(self, instance, validated_data): def update(self, instance, validated_data):
@@ -201,10 +223,9 @@ class ResourceSerializer(serializers.ModelSerializer):
user_id = validated_data.pop('user_id', None) user_id = validated_data.pop('user_id', None)
if user_id is not None: if user_id is not None:
if user_id: if user_id:
try: user = self._get_valid_user(user_id)
validated_data['user'] = User.objects.get(id=user_id) if user:
except User.DoesNotExist: validated_data['user'] = user
pass
else: else:
validated_data['user'] = None validated_data['user'] = None
return super().update(instance, validated_data) return super().update(instance, validated_data)
@@ -284,12 +305,17 @@ class EventSerializer(serializers.ModelSerializer):
required=False, required=False,
help_text="List of Staff (User) IDs to assign" help_text="List of Staff (User) IDs to assign"
) )
customer = serializers.IntegerField(
write_only=True,
required=False,
help_text="Customer (User) ID to assign"
)
class Meta: class Meta:
model = Event model = Event
fields = [ fields = [
'id', 'title', 'start_time', 'end_time', 'status', 'notes', 'id', 'title', 'start_time', 'end_time', 'status', 'notes',
'duration_minutes', 'participants', 'resource_ids', 'staff_ids', 'duration_minutes', 'participants', 'resource_ids', 'staff_ids', 'customer',
'resource_id', 'customer_id', 'service_id', 'customer_name', 'service_name', 'is_paid', 'resource_id', 'customer_id', 'service_id', 'customer_name', 'service_name', 'is_paid',
'created_at', 'updated_at', 'created_by', 'created_at', 'updated_at', 'created_by',
] ]
@@ -426,17 +452,18 @@ class EventSerializer(serializers.ModelSerializer):
"""Create event and associated participants""" """Create event and associated participants"""
resource_ids = validated_data.pop('resource_ids', []) resource_ids = validated_data.pop('resource_ids', [])
staff_ids = validated_data.pop('staff_ids', []) staff_ids = validated_data.pop('staff_ids', [])
customer_id = validated_data.pop('customer', None)
# Set created_by from request user (only if authenticated) # Set created_by from request user (only if authenticated)
request = self.context.get('request') request = self.context.get('request')
if request and hasattr(request, 'user') and request.user.is_authenticated: if request and hasattr(request, 'user') and request.user.is_authenticated:
validated_data['created_by'] = request.user validated_data['created_by'] = request.user
else: else:
validated_data['created_by'] = None # TODO: Remove for production validated_data['created_by'] = None # TODO: Remove for production
# Create the event # Create the event
event = Event.objects.create(**validated_data) event = Event.objects.create(**validated_data)
# Create Resource participants # Create Resource participants
resource_content_type = ContentType.objects.get_for_model(Resource) resource_content_type = ContentType.objects.get_for_model(Resource)
for resource_id in resource_ids: for resource_id in resource_ids:
@@ -446,7 +473,7 @@ class EventSerializer(serializers.ModelSerializer):
object_id=resource_id, object_id=resource_id,
role=Participant.Role.RESOURCE role=Participant.Role.RESOURCE
) )
# Create Staff participants # Create Staff participants
from smoothschedule.users.models import User from smoothschedule.users.models import User
user_content_type = ContentType.objects.get_for_model(User) user_content_type = ContentType.objects.get_for_model(User)
@@ -457,13 +484,23 @@ class EventSerializer(serializers.ModelSerializer):
object_id=staff_id, object_id=staff_id,
role=Participant.Role.STAFF role=Participant.Role.STAFF
) )
# Create Customer participant
if customer_id:
Participant.objects.create(
event=event,
content_type=user_content_type,
object_id=customer_id,
role=Participant.Role.CUSTOMER
)
return event return event
def update(self, instance, validated_data): def update(self, instance, validated_data):
"""Update event. Participants managed separately.""" """Update event. Participants managed separately."""
validated_data.pop('resource_ids', None) validated_data.pop('resource_ids', None)
validated_data.pop('staff_ids', None) validated_data.pop('staff_ids', None)
validated_data.pop('customer', None)
for attr, value in validated_data.items(): for attr, value in validated_data.items():
setattr(instance, attr, value) setattr(instance, attr, value)

View File

@@ -77,18 +77,37 @@ class ResourceViewSet(viewsets.ModelViewSet):
""" """
queryset = Resource.objects.all() queryset = Resource.objects.all()
serializer_class = ResourceSerializer serializer_class = ResourceSerializer
# TODO: Re-enable authentication for production permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] # Temporarily allow unauthenticated access for development
filterset_fields = ['is_active', 'max_concurrent_events'] filterset_fields = ['is_active', 'max_concurrent_events']
search_fields = ['name', 'description'] search_fields = ['name', 'description']
ordering_fields = ['name', 'created_at', 'max_concurrent_events'] ordering_fields = ['name', 'created_at', 'max_concurrent_events']
ordering = ['name'] ordering = ['name']
def get_queryset(self):
"""
Return resources for the current tenant.
CRITICAL: Validates user belongs to the current tenant.
"""
queryset = Resource.objects.all()
user = self.request.user
if not user.is_authenticated:
return queryset.none()
# Validate user belongs to the current tenant
request_tenant = getattr(self.request, 'tenant', None)
if user.tenant and request_tenant:
if user.tenant.schema_name != request_tenant.schema_name:
return queryset.none()
return queryset
def perform_create(self, serializer): def perform_create(self, serializer):
"""Create resource (quota-checked by HasQuota permission)""" """Create resource (quota-checked by HasQuota permission)"""
serializer.save() serializer.save()
def perform_update(self, serializer): def perform_update(self, serializer):
"""Update resource""" """Update resource"""
serializer.save() serializer.save()
@@ -113,8 +132,7 @@ class EventViewSet(viewsets.ModelViewSet):
""" """
queryset = Event.objects.all() queryset = Event.objects.all()
serializer_class = EventSerializer serializer_class = EventSerializer
# TODO: Re-enable authentication for production permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] # Temporarily allow unauthenticated access for development
filterset_fields = ['status'] filterset_fields = ['status']
search_fields = ['title', 'notes'] search_fields = ['title', 'notes']
@@ -123,10 +141,41 @@ class EventViewSet(viewsets.ModelViewSet):
def get_queryset(self): def get_queryset(self):
""" """
Filter events by date range if start_date and end_date are provided. Filter events by date range and user role.
CRITICAL for multi-tenancy:
- Users can only see events from their own tenant
- Customers can only see events where they are a participant
- Staff/Managers/Owners see all events in their tenant
""" """
queryset = Event.objects.all() queryset = Event.objects.all()
# CRITICAL: Validate user belongs to the current tenant
user = self.request.user
if not user.is_authenticated:
return queryset.none()
# Get the current tenant from the request (set by TenantMainMiddleware/TenantHeaderMiddleware)
request_tenant = getattr(self.request, 'tenant', None)
# If user has a tenant, verify it matches the request tenant
# This prevents users from accessing other tenants' data via subdomain/header manipulation
if user.tenant and request_tenant:
if user.tenant.schema_name != request_tenant.schema_name:
# User is accessing a tenant they don't belong to - return empty
return queryset.none()
# Filter by user role
if user.role == User.Role.CUSTOMER:
# Customers only see events where they are a participant
from django.contrib.contenttypes.models import ContentType
user_content_type = ContentType.objects.get_for_model(User)
participant_event_ids = Participant.objects.filter(
content_type=user_content_type,
object_id=user.id
).values_list('event_id', flat=True)
queryset = queryset.filter(id__in=participant_event_ids)
# Filter by date range # Filter by date range
start_date = self.request.query_params.get('start_date') start_date = self.request.query_params.get('start_date')
end_date = self.request.query_params.get('end_date') end_date = self.request.query_params.get('end_date')
@@ -153,11 +202,7 @@ class EventViewSet(viewsets.ModelViewSet):
to check if resources have capacity. If not, DRF automatically to check if resources have capacity. If not, DRF automatically
returns 400 Bad Request with error details. returns 400 Bad Request with error details.
""" """
# TODO: Re-enable authentication - this is temporary for development serializer.save(created_by=self.request.user)
if self.request.user.is_authenticated:
serializer.save(created_by=self.request.user)
else:
serializer.save(created_by=None)
def perform_update(self, serializer): def perform_update(self, serializer):
""" """
@@ -184,6 +229,26 @@ class ParticipantViewSet(viewsets.ModelViewSet):
ordering_fields = ['created_at'] ordering_fields = ['created_at']
ordering = ['-created_at'] ordering = ['-created_at']
def get_queryset(self):
"""
Return participants for the current tenant.
CRITICAL: Validates user belongs to the current tenant.
"""
queryset = Participant.objects.all()
user = self.request.user
if not user.is_authenticated:
return queryset.none()
# Validate user belongs to the current tenant
request_tenant = getattr(self.request, 'tenant', None)
if user.tenant and request_tenant:
if user.tenant.schema_name != request_tenant.schema_name:
return queryset.none()
return queryset
class CustomerViewSet(viewsets.ModelViewSet): class CustomerViewSet(viewsets.ModelViewSet):
""" """
@@ -192,8 +257,7 @@ class CustomerViewSet(viewsets.ModelViewSet):
Customers are Users with role=CUSTOMER belonging to the current tenant. Customers are Users with role=CUSTOMER belonging to the current tenant.
""" """
serializer_class = CustomerSerializer serializer_class = CustomerSerializer
# TODO: Re-enable authentication for production permission_classes = [IsAuthenticated]
permission_classes = [AllowAny] # Temporarily allow unauthenticated access for development
filterset_fields = ['is_active'] filterset_fields = ['is_active']
search_fields = ['email', 'first_name', 'last_name'] search_fields = ['email', 'first_name', 'last_name']
@@ -207,12 +271,28 @@ class CustomerViewSet(viewsets.ModelViewSet):
Customers are Users with role=CUSTOMER. Customers are Users with role=CUSTOMER.
In sandbox mode, only returns customers with is_sandbox=True. In sandbox mode, only returns customers with is_sandbox=True.
In live mode, only returns customers with is_sandbox=False. In live mode, only returns customers with is_sandbox=False.
CRITICAL: Only returns customers belonging to the current user's tenant.
""" """
queryset = User.objects.filter(role=User.Role.CUSTOMER) queryset = User.objects.filter(role=User.Role.CUSTOMER)
# Filter by tenant if user is authenticated and has a tenant user = self.request.user
if self.request.user.is_authenticated and self.request.user.tenant: if not user.is_authenticated:
queryset = queryset.filter(tenant=self.request.user.tenant) return queryset.none()
# CRITICAL: Validate user belongs to the current request tenant
request_tenant = getattr(self.request, 'tenant', None)
if user.tenant and request_tenant:
if user.tenant.schema_name != request_tenant.schema_name:
# User is accessing a tenant they don't belong to - return empty
return queryset.none()
# Filter by user's tenant for multi-tenancy security
if user.tenant:
queryset = queryset.filter(tenant=user.tenant)
else:
# If user has no tenant, return empty queryset for safety
return queryset.none()
# Filter by sandbox mode - check request.sandbox_mode set by middleware # Filter by sandbox mode - check request.sandbox_mode set by middleware
is_sandbox = getattr(self.request, 'sandbox_mode', False) is_sandbox = getattr(self.request, 'sandbox_mode', False)
@@ -319,8 +399,7 @@ class StaffViewSet(viewsets.ModelViewSet):
- POST /api/staff/{id}/toggle_active/ - Toggle active status - POST /api/staff/{id}/toggle_active/ - Toggle active status
""" """
serializer_class = StaffSerializer serializer_class = StaffSerializer
# TODO: Re-enable authentication for production permission_classes = [IsAuthenticated]
permission_classes = [AllowAny]
search_fields = ['email', 'first_name', 'last_name'] search_fields = ['email', 'first_name', 'last_name']
ordering_fields = ['email', 'first_name', 'last_name'] ordering_fields = ['email', 'first_name', 'last_name']
@@ -337,9 +416,22 @@ class StaffViewSet(viewsets.ModelViewSet):
Staff are Users with roles: TENANT_OWNER, TENANT_MANAGER, TENANT_STAFF. Staff are Users with roles: TENANT_OWNER, TENANT_MANAGER, TENANT_STAFF.
In sandbox mode, only returns staff with is_sandbox=True. In sandbox mode, only returns staff with is_sandbox=True.
In live mode, only returns staff with is_sandbox=False. In live mode, only returns staff with is_sandbox=False.
CRITICAL: Only returns users belonging to the current user's tenant.
""" """
from django.db.models import Q from django.db.models import Q
user = self.request.user
if not user.is_authenticated:
return User.objects.none()
# CRITICAL: Validate user belongs to the current request tenant
request_tenant = getattr(self.request, 'tenant', None)
if user.tenant and request_tenant:
if user.tenant.schema_name != request_tenant.schema_name:
# User is accessing a tenant they don't belong to - return empty
return User.objects.none()
# Include inactive staff for listing (so admins can reactivate them) # Include inactive staff for listing (so admins can reactivate them)
show_inactive = self.request.query_params.get('show_inactive', 'true') show_inactive = self.request.query_params.get('show_inactive', 'true')
@@ -352,10 +444,12 @@ class StaffViewSet(viewsets.ModelViewSet):
if show_inactive.lower() != 'true': if show_inactive.lower() != 'true':
queryset = queryset.filter(is_active=True) queryset = queryset.filter(is_active=True)
# Filter by tenant if user is authenticated and has a tenant # Filter by user's tenant for multi-tenancy security
# TODO: Re-enable this when authentication is enabled if user.tenant:
# if self.request.user.is_authenticated and self.request.user.tenant: queryset = queryset.filter(tenant=user.tenant)
# queryset = queryset.filter(tenant=self.request.user.tenant) else:
# If user has no tenant, return empty queryset for safety
return queryset.none()
# Filter by sandbox mode - check request.sandbox_mode set by middleware # Filter by sandbox mode - check request.sandbox_mode set by middleware
is_sandbox = getattr(self.request, 'sandbox_mode', False) is_sandbox = getattr(self.request, 'sandbox_mode', False)