"""
Risk Gate Checks

Pre-trade risk filters that must pass before any position is taken.
Checks spread, liquidity, chop, and data quality.
Per-strategy thresholds when a strategy is specified.
"""

import numpy as np
from typing import Dict, Any, Optional


def compute(
    bid: Optional[float],
    ask: Optional[float],
    orderbook_depth_usd: Optional[float],
    highs: Optional[np.ndarray] = None,
    lows: Optional[np.ndarray] = None,
    closes: Optional[np.ndarray] = None,
    is_btc: bool = False,
    required_fields: Optional[Dict[str, Any]] = None,
    strategy: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Evaluate pre-trade risk gates.

    Gates fail closed: a missing or invalid input makes its gate fail. The
    one exception is the data quality gate, which passes when there are no
    fields to check.

    Args:
        bid: Current best bid price; must be > 0 for the spread gate to pass.
        ask: Current best ask price; must be > 0 and >= ``bid`` for the
            spread gate to pass. A zero ask or a crossed book is treated as
            bad data rather than as a tight (negative) spread.
        orderbook_depth_usd: Total USD depth within 0.5% of mid price.
        highs: 1-D array of high prices (for chop index, needs >= 15 bars:
            the 14-bar window plus one prior bar for the previous close).
        lows: 1-D array of low prices (same length requirement as ``highs``).
        closes: 1-D array of close prices (same length requirement as
            ``highs``).
        is_btc: True if the symbol is BTC (tighter spread threshold for legacy mode).
        required_fields: Dict of field_name -> value; all must be non-None
            to pass the data quality gate.
        strategy: Strategy name ("fast", "swing", "medium", "position").
            When set, uses per-strategy thresholds. When None, uses defaults.

    Returns:
        Dict with, for each gate, the measured value, the threshold used and
        a pass/fail flag (``gate.*`` keys), plus an overall ``tradeable``
        boolean that is True only when every gate passes.
    """
    gates: Dict[str, Any] = {}

    # Load per-strategy thresholds or fall back to defaults
    if strategy:
        from .strategies import get_strategy

        strat = get_strategy(strategy)
        spread_threshold_bps = float(strat["spread_bps"])
        depth_threshold = float(strat["depth_usd"])
        chop_threshold = float(strat["chop_max"])
    else:
        # Legacy defaults (10 bps BTC, 20 bps alts, $50K depth, chop <= 60)
        spread_threshold_bps = 10.0 if is_btc else 20.0
        depth_threshold = 50_000.0
        chop_threshold = 60.0

    # --- Spread gate (bps) ---
    # Requiring both sides > 0 keeps mid strictly positive (no division by
    # zero) and rejects a zero/placeholder ask, which would otherwise yield a
    # large negative spread that trivially passes the threshold.
    if bid is not None and ask is not None and bid > 0 and ask > 0:
        mid = (bid + ask) / 2
        spread_bps = (ask - bid) / mid * 10000
        gates["gate.spread_bps"] = spread_bps
        gates["gate.spread_threshold_bps"] = spread_threshold_bps
        # A negative spread means a crossed book: report it, but fail safe.
        gates["gate.spread_ok"] = bool(0 <= spread_bps < spread_threshold_bps)
    else:
        gates["gate.spread_bps"] = None
        gates["gate.spread_threshold_bps"] = spread_threshold_bps
        gates["gate.spread_ok"] = False

    # --- Liquidity gate ---
    gates["gate.depth_threshold"] = depth_threshold
    if orderbook_depth_usd is not None:
        gates["gate.depth_usd"] = orderbook_depth_usd
        gates["gate.liquidity_ok"] = bool(orderbook_depth_usd > depth_threshold)
    else:
        gates["gate.depth_usd"] = None
        gates["gate.liquidity_ok"] = False

    # --- Chop gate ---
    # Chop Index = 100 * LOG10(SUM(ATR,14) / (highest_high_14 - lowest_low_14)) / LOG10(14)
    chop_index = _chop_index(highs, lows, closes, period=14)
    gates["gate.chop_index"] = chop_index
    gates["gate.chop_threshold"] = chop_threshold
    if chop_index is not None:
        gates["gate.chop_ok"] = bool(chop_index <= chop_threshold)
    else:
        gates["gate.chop_ok"] = False

    # --- Data quality gate ---
    if required_fields:
        gates["gate.data_quality_ok"] = all(
            v is not None for v in required_fields.values()
        )
    else:
        # No fields to check -> pass by default
        gates["gate.data_quality_ok"] = True

    # --- Overall tradeable ---
    gates["tradeable"] = all(
        (
            gates["gate.spread_ok"],
            gates["gate.liquidity_ok"],
            gates["gate.chop_ok"],
            gates["gate.data_quality_ok"],
        )
    )

    return gates


def _chop_index(
    highs: Optional[np.ndarray],
    lows: Optional[np.ndarray],
    closes: Optional[np.ndarray],
    period: int = 14,
) -> Optional[float]:
    """
    Calculate the Chop Index over *period* bars.

    Formula:
        CI = 100 * log10(sum(ATR_1, period) / (highest_high - lowest_low)) / log10(period)

    Returns None if insufficient data.
    """
    if (
        highs is None
        or lows is None
        or closes is None
        or len(highs) < period + 1
        or len(lows) < period + 1
        or len(closes) < period + 1
    ):
        return None

    # Use last period+1 bars so we can compute period ATR values
    h = highs[-(period + 1) :]
    l = lows[-(period + 1) :]
    c = closes[-(period + 1) :]

    # True Range for each of the last *period* bars
    tr = np.maximum(
        h[1:] - l[1:],
        np.maximum(np.abs(h[1:] - c[:-1]), np.abs(l[1:] - c[:-1])),
    )

    atr_sum = np.sum(tr)
    highest_high = np.max(h[1:])
    lowest_low = np.min(l[1:])
    hl_range = highest_high - lowest_low

    if hl_range <= 0:
        return None

    chop = 100.0 * np.log10(atr_sum / hl_range) / np.log10(period)
    return float(chop)
