Files
smoothschedule/smoothschedule/schedule/safe_scripting.py
poduck e9b3eb9e84 feat: Add HTTP methods and URL whitelist system for plugins
Backend Changes:
- Extended SafeScriptAPI to support all HTTP methods (GET, POST, PUT, PATCH, DELETE)
- Created WhitelistedURL model for per-plugin and platform-wide URL whitelisting
- Added _validate_url() method with SSRF protection and private IP blocking
- Updated SafeScriptAPI to accept scheduled_task parameter for whitelist checking
- All HTTP methods now validate against whitelist before making requests

WhitelistedURL Model:
- Supports two scopes: PLATFORM (all plugins) and PLUGIN (specific plugin)
- Stores URL patterns with wildcard support (e.g., https://api.example.com/*)
- Tracks allowed HTTP methods per URL
- Includes approval workflow (approved_by, approved_at)
- Stores original plugin code for verification
- Domain-based indexing for fast lookup
- Database constraint ensures platform-wide entries have no plugin assigned

Security Features:
- SSRF prevention: blocks localhost, loopback, and private IP ranges
- Per-plugin whitelist: each ScheduledTask can only access its whitelisted URLs
- Platform-wide whitelist: approved URLs accessible by all plugins
- HTTP method validation: URLs must explicitly allow each method
- URL pattern matching with wildcard support

Related Models:
- WhitelistedURL.scheduled_task -> ScheduledTask (plugin that owns the whitelist)
- WhitelistedURL.approved_by -> User (platform user who approved the URL)

Migration: schedule/migrations/0014_whitelistedurl.py

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-28 21:00:39 -05:00

807 lines
23 KiB
Python

"""
Safe Scripting Engine for Customer Automations
Allows customers to write simple logic (if/else, loops, variables) while preventing:
- Infinite loops
- Excessive memory usage
- File system access
- Network access (except approved APIs)
- Code injection
- Resource exhaustion
Uses RestrictedPython for safe code execution with additional safety layers.
"""
import ast
import time
import sys
from typing import Any, Dict, List, Optional
from io import StringIO
from contextlib import redirect_stdout, redirect_stderr
import logging
logger = logging.getLogger(__name__)
class ResourceLimitExceeded(Exception):
    """Signal that a customer script ran past one of its resource limits."""
class ScriptExecutionError(Exception):
    """Signal that a customer script was rejected or could not complete."""
class SafeScriptAPI:
    """
    Safe API that customer scripts can access.

    Only exposes whitelisted operations that interact with their own data. Scripts
    receive this object as the ``api`` global, so every public method and public
    attribute here is reachable from customer code.
    """

    # Maximum number of redirect hops followed by the http_* helpers; every hop is
    # re-validated against the SSRF checks and the whitelist.
    _MAX_REDIRECTS = 5
    # Per-request timeout (seconds) passed to requests.
    _HTTP_TIMEOUT = 10

    def __init__(self, business, user, execution_context, scheduled_task=None):
        # NOTE(review): these are public attributes, so scripts can reach them as
        # api.business / api.user / api.context / api.scheduled_task - confirm intended.
        self.business = business
        self.user = user
        self.context = execution_context
        self.scheduled_task = scheduled_task  # ScheduledTask instance for whitelist checking
        self._api_call_count = 0
        self._max_api_calls = 50  # Prevent API spam

    def _check_api_limit(self):
        """
        Count one API call and enforce the per-instance call budget.

        Raises:
            ResourceLimitExceeded: Once more than ``_max_api_calls`` calls have been made
        """
        self._api_call_count += 1
        if self._api_call_count > self._max_api_calls:
            raise ResourceLimitExceeded(f"API call limit exceeded ({self._max_api_calls} calls)")

    @staticmethod
    def _clamp_limit(value, maximum=1000):
        """
        Coerce a script-supplied ``limit`` into an int within ``[0, maximum]``.

        Raises:
            ScriptExecutionError: If ``value`` cannot be converted to an int
        """
        try:
            limit = int(value)
        except (TypeError, ValueError):
            raise ScriptExecutionError(f"Invalid limit: {value!r}") from None
        # Negative values would become negative queryset slices, which Django rejects.
        return max(0, min(limit, maximum))

    @staticmethod
    def _parse_date(value, field):
        """
        Parse a ``YYYY-MM-DD`` filter value into a naive datetime at midnight.

        Raises:
            ScriptExecutionError: If ``value`` is not a string in that format
        """
        from datetime import datetime
        try:
            return datetime.strptime(value, '%Y-%m-%d')
        except (TypeError, ValueError):
            raise ScriptExecutionError(
                f"Invalid {field} (expected YYYY-MM-DD): {value!r}"
            ) from None

    def get_appointments(self, **filters):
        """
        Get appointments for this business.

        Args:
            status: Filter by status (SCHEDULED, COMPLETED, CANCELED)
            start_date: Only appointments starting on or after this date (YYYY-MM-DD)
            end_date: Only appointments starting on or before this date, inclusive (YYYY-MM-DD)
            limit: Maximum results (default: 100, max: 1000)

        Returns:
            List of appointment dictionaries

        Raises:
            ScriptExecutionError: If a date or limit filter is malformed
        """
        self._check_api_limit()
        from .models import Event
        from django.utils import timezone
        from datetime import timedelta

        # NOTE(review): no explicit business filter here; presumably tenant isolation is
        # handled outside this method (e.g. per-tenant schema) - confirm.
        queryset = Event.objects.all()

        # Apply filters
        if 'status' in filters:
            queryset = queryset.filter(status=filters['status'])
        if 'start_date' in filters:
            start = self._parse_date(filters['start_date'], 'start_date')
            queryset = queryset.filter(start_time__gte=timezone.make_aware(start))
        if 'end_date' in filters:
            end = self._parse_date(filters['end_date'], 'end_date')
            # Compare against the next midnight so the whole end_date day is included.
            queryset = queryset.filter(
                start_time__lt=timezone.make_aware(end + timedelta(days=1))
            )

        # Enforce limits
        queryset = queryset[:self._clamp_limit(filters.get('limit', 100))]

        # Serialize to safe dictionaries
        return [
            {
                'id': event.id,
                'title': event.title,
                'start_time': event.start_time.isoformat(),
                'end_time': event.end_time.isoformat(),
                'status': event.status,
                'notes': event.notes,
            }
            for event in queryset
        ]

    def get_customers(self, **filters):
        """
        Get customers for this business.

        Args:
            limit: Maximum results (default: 100, max: 1000)
            has_email: Filter to customers with email addresses

        Returns:
            List of customer dictionaries

        Raises:
            ScriptExecutionError: If ``limit`` is not an integer
        """
        self._check_api_limit()
        from smoothschedule.users.models import User

        # NOTE(review): filters only on role, not on self.business - confirm that User rows
        # are tenant-scoped, otherwise this returns customers of other businesses.
        queryset = User.objects.filter(role='customer')
        if filters.get('has_email'):
            queryset = queryset.exclude(email='')
        queryset = queryset[:self._clamp_limit(filters.get('limit', 100))]

        return [
            {
                'id': user.id,
                'email': user.email,
                'name': user.get_full_name() or user.username,
                'phone': getattr(user, 'phone', ''),
            }
            for user in queryset
        ]

    def send_email(self, to, subject, body):
        """
        Send an email to a customer.

        Args:
            to: Email address or customer ID
            subject: Email subject (max 200 characters)
            body: Email body, plain text (max 10,000 characters)

        Returns:
            True if sent successfully, False if the mail backend raised

        Raises:
            ScriptExecutionError: If the customer ID is unknown, the address is invalid,
                or subject/body are not strings or are too long
        """
        self._check_api_limit()
        from django.core.mail import send_mail
        from django.conf import settings

        # Resolve customer ID to email if needed (bool is an int subclass; not an ID).
        if isinstance(to, int) and not isinstance(to, bool):
            from smoothschedule.users.models import User
            try:
                user = User.objects.get(id=to)
                to = user.email
            except User.DoesNotExist:
                raise ScriptExecutionError(f"Customer {to} not found")

        # Validate email
        if not isinstance(to, str) or '@' not in to:
            raise ScriptExecutionError(f"Invalid email address: {to}")

        # Type and length limits
        if not isinstance(subject, str) or not isinstance(body, str):
            raise ScriptExecutionError("Subject and body must be strings")
        if len(subject) > 200:
            raise ScriptExecutionError("Subject too long (max 200 characters)")
        if len(body) > 10000:
            raise ScriptExecutionError("Body too long (max 10,000 characters)")

        try:
            send_mail(
                subject=subject,
                message=body,
                from_email=settings.DEFAULT_FROM_EMAIL,
                recipient_list=[to],
                fail_silently=False,
            )
            return True
        except Exception as e:
            # Deliberate best-effort: report failure to the script, keep the traceback in logs.
            logger.exception("Failed to send email: %s", e)
            return False

    def log(self, message):
        """Log a message (for debugging) and return it unchanged."""
        logger.info("[Customer Script] %s", message)
        return message

    @staticmethod
    def _unsafe_url_reason(url) -> Optional[str]:
        """
        Return why ``url`` must not be requested, or ``None`` if it passes the SSRF checks.

        The URL is rejected in any of these cases:
        - the scheme is not http or https;
        - the host is missing;
        - the host is a localhost name (including a trailing dot and ``*.localhost``);
        - the host is a literal IP that is not globally routable (private, loopback,
          link-local, shared, reserved or unspecified);
        - the host is a multicast address.

        IPv4 literals are also recognised in the shorthand, decimal and hex forms
        accepted by ``socket.inet_aton`` (e.g. ``127.1``, ``2130706433``). Plain
        ``ipaddress`` does not parse those forms, but HTTP clients resolve them.

        Hostnames are NOT resolved here. A whitelisted name whose DNS record points
        at an internal address is therefore not caught by this check.
        """
        from urllib.parse import urlparse
        import ipaddress
        import socket

        try:
            parsed = urlparse(url)
            host = parsed.hostname
        except (TypeError, ValueError, AttributeError):
            return "Invalid URL"

        if parsed.scheme not in ('http', 'https'):
            return "Only http and https URLs are allowed"
        if not host:
            return "URL must include a hostname"

        # .hostname is already lower-cased; a trailing dot is the same DNS name.
        host = host.rstrip('.')
        if host in ('localhost', '127.0.0.1', '0.0.0.0', '::1') or host.endswith('.localhost'):
            return "Cannot access localhost"

        try:
            ip = ipaddress.ip_address(host)
        except ValueError:
            try:
                ip = ipaddress.IPv4Address(socket.inet_aton(host))
            except (OSError, ValueError):
                ip = None  # Not an IP literal; the whitelist decides for hostnames.

        # is_global is False for private/loopback/link-local/shared/reserved ranges, but
        # True for multicast, so multicast is checked separately.
        if ip is not None and (not ip.is_global or ip.is_multicast):
            return "Cannot access private IP addresses"
        return None

    def _validate_url(self, url, method='GET'):
        """
        Validate URL against whitelist and check for SSRF attacks.

        Args:
            url: URL to validate
            method: HTTP method (GET, POST, PUT, PATCH, DELETE)

        Raises:
            ScriptExecutionError: If URL is not whitelisted or is unsafe
        """
        reason = self._unsafe_url_reason(url)
        if reason is not None:
            raise ScriptExecutionError(reason)

        from .models import WhitelistedURL

        # Check whitelist using database model
        if not WhitelistedURL.is_url_whitelisted(url, method, self.scheduled_task):
            raise ScriptExecutionError(
                f"URL '{url}' with method '{method}' is not whitelisted for this plugin. "
                f"Contact support at pluginaccess@smoothschedule.com to request whitelisting."
            )

    def _send_http_request(self, method, url, data=None, headers=None):
        """
        Perform one script-initiated HTTP request and return the final response text.

        The API-call budget is charged once. Redirects are followed here, not by
        requests, so that every redirect target goes through ``_validate_url``.
        Otherwise a whitelisted URL could redirect the server to an internal address.

        Redirect handling mirrors requests' defaults:
        - 302/303 become GET, and so does 301 after POST;
        - the body and its content headers are dropped except on 307/308;
        - Authorization is dropped when the host changes.

        Args:
            method: HTTP method name (GET, POST, PUT, PATCH, DELETE)
            url: URL to request (must be whitelisted)
            data: Request body; a dict is sent as JSON, anything else as raw data
            headers: Optional headers dictionary

        Returns:
            Response text of the final, non-redirect response

        Raises:
            ResourceLimitExceeded: If the API call limit is exhausted
            ScriptExecutionError: If a URL in the redirect chain is unsafe or not
                whitelisted, the request fails, the final status is 4xx/5xx, or there
                are too many redirects
        """
        self._check_api_limit()
        self._validate_url(url, method)

        import requests
        from urllib.parse import urljoin, urlparse

        request_headers = dict(headers or {})
        if data is None:
            body = {}
        elif isinstance(data, dict):
            body = {'json': data}  # Auto-detect JSON content
        else:
            body = {'data': data}

        def drop_headers(hdrs, names):
            return {k: v for k, v in hdrs.items() if str(k).lower() not in names}

        current_url, current_method = url, method
        for hop in range(self._MAX_REDIRECTS + 1):
            if hop:
                # Redirect targets get the same SSRF + whitelist checks as the original URL.
                self._validate_url(current_url, current_method)
            try:
                response = requests.request(
                    current_method,
                    current_url,
                    headers=request_headers,
                    timeout=self._HTTP_TIMEOUT,
                    allow_redirects=False,
                    **body,
                )
                if not response.is_redirect:
                    response.raise_for_status()
                    return response.text
            except requests.RequestException as e:
                raise ScriptExecutionError(f"HTTP {method} request failed: {e}")

            status = response.status_code
            try:
                next_url = urljoin(current_url, response.headers['Location'])
                host_changed = urlparse(next_url).hostname != urlparse(current_url).hostname
            except ValueError as e:
                raise ScriptExecutionError(
                    f"HTTP {method} request failed: invalid redirect location: {e}"
                )
            if status in (302, 303) or (status == 301 and current_method == 'POST'):
                current_method = 'GET'
            if status not in (307, 308):
                body = {}
                request_headers = drop_headers(
                    request_headers, ('content-type', 'content-length', 'transfer-encoding')
                )
            if host_changed:
                request_headers = drop_headers(request_headers, ('authorization',))
            current_url = next_url

        raise ScriptExecutionError(
            f"HTTP {method} request failed: too many redirects (max {self._MAX_REDIRECTS})"
        )

    def http_get(self, url, headers=None):
        """
        Make an HTTP GET request to approved domains.

        Args:
            url: URL to fetch (must be whitelisted)
            headers: Optional headers dictionary

        Returns:
            Response text

        Raises:
            ScriptExecutionError: If the URL or a redirect target is not whitelisted,
                or the request fails
        """
        return self._send_http_request('GET', url, headers=headers)

    def http_post(self, url, data=None, headers=None):
        """
        Make an HTTP POST request to approved domains.

        Args:
            url: URL to post to (must be whitelisted)
            data: Data to send (dict is sent as JSON, anything else as raw data)
            headers: Optional headers dictionary

        Returns:
            Response text

        Raises:
            ScriptExecutionError: If the URL or a redirect target is not whitelisted,
                or the request fails
        """
        return self._send_http_request('POST', url, data=data, headers=headers)

    def http_put(self, url, data=None, headers=None):
        """
        Make an HTTP PUT request to approved domains.

        Args:
            url: URL to put to (must be whitelisted)
            data: Data to send (dict is sent as JSON, anything else as raw data)
            headers: Optional headers dictionary

        Returns:
            Response text

        Raises:
            ScriptExecutionError: If the URL or a redirect target is not whitelisted,
                or the request fails
        """
        return self._send_http_request('PUT', url, data=data, headers=headers)

    def http_patch(self, url, data=None, headers=None):
        """
        Make an HTTP PATCH request to approved domains.

        Args:
            url: URL to patch (must be whitelisted)
            data: Data to send (dict is sent as JSON, anything else as raw data)
            headers: Optional headers dictionary

        Returns:
            Response text

        Raises:
            ScriptExecutionError: If the URL or a redirect target is not whitelisted,
                or the request fails
        """
        return self._send_http_request('PATCH', url, data=data, headers=headers)

    def http_delete(self, url, headers=None):
        """
        Make an HTTP DELETE request to approved domains.

        Args:
            url: URL to delete (must be whitelisted)
            headers: Optional headers dictionary

        Returns:
            Response text

        Raises:
            ScriptExecutionError: If the URL or a redirect target is not whitelisted,
                or the request fails
        """
        return self._send_http_request('DELETE', url, headers=headers)

    def create_appointment(self, title, start_time, end_time, **kwargs):
        """
        Create a new appointment.

        Args:
            title: Appointment title
            start_time: Start datetime (ISO format; 'Z' or an explicit offset is honoured,
                naive values are interpreted in the current Django timezone)
            end_time: End datetime (ISO format, same rules as start_time)
            notes: Optional notes

        Returns:
            Created appointment dictionary

        Raises:
            ScriptExecutionError: If a datetime is not a valid ISO string or end_time is
                before start_time
        """
        self._check_api_limit()
        from .models import Event
        from django.utils import timezone
        from datetime import datetime

        def parse(value):
            if not isinstance(value, str):
                raise ScriptExecutionError(
                    f"Invalid datetime format: expected an ISO string, got {value!r}"
                )
            try:
                parsed = datetime.fromisoformat(value.replace('Z', '+00:00'))
                # Offset-bearing inputs are already aware; make_aware would reject them.
                return timezone.make_aware(parsed) if timezone.is_naive(parsed) else parsed
            except ValueError as e:
                raise ScriptExecutionError(f"Invalid datetime format: {e}")

        start = parse(start_time)
        end = parse(end_time)
        if end < start:
            raise ScriptExecutionError("end_time must not be before start_time")

        # Create event
        event = Event.objects.create(
            title=title,
            start_time=start,
            end_time=end,
            notes=kwargs.get('notes', ''),
            status='SCHEDULED',
            created_by=self.user,
        )

        return {
            'id': event.id,
            'title': event.title,
            'start_time': event.start_time.isoformat(),
            'end_time': event.end_time.isoformat(),
        }

    def count(self, items):
        """Count items in a list"""
        return len(items)

    def sum(self, items):
        """Sum numeric items"""
        return sum(items)

    def filter(self, items, condition):
        """
        Filter items by condition.

        Example:
            customers = api.get_customers()
            active = api.filter(customers, lambda c: c['email'] != '')
        """
        return [item for item in items if condition(item)]
class SafeScriptEngine:
    """
    Execute customer scripts safely with resource limits.

    A script is checked by ``_validate_script`` and rewritten by ``_inject_loop_guards``
    so that every for/while body calls an iteration/time guard. It is then run with the
    built-in ``compile``/``exec`` against a restricted ``__builtins__`` table. These are
    defence-in-depth checks, not a complete sandbox: there is no RestrictedPython and no
    process isolation.
    """

    # Resource limits
    MAX_EXECUTION_TIME = 30  # seconds; checked on every loop iteration and after exec returns
    MAX_OUTPUT_SIZE = 10000  # characters of captured stdout kept in the result
    MAX_ITERATIONS = 10000  # total loop iterations per execute() call
    MAX_MEMORY_MB = 50  # megabytes; NOTE: declared only - nothing in this class enforces it
    MAX_SCRIPT_SIZE = 50000  # characters (reported to users as "50KB")

    # Global name the injected loop guard is bound to; scripts may not reference it.
    _GUARD_NAME = '_iteration_check'
    # Call targets rejected by name (none of them is in SAFE_BUILTINS either).
    _FORBIDDEN_CALLS = ('exec', 'eval', 'compile', '__import__')

    # Allowed built-in functions (whitelist)
    SAFE_BUILTINS = {
        'len': len,
        'range': range,
        'min': min,
        'max': max,
        'sum': sum,
        'abs': abs,
        'round': round,
        'int': int,
        'float': float,
        'str': str,
        'bool': bool,
        'list': list,
        'dict': dict,
        'enumerate': enumerate,
        'zip': zip,
        'sorted': sorted,
        'reversed': reversed,
        'any': any,
        'all': all,
        'print': print,  # stdout is captured into the result's 'output'
        'True': True,
        'False': False,
        'None': None,
    }

    def __init__(self):
        self._iteration_count = 0
        self._deadline = None  # monotonic timestamp; set by execute()

    def _check_iterations(self):
        """
        Loop guard injected at the top of every loop body.

        Counts one iteration and enforces both the iteration limit and the wall-clock
        deadline of the current ``execute`` call. Long single calls with no loop (e.g.
        ``sum(range(10**12))``) are only caught by the post-exec time check.

        Raises:
            ResourceLimitExceeded: If either limit is exceeded
        """
        self._iteration_count += 1
        if self._iteration_count > self.MAX_ITERATIONS:
            raise ResourceLimitExceeded(
                f"Loop iteration limit exceeded ({self.MAX_ITERATIONS} iterations)"
            )
        if self._deadline is not None and time.monotonic() > self._deadline:
            raise ResourceLimitExceeded(
                f"Execution time exceeded ({self.MAX_EXECUTION_TIME}s)"
            )

    def _validate_script(self, script: str) -> None:
        """
        Validate script before execution.

        Checks for:
        - Excessive size (checked before parsing, so huge inputs are not parsed)
        - Syntax errors
        - Imports, and calls to exec/eval/compile/__import__
        - Class and (async) function definitions
        - Access to underscore-prefixed attributes (``api._api_call_count``,
          ``x.__class__`` ...) and to dunder names or the loop-guard name, which
          would otherwise let a script escape the builtins table, reset API limits,
          or disable the loop guard

        Raises:
            ScriptExecutionError: Describing the first violation found
        """
        # Check script size
        if len(script) > self.MAX_SCRIPT_SIZE:
            raise ScriptExecutionError("Script too large (max 50KB)")

        try:
            tree = ast.parse(script)
        except SyntaxError as e:
            raise ScriptExecutionError(f"Syntax error: {e}")

        for node in ast.walk(tree):
            # No imports
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                raise ScriptExecutionError(
                    "Import statements not allowed. Use provided 'api' object instead."
                )

            # No exec/eval/compile
            if (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
                    and node.func.id in self._FORBIDDEN_CALLS):
                raise ScriptExecutionError(f"Function '{node.func.id}' not allowed")

            # No class definitions (for now)
            if isinstance(node, ast.ClassDef):
                raise ScriptExecutionError("Class definitions not allowed")

            # No function definitions (for now - could allow later)
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                raise ScriptExecutionError(
                    "Function definitions not allowed. Use inline logic instead."
                )

            # No private/dunder attribute access
            if isinstance(node, ast.Attribute) and node.attr.startswith('_'):
                raise ScriptExecutionError(
                    f"Access to attribute '{node.attr}' not allowed"
                )

            # No dunder names, and no touching the injected loop guard
            if isinstance(node, ast.Name) and (
                node.id.startswith('__') or node.id == self._GUARD_NAME
            ):
                raise ScriptExecutionError(f"Name '{node.id}' not allowed")

    def execute(
        self,
        script: str,
        api: "SafeScriptAPI",
        initial_vars: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """
        Execute a customer script safely.

        Args:
            script: Python code to execute
            api: SafeScriptAPI instance, exposed to the script as ``api``
            initial_vars: Optional initial variables; they cannot override the reserved
                names ``__builtins__``, ``api`` or the loop guard

        Returns:
            Dictionary with execution results:
            - success: bool
            - output: str (captured print statements, truncated to MAX_OUTPUT_SIZE)
            - result: Any (value of 'result' variable if set)
            - error: str (if failed) or None
            - iterations: int (loop iterations counted)
            - execution_time: float (seconds)

            Validation failures (syntax errors, imports, forbidden names, oversize
            scripts) are reported as ``success: False`` rather than raised.
        """
        started = time.monotonic()
        self._iteration_count = 0
        self._deadline = started + self.MAX_EXECUTION_TIME
        stdout_capture = StringIO()
        stderr_capture = StringIO()

        def outcome(success, result=None, error=None):
            output = stdout_capture.getvalue()
            if len(output) > self.MAX_OUTPUT_SIZE:
                output = output[:self.MAX_OUTPUT_SIZE] + "\n... (output truncated)"
            return {
                'success': success,
                'output': output,
                'result': result,
                'error': error,
                'iterations': self._iteration_count,
                'execution_time': time.monotonic() - started,
            }

        try:
            self._validate_script(script)
        except ScriptExecutionError as e:
            return outcome(False, error=str(e))

        # Caller variables first, reserved names last, so initial_vars cannot swap in
        # real builtins, a different api object, or a no-op loop guard.
        safe_globals = dict(initial_vars or {})
        safe_globals.update({
            '__builtins__': dict(self.SAFE_BUILTINS),  # copy: scripts never see the shared table
            'api': api,
            self._GUARD_NAME: self._check_iterations,
        })

        guarded_script = self._inject_loop_guards(script)

        try:
            # NOTE: redirect_stdout/redirect_stderr swap the process-wide sys streams.
            with redirect_stdout(stdout_capture), redirect_stderr(stderr_capture):
                exec(compile(guarded_script, '<customer_script>', 'exec'), safe_globals)

            # Catches long single calls that never hit a loop guard.
            if time.monotonic() - started > self.MAX_EXECUTION_TIME:
                raise ResourceLimitExceeded(
                    f"Execution time exceeded ({self.MAX_EXECUTION_TIME}s)"
                )
        except ResourceLimitExceeded as e:
            return outcome(False, error=str(e))
        except Exception as e:
            error_msg = f"{type(e).__name__}: {str(e)}"
            stderr_output = stderr_capture.getvalue()
            if stderr_output:
                error_msg += f"\n{stderr_output}"
            return outcome(False, error=error_msg)

        return outcome(True, result=safe_globals.get('result', None))

    def _inject_loop_guards(self, script: str) -> str:
        """
        Inject iteration checks into loops to prevent infinite loops.

        Transforms:
            for i in range(10):
                print(i)

        Into:
            for i in range(10):
                _iteration_check()
                print(i)

        Returns the script unchanged if it does not parse (validation reports that).
        """
        try:
            tree = ast.parse(script)
        except SyntaxError:
            # If it doesn't parse, validation will catch it
            return script

        guard_name = self._GUARD_NAME

        def make_guard():
            return ast.Expr(
                value=ast.Call(
                    func=ast.Name(id=guard_name, ctx=ast.Load()),
                    args=[],
                    keywords=[],
                )
            )

        class LoopGuardInjector(ast.NodeTransformer):
            def _guard_loop(self, node):
                # Add iteration check at start of loop body, then recurse into nested loops.
                node.body.insert(0, make_guard())
                return self.generic_visit(node)

            visit_For = _guard_loop
            visit_While = _guard_loop

        transformed = LoopGuardInjector().visit(tree)
        ast.fix_missing_locations(transformed)
        return ast.unparse(transformed)
def test_script_execution():
    """
    Manually exercise SafeScriptEngine with a few representative scripts.

    Tests 1 and 2 call ``api.get_appointments``, which queries the ``Event`` model, so
    they need a configured Django environment and database to succeed.
    """
    engine = SafeScriptEngine()

    # Create mock API
    class MockBusiness:
        name = "Test Business"

    api = SafeScriptAPI(
        business=MockBusiness(),
        user=None,
        execution_context={}
    )

    def run(script):
        """Execute ``script``, turning a raised validation error into a failure result."""
        try:
            return engine.execute(script, api)
        except ScriptExecutionError as exc:
            # execute() may raise for scripts rejected during static validation
            # (e.g. the import in Test 3) instead of returning a result dict.
            return {'success': False, 'output': '', 'result': None, 'error': str(exc)}

    # Test 1: Simple script
    script1 = """
# Get appointments
appointments = api.get_appointments(status='SCHEDULED', limit=10)
# Count them
count = len(appointments)
# Log result
api.log(f"Found {count} appointments")
result = count
"""

    # Test 2: Conditional logic
    script2 = """
appointments = api.get_appointments(limit=100)
# Count by status
scheduled = 0
completed = 0
for apt in appointments:
    if apt['status'] == 'SCHEDULED':
        scheduled += 1
    elif apt['status'] == 'COMPLETED':
        completed += 1
result = {
    'scheduled': scheduled,
    'completed': completed,
    'total': len(appointments)
}
"""

    # Test 3: Forbidden operation (should fail)
    script3 = """
import os
os.system('echo hello')
"""

    # Test 4: Infinite loop protection
    script4 = """
count = 0
while True:
    count += 1
"""

    # (label, script, result fields to print) - fields print as "Success: ...", etc.
    cases = [
        ("Test 1: Simple script", script1, ('success', 'result', 'output')),
        ("Test 2: Conditional logic", script2, ('success', 'result')),
        ("Test 3: Forbidden operation", script3, ('success', 'error')),
        ("Test 4: Infinite loop protection", script4, ('success', 'error')),
    ]

    for label, script, fields in cases:
        print(label)
        outcome = run(script)
        for field in fields:
            print(f"{field.capitalize()}: {outcome[field]}")
        print()
# Running this module directly executes the manual smoke test above.
if __name__ == '__main__':
    test_script_execution()