428 lines
16 KiB
Python
428 lines
16 KiB
Python
from drf_spectacular.openapi import AutoSchema
|
|
import logging
|
|
import json
|
|
from drf_spectacular.utils import OpenApiParameter
|
|
from drf_spectacular.types import OpenApiTypes
|
|
from django.conf import settings
|
|
|
|
|
|
class CustomAutoSchema(AutoSchema):
    """
    drf-spectacular AutoSchema that assigns operations to configured tags.

    Each entry of ``SPECTACULAR_SETTINGS['TAGS']`` may carry a custom
    ``viewsets`` list of view class names; an operation whose view class is
    listed there is tagged with that entry's ``name``.
    """

    def get_tags(self):
        """Return ``[name]`` of the first configured tag listing this view's class.

        Falls back to ``AutoSchema.get_tags()`` when no configured tag lists it.
        """
        view_name = self.view.__class__.__name__
        configured_tags = getattr(settings, 'SPECTACULAR_SETTINGS', {}).get('TAGS', [])

        matched_tag = next(
            (tag for tag in configured_tags if view_name in tag.get('viewsets', [])),
            None,
        )
        if matched_tag is None:
            return super().get_tags()
        return [matched_tag['name']]
|
|
|
|
|
|
def postprocess_schema_enhancements(result, generator, request, public):
    """
    drf-spectacular postprocessing hook that enriches the generated schema.

    For every GET/POST/PUT/PATCH/DELETE operation it:

    * on GET list endpoints (path without ``{``), documents the query
      parameters, including ``search`` with the view's ``search_fields``;
    * on GET detail endpoints, adds a ``lang`` query parameter;
    * adds standard error responses;
    * appends a cURL example request to the description.

    It also adds pagination URL examples to ``Paginated*`` component schemas.
    It removes the project-specific ``viewsets`` key from the top-level
    ``tags``, so the emitted document only carries Tag Object fields.

    Args:
        result: Schema dict produced by drf-spectacular; modified in place.
        generator: The schema generator. Its ``endpoints`` tuples
            ``(path, path_regex, method, callback)`` are used to look up each
            view's ``search_fields``.
        request: Unused; part of the hook signature.
        public: Unused; part of the hook signature.

    Returns:
        The same ``result`` dict.
    """
    # Map endpoint paths to the view's search_fields so the ``search``
    # parameter can say which fields it actually searches.
    path_search_fields = {}
    for endpoint_path, _path_regex, _method, callback in getattr(generator, 'endpoints', None) or []:
        view_class = getattr(callback, 'cls', None)
        search_fields = getattr(view_class, 'search_fields', None)
        if search_fields:
            path_search_fields[endpoint_path] = search_fields

    # Component schemas are needed to resolve $refs when building body examples.
    schemas_components = result.get('components', {}).get('schemas', {})
    http_methods = {'get', 'post', 'put', 'patch', 'delete'}

    for path, path_item in result.get('paths', {}).items():
        for method, operation in path_item.items():
            verb = method.lower()
            # Path items can also hold non-operation keys (e.g. 'parameters').
            if verb not in http_methods:
                continue

            if verb == 'get':
                if '{' not in path:
                    # List endpoint: no path template parameters.
                    _enhance_list_parameters(operation, path, path_search_fields.get(path, []))
                else:
                    _add_lang_parameter(operation)

            _add_standard_responses(operation, method, path)
            _add_code_samples(operation, path, method, schemas_components)

    _add_pagination_examples(result)
    _strip_custom_tag_keys(result)

    return result


def _strip_custom_tag_keys(result):
    """Remove the project-only ``viewsets`` key from the schema's top-level tags.

    New dicts are built instead of mutating the entries in place.
    NOTE(review): drf-spectacular appears to place the configured ``TAGS``
    list into the schema as-is (confirm for the installed version). If so,
    those entries are the settings objects that ``CustomAutoSchema.get_tags``
    keeps reading.
    """
    tags = result.get('tags')
    if not tags:
        return
    result['tags'] = [
        {key: value for key, value in tag.items() if key != 'viewsets'}
        if isinstance(tag, dict) else tag
        for tag in tags
    ]
|
|
|
|
|
|
def _enhance_list_parameters(operation, path, search_fields):
|
|
"""Enhance filter parameters for list endpoints"""
|
|
parameters = operation.get('parameters', [])
|
|
|
|
# Add lang parameter for translatable fields
|
|
lang_param = {
|
|
'name': 'lang',
|
|
'in': 'query',
|
|
'required': False,
|
|
'schema': {
|
|
'type': 'string',
|
|
'enum': ['en', 'de', 'fr'],
|
|
},
|
|
'description': 'Language code for translatable fields. If specified, translatable fields return as strings instead of dicts. Omit to get all translations.'
|
|
}
|
|
parameters.insert(0, lang_param)
|
|
|
|
# Build list of parameters to keep
|
|
filtered_params = []
|
|
|
|
for param in parameters:
|
|
param_name = param.get('name', '')
|
|
schema = param.get('schema', {})
|
|
param_type = schema.get('type', 'string')
|
|
|
|
# Enhance search parameter
|
|
if param_name == 'search':
|
|
if search_fields:
|
|
fields_str = ', '.join(search_fields)
|
|
param['description'] = f'Search (case-insensitive) across fields: {fields_str}'
|
|
filtered_params.append(param)
|
|
# Skip search parameter if no search_fields are defined
|
|
continue
|
|
|
|
# Enhance ordering parameter
|
|
elif param_name == 'ordering':
|
|
param['description'] = "Order results by field. Prefix with '-' for descending (e.g., '-created_at')"
|
|
filtered_params.append(param)
|
|
|
|
# Enhance page parameter
|
|
elif param_name == 'page':
|
|
param['description'] = "Page number (1-indexed)"
|
|
filtered_params.append(param)
|
|
|
|
# Enhance other parameters based on type
|
|
elif not param.get('description') or param['description'] in ['', 'A page number within the paginated result set.', 'Which field to use when ordering the results.']:
|
|
if param_type == 'boolean':
|
|
param['description'] = f"Filter by {param_name} (use true/false)"
|
|
elif param_type == 'integer':
|
|
# Check if it's likely an ID field
|
|
if param_name in ['category', 'product', 'file', 'product_version', 'user', 'parent', 'group_type']:
|
|
param['description'] = f"Filter by {param_name.replace('_', ' ').title()} ID"
|
|
elif '__' in param_name:
|
|
param['description'] = f"Filter by related ID (lookup via {param_name})"
|
|
else:
|
|
param['description'] = f"Filter by {param_name} (integer value)"
|
|
elif param_type == 'array':
|
|
# M2M fields
|
|
if param_name in ['tags', 'groups', 'users']:
|
|
param['description'] = f"Filter by {param_name.replace('_', ' ').title()} ID (repeat parameter for multiple values, e.g. ?{param_name}=1&{param_name}=2)"
|
|
else:
|
|
param['description'] = f"Filter by {param_name} (repeat parameter for multiple values)"
|
|
else:
|
|
param['description'] = f"Filter by exact match on {param_name}"
|
|
filtered_params.append(param)
|
|
else:
|
|
# Keep other parameters as-is
|
|
filtered_params.append(param)
|
|
|
|
# Replace parameters list with filtered version
|
|
operation['parameters'] = filtered_params
|
|
|
|
|
|
def _add_lang_parameter(operation):
|
|
"""Add lang parameter to GET detail endpoints for translatable fields"""
|
|
if 'parameters' not in operation:
|
|
operation['parameters'] = []
|
|
|
|
# Check if lang parameter already exists
|
|
has_lang = any(p.get('name') == 'lang' for p in operation['parameters'])
|
|
if not has_lang:
|
|
lang_param = {
|
|
'name': 'lang',
|
|
'in': 'query',
|
|
'required': False,
|
|
'schema': {
|
|
'type': 'string',
|
|
'enum': ['en', 'de', 'fr'],
|
|
},
|
|
'description': 'Language code for translatable fields. If specified, translatable fields return as strings instead of dicts. Omit to get all translations.'
|
|
}
|
|
operation['parameters'].insert(0, lang_param)
|
|
|
|
|
|
def _add_standard_responses(operation, method, path):
|
|
"""Add standard error responses to all operations"""
|
|
if 'responses' not in operation:
|
|
operation['responses'] = {}
|
|
|
|
method_lower = method.lower()
|
|
|
|
# All methods can have 401, 403, 500
|
|
operation['responses'].setdefault('401', {
|
|
'description': 'Unauthorized - Authentication credentials were not provided or are invalid'
|
|
})
|
|
operation['responses'].setdefault('403', {
|
|
'description': 'Forbidden - You do not have permission to perform this action'
|
|
})
|
|
operation['responses'].setdefault('500', {
|
|
'description': 'Internal Server Error - An unexpected error occurred on the server'
|
|
})
|
|
|
|
# Method-specific responses
|
|
if method_lower == 'get':
|
|
if '{' in path: # Detail endpoint
|
|
operation['responses'].setdefault('404', {
|
|
'description': 'Not Found - The requested resource does not exist'
|
|
})
|
|
|
|
elif method_lower == 'post':
|
|
operation['responses'].setdefault('400', {
|
|
'description': 'Bad Request - Invalid input data or validation error'
|
|
})
|
|
operation['responses'].setdefault('409', {
|
|
'description': 'Conflict - The resource already exists or there is a conflict with the current state'
|
|
})
|
|
|
|
elif method_lower in ['put', 'patch']:
|
|
operation['responses'].setdefault('400', {
|
|
'description': 'Bad Request - Invalid input data or validation error'
|
|
})
|
|
operation['responses'].setdefault('404', {
|
|
'description': 'Not Found - The requested resource does not exist'
|
|
})
|
|
operation['responses'].setdefault('409', {
|
|
'description': 'Conflict - There is a conflict with the current state of the resource'
|
|
})
|
|
|
|
elif method_lower == 'delete':
|
|
operation['responses'].setdefault('404', {
|
|
'description': 'Not Found - The requested resource does not exist'
|
|
})
|
|
|
|
|
|
def _add_code_samples(operation, path, method, schemas_components=None):
    """
    Append a cURL example request to an operation's Markdown description.

    The command targets ``settings.LIVE_URL`` (default
    ``http://localhost:8000``) plus ``path`` and uses placeholder basic-auth
    credentials. For POST/PUT/PATCH it also sends a JSON body, built from the
    request body's ``application/json`` example(s) or generated from its
    schema.

    Args:
        operation: OpenAPI operation dict; its ``description`` is extended in place.
        path: Endpoint path template, e.g. ``/api/v1/items/{id}/``.
        method: HTTP method name (any case).
        schemas_components: ``components.schemas`` mapping used to resolve ``$ref``s.
    """
    method_upper = method.upper()
    # `or` also covers LIVE_URL = None/''. rstrip avoids "//" when LIVE_URL ends in "/".
    live_url = (getattr(settings, 'LIVE_URL', None) or 'http://localhost:8000').rstrip('/')

    curl_parts = [
        f"curl -X {method_upper}",
        f'"{live_url}{path}"',
        '-u "username:password"',  # placeholder basic-auth credentials
        '-H "Accept: application/json"',
    ]

    if method_upper in ('POST', 'PUT', 'PATCH'):
        curl_parts.append('-H "Content-Type: application/json"')

        body_example = _pick_body_example(operation, schemas_components)
        has_body = body_example is not None
        if has_body:
            body_json = json.dumps(body_example, indent=2)
            # Inside a single-quoted shell string, a literal ' must be written as '"'"'.
            body_json_escaped = body_json.replace("'", "'\"'\"'")
            curl_parts.append(f"-d '{body_json_escaped}'")

        # Log actual body presence. A substring test for '-d' on the command
        # would also match paths such as /some-data/.
        logging.getLogger(__name__).debug(
            "Generated cURL for %s (%s): has body data: %s", path, method_upper, has_body
        )

    curl_command = ' \\\n '.join(curl_parts)

    # The sample is appended to the Markdown description so every renderer shows it.
    # NOTE(review): ReDoc also supports an `x-codeSamples` extension if a dedicated panel is preferred.
    operation['description'] = (
        (operation.get('description') or '')
        + f"\n\n## Example Request\n\n```bash\n{curl_command}\n```\n"
    )


def _pick_body_example(operation, schemas_components=None):
    """Return an example JSON request body for ``operation``, or None if none is available.

    Sources are checked in this order on the ``application/json`` request
    body:

    1. the media type's ``example``;
    2. the ``value`` of its first ``examples`` entry;
    3. the schema's ``example``;
    4. an example generated from the schema.

    Explicit examples are honoured even when falsy (e.g. ``{}``).
    """
    json_content = (
        operation.get('requestBody', {}).get('content', {}).get('application/json', {})
    )

    example = None
    if 'example' in json_content:
        example = json_content['example']
    elif json_content.get('examples'):
        first_example = next(iter(json_content['examples'].values()))
        example = first_example.get('value')

    if example is None:
        schema = json_content.get('schema', {})
        if 'example' in schema:
            example = schema['example']
        else:
            example = _generate_example_from_schema(schema, schemas_components)

    return example
|
|
|
|
|
|
def _generate_example_from_schema(schema, schemas_components=None, depth=0):
    """
    Generate an example JSON value for an OpenAPI ``schema`` (e.g. a request body).

    This performs exactly the same steps as ``_generate_example_value``: depth
    cap, ``$ref`` resolution, explicit ``example``, and objects minus
    ``readOnly`` properties, with a per-type fallback. It therefore delegates
    to it instead of keeping a second copy of that logic in sync.

    Args:
        schema: OpenAPI schema dict.
        schemas_components: ``components.schemas`` mapping for ``$ref`` resolution.
        depth: Current nesting depth, used to stop runaway recursion.

    Returns:
        An example value, or None when none can be produced.
    """
    return _generate_example_value(schema, schemas_components, depth)
|
|
|
|
|
|
def _generate_example_value(schema, schemas_components=None, depth=0):
|
|
"""Generate example value for a schema field"""
|
|
# Prevent infinite recursion
|
|
if depth > 5:
|
|
return None
|
|
|
|
# Handle $ref
|
|
if '$ref' in schema:
|
|
ref_path = schema['$ref']
|
|
if ref_path.startswith('#/components/schemas/'):
|
|
schema_name = ref_path.replace('#/components/schemas/', '')
|
|
if schemas_components and schema_name in schemas_components:
|
|
resolved_schema = schemas_components[schema_name]
|
|
return _generate_example_from_schema(resolved_schema, schemas_components, depth + 1)
|
|
return None
|
|
|
|
if 'example' in schema:
|
|
return schema['example']
|
|
|
|
schema_type = schema.get('type')
|
|
schema_format = schema.get('format')
|
|
|
|
if schema_type == 'string':
|
|
if schema_format == 'email':
|
|
return 'user@example.com'
|
|
elif schema_format == 'date':
|
|
return '2024-01-01'
|
|
elif schema_format == 'date-time':
|
|
return '2024-01-01T12:00:00Z'
|
|
elif schema_format == 'uuid':
|
|
return '123e4567-e89b-12d3-a456-426614174000'
|
|
else:
|
|
return 'string'
|
|
elif schema_type == 'integer':
|
|
return 1
|
|
elif schema_type == 'number':
|
|
return 1.0
|
|
elif schema_type == 'boolean':
|
|
return True
|
|
elif schema_type == 'array':
|
|
items_schema = schema.get('items', {})
|
|
item_example = _generate_example_value(items_schema, schemas_components, depth + 1)
|
|
return [item_example] if item_example is not None else []
|
|
elif schema_type == 'object':
|
|
properties = schema.get('properties', {})
|
|
example = {}
|
|
for prop_name, prop_schema in properties.items():
|
|
if prop_schema.get('readOnly'):
|
|
continue
|
|
value = _generate_example_value(prop_schema, schemas_components, depth + 1)
|
|
if value is not None:
|
|
example[prop_name] = value
|
|
return example
|
|
|
|
return None
|
|
|
|
|
|
def _add_pagination_examples(result):
    """
    Set URL examples on ``next``/``previous`` of ``Paginated*`` component schemas.

    Each paginated schema is matched to the GET list endpoint (a path without
    a ``{`` template) whose 200 ``application/json`` response ``$ref``s it.
    Its examples become ``<LIVE_URL><path>?page=<next-page:int>`` and
    ``...?page=<prev-page:int>``. Unmatched schemas use the placeholder path
    ``/api/v1/.../``. Both properties are also marked ``nullable``.

    Args:
        result: Schema dict produced by drf-spectacular; modified in place.
    """
    schemas = result.get('components', {}).get('schemas')
    if not schemas:
        return

    # Map each referenced schema name to the list endpoint that returns it.
    prefix = '#/components/schemas/'
    schema_to_path = {}
    for path, path_item in result.get('paths', {}).items():
        if '{' in path:
            continue  # detail endpoints are not paginated
        for method, operation in path_item.items():
            if method.lower() != 'get':
                continue
            schema_ref = (
                operation.get('responses', {})
                .get('200', {})
                .get('content', {})
                .get('application/json', {})
                .get('schema', {})
                .get('$ref', '')
            )
            # e.g. "#/components/schemas/PaginatedProductList"
            if schema_ref.startswith(prefix):
                schema_to_path[schema_ref[len(prefix):]] = path

    # `or` also covers LIVE_URL = None/''. rstrip avoids "//" when LIVE_URL ends in "/".
    live_url = (getattr(settings, 'LIVE_URL', None) or 'http://localhost:8000').rstrip('/')

    for schema_name, schema in schemas.items():
        properties = schema.get('properties')
        if not schema_name.startswith('Paginated') or properties is None:
            continue

        endpoint_path = schema_to_path.get(schema_name, '/api/v1/.../')
        for prop_name, page_placeholder in (('next', 'next-page'), ('previous', 'prev-page')):
            if prop_name in properties:
                properties[prop_name]['example'] = f"{live_url}{endpoint_path}?page=<{page_placeholder}:int>"
                properties[prop_name]['nullable'] = True
|