# Source code for kanoa.core.token_guard

"""
Token counting and cost guardrails for API calls.

This module provides pre-flight token counting and user-friendly guardrails
to prevent unexpected costs from large API requests.

Features:
- Pre-flight token counting before API calls
- Backend-agnostic design (protocol-based)
- Configurable thresholds for warnings, approval prompts, and hard limits
- Cost estimation based on current pricing
- Jupyter-friendly interactive approval
- Environment variable overrides for automation

Usage:
    from kanoa.backends.gemini import GeminiTokenCounter
    from kanoa.core.token_guard import TokenGuard

    # Create backend-specific counter
    counter = GeminiTokenCounter(client, model="gemini-3-pro-preview")

    # Wrap with guard
    guard = TokenGuard(counter)

    # Check before API call
    result = guard.check(contents, pricing=PRICING)

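    # TokenLimitExceeded takes (token_count, limit, estimated_cost).
    # guard.guard(contents, pricing=PRICING) performs this check-and-raise
    # in a single call.
    if result.requires_approval and not result.approved:
        limit = (
            guard.reject_threshold
            if result.level == "reject"
            else guard.approval_threshold
        )
        raise TokenLimitExceeded(result.token_count, limit, result.estimated_cost)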

    # Proceed with API call...
"""

import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol

from ..config import options

# Default thresholds (tokens).
# NOTE: TokenGuard in this module does not read these constants. When neither
# a constructor argument nor an environment variable supplies a threshold, it
# falls back to the matching kanoa.options.token_*_threshold setting. The
# constants are still exported through __all__.
DEFAULT_WARN_THRESHOLD = 2048
DEFAULT_APPROVAL_THRESHOLD = 50_000
DEFAULT_REJECT_THRESHOLD = 200_000


# =============================================================================
# Token Counter Protocol & Base Class
# =============================================================================


class TokenCounter(Protocol):
    """Structural interface for backend-agnostic token counters.

    Any object exposing the two read-only properties and the two methods
    below satisfies this protocol; inheriting from it is not required.
    """

    @property
    def backend_name(self) -> str:
        """Name of the backend (e.g., 'gemini', 'claude')."""

    @property
    def model(self) -> str:
        """Name of the model."""

    def count_tokens(self, contents: Any) -> int:
        """
        Count the tokens in ``contents``.

        Args:
            contents: Content to count (format varies by backend)

        Returns:
            Token count
        """

    def estimate_tokens(self, contents: Any) -> int:
        """
        Estimate the token count; the fallback for when API counting fails.

        Args:
            contents: Content to estimate

        Returns:
            Estimated token count
        """
class BaseTokenCounter(ABC):
    """Base class for token counters with shared functionality.

    Subclasses must provide ``backend_name``, ``model`` and ``count_tokens``;
    ``estimate_tokens`` is a size-based fallback they inherit.
    """

    @property
    @abstractmethod
    def backend_name(self) -> str:
        """Return the backend name."""
        ...

    @property
    @abstractmethod
    def model(self) -> str:
        """Return the model name."""
        ...

    @abstractmethod
    def count_tokens(self, contents: Any) -> int:
        """Count tokens using the backend API."""
        ...

    def estimate_tokens(self, contents: Any) -> int:
        """
        Fallback token estimation based on content size (~4 chars per token).

        Recognised inputs:
        - a plain string
        - a list whose items are any mix of:
            - strings
            - message dicts with a ``"content"`` key holding either a string
              or a list of part dicts with a ``"text"`` key (Claude format)
            - objects with a ``parts`` attribute whose parts expose ``text``
              (Gemini Content objects)

        Each text fragment is floored separately (``len(text) // 4``).
        Anything that is not a string -- a missing or ``None`` ``text``, a
        ``None`` ``parts`` list, an unrecognised item -- contributes 0 rather
        than raising.

        Args:
            contents: Content to estimate

        Returns:
            Estimated token count; 0 for unsupported input types
        """

        def text_tokens(value: Any) -> int:
            # Non-text parts may carry ``text=None`` (or no text at all);
            # calling len() on those would raise TypeError, so count them as 0.
            return len(value) // 4 if isinstance(value, str) else 0

        if isinstance(contents, str):
            return text_tokens(contents)
        if not isinstance(contents, list):
            return 0

        total = 0
        for item in contents:
            if isinstance(item, str):
                total += text_tokens(item)
            elif isinstance(item, dict):
                # Handle message dicts (Claude format)
                content = item.get("content", "")
                if isinstance(content, str):
                    total += text_tokens(content)
                elif isinstance(content, list):
                    total += sum(
                        text_tokens(part.get("text", ""))
                        for part in content
                        if isinstance(part, dict)
                    )
            elif hasattr(item, "parts"):
                # Handle Gemini Content objects; ``parts`` may be None.
                for part in item.parts or ():
                    total += text_tokens(getattr(part, "text", None))
        return total
class FallbackTokenCounter(BaseTokenCounter):
    """Estimation-only counter for use when no counting API is available."""

    def __init__(self, backend_name: str = "unknown", model: str = "unknown"):
        """
        Store the labels reported by ``backend_name`` and ``model``.

        Args:
            backend_name: Backend label to report (default: "unknown")
            model: Model label to report (default: "unknown")
        """
        self._model = model
        self._backend_name = backend_name

    @property
    def backend_name(self) -> str:
        """Backend label given at construction."""
        return self._backend_name

    @property
    def model(self) -> str:
        """Model label given at construction."""
        return self._model

    def count_tokens(self, contents: Any) -> int:
        """Return the size-based estimate; no API call is made."""
        estimated = self.estimate_tokens(contents)
        return estimated


# =============================================================================
# Token Guard Result & Exception
# =============================================================================
@dataclass
class TokenCheckResult:
    """Outcome of a pre-flight token count check."""

    # Number of tokens counted for the request.
    token_count: int
    # Estimated cost of the request.
    estimated_cost: float
    # Severity level: "ok", "warn", "approval", "reject"
    level: str
    # Whether the request may proceed.
    approved: bool
    # Human-readable summary of the check.
    message: str
    # Whether the check called for an approval decision.
    requires_approval: bool = False

    def __str__(self) -> str:
        """Return a one-line summary: token count, cost and level."""
        tokens = format(self.token_count, ",")
        cost = format(self.estimated_cost, ".4f")
        return f"TokenCheck: {tokens} tokens, ~${cost}, level={self.level}"
class TokenLimitExceeded(Exception):
    """Raised when a request is blocked by a token limit.

    The offending count, the limit it was measured against and the estimated
    cost are kept as attributes so callers can inspect them.
    """

    def __init__(self, token_count: int, limit: int, estimated_cost: float):
        """
        Build the exception and its human-readable message.

        Args:
            token_count: Number of tokens in the blocked request
            limit: Threshold the request was measured against
            estimated_cost: Estimated cost in dollars
        """
        summary = (
            f"Token count ({token_count:,}) exceeds limit ({limit:,}). "
            f"Estimated cost: ${estimated_cost:.4f}. "
            f"Set KANOA_TOKEN_REJECT_THRESHOLD to increase limit."
        )
        super().__init__(summary)
        self.token_count = token_count
        self.limit = limit
        self.estimated_cost = estimated_cost
# ============================================================================= # Token Guard # =============================================================================
class TokenGuard:
    """
    Pre-flight token counting and cost guardrails.

    Provides configurable thresholds for:
    - Warnings: Log a warning but proceed
    - Approval: Prompt user for confirmation (Jupyter-friendly)
    - Rejection: Hard limit that blocks the request

    Each threshold is resolved, in order, from the constructor argument, the
    matching ``KANOA_TOKEN_*_THRESHOLD`` environment variable, and finally
    ``kanoa.options``. A value of 0 (or an unset/blank variable) at any step
    means "not set" and falls through to the next source.
    """

    # Per-million-token input prices (dollars) used when the caller supplies
    # no pricing dict or the dict lacks a key.
    _DEFAULT_PRICING: Dict[str, float] = {"input_short": 2.00, "input_long": 4.00}

    def __init__(
        self,
        counter: TokenCounter,
        warn_threshold: Optional[int] = None,
        approval_threshold: Optional[int] = None,
        reject_threshold: Optional[int] = None,
        auto_approve: bool = False,
    ):
        """
        Initialize TokenGuard.

        Args:
            counter: TokenCounter instance (backend-specific)
            warn_threshold: Token count to trigger warning
                (default: kanoa.options.token_warn_threshold)
            approval_threshold: Token count to require approval
                (default: kanoa.options.token_approval_threshold)
            reject_threshold: Token count to reject request
                (default: kanoa.options.token_reject_threshold)
            auto_approve: Skip interactive prompts (for automation)

        Raises:
            ValueError: If a consulted ``KANOA_TOKEN_*_THRESHOLD`` variable is
                set to something that is not an integer
        """
        self.counter = counter

        # Load thresholds from args, env vars, or kanoa.options.
        # ``or`` chaining means 0 / None falls through to the next source.
        self.warn_threshold = (
            warn_threshold
            or self._env_threshold("KANOA_TOKEN_WARN_THRESHOLD")
            or options.token_warn_threshold
        )
        self.approval_threshold = (
            approval_threshold
            or self._env_threshold("KANOA_TOKEN_APPROVAL_THRESHOLD")
            or options.token_approval_threshold
        )
        self.reject_threshold = (
            reject_threshold
            or self._env_threshold("KANOA_TOKEN_REJECT_THRESHOLD")
            or options.token_reject_threshold
        )

        # Auto-approve can be set via arg, env var ("1") or options
        self.auto_approve = (
            auto_approve
            or os.environ.get("KANOA_AUTO_APPROVE") == "1"
            or options.auto_approve
        )

    @staticmethod
    def _env_threshold(name: str) -> Optional[int]:
        """
        Read an integer threshold from the environment variable ``name``.

        Returns:
            The parsed integer, or None when the variable is unset or blank

        Raises:
            ValueError: If the variable holds a non-integer value; the message
                names the variable so the misconfiguration is easy to find
        """
        raw = os.environ.get(name)
        if raw is None or not raw.strip():
            return None
        try:
            return int(raw)
        except ValueError:
            raise ValueError(
                f"Environment variable {name} must be an integer token count, "
                f"got {raw!r}"
            ) from None

    # Expose counter properties for convenience
    @property
    def backend_name(self) -> str:
        """Return the backend name from the counter."""
        return self.counter.backend_name

    @property
    def model(self) -> str:
        """Return the model name from the counter."""
        return self.counter.model

    def count_tokens(self, contents: Any) -> int:
        """
        Count tokens using the configured counter.

        Args:
            contents: Content to count (format varies by backend)

        Returns:
            Token count
        """
        return self.counter.count_tokens(contents)

    def estimate_cost(
        self,
        token_count: int,
        pricing: Dict[str, float],
        context_threshold: int = 200_000,
    ) -> float:
        """
        Estimate cost for input tokens based on pricing.

        Args:
            token_count: Number of input tokens
            pricing: Pricing dict with 'input_short', 'input_long' keys
                (price per million tokens); missing keys fall back to
                2.00 and 4.00 respectively
            context_threshold: Threshold for short vs long context pricing;
                counts up to and including it use the short price

        Returns:
            Estimated cost in dollars
        """
        tier = "input_short" if token_count <= context_threshold else "input_long"
        price_per_million = pricing.get(tier, self._DEFAULT_PRICING[tier])
        return token_count / 1_000_000 * price_per_million

    def check(
        self,
        contents: Any,
        pricing: Optional[Dict[str, float]] = None,
    ) -> TokenCheckResult:
        """
        Check token count and determine if request should proceed.

        Levels are evaluated from most to least severe: reject, approval,
        warn, ok. At the approval level the user is prompted unless
        ``auto_approve`` is set. A reject result is reported with
        ``requires_approval=True`` and ``approved=False``.

        Args:
            contents: Content to check
            pricing: Optional pricing dict for cost estimation

        Returns:
            TokenCheckResult with approval status and message
        """
        pricing = pricing or dict(self._DEFAULT_PRICING)

        # Count tokens
        token_count = self.count_tokens(contents)
        estimated_cost = self.estimate_cost(token_count, pricing)

        # Determine level
        if token_count >= self.reject_threshold:
            return TokenCheckResult(
                token_count=token_count,
                estimated_cost=estimated_cost,
                level="reject",
                approved=False,
                requires_approval=True,
                message=(
                    f"Request rejected: {token_count:,} tokens exceeds "
                    f"limit of {self.reject_threshold:,}. "
                    f"Estimated cost: ${estimated_cost:.4f}"
                ),
            )

        if token_count >= self.approval_threshold:
            # Need user approval
            if self.auto_approve:
                return TokenCheckResult(
                    token_count=token_count,
                    estimated_cost=estimated_cost,
                    level="approval",
                    approved=True,
                    requires_approval=False,
                    message=(
                        f"Large request auto-approved: {token_count:,} tokens, "
                        f"~${estimated_cost:.4f}"
                    ),
                )

            # Interactive approval
            approved = self._request_approval(token_count, estimated_cost)
            return TokenCheckResult(
                token_count=token_count,
                estimated_cost=estimated_cost,
                level="approval",
                approved=approved,
                requires_approval=True,
                message=(
                    f"Large request {'approved' if approved else 'denied'}: "
                    f"{token_count:,} tokens, ~${estimated_cost:.4f}"
                ),
            )

        # "warn" and "ok" differ only in the level label
        level = "warn" if token_count >= self.warn_threshold else "ok"
        return TokenCheckResult(
            token_count=token_count,
            estimated_cost=estimated_cost,
            level=level,
            approved=True,
            requires_approval=False,
            message=f"{token_count:,} tokens, ~${estimated_cost:.4f}",
        )

    def _request_approval(self, token_count: int, estimated_cost: float) -> bool:
        """
        Request user approval for large requests.

        Jupyter-friendly: uses input() which works in notebooks. Only "y" or
        "yes" (case-insensitive) approves; EOF or Ctrl-C counts as a denial.

        Args:
            token_count: Number of tokens
            estimated_cost: Estimated cost in dollars

        Returns:
            True if approved, False otherwise
        """
        print("\n" + "=" * 60)
        print("LARGE TOKEN REQUEST - APPROVAL REQUIRED")
        print("=" * 60)
        print(f"  Token count:    {token_count:,}")
        print(f"  Estimated cost: ${estimated_cost:.4f}")
        print(f"  Approval limit: {self.approval_threshold:,} tokens")
        print("=" * 60)

        try:
            response = input("Proceed with this request? [y/N]: ").strip().lower()
            approved = response in ("y", "yes")
            if not approved:
                print("Request cancelled by user.")
            return approved
        except (EOFError, KeyboardInterrupt):
            print("\nRequest cancelled.")
            return False

    def guard(
        self,
        contents: Any,
        pricing: Optional[Dict[str, float]] = None,
    ) -> TokenCheckResult:
        """
        Check tokens and raise exception if rejected.

        Convenience method that combines check() with automatic rejection.
        A "warn" result has its message printed before being returned.

        Args:
            contents: Content to check
            pricing: Optional pricing dict

        Returns:
            TokenCheckResult if approved

        Raises:
            TokenLimitExceeded: If request exceeds reject threshold or user
                denies. The exception's ``limit`` is the reject threshold in
                the first case and the approval threshold in the second.
        """
        result = self.check(contents, pricing)

        if result.level == "reject" or (
            result.requires_approval and not result.approved
        ):
            raise TokenLimitExceeded(
                result.token_count,
                (
                    self.reject_threshold
                    if result.level == "reject"
                    else self.approval_threshold
                ),
                result.estimated_cost,
            )

        # Log warning/info
        if result.level == "warn":
            print(result.message)

        return result
# =============================================================================
# Public API
# =============================================================================

# Names exported by ``from kanoa.core.token_guard import *``.
__all__ = [
    # Core classes
    "TokenGuard",
    "TokenCheckResult",
    "TokenLimitExceeded",
    # Protocol and base class
    "TokenCounter",
    "BaseTokenCounter",
    # Fallback counter (backend-agnostic)
    "FallbackTokenCounter",
    # Constants
    "DEFAULT_WARN_THRESHOLD",
    "DEFAULT_APPROVAL_THRESHOLD",
    "DEFAULT_REJECT_THRESHOLD",
]