import json
import logging

from django.conf import settings
from drf_spectacular.openapi import AutoSchema
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import OpenApiParameter

logger = logging.getLogger(__name__)

# Language codes offered by the ``lang`` query parameter on GET endpoints.
LANG_CODES = ('en', 'de', 'fr')

# HTTP methods whose operations the postprocessing hook enhances.
_HTTP_METHODS = ('get', 'post', 'put', 'patch', 'delete')

# Prefix of local component references, e.g. "#/components/schemas/Product".
_SCHEMA_REF_PREFIX = '#/components/schemas/'

# Recursion limit for example synthesis; guards against self-referencing schemas.
_MAX_EXAMPLE_DEPTH = 5

# Used when settings.LIVE_URL is not configured.
_DEFAULT_LIVE_URL = 'http://localhost:8000'

_LANG_DESCRIPTION = (
    'Language code for translatable fields. If specified, translatable fields '
    'return as strings instead of dicts. Omit to get all translations.'
)

# Integer filter parameters that are documented as foreign-key ID filters.
_ID_FILTER_PARAMS = frozenset(
    {'category', 'product', 'file', 'product_version', 'user', 'parent', 'group_type'}
)

# Array filter parameters that are documented as many-to-many ID filters.
_M2M_FILTER_PARAMS = frozenset({'tags', 'groups', 'users'})

# Descriptions treated as "missing" and therefore replaced with a generated one.
_PLACEHOLDER_DESCRIPTIONS = frozenset({
    '',
    'A page number within the paginated result set.',
    'Which field to use when ordering the results.',
})

_UNAUTHORIZED = 'Unauthorized - Authentication credentials were not provided or are invalid'
_FORBIDDEN = 'Forbidden - You do not have permission to perform this action'
_SERVER_ERROR = 'Internal Server Error - An unexpected error occurred on the server'
_NOT_FOUND = 'Not Found - The requested resource does not exist'
_BAD_REQUEST = 'Bad Request - Invalid input data or validation error'
_CONFLICT_CREATE = 'Conflict - The resource already exists or there is a conflict with the current state'
_CONFLICT_UPDATE = 'Conflict - There is a conflict with the current state of the resource'

# Placeholder values for string fields, keyed by OpenAPI ``format``.
_STRING_FORMAT_EXAMPLES = {
    'email': 'user@example.com',
    'date': '2024-01-01',
    'date-time': '2024-01-01T12:00:00Z',
    'uuid': '123e4567-e89b-12d3-a456-426614174000',
}

# Page numbers used in the ``next`` / ``previous`` pagination URL examples
# (same values DRF's PageNumberPagination uses in its own response schema).
_PAGINATION_EXAMPLE_PAGES = (('next', 4), ('previous', 2))


class CustomAutoSchema(AutoSchema):
    """
    Custom schema generator for drf-spectacular.

    Assigns operation tags based on the ``viewsets`` lists configured in
    ``SPECTACULAR_SETTINGS['TAGS']``; all other behavior is inherited.
    """

    def get_tags(self):
        """Map viewset class names to custom tags for better organization.

        Returns ``[tag['name']]`` for the first ``TAGS`` entry whose
        ``viewsets`` list contains this view's class name, otherwise the
        default tags from ``AutoSchema.get_tags``.
        """
        viewset_class_name = self.view.__class__.__name__
        spectacular_settings = getattr(settings, 'SPECTACULAR_SETTINGS', {})
        for tag in spectacular_settings.get('TAGS', []):
            if viewset_class_name in tag.get('viewsets', []):
                return [tag['name']]
        return super().get_tags()


def _get_live_url():
    """Return ``settings.LIVE_URL`` (or the localhost default) without a trailing slash.

    Stripping the slash avoids ``//`` when a path starting with ``/`` is appended.
    """
    return (getattr(settings, 'LIVE_URL', _DEFAULT_LIVE_URL) or _DEFAULT_LIVE_URL).rstrip('/')


def _resolve_ref(ref_path, schemas_components):
    """Return the component schema a local ``$ref`` points to, or None if unresolvable."""
    if not schemas_components or not ref_path.startswith(_SCHEMA_REF_PREFIX):
        return None
    return schemas_components.get(ref_path[len(_SCHEMA_REF_PREFIX):])


def postprocess_schema_enhancements(result, generator, request, public):
    """
    Postprocessing hook to enhance the OpenAPI schema.

    Adds detailed parameter descriptions, standard error responses, cURL
    request samples and pagination URL examples.

    Args:
        result: The generated OpenAPI schema dict; modified in place.
        generator: The schema generator; its ``endpoints`` are iterated as
            ``(path, path_regex, method, callback)`` tuples.
        request: Unused; part of the hook signature.
        public: Unused; part of the hook signature.

    Returns:
        The same ``result`` dict, enhanced.
    """
    path_search_fields = _collect_search_fields(generator)
    # Component schemas, used to resolve $refs when synthesizing body examples.
    schemas_components = result.get('components', {}).get('schemas', {})

    for path, path_item in result.get('paths', {}).items():
        for method, operation in path_item.items():
            method_lower = method.lower()
            # Path items may also hold non-operation keys such as "parameters".
            if method_lower not in _HTTP_METHODS:
                continue

            if method_lower == 'get':
                # NOTE(review): "no '{' in path" is the list/detail heuristic, so
                # nested list routes like /a/{id}/b/ are treated as detail
                # endpoints - confirm this is acceptable for the project's URLs.
                if '{' not in path:
                    _enhance_list_parameters(operation, path, path_search_fields.get(path, []))
                else:
                    _add_lang_parameter(operation)

            _add_standard_responses(operation, method, path)
            _add_code_samples(operation, path, method, schemas_components)

    _add_pagination_examples(result)
    return result


def _collect_search_fields(generator):
    """Map each endpoint path to its view class's non-empty ``search_fields``."""
    path_search_fields = {}
    for endpoint_path, _path_regex, _method, callback in generator.endpoints:
        view_class = getattr(callback, 'cls', None)
        search_fields = getattr(view_class, 'search_fields', None)
        if search_fields:
            path_search_fields[endpoint_path] = search_fields
    return path_search_fields


def _build_lang_parameter():
    """Return a fresh OpenAPI query-parameter object for ``lang``."""
    return {
        'name': 'lang',
        'in': 'query',
        'required': False,
        'schema': {
            'type': 'string',
            'enum': list(LANG_CODES),
        },
        'description': _LANG_DESCRIPTION,
    }


def _has_parameter(parameters, name):
    """Return True if ``parameters`` already contains a parameter called ``name``."""
    return any(param.get('name') == name for param in parameters)


def _enhance_list_parameters(operation, path, search_fields):
    """Enhance filter parameters for list endpoints.

    Prepends a ``lang`` parameter unless one already exists (duplicate
    name/location pairs are invalid OpenAPI). Rewrites descriptions for
    ``search``, ``ordering``, ``page`` and undescribed filters. Drops
    ``search`` when the view defines no ``search_fields``.

    Args:
        operation: OpenAPI operation dict; ``parameters`` is replaced.
        path: Endpoint path (unused; kept for signature compatibility).
        search_fields: The view's ``search_fields`` (may be empty).
    """
    parameters = list(operation.get('parameters') or [])
    if not _has_parameter(parameters, 'lang'):
        parameters.insert(0, _build_lang_parameter())

    filtered_params = []
    for param in parameters:
        param_name = param.get('name', '')

        if param_name == 'search':
            # The search filter documents `search` even for views without
            # search_fields; hide it there because it would do nothing.
            if not search_fields:
                continue
            param['description'] = f"Search (case-insensitive) across fields: {', '.join(search_fields)}"
        elif param_name == 'ordering':
            param['description'] = "Order results by field. Prefix with '-' for descending (e.g., '-created_at')"
        elif param_name == 'page':
            param['description'] = "Page number (1-indexed)"
        elif (param.get('description') or '') in _PLACEHOLDER_DESCRIPTIONS:
            param_type = (param.get('schema') or {}).get('type', 'string')
            param['description'] = _describe_filter_parameter(param_name, param_type)

        filtered_params.append(param)

    operation['parameters'] = filtered_params


def _describe_filter_parameter(param_name, param_type):
    """Return a generated description for a filter parameter based on its schema type."""
    if param_type == 'boolean':
        return f"Filter by {param_name} (use true/false)"
    if param_type == 'integer':
        if param_name in _ID_FILTER_PARAMS:
            return f"Filter by {param_name.replace('_', ' ').title()} ID"
        if '__' in param_name:
            return f"Filter by related ID (lookup via {param_name})"
        return f"Filter by {param_name} (integer value)"
    if param_type == 'array':
        if param_name in _M2M_FILTER_PARAMS:
            return (
                f"Filter by {param_name.replace('_', ' ').title()} ID "
                f"(repeat parameter for multiple values, e.g. ?{param_name}=1&{param_name}=2)"
            )
        return f"Filter by {param_name} (repeat parameter for multiple values)"
    return f"Filter by exact match on {param_name}"


def _add_lang_parameter(operation):
    """Add the ``lang`` parameter to a GET detail operation unless it already has one."""
    parameters = operation.setdefault('parameters', [])
    if not _has_parameter(parameters, 'lang'):
        parameters.insert(0, _build_lang_parameter())


def _add_standard_responses(operation, method, path):
    """Add standard error responses to an operation without overwriting existing ones.

    Every operation gets 401/403/500. GET detail endpoints (path contains
    ``{``) and DELETE get 404. POST gets 400/409. PUT/PATCH get 400/404/409.
    """
    responses = operation.setdefault('responses', {})
    method_lower = method.lower()

    extra = {'401': _UNAUTHORIZED, '403': _FORBIDDEN, '500': _SERVER_ERROR}
    if method_lower == 'get':
        if '{' in path:
            extra['404'] = _NOT_FOUND
    elif method_lower == 'post':
        extra['400'] = _BAD_REQUEST
        extra['409'] = _CONFLICT_CREATE
    elif method_lower in ('put', 'patch'):
        extra['400'] = _BAD_REQUEST
        extra['404'] = _NOT_FOUND
        extra['409'] = _CONFLICT_UPDATE
    elif method_lower == 'delete':
        extra['404'] = _NOT_FOUND

    for status_code, description in extra.items():
        responses.setdefault(status_code, {'description': description})


def _add_code_samples(operation, path, method, schemas_components=None):
    """Append a cURL example request (Markdown code fence) to the operation description.

    For POST/PUT/PATCH a JSON body is included. It comes from an explicit
    example when one exists, otherwise it is synthesized from the request
    schema.
    """
    method_upper = method.upper()
    full_url = f"{_get_live_url()}{path}"

    curl_parts = [
        f"curl -X {method_upper}",
        f'"{full_url}"',
        '-u "username:password"',
        '-H "Accept: application/json"',
    ]

    if method_upper in ('POST', 'PUT', 'PATCH'):
        curl_parts.append('-H "Content-Type: application/json"')
        body_example = _find_request_body_example(operation, schemas_components)
        has_body = bool(body_example)
        if has_body:
            body_json = json.dumps(body_example, indent=2)
            # Inside a single-quoted shell string, a literal ' must be written
            # as: close quote, double-quoted ', reopen quote.
            body_json_escaped = body_json.replace("'", "'\"'\"'")
            curl_parts.append(f"-d '{body_json_escaped}'")
        logger.debug("Generated cURL for %s (%s): has body data: %s", path, method_upper, has_body)

    curl_command = ' \\\n  '.join(curl_parts)
    curl_markdown = f"\n\n## Example Request\n\n```bash\n{curl_command}\n```\n"
    operation['description'] = (operation.get('description') or '') + curl_markdown


def _find_request_body_example(operation, schemas_components):
    """Return an example JSON request body for ``operation``, or None.

    Checks, in order: the media type ``example``, its first ``examples``
    entry, the schema's ``example``, and finally an example generated
    from the schema.
    """
    request_body = operation.get('requestBody', {})
    json_content = request_body.get('content', {}).get('application/json', {})

    body_example = None
    if 'example' in json_content:
        body_example = json_content['example']
    elif json_content.get('examples'):
        first_example = next(iter(json_content['examples'].values()))
        body_example = first_example.get('value')

    if not body_example:
        schema = json_content.get('schema', {})
        if 'example' in schema:
            body_example = schema['example']
        else:
            body_example = _generate_example_from_schema(schema, schemas_components)
    return body_example


def _generate_example_from_schema(schema, schemas_components=None, depth=0):
    """Generate an example JSON value from an OpenAPI schema.

    Thin alias kept for callers; the logic lives in ``_generate_example_value``.
    """
    return _generate_example_value(schema, schemas_components, depth)


def _generate_example_value(schema, schemas_components=None, depth=0):
    """Generate an example value for an OpenAPI (sub)schema.

    Resolves local ``$ref``s against ``schemas_components``. Prefers an
    explicit ``example``, then the first ``enum`` member, then
    ``allOf``/``oneOf``/``anyOf`` sub-schemas, then a placeholder by
    ``type``. Read-only object properties are omitted. Returns None when
    nothing sensible can be produced or ``depth`` exceeds the recursion
    limit.
    """
    if depth > _MAX_EXAMPLE_DEPTH or not isinstance(schema, dict):
        return None

    if '$ref' in schema:
        resolved_schema = _resolve_ref(schema['$ref'], schemas_components)
        if resolved_schema is None:
            return None
        return _generate_example_value(resolved_schema, schemas_components, depth + 1)

    if 'example' in schema:
        return schema['example']

    # Choice fields: any member is valid, a generic placeholder is not.
    if schema.get('enum'):
        return schema['enum'][0]

    combined = _generate_combined_example(schema, schemas_components, depth)
    if combined is not None:
        return combined

    schema_type = schema.get('type')
    if schema_type == 'string':
        return _STRING_FORMAT_EXAMPLES.get(schema.get('format'), 'string')
    if schema_type == 'integer':
        return 1
    if schema_type == 'number':
        return 1.0
    if schema_type == 'boolean':
        return True
    if schema_type == 'array':
        item_example = _generate_example_value(schema.get('items', {}), schemas_components, depth + 1)
        return [item_example] if item_example is not None else []
    if schema_type == 'object':
        return _generate_object_example(schema, schemas_components, depth)
    return None


def _generate_combined_example(schema, schemas_components, depth):
    """Build an example from ``allOf`` (merging object parts) or ``oneOf``/``anyOf`` (first usable)."""
    all_of_parts = [
        value
        for value in (
            _generate_example_value(sub_schema, schemas_components, depth + 1)
            for sub_schema in schema.get('allOf') or []
        )
        if value is not None
    ]
    if all_of_parts:
        if all(isinstance(part, dict) for part in all_of_parts):
            merged = {}
            for part in all_of_parts:
                merged.update(part)
            return merged
        return all_of_parts[0]

    for combinator in ('oneOf', 'anyOf'):
        for sub_schema in schema.get(combinator) or []:
            value = _generate_example_value(sub_schema, schemas_components, depth + 1)
            if value is not None:
                return value
    return None


def _generate_object_example(schema, schemas_components, depth):
    """Return a dict of example values for the writable properties of an object schema."""
    example = {}
    for prop_name, prop_schema in schema.get('properties', {}).items():
        # Read-only fields are not accepted in request bodies.
        if prop_schema.get('readOnly'):
            continue
        value = _generate_example_value(prop_schema, schemas_components, depth + 1)
        if value is not None:
            example[prop_name] = value
    return example


def _add_pagination_examples(result):
    """Add example ``next``/``previous`` URLs to every ``Paginated*`` component schema.

    Each schema is linked to the list endpoint (no ``{`` in its path) whose
    GET 200 JSON response references it. Unmatched schemas use a generic
    ``/api/v1/.../`` placeholder path.
    """
    components = result.get('components', {})
    if 'schemas' not in components:
        return

    # Map paginated schema name -> list endpoint path that returns it.
    schema_to_path = {}
    for path, path_item in result.get('paths', {}).items():
        if '{' in path:
            continue
        for method, operation in path_item.items():
            if method.lower() != 'get':
                continue
            schema_ref = (
                operation.get('responses', {})
                .get('200', {})
                .get('content', {})
                .get('application/json', {})
                .get('schema', {})
                .get('$ref', '')
            )
            if schema_ref.startswith(_SCHEMA_REF_PREFIX):
                schema_to_path[schema_ref[len(_SCHEMA_REF_PREFIX):]] = path

    live_url = _get_live_url()
    for schema_name, schema in components['schemas'].items():
        if not schema_name.startswith('Paginated') or 'properties' not in schema:
            continue

        properties = schema['properties']
        endpoint_path = schema_to_path.get(schema_name, '/api/v1/.../')
        # NOTE(review): assumes the page query parameter is named "page".
        for field_name, page_number in _PAGINATION_EXAMPLE_PAGES:
            if field_name in properties:
                properties[field_name]['example'] = f"{live_url}{endpoint_path}?page={page_number}"
                properties[field_name]['nullable'] = True