Files
Django-Enhanced-API-Docs/schema.py

428 lines
16 KiB
Python
Raw Normal View History

2026-01-13 11:00:34 +01:00
from drf_spectacular.openapi import AutoSchema
import logging
import json
from drf_spectacular.utils import OpenApiParameter
from drf_spectacular.types import OpenApiTypes
from django.conf import settings
class CustomAutoSchema(AutoSchema):
    """
    drf-spectacular ``AutoSchema`` subclass that assigns operation tags
    from the project's ``SPECTACULAR_SETTINGS['TAGS']`` configuration.
    """

    def get_tags(self):
        """Return the tag list for the current view.

        Each entry of ``SPECTACULAR_SETTINGS['TAGS']`` may carry a
        ``viewsets`` list of view class names.  The first entry whose list
        contains this view's class name wins and its ``name`` becomes the
        only tag; when nothing matches, the parent implementation decides.
        """
        view_name = self.view.__class__.__name__
        configured_tags = getattr(settings, 'SPECTACULAR_SETTINGS', {}).get('TAGS', [])

        matching_tag = next(
            (
                candidate
                for candidate in configured_tags
                if view_name in candidate.get('viewsets', [])
            ),
            None,
        )
        if matching_tag is None:
            return super().get_tags()
        return [matching_tag['name']]
def postprocess_schema_enhancements(result, generator, request, public):
    """
    drf-spectacular postprocessing hook that decorates the generated schema.

    For every GET/POST/PUT/PATCH/DELETE operation found under ``paths`` it:
    enriches query parameters (list GETs) or adds the ``lang`` parameter
    (detail GETs), adds standard error responses, and appends a cURL sample.
    Finally it adds pagination URL examples to the component schemas.

    ``request`` and ``public`` are accepted because the hook signature
    requires them; they are not used here.  The (mutated) ``result`` dict is
    returned.
    """
    # Remember which endpoint paths are served by a view declaring search_fields.
    search_fields_by_path = {}
    for endpoint_path, _path_regex, _endpoint_method, callback in generator.endpoints:
        view_class = getattr(callback, 'cls', None)
        declared_fields = getattr(view_class, 'search_fields', None)
        if declared_fields:
            search_fields_by_path[endpoint_path] = declared_fields

    # Component schemas are needed to resolve $ref while building body examples.
    component_schemas = result.get('components', {}).get('schemas', {})
    handled_methods = ('get', 'post', 'put', 'patch', 'delete')

    for path, path_item in result.get('paths', {}).items():
        # A path template placeholder marks a detail endpoint.
        is_detail_path = '{' in path

        for method, operation in path_item.items():
            verb = method.lower()
            if verb not in handled_methods:
                continue

            if verb == 'get':
                if is_detail_path:
                    _add_lang_parameter(operation)
                else:
                    _enhance_list_parameters(
                        operation, path, search_fields_by_path.get(path, [])
                    )

            _add_standard_responses(operation, method, path)
            _add_code_samples(operation, path, method, component_schemas)

    _add_pagination_examples(result)
    return result
def _enhance_list_parameters(operation, path, search_fields):
"""Enhance filter parameters for list endpoints"""
parameters = operation.get('parameters', [])
# Add lang parameter for translatable fields
lang_param = {
'name': 'lang',
'in': 'query',
'required': False,
'schema': {
'type': 'string',
'enum': ['en', 'de', 'fr'],
},
'description': 'Language code for translatable fields. If specified, translatable fields return as strings instead of dicts. Omit to get all translations.'
}
parameters.insert(0, lang_param)
# Build list of parameters to keep
filtered_params = []
for param in parameters:
param_name = param.get('name', '')
schema = param.get('schema', {})
param_type = schema.get('type', 'string')
# Enhance search parameter
if param_name == 'search':
if search_fields:
fields_str = ', '.join(search_fields)
param['description'] = f'Search (case-insensitive) across fields: {fields_str}'
filtered_params.append(param)
# Skip search parameter if no search_fields are defined
continue
# Enhance ordering parameter
elif param_name == 'ordering':
param['description'] = "Order results by field. Prefix with '-' for descending (e.g., '-created_at')"
filtered_params.append(param)
# Enhance page parameter
elif param_name == 'page':
param['description'] = "Page number (1-indexed)"
filtered_params.append(param)
# Enhance other parameters based on type
elif not param.get('description') or param['description'] in ['', 'A page number within the paginated result set.', 'Which field to use when ordering the results.']:
if param_type == 'boolean':
param['description'] = f"Filter by {param_name} (use true/false)"
elif param_type == 'integer':
# Check if it's likely an ID field
if param_name in ['category', 'product', 'file', 'product_version', 'user', 'parent', 'group_type']:
param['description'] = f"Filter by {param_name.replace('_', ' ').title()} ID"
elif '__' in param_name:
param['description'] = f"Filter by related ID (lookup via {param_name})"
else:
param['description'] = f"Filter by {param_name} (integer value)"
elif param_type == 'array':
# M2M fields
if param_name in ['tags', 'groups', 'users']:
param['description'] = f"Filter by {param_name.replace('_', ' ').title()} ID (repeat parameter for multiple values, e.g. ?{param_name}=1&{param_name}=2)"
else:
param['description'] = f"Filter by {param_name} (repeat parameter for multiple values)"
else:
param['description'] = f"Filter by exact match on {param_name}"
filtered_params.append(param)
else:
# Keep other parameters as-is
filtered_params.append(param)
# Replace parameters list with filtered version
operation['parameters'] = filtered_params
def _add_lang_parameter(operation):
"""Add lang parameter to GET detail endpoints for translatable fields"""
if 'parameters' not in operation:
operation['parameters'] = []
# Check if lang parameter already exists
has_lang = any(p.get('name') == 'lang' for p in operation['parameters'])
if not has_lang:
lang_param = {
'name': 'lang',
'in': 'query',
'required': False,
'schema': {
'type': 'string',
'enum': ['en', 'de', 'fr'],
},
'description': 'Language code for translatable fields. If specified, translatable fields return as strings instead of dicts. Omit to get all translations.'
}
operation['parameters'].insert(0, lang_param)
def _add_standard_responses(operation, method, path):
"""Add standard error responses to all operations"""
if 'responses' not in operation:
operation['responses'] = {}
method_lower = method.lower()
# All methods can have 401, 403, 500
operation['responses'].setdefault('401', {
'description': 'Unauthorized - Authentication credentials were not provided or are invalid'
})
operation['responses'].setdefault('403', {
'description': 'Forbidden - You do not have permission to perform this action'
})
operation['responses'].setdefault('500', {
'description': 'Internal Server Error - An unexpected error occurred on the server'
})
# Method-specific responses
if method_lower == 'get':
if '{' in path: # Detail endpoint
operation['responses'].setdefault('404', {
'description': 'Not Found - The requested resource does not exist'
})
elif method_lower == 'post':
operation['responses'].setdefault('400', {
'description': 'Bad Request - Invalid input data or validation error'
})
operation['responses'].setdefault('409', {
'description': 'Conflict - The resource already exists or there is a conflict with the current state'
})
elif method_lower in ['put', 'patch']:
operation['responses'].setdefault('400', {
'description': 'Bad Request - Invalid input data or validation error'
})
operation['responses'].setdefault('404', {
'description': 'Not Found - The requested resource does not exist'
})
operation['responses'].setdefault('409', {
'description': 'Conflict - There is a conflict with the current state of the resource'
})
elif method_lower == 'delete':
operation['responses'].setdefault('404', {
'description': 'Not Found - The requested resource does not exist'
})
def _add_code_samples(operation, path, method, schemas_components=None):
    """Append an "Example Request" cURL snippet to an operation's description.

    The command targets ``settings.LIVE_URL`` (default
    ``http://localhost:8000``) joined with ``path``, uses basic-auth
    placeholder credentials and an ``Accept: application/json`` header.
    For POST/PUT/PATCH a JSON body is added when an example can be found, in
    this order: the media type's ``example``, the first entry of its
    ``examples``, the schema's ``example``, and finally an example generated
    from the schema structure.

    Args:
        operation: OpenAPI operation dict; its ``description`` is extended
            in place.
        path: Endpoint path appended to the base URL.
        method: HTTP method name in any letter case.
        schemas_components: Component schemas used to resolve ``$ref`` when
            an example has to be generated.
    """
    method_upper = method.upper()
    # Drop trailing slashes from the base URL; otherwise a LIVE_URL such as
    # "https://host/" joined with a slash-prefixed path yields "https://host//...".
    live_url = str(getattr(settings, 'LIVE_URL', 'http://localhost:8000')).rstrip('/')
    full_url = f"{live_url}{path}"

    curl_parts = [
        f"curl -X {method_upper}",
        f'"{full_url}"',
        '-u "username:password"',
        '-H "Accept: application/json"',
    ]

    sends_body = method_upper in ('POST', 'PUT', 'PATCH')
    has_body = False
    if sends_body:
        curl_parts.append('-H "Content-Type: application/json"')

        json_content = (
            operation.get('requestBody', {})
            .get('content', {})
            .get('application/json', {})
        )

        body_example = None
        if 'example' in json_content:
            body_example = json_content['example']
        elif json_content.get('examples'):
            first_example = next(iter(json_content['examples'].values()))
            body_example = first_example.get('value')

        # Nothing usable on the media type: fall back to the schema itself.
        if not body_example:
            schema = json_content.get('schema', {})
            if 'example' in schema:
                body_example = schema['example']
            else:
                body_example = _generate_example_from_schema(schema, schemas_components)

        if body_example:
            body_json = json.dumps(body_example, indent=2)
            # The body is wrapped in single quotes for the shell, so every
            # single quote inside the JSON must close, escape and reopen it.
            body_json_escaped = body_json.replace("'", "'\"'\"'")
            curl_parts.append(f"-d '{body_json_escaped}'")
            has_body = True

    curl_command = ' \\\n '.join(curl_parts)

    if sends_body:
        # Report the tracked flag rather than searching the command for "-d",
        # which would also match URLs that merely contain that substring.
        logging.debug(
            "Generated cURL for %s (%s): has body data: %s",
            path, method_upper, has_body,
        )

    # A description explicitly set to None is treated like a missing one.
    current_description = operation.get('description') or ''
    curl_markdown = (
        "\n"
        "## Example Request\n"
        "```bash\n"
        f"{curl_command}\n"
        "```\n"
    )
    operation['description'] = current_description + curl_markdown
def _generate_example_from_schema(schema, schemas_components=None, depth=0):
"""Generate example JSON from schema"""
# Prevent infinite recursion
if depth > 5:
return None
if '$ref' in schema:
# Resolve reference
ref_path = schema['$ref']
if ref_path.startswith('#/components/schemas/'):
schema_name = ref_path.replace('#/components/schemas/', '')
if schemas_components and schema_name in schemas_components:
resolved_schema = schemas_components[schema_name]
return _generate_example_from_schema(resolved_schema, schemas_components, depth + 1)
return None
# Check if schema has an example
if 'example' in schema:
return schema['example']
schema_type = schema.get('type')
if schema_type == 'object':
properties = schema.get('properties', {})
example = {}
for prop_name, prop_schema in properties.items():
if prop_schema.get('readOnly'):
continue
value = _generate_example_value(prop_schema, schemas_components, depth + 1)
if value is not None:
example[prop_name] = value
return example
return _generate_example_value(schema, schemas_components, depth)
def _generate_example_value(schema, schemas_components=None, depth=0):
"""Generate example value for a schema field"""
# Prevent infinite recursion
if depth > 5:
return None
# Handle $ref
if '$ref' in schema:
ref_path = schema['$ref']
if ref_path.startswith('#/components/schemas/'):
schema_name = ref_path.replace('#/components/schemas/', '')
if schemas_components and schema_name in schemas_components:
resolved_schema = schemas_components[schema_name]
return _generate_example_from_schema(resolved_schema, schemas_components, depth + 1)
return None
if 'example' in schema:
return schema['example']
schema_type = schema.get('type')
schema_format = schema.get('format')
if schema_type == 'string':
if schema_format == 'email':
return 'user@example.com'
elif schema_format == 'date':
return '2024-01-01'
elif schema_format == 'date-time':
return '2024-01-01T12:00:00Z'
elif schema_format == 'uuid':
return '123e4567-e89b-12d3-a456-426614174000'
else:
return 'string'
elif schema_type == 'integer':
return 1
elif schema_type == 'number':
return 1.0
elif schema_type == 'boolean':
return True
elif schema_type == 'array':
items_schema = schema.get('items', {})
item_example = _generate_example_value(items_schema, schemas_components, depth + 1)
return [item_example] if item_example is not None else []
elif schema_type == 'object':
properties = schema.get('properties', {})
example = {}
for prop_name, prop_schema in properties.items():
if prop_schema.get('readOnly'):
continue
value = _generate_example_value(prop_schema, schemas_components, depth + 1)
if value is not None:
example[prop_name] = value
return example
return None
def _add_pagination_examples(result):
    """Add example ``next``/``previous`` URLs to paginated component schemas.

    Every component schema whose name starts with ``Paginated`` and that has
    ``properties`` gets, on its ``next`` and ``previous`` properties, an
    ``example`` URL and ``nullable: True``.  The URL is built from
    ``settings.LIVE_URL`` (default ``http://localhost:8000``) and the path of
    the list GET operation whose 200 JSON response references that schema;
    ``/api/v1/.../`` is used when no such operation is found.

    Args:
        result: The OpenAPI schema dict; component schemas are modified in
            place. Nothing happens when it has no component schemas.
    """
    component_schemas = (result.get('components') or {}).get('schemas')
    if not component_schemas:
        return

    # Map each response schema name to the list endpoint that returns it.
    ref_prefix = '#/components/schemas/'
    schema_to_path = {}
    for path, path_item in result.get('paths', {}).items():
        # Paths with a placeholder are detail endpoints, not lists.
        if '{' in path:
            continue
        for method, operation in path_item.items():
            if method.lower() != 'get':
                continue
            schema_ref = (
                operation.get('responses', {})
                .get('200', {})
                .get('content', {})
                .get('application/json', {})
                .get('schema', {})
                .get('$ref', '')
            )
            if schema_ref.startswith(ref_prefix):
                # Slice off the prefix only; the remainder is the schema name.
                schema_to_path[schema_ref[len(ref_prefix):]] = path

    # Drop trailing slashes from the base URL; otherwise a LIVE_URL such as
    # "https://host/" joined with a slash-prefixed path yields "https://host//...".
    live_url = str(getattr(settings, 'LIVE_URL', 'http://localhost:8000')).rstrip('/')
    link_placeholders = (
        ('next', '<next-page:int>'),
        ('previous', '<prev-page:int>'),
    )

    for schema_name, schema in component_schemas.items():
        if not schema_name.startswith('Paginated') or 'properties' not in schema:
            continue

        endpoint_path = schema_to_path.get(schema_name, '/api/v1/.../')
        properties = schema['properties']
        for link_name, placeholder in link_placeholders:
            link_schema = properties.get(link_name)
            if link_schema is None:
                continue
            link_schema['example'] = f"{live_url}{endpoint_path}?page={placeholder}"
            link_schema['nullable'] = True