Source code for py_alpaca_api.http.feed_manager

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, ClassVar

from py_alpaca_api.exceptions import APIRequestError, ValidationError

# Module-level logger; named after the module so log output can be filtered
# per package through the standard logging hierarchy.
logger = logging.getLogger(__name__)
class FeedType(Enum):
    """Available data feed types.

    Each member's value is the lowercase string used as the feed identifier.
    """

    SIP = "sip"
    IEX = "iex"
    OTC = "otc"

    @classmethod
    def from_string(cls, value: str) -> FeedType:
        """Create FeedType from string value.

        The lookup is case-insensitive: ``value`` is lowercased before it is
        matched against the member values.

        Args:
            value: Feed name, e.g. ``"sip"`` or ``"IEX"``.

        Returns:
            The matching ``FeedType`` member.

        Raises:
            ValidationError: If ``value`` does not match any member value.
                The original ``ValueError`` is chained as the cause.
        """
        normalized = value.lower()
        try:
            member = cls(normalized)
        except ValueError as e:
            raise ValidationError(
                f"Invalid feed type: {value}. Must be one of {[f.value for f in cls]}"
            ) from e
        else:
            return member
class SubscriptionLevel(Enum):
    """User subscription levels."""

    BASIC = "basic"
    UNLIMITED = "unlimited"
    BUSINESS = "business"

    @classmethod
    def from_error(cls, error_message: str) -> SubscriptionLevel | None:
        """Detect subscription level from error message.

        The match is a case-insensitive substring search.

        Args:
            error_message: Text of the error to inspect.

        Returns:
            ``None`` when the message does not contain "subscription".
            Otherwise ``UNLIMITED`` when the message also contains
            "unlimited" or "business" (both map to ``UNLIMITED``; ``BUSINESS``
            is never returned), and ``BASIC`` in every other case.
        """
        lowered = error_message.lower()

        # Messages that do not talk about subscriptions carry no signal.
        if "subscription" not in lowered:
            return None

        mentions_paid_tier = any(
            tier in lowered for tier in ("unlimited", "business")
        )
        return cls.UNLIMITED if mentions_paid_tier else cls.BASIC
@dataclass
class FeedConfig:
    """Configuration for feed management."""

    # Feed used for any endpoint without an explicit override.
    preferred_feed: FeedType = FeedType.SIP
    # Feeds to try, in order, once the selected feed has failed.
    fallback_feeds: list[FeedType] = field(default_factory=lambda: [FeedType.IEX])
    # When False, FeedManager.handle_feed_error never proposes an alternative.
    auto_fallback: bool = True
    # Explicitly configured subscription level; None means unknown.
    subscription_level: SubscriptionLevel | None = None
    # Per-endpoint overrides of ``preferred_feed``, keyed by endpoint string.
    endpoint_feeds: dict[str, FeedType] = field(default_factory=dict)

    def get_feed_for_endpoint(self, endpoint: str) -> FeedType:
        """Get the configured feed for a specific endpoint.

        Args:
            endpoint: Endpoint string, matched exactly against the keys of
                ``endpoint_feeds``.

        Returns:
            The override registered for ``endpoint`` if there is one,
            otherwise ``preferred_feed``.
        """
        overrides = self.endpoint_feeds
        if endpoint in overrides:
            return overrides[endpoint]
        return self.preferred_feed
class FeedManager:
    """Manages data feed selection and fallback logic.

    The manager remembers which feeds have failed per endpoint (optionally
    per endpoint and symbol) and which subscription level applies, and uses
    both to decide which feed value to hand out. A subscription level
    detected at runtime takes precedence over the one in the configuration.
    """

    # Endpoints that support feed parameter
    FEED_SUPPORTED_ENDPOINTS: ClassVar[set[str]] = {
        "bars",
        "quotes",
        "trades",
        "snapshots",
        "latest/bars",
        "latest/quotes",
        "latest/trades",
    }

    # Feed availability by subscription level
    SUBSCRIPTION_FEEDS: ClassVar[dict[SubscriptionLevel, list[FeedType]]] = {
        SubscriptionLevel.BASIC: [FeedType.IEX],
        SubscriptionLevel.UNLIMITED: [FeedType.SIP, FeedType.IEX, FeedType.OTC],
        SubscriptionLevel.BUSINESS: [FeedType.SIP, FeedType.IEX, FeedType.OTC],
    }

    def __init__(self, config: FeedConfig | None = None):
        """Initialize the feed manager.

        Args:
            config: Feed configuration. If None, uses defaults.
        """
        self.config = config or FeedConfig()
        # Maps a failure key (see _failure_key) to the feeds that failed for it.
        self._failed_feeds: dict[str, set[FeedType]] = {}
        self._detected_subscription_level: SubscriptionLevel | None = None

    def get_feed(self, endpoint: str, symbol: str | None = None) -> str | None:
        """Get the appropriate feed for an endpoint.

        Args:
            endpoint: The API endpoint being called
            symbol: Optional symbol for endpoint-specific logic

        Returns:
            Feed parameter value or None if endpoint doesn't support feeds.
            If the configured feed has previously failed for this
            endpoint/symbol, the first fallback feed that has not failed and
            is available for the current subscription level is returned
            instead. When no such fallback exists, the configured feed is
            still returned.
        """
        if not self._supports_feed(endpoint):
            return None

        feed = self.config.get_feed_for_endpoint(endpoint)
        failed = self._failed_feeds.get(self._failure_key(endpoint, symbol), set())

        if feed in failed:
            # Apply the same eligibility rule as handle_feed_error so both
            # methods agree on which fallback is usable.
            for fallback in self.config.fallback_feeds:
                if fallback not in failed and self._is_feed_available(fallback):
                    logger.info(f"Using fallback feed {fallback.value} for {endpoint}")
                    return fallback.value

        return feed.value

    def handle_feed_error(
        self,
        endpoint: str,
        feed: str,
        error: APIRequestError,
        symbol: str | None = None,
    ) -> str | None:
        """Handle feed-related errors and return alternative feed if available.

        Args:
            endpoint: The API endpoint that failed
            feed: The feed that caused the error
            error: The API error
            symbol: Optional symbol for endpoint-specific tracking

        Returns:
            Alternative feed to try, or None if no alternatives available.
            None is also returned when automatic fallback is disabled or when
            ``feed`` is not a valid feed name.
        """
        if not self.config.auto_fallback:
            return None

        # Try to detect subscription level from error; the first detection wins.
        error_msg = str(error)
        detected_level = SubscriptionLevel.from_error(error_msg)
        if detected_level and not self._detected_subscription_level:
            self._detected_subscription_level = detected_level
            logger.info(f"Detected subscription level: {detected_level.value}")

        # Validate before touching the failure table so an invalid feed name
        # does not leave an empty entry behind.
        try:
            feed_type = FeedType.from_string(feed)
        except ValidationError:
            logger.exception(f"Invalid feed type in error handling: {feed}")
            return None

        # Track failed feed
        endpoint_key = self._failure_key(endpoint, symbol)
        failed = self._failed_feeds.setdefault(endpoint_key, set())
        failed.add(feed_type)
        logger.warning(f"Feed {feed} failed for {endpoint_key}: {error_msg}")

        # Find alternative feed
        for fallback in self.config.fallback_feeds:
            if fallback not in failed and self._is_feed_available(fallback):
                logger.info(f"Falling back to {fallback.value} feed for {endpoint_key}")
                return fallback.value

        logger.error(f"No alternative feeds available for {endpoint_key}")
        return None

    def detect_subscription_level(self, api_client: Any) -> SubscriptionLevel:
        """Detect user's subscription level by testing API access.

        Issues one test request for the SIP feed. Success is recorded as
        ``UNLIMITED``; any ``APIRequestError`` is recorded as ``BASIC``.
        Other exception types propagate to the caller. The result is also
        written to ``config.subscription_level``.

        Args:
            api_client: API client instance to test with

        Returns:
            Detected subscription level
        """
        # Try SIP feed first (requires Unlimited/Business)
        test_endpoint = "latest/quotes"
        test_params = {"symbols": "AAPL", "feed": FeedType.SIP.value}

        try:
            api_client._make_request(
                "GET", f"/stocks/{test_endpoint}", params=test_params
            )
        except APIRequestError as e:
            # Basic is the safe default whether or not the error is one we
            # recognise; only the log message differs.
            self._detected_subscription_level = SubscriptionLevel.BASIC
            error_text = str(e).lower()
            if "subscription" in error_text or "unauthorized" in error_text:
                logger.info("Detected Basic subscription level")
            else:
                logger.warning(
                    f"Could not detect subscription level: {e}. Defaulting to Basic."
                )
        else:
            # If successful, user has at least Unlimited
            self._detected_subscription_level = SubscriptionLevel.UNLIMITED
            logger.info("Detected Unlimited/Business subscription level")

        self.config.subscription_level = self._detected_subscription_level
        return self._detected_subscription_level

    def validate_feed(self, endpoint: str, feed: str) -> bool:
        """Validate if a feed is appropriate for an endpoint.

        Args:
            endpoint: The API endpoint
            feed: The feed to validate

        Returns:
            True if the endpoint supports feeds, ``feed`` is a valid feed
            name, and the feed is available for the current subscription
            level.
        """
        if not self._supports_feed(endpoint):
            return False

        try:
            feed_type = FeedType.from_string(feed)
        except ValidationError:
            return False

        return self._is_feed_available(feed_type)

    def reset_failures(self, endpoint: str | None = None) -> None:
        """Reset tracked feed failures.

        Args:
            endpoint: Optional endpoint to reset. If None (or empty), resets
                all. Otherwise removes the entry for the endpoint itself and
                every ``"<endpoint>:<symbol>"`` entry.
        """
        if endpoint:
            prefix = f"{endpoint}:"
            # Collect first: a dict cannot be resized while it is iterated.
            stale_keys = [
                key
                for key in self._failed_feeds
                if key == endpoint or key.startswith(prefix)
            ]
            for key in stale_keys:
                del self._failed_feeds[key]
        else:
            self._failed_feeds.clear()

        logger.info(f"Reset feed failures for {endpoint or 'all endpoints'}")

    @staticmethod
    def _failure_key(endpoint: str, symbol: str | None) -> str:
        """Build the key under which failures for endpoint/symbol are tracked."""
        return f"{endpoint}:{symbol}" if symbol else endpoint

    def _supports_feed(self, endpoint: str) -> bool:
        """Check if an endpoint supports feed parameter.

        Args:
            endpoint: The API endpoint

        Returns:
            True if any entry of ``FEED_SUPPORTED_ENDPOINTS`` occurs as a
            substring of ``endpoint``.
        """
        # Check if any supported endpoint pattern matches
        return any(supported in endpoint for supported in self.FEED_SUPPORTED_ENDPOINTS)

    def _is_feed_available(self, feed: FeedType) -> bool:
        """Check if a feed is available based on subscription level.

        Args:
            feed: The feed to check

        Returns:
            True if feed is available. When the subscription level is
            unknown, every feed is assumed to be available.
        """
        level = self._detected_subscription_level or self.config.subscription_level
        if not level:
            # If we don't know subscription level, assume all feeds available
            return True

        return feed in self.SUBSCRIPTION_FEEDS.get(level, [])

    def get_available_feeds(self) -> list[FeedType]:
        """Get list of available feeds based on subscription level.

        Returns:
            List of available feed types. The list is a fresh copy, so
            callers may modify it without affecting ``SUBSCRIPTION_FEEDS``.
            All feed types are returned when the subscription level is
            unknown.
        """
        level = self._detected_subscription_level or self.config.subscription_level
        if not level:
            # If unknown, return all feeds
            return list(FeedType)

        # Copy so the shared class-level table cannot be mutated by callers.
        return list(self.SUBSCRIPTION_FEEDS.get(level, [FeedType.IEX]))