Build Scalable APIs with AWS API Gateway

David Childs

Master AWS API Gateway with advanced patterns for authentication, rate limiting, caching, and serverless API development at scale.

API Gateway promises serverless APIs that scale infinitely, but the reality is more complex. After building APIs handling millions of requests daily, I've learned the patterns that separate hobby projects from production-grade serverless APIs. Here's your complete guide to API Gateway mastery.

API Gateway Architecture Patterns

REST API vs HTTP API: A Scoring Helper

# api_selector.py
def choose_api_type(requirements):
    """Recommend REST API or HTTP API from a mapping of requirement flags.

    Each requirement that applies adds a weight to the API flavour that
    serves it better. The flavour with the larger total wins; a tie
    (including no requirements at all) goes to REST API.

    Args:
        requirements: mapping with optional truthy flags ``cost_sensitive``,
            ``latency_critical``, ``api_keys``, ``transformation``,
            ``waf_required``, ``caching`` and an optional ``auth_type``
            string (``'jwt'``, ``'cognito'`` or ``'iam'`` are scored).

    Returns:
        dict with ``recommendation`` and the two raw scores.
    """
    auth_type = requirements.get('auth_type')

    # (weight, applies) pairs in favour of HTTP API:
    # lower cost, simple JWT authorization, lower latency.
    http_factors = [
        (3, requirements.get('cost_sensitive')),
        (2, auth_type == 'jwt'),
        (2, requirements.get('latency_critical')),
    ]
    # (weight, applies) pairs in favour of REST API:
    # API keys, request/response transformation, WAF, caching, Cognito/IAM auth.
    rest_factors = [
        (3, requirements.get('api_keys')),
        (2, requirements.get('transformation')),
        (3, requirements.get('waf_required')),
        (2, requirements.get('caching')),
        (2, auth_type in ['cognito', 'iam']),
    ]

    http_total = sum(weight for weight, applies in http_factors if applies)
    rest_total = sum(weight for weight, applies in rest_factors if applies)

    winner = 'REST API'
    if http_total > rest_total:
        winner = 'HTTP API'

    return {
        'recommendation': winner,
        'http_api_score': http_total,
        'rest_api_score': rest_total,
    }

Terraform Configuration for Production API

# api_gateway.tf
resource "aws_api_gateway_rest_api" "main" {
  name        = "production-api"
  description = "Production REST API"

  # Content types API Gateway should handle as binary (file uploads).
  binary_media_types = [
    "image/*",
    "application/pdf",
    "application/octet-stream",
  ]

  # Edge-optimized endpoint; REGIONAL and PRIVATE are the alternatives.
  endpoint_configuration {
    types = ["EDGE"]
  }
}

# Request Validator
resource "aws_api_gateway_request_validator" "main" {
  rest_api_id = aws_api_gateway_rest_api.main.id
  name        = "request-validator"

  # Validate both the request body (against the method's model) and the
  # declared request parameters.
  validate_request_body       = true
  validate_request_parameters = true
}

# API Model for validation
resource "aws_api_gateway_model" "user" {
  name         = "User"
  rest_api_id  = aws_api_gateway_rest_api.main.id
  content_type = "application/json"

  # JSON Schema (draft-04) for a User payload: email and name are required,
  # age is optional but bounded. jsonencode emits object keys in sorted
  # order, so the key order below does not affect the rendered schema.
  schema = jsonencode({
    "$schema" = "http://json-schema.org/draft-04/schema#"
    title     = "User"
    type      = "object"
    properties = {
      age = {
        type    = "integer"
        minimum = 0
        maximum = 150
      }
      email = {
        type   = "string"
        format = "email"
      }
      name = {
        type      = "string"
        minLength = 1
        maxLength = 100
      }
    }
    required = ["email", "name"]
  })
}

# Method with validation
resource "aws_api_gateway_method" "create_user" {
  rest_api_id = aws_api_gateway_rest_api.main.id
  resource_id = aws_api_gateway_resource.users.id
  http_method = "POST"

  # Callers must present a token accepted by the Cognito user pool authorizer.
  authorization = "COGNITO_USER_POOLS"
  authorizer_id = aws_api_gateway_authorizer.cognito.id

  # JSON bodies are checked against the User model by the request validator.
  request_models = {
    "application/json" = aws_api_gateway_model.user.name
  }
  request_validator_id = aws_api_gateway_request_validator.main.id
}

Authentication and Authorization

Cognito Authorizer

# cognito_auth.py
import boto3
import jwt
from jwt.algorithms import RSAAlgorithm
import requests
from functools import wraps

class CognitoAuthorizer:
    """Verify Amazon Cognito user-pool JWTs and protect Lambda handlers.

    The pool's JSON Web Key Set (JWKS) is downloaded once in ``__init__`` and
    reused for every token verified by this instance.
    """

    def __init__(self, user_pool_id, region, app_client_id=None):
        """Download the signing keys for ``user_pool_id``.

        Args:
            user_pool_id: Cognito user pool ID (e.g. ``us-east-1_AbC123``).
            region: AWS region hosting the pool.
            app_client_id: optional app client ID. When set, ID tokens must
                name it in ``aud`` and access tokens in ``client_id``.

        Raises:
            requests.HTTPError: if the JWKS endpoint returns an error status.
        """
        self.user_pool_id = user_pool_id
        self.region = region
        self.app_client_id = app_client_id
        # Cognito's ``iss`` claim is the pool URL; the JWKS is published under it.
        self.issuer = f'https://cognito-idp.{region}.amazonaws.com/{user_pool_id}'
        self.jwks_url = f'{self.issuer}/.well-known/jwks.json'
        # Bounded timeout so an unreachable endpoint cannot stall a cold start.
        response = requests.get(self.jwks_url, timeout=10)
        response.raise_for_status()
        self.jwks = response.json()

    def verify_token(self, token):
        """Verify a Cognito JWT and return its claims.

        Checks the RS256 signature against the pool's JWKS, the expiry, the
        issuer, the ``token_use`` claim and (when ``app_client_id`` was given)
        the client the token was issued to.

        Returns:
            dict of decoded claims.

        Raises:
            ValueError: if any check fails; the message starts with
                ``'Token verification failed: '``.
        """
        try:
            kid = jwt.get_unverified_header(token)['kid']
            key = next(
                (RSAAlgorithm.from_jwk(k) for k in self.jwks['keys'] if k['kid'] == kid),
                None,
            )
            if key is None:
                raise ValueError('Public key not found')

            # Cognito ID tokens carry the *app client* ID in 'aud' (not the
            # pool ID) and access tokens have no 'aud' at all, so PyJWT's
            # audience check is disabled and the client is checked below.
            claims = jwt.decode(
                token,
                key,
                algorithms=['RS256'],
                issuer=self.issuer,
                options={"verify_exp": True, "verify_aud": False}
            )
        except Exception as e:
            raise ValueError(f'Token verification failed: {str(e)}') from e

        token_use = claims.get('token_use')
        if token_use not in ('id', 'access'):
            raise ValueError('Token verification failed: unexpected token_use')
        if self.app_client_id is not None:
            issued_to = claims.get('aud') if token_use == 'id' else claims.get('client_id')
            if issued_to != self.app_client_id:
                raise ValueError('Token verification failed: token issued for another app client')
        return claims

    @staticmethod
    def _extract_bearer_token(event):
        """Return the token from the request's Authorization header, or ''.

        ``event['headers']`` may be missing or None, and HTTP APIs lower-case
        header names, so the lookup tolerates both and is case-insensitive.
        """
        headers = event.get('headers') or {}
        value = ''
        for name, header_value in headers.items():
            if name.lower() == 'authorization':
                value = header_value or ''
                break
        if value[:7].lower() == 'bearer ':
            value = value[7:]
        return value.strip()

    def authorize(self, required_groups=None):
        """Build a decorator that rejects unauthenticated Lambda invocations.

        The wrapped handler runs only when the request carries a valid bearer
        token and, if ``required_groups`` is given, the caller belongs to at
        least one of those Cognito groups. The verified identity is exposed to
        the handler under ``event['requestContext']['authorizer']``.

        Args:
            required_groups: optional iterable of group names; any match passes.

        Returns:
            A decorator for ``handler(event, context)`` functions. Rejections
            are returned as ``{'statusCode': 401 or 403, 'body': <json>}``.
        """
        import json  # not among this snippet's module-level imports

        def decorator(func):
            @wraps(func)
            def wrapper(event, context):
                token = self._extract_bearer_token(event)
                if not token:
                    return {
                        'statusCode': 401,
                        'body': json.dumps({'error': 'No authorization token'})
                    }

                # Only token verification maps to 401; errors raised by the
                # wrapped handler itself must not be disguised as auth failures.
                try:
                    claims = self.verify_token(token)
                except ValueError as e:
                    return {
                        'statusCode': 401,
                        'body': json.dumps({'error': str(e)})
                    }

                if required_groups:
                    user_groups = claims.get('cognito:groups', [])
                    if not any(group in user_groups for group in required_groups):
                        return {
                            'statusCode': 403,
                            'body': json.dumps({'error': 'Insufficient permissions'})
                        }

                event.setdefault('requestContext', {})['authorizer'] = {
                    'claims': claims,
                    'userId': claims['sub'],
                    'email': claims.get('email')
                }
                return func(event, context)

            return wrapper
        return decorator

Lambda Authorizer with Caching

# lambda_authorizer.py
import json
import time
import hmac
import hashlib
import base64

def lambda_handler(event, context):
    """TOKEN Lambda authorizer: validate the caller's token and allow access.

    Args:
        event: authorizer event with ``authorizationToken`` and ``methodArn``.
        context: Lambda context (unused).

    Returns:
        Policy document built by ``generate_policy`` allowing every method of
        the stage named in ``methodArn``.

    Raises:
        Exception: ``'Unauthorized'`` on any validation failure; API Gateway
            turns exactly this message into a 401 response.
    """
    token = event['authorizationToken']
    method_arn = event['methodArn']

    try:
        # Per-client policies are not applied yet; only the identity is used.
        principal_id, _policies = validate_token(token)
    except Exception as e:
        # Log why access was denied (without the token) so denials are diagnosable.
        print(f'Authorization denied: {type(e).__name__}: {e}')
        raise Exception('Unauthorized') from None

    # methodArn = arn:aws:execute-api:{region}:{account}:{api_id}/{stage}/{verb}/{path}.
    # API Gateway caches the returned policy (per token, for the authorizer's
    # TTL) and reuses it for other methods of the API, so it is scoped to the
    # whole stage; a policy naming only this method would 403 cached calls
    # to any other method.
    arn_prefix, stage = method_arn.split('/')[:2]
    stage_resource = f'{arn_prefix}/{stage}/*'

    return generate_policy(
        principal_id,
        'Allow',
        stage_resource,
        context={
            'userId': principal_id,
            'expiresAt': str(int(time.time()) + 3600)
        }
    )

def validate_token(token):
    """Validate an HMAC-signed ``client_id.timestamp.signature`` token.

    The signature must be the hex HMAC-SHA256 of ``"client_id.timestamp"``
    keyed with the client's secret, and the timestamp (Unix seconds) must be
    within 5 minutes of now.

    Returns:
        ``(client_id, policies)`` where policies come from ``get_client_policies``.

    Raises:
        ValueError: on a malformed token, bad timestamp, stale token or
            signature mismatch.
    """
    parts = token.split('.')
    if len(parts) != 3:
        raise ValueError('Invalid token format')
    client_id, timestamp, signature = parts

    try:
        issued_at = int(timestamp)
    except ValueError:
        raise ValueError('Invalid token timestamp') from None

    # Bounding the clock skew (either direction) limits how long a captured
    # token can be replayed; it does not prevent replay inside the window.
    if abs(issued_at - int(time.time())) > 300:  # 5 minutes
        raise ValueError('Token expired')

    secret = get_client_secret(client_id)  # Fetch from Secrets Manager
    expected_signature = hmac.new(
        secret.encode(),
        f"{client_id}.{timestamp}".encode(),
        hashlib.sha256
    ).hexdigest()

    # Compare bytes: compare_digest raises TypeError (not ValueError) when
    # given a str containing non-ASCII characters.
    if not hmac.compare_digest(signature.encode(), expected_signature.encode()):
        raise ValueError('Invalid signature')

    return client_id, get_client_policies(client_id)

def generate_policy(principal_id, effect, resource, context=None,
                    usage_identifier_key=None):
    """Build the response document a Lambda authorizer returns to API Gateway.

    Args:
        principal_id: identifier of the authenticated caller.
        effect: ``'Allow'`` or ``'Deny'``.
        resource: execute-api ARN (or ARN pattern) the statement covers.
        context: optional dict of extra values forwarded to the integration;
            included only when non-empty.
        usage_identifier_key: optional API key value used to meter the
            request against a usage plan when the API's key source is
            AUTHORIZER. Omitted unless supplied.

    Returns:
        dict with ``principalId`` and ``policyDocument``, plus ``context`` and
        ``usageIdentifierKey`` when supplied.
    """
    policy = {
        'principalId': principal_id,
        'policyDocument': {
            'Version': '2012-10-17',
            'Statement': [
                {
                    'Action': 'execute-api:Invoke',
                    'Effect': effect,
                    'Resource': resource
                }
            ]
        }
    }

    if context:
        policy['context'] = context

    # usageIdentifierKey is an API key *value* for usage plans, not a caching
    # switch (authorizer result caching is the authorizer's TTL setting).
    # Defaulting it to the principal ID made API Gateway treat that ID as an
    # API key, so it is only sent when a real key is provided.
    if usage_identifier_key is not None:
        policy['usageIdentifierKey'] = usage_identifier_key

    return policy

Rate Limiting and Throttling

Usage Plans and API Keys

# usage_plan_manager.py
import boto3

class UsagePlanManager:
    """Provision tiered API Gateway usage plans and per-customer API keys."""

    def __init__(self):
        self.apigateway = boto3.client('apigateway')

    def create_tiered_usage_plans(self, api_id, stage_name):
        """Create the basic, premium and enterprise usage plans for one stage.

        Args:
            api_id: REST API ID the plans apply to.
            stage_name: deployed stage of that API.

        Returns:
            list of the three ``create_usage_plan`` responses, in tier order.
        """
        plans = [
            {
                'name': 'basic',
                'description': 'Basic tier - 1000 requests/day',
                'throttle': {
                    'rateLimit': 10,
                    'burstLimit': 20
                },
                'quota': {
                    'limit': 1000,
                    'period': 'DAY'
                }
            },
            {
                'name': 'premium',
                'description': 'Premium tier - 10000 requests/day',
                'throttle': {
                    'rateLimit': 100,
                    'burstLimit': 200
                },
                'quota': {
                    'limit': 10000,
                    'period': 'DAY'
                }
            },
            {
                'name': 'enterprise',
                'description': 'Enterprise tier - Unlimited',
                'throttle': {
                    'rateLimit': 1000,
                    'burstLimit': 2000
                }
                # No quota for enterprise
            }
        ]

        created_plans = []
        for plan in plans:
            request = {
                'name': plan['name'],
                'description': plan['description'],
                # The plan-level throttle below already applies to every method
                # of the stage, so no per-method throttle map is attached.
                'apiStages': [{'apiId': api_id, 'stage': stage_name}],
                'throttle': plan['throttle'],
            }
            # boto3 type-checks parameters and rejects quota=None, so a quota is
            # only sent for tiers that define one.
            if 'quota' in plan:
                request['quota'] = plan['quota']
            created_plans.append(self.apigateway.create_usage_plan(**request))

        return created_plans

    def create_api_key_for_customer(self, customer_id, usage_plan_id):
        """Create an enabled API key for a customer and attach it to a usage plan.

        If attaching the key fails, the new key is deleted again so no enabled
        key is left without a governing plan, and the error is re-raised.

        Returns:
            the ``create_api_key`` response.
        """
        api_key = self.apigateway.create_api_key(
            name=f'customer-{customer_id}',
            description=f'API key for customer {customer_id}',
            enabled=True,
            tags={
                # Tag values must be strings (customer IDs may be ints).
                'CustomerId': str(customer_id)
            }
        )

        try:
            self.apigateway.create_usage_plan_key(
                usagePlanId=usage_plan_id,
                keyId=api_key['id'],
                keyType='API_KEY'
            )
        except Exception:
            self.apigateway.delete_api_key(apiKey=api_key['id'])
            raise

        return api_key

Per-Method Throttling

# method_throttling.tf
resource "aws_api_gateway_method_settings" "settings" {
  rest_api_id = aws_api_gateway_rest_api.main.id
  stage_name  = aws_api_gateway_deployment.main.stage_name
  method_path = "*/*"  # Apply to all methods

  settings {
    metrics_enabled = true
    logging_level   = "INFO"
    # Full request/response data tracing writes bodies and headers (tokens,
    # PII) to CloudWatch Logs; AWS advises against it for production stages.
    # Enable it temporarily when troubleshooting.
    data_trace_enabled = false

    # Default throttling (steady-state requests/second and burst)
    throttling_rate_limit  = 1000
    throttling_burst_limit = 2000

    # Caching. NOTE: responses are only cached when the stage itself has a
    # cache cluster enabled (cache_cluster_enabled on the stage).
    caching_enabled                         = true
    cache_ttl_in_seconds                    = 300
    cache_data_encrypted                    = true
    require_authorization_for_cache_control = true
  }
}

# Per-method override
resource "aws_api_gateway_method_settings" "heavy_endpoint" {
  rest_api_id = aws_api_gateway_rest_api.main.id
  stage_name  = aws_api_gateway_deployment.main.stage_name
  method_path = "data/GET"  # "<resource path without leading slash>/<HTTP method>"

  # Every setting is stated explicitly rather than relying on it being
  # inherited from the "*/*" defaults, so this override keeps cache
  # encryption, cache-control authorization, metrics and logging.
  settings {
    metrics_enabled = true
    logging_level   = "INFO"

    throttling_rate_limit  = 100  # Lower limit for heavy endpoint
    throttling_burst_limit = 200

    caching_enabled                         = true
    cache_ttl_in_seconds                    = 3600  # Longer cache for expensive operation
    cache_data_encrypted                    = true
    require_authorization_for_cache_control = true
  }
}

Response Caching

Cache Key Configuration

# cache_manager.py
def configure_cache_key(method_arn, parameters, resource_id=None,
                        http_method='GET', apigateway=None):
    """Declare request parameters on a method and make them cache keys.

    A request parameter only becomes part of the API Gateway cache key when
    it is declared on the method *and* listed in the integration's
    ``cacheKeyParameters``; this function does both.

    Args:
        method_arn: execute-api ARN of the form
            ``arn:aws:execute-api:{region}:{account}:{api_id}/{stage}/{verb}/{path}``.
        parameters: iterable of full parameter names such as
            ``'method.request.querystring.category'`` (a single string is
            accepted). Falsy values fall back to the ``category`` query
            string and the ``Accept-Language`` header.
        resource_id: REST API resource ID. When omitted it is looked up from
            the resource path embedded in ``method_arn``.
        http_method: HTTP method to configure (default ``'GET'``).
        apigateway: optional boto3 ``apigateway`` client; created when omitted.

    Raises:
        ValueError: if ``method_arn`` is malformed or its resource path is not
            found in the API.
    """
    if apigateway is None:
        import boto3
        apigateway = boto3.client('apigateway')

    arn_parts = method_arn.split(':', 5)
    if len(arn_parts) != 6:
        raise ValueError(f'Not an execute-api ARN: {method_arn}')
    segments = arn_parts[5].split('/')
    api_id = segments[0]

    if resource_id is None:
        if len(segments) < 3:
            raise ValueError(f'methodArn has no resource path: {method_arn}')
        resource_path = '/' + '/'.join(segments[3:])
        resource_id = _find_resource_id(apigateway, api_id, resource_path)

    if isinstance(parameters, str):
        parameters = [parameters]
    cache_params = list(parameters) if parameters else [
        'method.request.querystring.category',
        'method.request.header.Accept-Language',
    ]

    # Declare each parameter on the method ('true' marks it as required).
    apigateway.update_method(
        restApiId=api_id,
        resourceId=resource_id,
        httpMethod=http_method,
        patchOperations=[
            {'op': 'add', 'path': f'/requestParameters/{name}', 'value': 'true'}
            for name in cache_params
        ]
    )

    # Add each parameter to the integration's cache key.
    apigateway.update_integration(
        restApiId=api_id,
        resourceId=resource_id,
        httpMethod=http_method,
        patchOperations=[
            {'op': 'add', 'path': f'/cacheKeyParameters/{name}'}
            for name in cache_params
        ]
    )


def _find_resource_id(apigateway, api_id, resource_path):
    """Return the ID of the resource at ``resource_path`` in REST API ``api_id``."""
    for page in apigateway.get_paginator('get_resources').paginate(restApiId=api_id):
        for item in page.get('items', []):
            if item.get('path') == resource_path:
                return item['id']
    raise ValueError(f'Resource {resource_path} not found in API {api_id}')

Cache Invalidation Strategy

# cache_invalidation.py
import hashlib

def invalidate_cache(api_id, stage_name, path, apigateway=None):
    """Flush a stage's response cache and return per-request refresh headers.

    API Gateway only offers a whole-stage flush, so the entire cache of
    ``stage_name`` is flushed; ``path`` is not used by that call. The returned
    headers are for clients that need to refresh a single entry afterwards:
    ``Cache-Control: max-age=0`` asks API Gateway to bypass and replace the
    cached entry (when the stage requires authorization for cache control,
    the caller needs ``execute-api:InvalidateCache``). ``X-Cache-Buster``
    only changes the cache key if that header is configured as a cache key
    parameter.

    Args:
        api_id: REST API ID.
        stage_name: stage whose cache is flushed.
        path: resource path the caller wants refreshed (informational).
        apigateway: optional boto3 ``apigateway`` client; created when omitted.

    Returns:
        dict of request headers.
    """
    import time

    if apigateway is None:
        import boto3
        apigateway = boto3.client('apigateway')

    apigateway.flush_stage_cache(
        restApiId=api_id,
        stageName=stage_name
    )

    # 8 hex chars derived from the current second.
    timestamp = str(int(time.time()))
    cache_buster = hashlib.md5(timestamp.encode()).hexdigest()[:8]

    return {
        'Cache-Control': 'max-age=0',
        'X-Cache-Buster': cache_buster
    }

Request/Response Transformation

Mapping Templates

# mapping_templates.py
def create_request_mapping_template():
    """Return a VTL request template that forwards the whole request as JSON.

    The integration receives the HTTP method, the JSON body, the header,
    query-string and path parameters, and selected request context fields.

    ``$util.escapeJavaScript`` makes a value safe inside a JSON string but
    also escapes single quotes with a backslash, which is not valid JSON.
    Each escaped value is therefore passed through ``.replaceAll`` to turn
    that escape back into a plain quote, which is the workaround given in the
    API Gateway mapping-template reference. The client-controlled user agent
    gets the same treatment.
    """
    template = r"""
{
    "method": "$context.httpMethod",
    "body": $input.json('$'),
    "headers": {
        #foreach($header in $input.params().header.keySet())
        "$header": "$util.escapeJavaScript($input.params().header.get($header)).replaceAll("\\'","'")"
        #if($foreach.hasNext),#end
        #end
    },
    "queryParams": {
        #foreach($queryParam in $input.params().querystring.keySet())
        "$queryParam": "$util.escapeJavaScript($input.params().querystring.get($queryParam)).replaceAll("\\'","'")"
        #if($foreach.hasNext),#end
        #end
    },
    "pathParams": {
        #foreach($pathParam in $input.params().path.keySet())
        "$pathParam": "$util.escapeJavaScript($input.params().path.get($pathParam)).replaceAll("\\'","'")"
        #if($foreach.hasNext),#end
        #end
    },
    "context": {
        "accountId": "$context.accountId",
        "apiId": "$context.apiId",
        "requestId": "$context.requestId",
        "requestTime": "$context.requestTime",
        "sourceIp": "$context.identity.sourceIp",
        "userAgent": "$util.escapeJavaScript($context.identity.userAgent).replaceAll("\\'","'")"
    }
}
"""
    return template

def create_response_mapping_template():
    """Return a VTL response template wrapping the integration result.

    A ``statusCode`` of 200 produces ``{"success": true, "data": ...}``;
    anything else produces ``{"success": false, "error": {...}}`` built from
    ``errorMessage`` and ``errorType``.

    NOTE(review): ``$inputRoot.body`` is inserted verbatim, which assumes the
    integration returns ``body`` as already-serialized JSON text; confirm
    against the backend.

    Error message and type are free text, so they are JSON-escaped with
    ``$util.escapeJavaScript``. Its backslash-escaped single quotes are not
    valid JSON, so they are reverted with ``.replaceAll``.
    """
    template = r"""
#set($inputRoot = $input.path('$'))
#if($inputRoot.statusCode == 200)
{
    "success": true,
    "data": $inputRoot.body,
    "metadata": {
        "timestamp": "$context.requestTime",
        "requestId": "$context.requestId"
    }
}
#else
{
    "success": false,
    "error": {
        "message": "$util.escapeJavaScript($inputRoot.errorMessage).replaceAll("\\'","'")",
        "type": "$util.escapeJavaScript($inputRoot.errorType).replaceAll("\\'","'")",
        "requestId": "$context.requestId"
    }
}
#end
"""
    return template

CORS Configuration

Comprehensive CORS Setup

# cors_configuration.py
def setup_cors(api_id, resource_id, apigateway=None):
    """Enable CORS on one REST API resource.

    Adds an ``OPTIONS`` method backed by a MOCK integration that answers
    preflight requests. It then adds ``Access-Control-Allow-Origin: *`` to
    the 200 response of each GET/POST/PUT/DELETE method defined on the
    resource.

    NOTE(review): Lambda *proxy* integrations bypass integration responses,
    so such functions must return the CORS headers themselves.

    Args:
        api_id: REST API ID.
        resource_id: resource to configure.
        apigateway: optional boto3 ``apigateway`` client; created when omitted.
    """
    if apigateway is None:
        import boto3
        apigateway = boto3.client('apigateway')

    # Add OPTIONS method
    apigateway.put_method(
        restApiId=api_id,
        resourceId=resource_id,
        httpMethod='OPTIONS',
        authorizationType='NONE'
    )

    # Configure mock integration for OPTIONS
    apigateway.put_integration(
        restApiId=api_id,
        resourceId=resource_id,
        httpMethod='OPTIONS',
        type='MOCK',
        requestTemplates={
            'application/json': '{"statusCode": 200}'
        }
    )

    # Configure method response
    apigateway.put_method_response(
        restApiId=api_id,
        resourceId=resource_id,
        httpMethod='OPTIONS',
        statusCode='200',
        responseParameters={
            'method.response.header.Access-Control-Allow-Headers': True,
            'method.response.header.Access-Control-Allow-Methods': True,
            'method.response.header.Access-Control-Allow-Origin': True,
            'method.response.header.Access-Control-Max-Age': True
        }
    )

    # Configure integration response
    apigateway.put_integration_response(
        restApiId=api_id,
        resourceId=resource_id,
        httpMethod='OPTIONS',
        statusCode='200',
        responseParameters={
            'method.response.header.Access-Control-Allow-Headers': "'Content-Type,X-Amz-Date,Authorization,X-Api-Key,X-Amz-Security-Token'",
            'method.response.header.Access-Control-Allow-Methods': "'GET,POST,PUT,DELETE,OPTIONS'",
            'method.response.header.Access-Control-Allow-Origin': "'*'",
            'method.response.header.Access-Control-Max-Age': "'86400'"
        }
    )

    # A verb not defined on this resource raises NotFoundException, and an
    # already-configured response may raise ConflictException; both just mean
    # there is nothing to add. Any other error (permissions, throttling, bad
    # input) is a real failure and propagates instead of being swallowed.
    skippable_errors = (
        apigateway.exceptions.NotFoundException,
        apigateway.exceptions.ConflictException,
    )

    # Add CORS headers to actual methods
    for method in ('GET', 'POST', 'PUT', 'DELETE'):
        try:
            apigateway.put_method_response(
                restApiId=api_id,
                resourceId=resource_id,
                httpMethod=method,
                statusCode='200',
                responseParameters={
                    'method.response.header.Access-Control-Allow-Origin': True
                }
            )

            apigateway.put_integration_response(
                restApiId=api_id,
                resourceId=resource_id,
                httpMethod=method,
                statusCode='200',
                responseParameters={
                    'method.response.header.Access-Control-Allow-Origin': "'*'"
                }
            )
        except skippable_errors:
            continue

WebSocket APIs

WebSocket Connection Manager

# websocket_manager.py
import boto3
import json

class WebSocketManager:
    """Track WebSocket connections in DynamoDB and push messages to them."""

    def __init__(self, table_name, endpoint_url=None):
        """Create the DynamoDB table handle and the management API client.

        Args:
            table_name: DynamoDB table keyed by ``connectionId``.
            endpoint_url: the WebSocket API's connection-management endpoint,
                ``https://{domainName}/{stage}`` (both values are in the
                event's ``requestContext``). ``post_to_connection`` needs it;
                when omitted the client is created without one, as before.
        """
        self.dynamodb = boto3.resource('dynamodb')
        self.table = self.dynamodb.Table(table_name)
        client_kwargs = {}
        if endpoint_url is not None:
            client_kwargs['endpoint_url'] = endpoint_url
        self.apigateway = boto3.client('apigatewaymanagementapi', **client_kwargs)

    def on_connect(self, connection_id, request_context):
        """Store a new connection with a 24-hour ``ttl`` attribute."""
        import time  # not among this snippet's module-level imports

        self.table.put_item(
            Item={
                'connectionId': connection_id,
                'connectedAt': request_context['connectedAt'],
                'sourceIp': request_context['identity']['sourceIp'],
                'userAgent': request_context['identity']['userAgent'],
                'ttl': int(time.time()) + 86400  # 24 hour TTL
            }
        )

        return {'statusCode': 200}

    def on_disconnect(self, connection_id):
        """Delete the stored connection."""
        self.table.delete_item(
            Key={'connectionId': connection_id}
        )

        return {'statusCode': 200}

    def on_message(self, connection_id, message):
        """Dispatch a JSON message by its ``action``: ``broadcast`` or ``direct``.

        Any error (bad JSON, missing keys, failed send) is reported as a 400
        response carrying the error text.
        """
        try:
            data = json.loads(message)
            action = data.get('action')

            if action == 'broadcast':
                # Broadcast to all connections except the sender
                self.broadcast_message(data['message'], exclude=connection_id)
            elif action == 'direct':
                # Send to specific connection
                self.send_message(data['targetId'], data['message'])

            return {'statusCode': 200}

        except Exception as e:
            return {
                'statusCode': 400,
                'body': json.dumps({'error': str(e)})
            }

    def _iter_connections(self):
        """Yield every stored connection, following DynamoDB scan pagination.

        A single ``scan`` call returns at most one page (1 MB) of items.
        """
        scan_kwargs = {}
        while True:
            page = self.table.scan(**scan_kwargs)
            yield from page.get('Items', [])
            last_key = page.get('LastEvaluatedKey')
            if last_key is None:
                return
            scan_kwargs['ExclusiveStartKey'] = last_key

    def broadcast_message(self, message, exclude=None):
        """Send ``message`` to every stored connection except ``exclude``.

        Connections the API reports as gone (GoneException) are deleted.
        Other send errors propagate rather than deleting a live connection.
        """
        gone_error = self.apigateway.exceptions.GoneException
        for connection in self._iter_connections():
            connection_id = connection['connectionId']
            if connection_id == exclude:
                continue
            try:
                self.send_message(connection_id, message)
            except gone_error:
                # Stale connection whose $disconnect never ran; remove it.
                self.on_disconnect(connection_id)

    def send_message(self, connection_id, message):
        """JSON-encode ``message`` and post it to one connection."""
        self.apigateway.post_to_connection(
            ConnectionId=connection_id,
            Data=json.dumps(message)
        )

Monitoring and Logging

CloudWatch Integration

# monitoring.py
def setup_api_monitoring(api_name, stage_name, region=None, alarm_actions=None,
                         cloudwatch=None):
    """Create a CloudWatch dashboard and error/latency alarms for one API stage.

    Args:
        api_name: API Gateway API name (the ``ApiName`` metric dimension).
        stage_name: stage name (the ``Stage`` metric dimension).
        region: region shown on the dashboard graphs; defaults to the
            CloudWatch client's region instead of a hard-coded one.
        alarm_actions: optional ARNs (e.g. SNS topics) notified when an
            alarm fires; without them the alarms change state silently.
        cloudwatch: optional boto3 CloudWatch client; created when omitted.
    """
    import json

    if cloudwatch is None:
        import boto3
        cloudwatch = boto3.client('cloudwatch')
    if region is None:
        region = cloudwatch.meta.region_name

    # API Gateway publishes these metrics with ApiName/Stage dimensions; a
    # graph that names no dimensions matches no published series and is empty.
    # The alarms below use the same dimensions.
    dimensions = ["ApiName", api_name, "Stage", stage_name]

    dashboard_body = {
        "widgets": [
            {
                "type": "metric",
                "properties": {
                    "metrics": [
                        ["AWS/ApiGateway", "Count", *dimensions, {"stat": "Sum", "label": "Total Requests"}],
                        [".", "4XXError", *dimensions, {"stat": "Sum", "label": "4XX Errors"}],
                        [".", "5XXError", *dimensions, {"stat": "Sum", "label": "5XX Errors"}]
                    ],
                    "period": 300,
                    "stat": "Average",
                    "region": region,
                    "title": "API Requests"
                }
            },
            {
                "type": "metric",
                "properties": {
                    "metrics": [
                        ["AWS/ApiGateway", "Latency", *dimensions, {"stat": "Average"}],
                        [".", ".", *dimensions, {"stat": "p99"}]
                    ],
                    "period": 300,
                    "stat": "Average",
                    "region": region,
                    "title": "API Latency"
                }
            }
        ]
    }

    cloudwatch.put_dashboard(
        DashboardName=f'{api_name}-dashboard',
        DashboardBody=json.dumps(dashboard_body)
    )

    alarms = [
        {
            'name': f'{api_name}-high-4xx-errors',
            'metric': '4XXError',
            'threshold': 100,
            'comparison': 'GreaterThanThreshold'
        },
        {
            'name': f'{api_name}-high-5xx-errors',
            'metric': '5XXError',
            'threshold': 10,
            'comparison': 'GreaterThanThreshold'
        },
        {
            'name': f'{api_name}-high-latency',
            'metric': 'Latency',
            'threshold': 1000,
            'comparison': 'GreaterThanThreshold',
            'statistic': 'Average'
        }
    ]

    for alarm in alarms:
        alarm_kwargs = {
            'AlarmName': alarm['name'],
            'ComparisonOperator': alarm['comparison'],
            'EvaluationPeriods': 2,
            'MetricName': alarm['metric'],
            'Namespace': 'AWS/ApiGateway',
            'Period': 300,
            'Statistic': alarm.get('statistic', 'Sum'),
            'Threshold': alarm['threshold'],
            'Dimensions': [
                {'Name': 'ApiName', 'Value': api_name},
                {'Name': 'Stage', 'Value': stage_name}
            ]
        }
        if alarm_actions:
            alarm_kwargs['AlarmActions'] = list(alarm_actions)
        cloudwatch.put_metric_alarm(**alarm_kwargs)

Best Practices Checklist

  • Choose the right API type (REST vs HTTP)
  • Implement proper authentication (Cognito, Lambda, IAM)
  • Configure request validation with models
  • Set up usage plans and API keys
  • Enable caching where appropriate
  • Configure CORS properly
  • Implement request/response transformation
  • Set up CloudWatch logging
  • Configure custom domain names
  • Enable AWS WAF for security
  • Implement proper error handling
  • Use VPC Link for private integrations
  • Monitor with X-Ray tracing
  • Set up deployment stages
  • Document with OpenAPI/Swagger

Conclusion

API Gateway is powerful but requires careful configuration for production use. Focus on security, implement proper throttling, leverage caching, and monitor everything. The key to success is understanding when to use each feature and how they interact. Start simple, validate your patterns, then scale with confidence.

Share this article

DC

David Childs

Consulting Systems Engineer with over 10 years of experience building scalable infrastructure and helping organizations optimize their technology stack.

Related Articles