"""
QuantClaw Technical Indicators

Ported from BebBot V3 — battle-tested computations preserved exactly.
Pure functions: numpy arrays in, dict of features out.
"""

import numpy as np
import pandas as pd
from typing import Dict, Any, Optional


def compute_indicators(df: pd.DataFrame, timeframe: str) -> Dict[str, Any]:
    """
    Calculate all technical indicators for a timeframe.

    The caller's DataFrame is not modified: OHLCV columns are coerced to
    numeric on a private copy (unparseable values become NaN).

    Args:
        df: DataFrame with columns: open, high, low, close, volume
        timeframe: Timeframe string (e.g., "1m", "5m", "1h", "4h")

    Returns:
        Dict with indicator values keyed as "indicator.name_timeframe".
        When ``df`` is None or has fewer than 50 rows, every value is None.

    Raises:
        KeyError: if the high, low, close or volume column is missing.
    """
    if df is None or len(df) < 50:
        return _empty_indicators(timeframe)

    # Coerce on a copy so we never mutate the caller's frame (assigning into a
    # slice/view of another frame would also trigger SettingWithCopyWarning).
    df = df.copy()
    for col in ["open", "high", "low", "close", "volume"]:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors="coerce")

    close = df["close"].values
    high = df["high"].values
    low = df["low"].values
    volume = df["volume"].values

    indicators: Dict[str, Any] = {}

    # EMAs (10, 20, 50, 100, 200)
    for period in [10, 20, 50, 100, 200]:
        indicators[f"ema{period}_{timeframe}"] = _ema(close, period)

    # RSI (14)
    indicators[f"rsi.level_{timeframe}"] = _rsi(close, 14)

    # MACD (12, 26, 9)
    macd, signal, histogram = _macd(close, 12, 26, 9)
    indicators[f"macd.macd_{timeframe}"] = macd
    indicators[f"macd.signal_{timeframe}"] = signal
    indicators[f"macd.histogram_{timeframe}"] = histogram

    # MACD cross strength: MACD-minus-signal distance in basis points of the last close.
    if macd is not None and signal is not None:
        cross_distance = macd - signal
        indicators[f"macd.cross_strength_{timeframe}"] = (cross_distance / close[-1]) * 10000 if close[-1] > 0 else 0
    else:
        indicators[f"macd.cross_strength_{timeframe}"] = None

    # Bollinger Bands (20, 2)
    upper, middle, lower = _bollinger_bands(close, 20, 2)
    indicators[f"boll.upper_{timeframe}"] = upper
    indicators[f"boll.middle_{timeframe}"] = middle
    indicators[f"boll.lower_{timeframe}"] = lower
    indicators[f"boll.pct_{timeframe}"] = (close[-1] - lower) / (upper - lower) if upper != lower else 0.5

    # ATR (14), absolute and as a percentage of the last close
    atr_abs = _atr(high, low, close, 14)
    atr_pct = (atr_abs / close[-1]) * 100 if close[-1] > 0 else 0
    indicators[f"atr.abs_{timeframe}"] = atr_abs
    indicators[f"atr.pct_{timeframe}"] = atr_pct
    indicators[f"inv_atr.pct_{timeframe}"] = (1.0 / atr_pct) if atr_pct > 0 else 0

    # Stochastic (14, 3, 3)
    stoch_k, stoch_d = _stochastic(high, low, close, 14, 3, 3)
    indicators[f"stoch.k_{timeframe}"] = stoch_k
    indicators[f"stoch.d_{timeframe}"] = stoch_d

    # Choppiness Index (14)
    indicators[f"chop.index_{timeframe}"] = _choppiness(high, low, close, 14)

    # Volume SMA ratio (20): last bar's volume relative to its 20-bar mean
    vol_sma = _sma(volume, 20)
    indicators[f"vol.ratio_{timeframe}"] = volume[-1] / vol_sma if vol_sma > 0 else 1.0

    # ADX (14) - Average Directional Index.
    # NOTE: always true after the 50-row guard above; kept as a defensive check.
    if len(df) >= 28:  # Need 14 + 14 for smoothing
        adx, plus_di, minus_di = _adx(high, low, close, 14)
        indicators[f"adx.level_{timeframe}"] = adx
        indicators[f"adx.plus_di_{timeframe}"] = plus_di
        indicators[f"adx.minus_di_{timeframe}"] = minus_di
    else:
        indicators[f"adx.level_{timeframe}"] = None
        indicators[f"adx.plus_di_{timeframe}"] = None
        indicators[f"adx.minus_di_{timeframe}"] = None

    # Williams %R (14)
    indicators[f"willr.level_{timeframe}"] = _williams_r(high, low, close, 14)

    # CCI (20)
    indicators[f"cci.level_{timeframe}"] = _cci(high, low, close, 20)

    # OBV (On-Balance Volume). Build the running history once in O(n) and
    # reuse it for the value, slope and divergence features so they all come
    # from the same computation.
    obv_history = [0.0]
    running_obv = 0.0
    for i in range(1, len(close)):
        if close[i] > close[i - 1]:
            running_obv += volume[i]
        elif close[i] < close[i - 1]:
            running_obv -= volume[i]
        obv_history.append(running_obv)
    obv = float(obv_history[-1])
    indicators[f"obv.value_{timeframe}"] = obv

    # OBV slope and divergence over the last 14 bars
    if len(df) >= 28:
        obv_14_ago = obv_history[-15] if len(obv_history) >= 15 else obv_history[0]
        indicators[f"obv.slope_{timeframe}"] = (obv - obv_14_ago) / 14

        # Divergence strength (actual magnitude, not binary)
        price_slope = (close[-1] - close[-15]) / close[-15] * 100 if close[-15] > 0 else 0
        obv_slope_pct = (obv - obv_14_ago) / abs(obv_14_ago) * 100 if obv_14_ago != 0 else 0

        # Positive = bullish divergence (price down but OBV up)
        # Negative = bearish divergence (price up but OBV down)
        divergence_strength = obv_slope_pct - price_slope
        indicators[f"obv.divergence_strength_{timeframe}"] = float(divergence_strength)
    else:
        indicators[f"obv.slope_{timeframe}"] = None
        indicators[f"obv.divergence_strength_{timeframe}"] = None

    # Keltner Channels (20, 2). Compare against None explicitly: a band value
    # of 0.0 is a real number, not "missing".
    kelt_upper, kelt_middle, kelt_lower = _keltner_channels(high, low, close, 20, 2)
    indicators[f"keltner.upper_{timeframe}"] = kelt_upper
    indicators[f"keltner.middle_{timeframe}"] = kelt_middle
    indicators[f"keltner.lower_{timeframe}"] = kelt_lower
    if kelt_upper is not None and kelt_lower is not None and kelt_upper != kelt_lower:
        indicators[f"keltner.pct_{timeframe}"] = (close[-1] - kelt_lower) / (kelt_upper - kelt_lower)
    else:
        indicators[f"keltner.pct_{timeframe}"] = None

    # Donchian Channels (20). Same explicit None check (a lowest low of 0.0 is valid).
    don_upper, don_lower = _donchian_channels(high, low, 20)
    indicators[f"donchian.upper_{timeframe}"] = don_upper
    indicators[f"donchian.lower_{timeframe}"] = don_lower
    if don_upper is not None and don_lower is not None:
        indicators[f"donchian.midline_{timeframe}"] = (don_upper + don_lower) / 2
        indicators[f"donchian.width_{timeframe}"] = don_upper - don_lower
    else:
        indicators[f"donchian.midline_{timeframe}"] = None
        indicators[f"donchian.width_{timeframe}"] = None

    return indicators


def _empty_indicators(timeframe: str) -> Dict[str, Any]:
    """Return dict with None values for insufficient data."""
    return {
        f"ema10_{timeframe}": None,
        f"ema20_{timeframe}": None,
        f"ema50_{timeframe}": None,
        f"ema100_{timeframe}": None,
        f"ema200_{timeframe}": None,
        f"rsi.level_{timeframe}": None,
        f"macd.macd_{timeframe}": None,
        f"macd.signal_{timeframe}": None,
        f"macd.histogram_{timeframe}": None,
        f"macd.cross_strength_{timeframe}": None,
        f"boll.upper_{timeframe}": None,
        f"boll.middle_{timeframe}": None,
        f"boll.lower_{timeframe}": None,
        f"boll.pct_{timeframe}": None,
        f"atr.abs_{timeframe}": None,
        f"atr.pct_{timeframe}": None,
        f"inv_atr.pct_{timeframe}": None,
        f"stoch.k_{timeframe}": None,
        f"stoch.d_{timeframe}": None,
        f"chop.index_{timeframe}": None,
        f"vol.ratio_{timeframe}": None,
        f"adx.level_{timeframe}": None,
        f"adx.plus_di_{timeframe}": None,
        f"adx.minus_di_{timeframe}": None,
        f"willr.level_{timeframe}": None,
        f"cci.level_{timeframe}": None,
        f"obv.value_{timeframe}": None,
        f"obv.slope_{timeframe}": None,
        f"obv.divergence_strength_{timeframe}": None,
        f"keltner.upper_{timeframe}": None,
        f"keltner.middle_{timeframe}": None,
        f"keltner.lower_{timeframe}": None,
        f"keltner.pct_{timeframe}": None,
        f"donchian.upper_{timeframe}": None,
        f"donchian.lower_{timeframe}": None,
        f"donchian.midline_{timeframe}": None,
        f"donchian.width_{timeframe}": None,
    }


# ---------------------------------------------------------------------------
# Internal computation helpers — math preserved exactly from BebBot V3
# ---------------------------------------------------------------------------


def _ema(data: np.ndarray, period: int) -> Optional[float]:
    """Calculate Exponential Moving Average."""
    if len(data) < period:
        return None

    alpha = 2 / (period + 1)
    ema = data[0]

    for price in data[1:]:
        ema = alpha * price + (1 - alpha) * ema

    return float(ema)


def _sma(data: np.ndarray, period: int) -> Optional[float]:
    """Calculate Simple Moving Average."""
    if len(data) < period:
        return None

    return float(np.mean(data[-period:]))


def _rsi(data: np.ndarray, period: int = 14) -> Optional[float]:
    """Calculate Relative Strength Index."""
    if len(data) < period + 1:
        return None

    deltas = np.diff(data)
    gains = np.where(deltas > 0, deltas, 0)
    losses = np.where(deltas < 0, -deltas, 0)

    avg_gain = np.mean(gains[-period:])
    avg_loss = np.mean(losses[-period:])

    if avg_loss == 0:
        return 100.0

    rs = avg_gain / avg_loss
    rsi = 100 - (100 / (1 + rs))

    return float(rsi)


def _macd(
    data: np.ndarray,
    fast: int = 12,
    slow: int = 26,
    signal: int = 9,
) -> tuple[Optional[float], Optional[float], Optional[float]]:
    """Calculate MACD."""
    if len(data) < slow + signal:
        return None, None, None

    ema_fast = _ema(data, fast)
    ema_slow = _ema(data, slow)

    if ema_fast is None or ema_slow is None:
        return None, None, None

    macd_line = ema_fast - ema_slow

    # Calculate MACD history for signal line — O(n) incremental
    macd_history = []
    k_fast = 2 / (fast + 1)
    k_slow = 2 / (slow + 1)

    # Initialize EMAs with SMA of first 'period' values
    ema_f = np.mean(data[:fast])
    ema_s = np.mean(data[:slow])

    # Build incrementally from slow period onwards
    for i in range(slow, len(data)):
        ema_f = data[i] * k_fast + ema_f * (1 - k_fast)
        ema_s = data[i] * k_slow + ema_s * (1 - k_slow)
        macd_history.append(ema_f - ema_s)

    if len(macd_history) < signal:
        return macd_line, None, None

    signal_line = _ema(np.array(macd_history), signal)

    if signal_line is None:
        return macd_line, None, None

    histogram = macd_line - signal_line

    return macd_line, signal_line, histogram


def _bollinger_bands(
    data: np.ndarray,
    period: int = 20,
    std_dev: float = 2.0,
) -> tuple[Optional[float], Optional[float], Optional[float]]:
    """Calculate Bollinger Bands."""
    if len(data) < period:
        return None, None, None

    middle = _sma(data, period)

    if middle is None:
        return None, None, None

    std = float(np.std(data[-period:]))
    upper = middle + (std_dev * std)
    lower = middle - (std_dev * std)

    return upper, middle, lower


def _atr(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    period: int = 14,
) -> Optional[float]:
    """Calculate Average True Range using Wilder's smoothing."""
    if len(high) < period + 1:
        return None

    # Calculate True Range
    tr = []
    for i in range(1, len(high)):
        h_l = high[i] - low[i]
        h_c = abs(high[i] - close[i - 1])
        l_c = abs(low[i] - close[i - 1])
        tr.append(max(h_l, h_c, l_c))

    if len(tr) < period:
        return None

    # ATR uses Wilder's smoothing (NOT EMA)
    # Wilder's: atr = (prev_atr * (period-1) + current_tr) / period
    atr = sum(tr[:period]) / period  # Initial ATR is SMA
    for i in range(period, len(tr)):
        atr = (atr * (period - 1) + tr[i]) / period

    return atr


def _stochastic(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    k_period: int = 14,
    d_period: int = 3,
    smooth: int = 3,
) -> tuple[Optional[float], Optional[float]]:
    """Calculate Stochastic Oscillator."""
    if len(close) < k_period + d_period:
        return None, None

    # Calculate %K
    highest_high = np.max(high[-k_period:])
    lowest_low = np.min(low[-k_period:])

    if highest_high == lowest_low:
        k_raw = 50.0
    else:
        k_raw = ((close[-1] - lowest_low) / (highest_high - lowest_low)) * 100

    # Smooth %K
    k_values = []
    for i in range(k_period, len(close) + 1):
        hh = np.max(high[i - k_period : i])
        ll = np.min(low[i - k_period : i])
        if hh != ll:
            k_values.append(((close[i - 1] - ll) / (hh - ll)) * 100)
        else:
            k_values.append(50.0)

    if len(k_values) < smooth:
        return k_raw, None

    k = _sma(np.array(k_values), smooth)

    # Calculate %D
    if len(k_values) < d_period + smooth:
        return k, None

    d = _sma(np.array(k_values), d_period)

    return k, d


def _choppiness(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    period: int = 14,
) -> Optional[float]:
    """
    Calculate Choppiness Index.

    Choppiness > 61: Market is choppy (sideways)
    Choppiness < 38: Market is trending
    """
    if len(high) < period + 1:
        return None

    # Calculate True Range sum
    tr_sum = 0
    for i in range(-period, 0):
        h_l = high[i] - low[i]
        h_c = abs(high[i] - close[i - 1]) if i > -period else 0
        l_c = abs(low[i] - close[i - 1]) if i > -period else 0
        tr_sum += max(h_l, h_c, l_c)

    # Calculate high-low range over period
    high_max = np.max(high[-period:])
    low_min = np.min(low[-period:])
    range_hl = high_max - low_min

    if range_hl == 0:
        return 50.0  # Neutral

    chop = 100 * np.log10(tr_sum / range_hl) / np.log10(period)

    return float(chop)


def _adx(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    period: int = 14,
) -> tuple[Optional[float], Optional[float], Optional[float]]:
    """
    Calculate ADX (Average Directional Index) with Wilder's smoothing.

    ADX measures trend strength (not direction):
    - ADX > 25: Strong trend
    - ADX > 40: Very strong trend
    - ADX < 20: Weak trend or ranging

    Returns:
        (adx, plus_di, minus_di) — All values 0-100, or None if insufficient data
    """
    if len(high) < period * 2 + 1:
        return None, None, None

    # True Range
    tr = np.zeros(len(high))
    for i in range(1, len(high)):
        hl = high[i] - low[i]
        hc = abs(high[i] - close[i - 1])
        lc = abs(low[i] - close[i - 1])
        tr[i] = max(hl, hc, lc)

    # Directional Movement
    plus_dm = np.zeros(len(high))
    minus_dm = np.zeros(len(high))
    for i in range(1, len(high)):
        up_move = high[i] - high[i - 1]
        down_move = low[i - 1] - low[i]

        if up_move > down_move and up_move > 0:
            plus_dm[i] = up_move
        if down_move > up_move and down_move > 0:
            minus_dm[i] = down_move

    # Wilder's smoothing (NOT standard EMA)
    alpha = 1.0 / period
    atr = np.zeros(len(tr))
    atr[period] = np.mean(tr[1 : period + 1])
    for i in range(period + 1, len(tr)):
        atr[i] = atr[i - 1] + alpha * (tr[i] - atr[i - 1])

    plus_di_smooth = np.zeros(len(plus_dm))
    plus_di_smooth[period] = np.mean(plus_dm[1 : period + 1])
    for i in range(period + 1, len(plus_dm)):
        plus_di_smooth[i] = plus_di_smooth[i - 1] + alpha * (plus_dm[i] - plus_di_smooth[i - 1])

    minus_di_smooth = np.zeros(len(minus_dm))
    minus_di_smooth[period] = np.mean(minus_dm[1 : period + 1])
    for i in range(period + 1, len(minus_dm)):
        minus_di_smooth[i] = minus_di_smooth[i - 1] + alpha * (minus_dm[i] - minus_di_smooth[i - 1])

    # +DI and -DI (protect against zero ATR during startup)
    plus_di = np.zeros_like(atr)
    minus_di = np.zeros_like(atr)
    valid_atr = atr > 0
    plus_di[valid_atr] = 100 * plus_di_smooth[valid_atr] / atr[valid_atr]
    minus_di[valid_atr] = 100 * minus_di_smooth[valid_atr] / atr[valid_atr]

    # DX (protect against zero sum)
    di_sum = plus_di + minus_di
    dx = np.zeros_like(di_sum)
    valid_sum = di_sum > 0
    dx[valid_sum] = 100 * np.abs(plus_di[valid_sum] - minus_di[valid_sum]) / di_sum[valid_sum]
    dx = np.nan_to_num(dx)

    # ADX (Wilder's smoothing of DX)
    adx = np.zeros(len(dx))
    adx[period * 2] = np.mean(dx[period : period * 2])
    for i in range(period * 2 + 1, len(dx)):
        adx[i] = adx[i - 1] + alpha * (dx[i] - adx[i - 1])

    return float(adx[-1]), float(plus_di[-1]), float(minus_di[-1])


def _williams_r(high: np.ndarray, low: np.ndarray, close: np.ndarray, period: int = 14) -> Optional[float]:
    """Williams %R momentum oscillator."""
    if len(high) < period:
        return None

    highest_high = np.max(high[-period:])
    lowest_low = np.min(low[-period:])

    if highest_high == lowest_low:
        return -50.0

    wr = -100 * (highest_high - close[-1]) / (highest_high - lowest_low)
    return float(wr)


def _cci(high: np.ndarray, low: np.ndarray, close: np.ndarray, period: int = 20) -> Optional[float]:
    """Commodity Channel Index."""
    if len(high) < period:
        return None

    typical_price = (high + low + close) / 3
    sma_tp = np.mean(typical_price[-period:])
    mean_deviation = np.mean(np.abs(typical_price[-period:] - sma_tp))

    if mean_deviation == 0:
        return 0.0

    cci = (typical_price[-1] - sma_tp) / (0.015 * mean_deviation)
    return float(cci)


def _obv(close: np.ndarray, volume: np.ndarray) -> float:
    """On-Balance Volume."""
    if len(close) < 2:
        return 0.0

    obv = 0
    for i in range(1, len(close)):
        if close[i] > close[i - 1]:
            obv += volume[i]
        elif close[i] < close[i - 1]:
            obv -= volume[i]

    return float(obv)


def _keltner_channels(
    high: np.ndarray,
    low: np.ndarray,
    close: np.ndarray,
    period: int = 20,
    multiplier: float = 2,
) -> tuple[Optional[float], Optional[float], Optional[float]]:
    """
    Return (upper, middle, lower) Keltner Channels.

    The middle line is ``_ema(close, period)``. The bands are offset by
    ``multiplier`` times ``_atr(high, low, close, period)``. All three are None
    when there is too little data for either helper.
    """
    if period > len(close):
        return None, None, None

    center = _ema(close, period)
    band_atr = _atr(high, low, close, period)
    if center is None or band_atr is None:
        return None, None, None

    offset = multiplier * band_atr
    return float(center + offset), float(center), float(center - offset)


def _donchian_channels(
    high: np.ndarray,
    low: np.ndarray,
    period: int = 20,
) -> tuple[Optional[float], Optional[float]]:
    """Donchian Channels — highest high and lowest low over period."""
    if len(high) < period:
        return None, None

    upper = np.max(high[-period:])
    lower = np.min(low[-period:])

    return float(upper), float(lower)
