"""
Synchronous Bybit V5 REST client.

Reads credentials from environment variables:
    BYBIT_API_KEY
    BYBIT_API_SECRET
    BYBIT_MODE        ("demo" | "live", default "demo")
"""

import logging
import os
import time
from decimal import Decimal
from typing import Any, Dict, List, Optional
from urllib.parse import urlencode

import httpx
import pandas as pd

from .auth import build_auth_headers
from .endpoints import BybitEndpoints, get_bybit_interval
from .models import InstrumentInfo

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Exceptions
# ---------------------------------------------------------------------------

class BybitAPIError(Exception):
    """Raised when Bybit returns a non-zero retCode.

    The client also raises it with ``ret_code=-1`` for conditions it detects
    itself (for example an unexpected response shape or exhausted retries).

    Attributes:
        ret_code: Numeric code from the response (or -1 for client-side errors).
        ret_msg: Human-readable message accompanying the code.
        full_response: The decoded response body the error was built from
            (may be an empty dict when no body is available).
    """

    def __init__(self, ret_code: int, ret_msg: str, full_response: Dict[str, Any]):
        self.ret_code = ret_code
        self.ret_msg = ret_msg
        self.full_response = full_response
        super().__init__(f"Bybit API error {ret_code}: {ret_msg}")

    def __reduce__(self):
        # Default exception pickling re-invokes the constructor with
        # ``self.args``, which here holds only the formatted message. That
        # would call ``BybitAPIError(msg)`` and fail with a TypeError, so
        # supply the real constructor arguments instead. This keeps
        # ``pickle`` and ``copy`` round-trips working.
        return (type(self), (self.ret_code, self.ret_msg, self.full_response))


# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------

class _TokenBucket:
    """Simple token-bucket rate limiter (sync)."""

    def __init__(self, capacity: int, refill_rate: float):
        self.capacity = capacity
        self.refill_rate = refill_rate
        self.tokens = float(capacity)
        self.last_refill = time.time()

    def _refill(self):
        now = time.time()
        self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.refill_rate)
        self.last_refill = now

    def wait(self, n: int = 1):
        while True:
            self._refill()
            if self.tokens >= n:
                self.tokens -= n
                return
            needed = n - self.tokens
            time.sleep(max(0.01, needed / self.refill_rate))


# ---------------------------------------------------------------------------
# Client
# ---------------------------------------------------------------------------

class BybitClient:
    """
    Synchronous Bybit V5 REST client.

    All methods block until the HTTP round-trip completes.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        api_secret: Optional[str] = None,
        mode: Optional[str] = None,
        timeout: int = 30,
    ):
        """Create a client, falling back to environment variables.

        Args:
            api_key: API key; when falsy, ``BYBIT_API_KEY`` is read from the
                environment.
            api_secret: API secret; when falsy, ``BYBIT_API_SECRET`` is read
                from the environment.
            mode: Mode string; when falsy, ``BYBIT_MODE`` is read from the
                environment and defaults to ``"demo"``. It is lower-cased
                before being used to pick the REST base URL.
            timeout: Timeout handed to the underlying ``httpx.Client``.

        Raises:
            KeyError: If a credential is neither passed nor present in the
                environment.
        """
        if not api_key:
            api_key = os.environ["BYBIT_API_KEY"]
        if not api_secret:
            api_secret = os.environ["BYBIT_API_SECRET"]
        if not mode:
            mode = os.environ.get("BYBIT_MODE", "demo")

        self.api_key = api_key
        self.api_secret = api_secret
        self.mode = mode.lower()
        self.base_url = BybitEndpoints.get_rest_base_url(self.mode)

        # One shared HTTP client and one client-side limiter
        # (bucket of 10 tokens, refilled at 10 tokens per second).
        self._http = httpx.Client(timeout=timeout)
        self._limiter = _TokenBucket(capacity=10, refill_rate=10.0)

        logger.debug("BybitClient initialised (%s): %s", self.mode, self.base_url)

    # ------------------------------------------------------------------
    # Low-level request
    # ------------------------------------------------------------------

    def _request(
        self,
        method: str,
        endpoint: str,
        params: Optional[Dict[str, Any]] = None,
        auth: bool = False,
        max_retries: int = 3,
    ) -> Dict[str, Any]:
        url = f"{self.base_url}{endpoint}"
        params = params or {}

        for attempt in range(max_retries):
            try:
                self._limiter.wait()

                if auth:
                    headers = build_auth_headers(
                        self.api_key, self.api_secret, params, method
                    )
                else:
                    headers = {"Content-Type": "application/json"}

                if method == "GET":
                    if auth:
                        qs = urlencode(sorted(params.items()))
                        response = self._http.get(f"{url}?{qs}", headers=headers)
                    else:
                        response = self._http.get(url, params=params, headers=headers)
                elif method == "POST":
                    response = self._http.post(url, json=params, headers=headers)
                else:
                    raise ValueError(f"Unsupported HTTP method: {method}")

                response.raise_for_status()
                data = response.json()

                if not isinstance(data, dict):
                    if attempt < max_retries - 1:
                        logger.warning("Non-dict response (%s), retrying...", type(data).__name__)
                        time.sleep(1)
                        continue
                    raise BybitAPIError(-1, f"Unexpected response type: {type(data).__name__}", {})

                ret_code = data.get("retCode", 0)

                # Rate limit -- back off and retry
                if ret_code == 10006:
                    if attempt < max_retries - 1:
                        logger.warning("Rate limit (attempt %d/%d), waiting 5 s", attempt + 1, max_retries)
                        time.sleep(5)
                        continue
                    raise BybitAPIError(ret_code, "Rate limit exceeded after retries", data)

                # Timestamp drift -- retry with fresh timestamp
                if ret_code == 10002:
                    if attempt < max_retries - 1:
                        logger.warning("Timestamp error (attempt %d/%d), retrying", attempt + 1, max_retries)
                        time.sleep(0.5)
                        continue
                    raise BybitAPIError(ret_code, data.get("retMsg", "timestamp error"), data)

                # 110043 = leverage not modified (already correct)
                # 34040  = trading stop not modified
                if ret_code not in (0, 110043, 34040):
                    raise BybitAPIError(ret_code, data.get("retMsg", "unknown"), data)

                return data

            except (httpx.ConnectTimeout, httpx.ReadTimeout):
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise

            except httpx.HTTPStatusError as exc:
                if exc.response.status_code == 429 and attempt < max_retries - 1:
                    time.sleep(5)
                    continue
                raise

            except httpx.NetworkError:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise

            except BybitAPIError:
                raise

            except Exception:
                raise

        raise BybitAPIError(-1, "Max retries reached", {})

    # ------------------------------------------------------------------
    # Public market data
    # ------------------------------------------------------------------

    def get_klines(
        self,
        symbol: str,
        interval: str,
        limit: int = 200,
    ) -> pd.DataFrame:
        """Return OHLCV DataFrame (chronological order).

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            interval: Interval string, translated with ``get_bybit_interval``.
            limit: Number of candles requested; capped at 1000.

        Returns:
            A DataFrame with columns ``time`` (int64), ``open``, ``high``,
            ``low``, ``close`` and ``volume`` (floats). The API list is
            reversed before the frame is built. When the API returns no
            rows the frame is empty but still carries the same columns.
        """
        raw_columns = ["time", "open", "high", "low", "close", "volume", "turnover"]
        out_columns = ["time", "open", "high", "low", "close", "volume"]

        params = {
            "category": "linear",
            "symbol": symbol,
            "interval": get_bybit_interval(interval),
            "limit": min(limit, 1000),
        }
        result = self._request("GET", BybitEndpoints.KLINES, params)
        klines = result.get("result", {}).get("list", [])
        if not klines:
            # Keep the schema stable so callers can select columns without
            # special-casing the empty result.
            return pd.DataFrame(columns=out_columns)

        df = pd.DataFrame(list(reversed(klines)), columns=raw_columns)

        # Explicit 64-bit: plain ``int`` maps to a 32-bit integer on platforms
        # where NumPy's default integer is 32-bit, which cannot hold
        # millisecond timestamps.
        df["time"] = df["time"].astype("int64")

        numeric = raw_columns[1:]
        df[numeric] = df[numeric].astype(float)
        return df[out_columns]

    def get_recent_trades(self, symbol: str, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return recent public trades (chronological order).

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            limit: Number of trades requested; capped at 1000.

        Returns:
            One dict per trade with ``time`` (int), ``price`` (float),
            ``qty`` (float, taken from the API's ``size`` field) and
            ``side``. The API list is reversed before conversion.
        """
        query = {"category": "linear", "symbol": symbol, "limit": min(limit, 1000)}
        payload = self._request("GET", BybitEndpoints.RECENT_TRADES, query)
        raw_trades = payload.get("result", {}).get("list", [])

        converted: List[Dict[str, Any]] = []
        for trade in reversed(raw_trades):
            converted.append(
                {
                    "time": int(trade["time"]),
                    "price": float(trade["price"]),
                    "qty": float(trade["size"]),
                    "side": trade["side"],
                }
            )
        return converted

    def get_orderbook(self, symbol: str, limit: int = 50) -> Dict[str, Any]:
        """Return orderbook snapshot with bids, asks, ts.

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            limit: Depth passed through to the API unchanged.

        Returns:
            ``{"bids": [[price, qty], ...], "asks": [[price, qty], ...],
            "ts": int}`` with prices and quantities as floats. Bids come
            from the response's ``b`` list and asks from ``a``. When the
            response has no ``ts``, the local clock in milliseconds is used.
        """
        query = {"category": "linear", "symbol": symbol, "limit": limit}
        payload = self._request("GET", BybitEndpoints.ORDERBOOK, query)
        book = payload.get("result", {})

        bids = []
        for price, size in book.get("b", []):
            bids.append([float(price), float(size)])

        asks = []
        for price, size in book.get("a", []):
            asks.append([float(price), float(size)])

        local_now_ms = int(time.time() * 1000)
        snapshot_ts = int(book.get("ts", local_now_ms))

        return {"bids": bids, "asks": asks, "ts": snapshot_ts}

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        """Return 24 h ticker data.

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.

        Returns:
            A dict with ``symbol``, float fields ``last_price``, ``bid``,
            ``ask``, ``volume_24h``, ``turnover_24h``, ``funding_rate``,
            ``open_interest``, ``mark_price``, ``index_price``, the int
            ``next_funding_time``, and ``ts`` (local clock, milliseconds).
            Numeric fields that are missing or empty in the response are
            reported as zero.

        Raises:
            BybitAPIError: With code -1 when the response lists no ticker.
        """
        params = {"category": "linear", "symbol": symbol}
        result = self._request("GET", BybitEndpoints.TICKERS, params)
        tickers = result.get("result", {}).get("list", [])
        if not tickers:
            raise BybitAPIError(-1, f"No ticker data for {symbol}", result)
        t = tickers[0]

        def as_float(key: str) -> float:
            # ``or 0`` covers an absent key as well as an empty string; a
            # plain ``.get(key, 0)`` would pass "" on to float() and raise.
            # (The API may send "" for fields that do not apply to an
            # instrument -- see the Get Tickers reference.)
            return float(t.get(key) or 0)

        return {
            "symbol": t["symbol"],
            "last_price": as_float("lastPrice"),
            "bid": as_float("bid1Price"),
            "ask": as_float("ask1Price"),
            "volume_24h": as_float("volume24h"),
            "turnover_24h": as_float("turnover24h"),
            "funding_rate": as_float("fundingRate"),
            "next_funding_time": int(t.get("nextFundingTime") or 0),
            "open_interest": as_float("openInterest"),
            "mark_price": as_float("markPrice"),
            "index_price": as_float("indexPrice"),
            "ts": int(time.time() * 1000),
        }

    def get_server_time(self) -> int:
        """Return Bybit server time in milliseconds.

        The nanosecond field ``timeNano`` is preferred so the result keeps
        millisecond precision. When it is absent or empty, the value falls
        back to ``timeSecond * 1000`` (whole-second precision), and to 0
        when neither field is present.
        """
        result = self._request("GET", BybitEndpoints.SERVER_TIME, auth=False)
        server = result.get("result", {})

        nanos = server.get("timeNano")
        if nanos:
            # Integer division keeps the result an exact int.
            return int(nanos) // 1_000_000
        return int(server.get("timeSecond", "0")) * 1000

    def get_open_interest(
        self, symbol: str, interval: str, limit: int = 200
    ) -> List[Dict[str, Any]]:
        """Return historical open interest (chronological).

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            interval: Sent unchanged as the ``intervalTime`` parameter.
            limit: Number of entries requested; capped at 200.

        Returns:
            One dict per entry with ``time`` (int, from ``timestamp``) and
            ``open_interest`` (float). The API list is reversed before
            conversion.
        """
        query = {
            "category": "linear",
            "symbol": symbol,
            "intervalTime": interval,
            "limit": min(limit, 200),
        }
        payload = self._request("GET", BybitEndpoints.OPEN_INTEREST, query)
        raw_entries = payload.get("result", {}).get("list", [])

        series: List[Dict[str, Any]] = []
        for entry in reversed(raw_entries):
            point = {
                "time": int(entry["timestamp"]),
                "open_interest": float(entry["openInterest"]),
            }
            series.append(point)
        return series

    def get_funding_history(
        self, symbol: str, limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Return funding rate history (chronological).

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            limit: Number of entries requested; capped at 200.

        Returns:
            One dict per entry with ``time`` (int, from
            ``fundingRateTimestamp``) and ``funding_rate`` (float). The API
            list is reversed before conversion.
        """
        query = {
            "category": "linear",
            "symbol": symbol,
            "limit": min(limit, 200),
        }
        payload = self._request("GET", BybitEndpoints.FUNDING_HISTORY, query)
        raw_entries = payload.get("result", {}).get("list", [])

        history: List[Dict[str, Any]] = []
        for entry in reversed(raw_entries):
            record = {
                "time": int(entry["fundingRateTimestamp"]),
                "funding_rate": float(entry["fundingRate"]),
            }
            history.append(record)
        return history

    def get_instrument_info(self, symbol: str) -> InstrumentInfo:
        """Return instrument trading rules.

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.

        Returns:
            An ``InstrumentInfo`` built from the first instrument in the
            response, reading its ``priceFilter``, ``lotSizeFilter`` and
            ``leverageFilter`` sections. Missing values fall back to the
            defaults written below.

        Raises:
            BybitAPIError: With code -1 when the response lists no
                instrument.
        """
        query = {"category": "linear", "symbol": symbol}
        payload = self._request("GET", BybitEndpoints.INSTRUMENTS_INFO, query)
        listing = payload.get("result", {}).get("list", [])
        if not listing:
            raise BybitAPIError(-1, f"No instrument info for {symbol}", payload)

        spec = listing[0]
        price_filter = spec.get("priceFilter", {})
        lot_filter = spec.get("lotSizeFilter", {})
        leverage_filter = spec.get("leverageFilter", {})

        # NOTE(review): ``min_notional`` is filled from ``minOrderQty``, which
        # by name is a quantity rather than a notional value -- confirm this
        # matches what consumers of InstrumentInfo expect.
        fields = {
            "symbol": spec["symbol"],
            "status": spec.get("status", "Trading"),
            "tick_size": Decimal(price_filter.get("tickSize", "0.01")),
            "lot_size": Decimal(lot_filter.get("qtyStep", "0.001")),
            "min_notional": Decimal(lot_filter.get("minOrderQty", "0")),
            "max_order_qty": Decimal(lot_filter.get("maxOrderQty", "0")),
            # Parsed via float first so a value such as "12.50" is accepted.
            "max_leverage": int(float(leverage_filter.get("maxLeverage", "50"))),
        }
        return InstrumentInfo(**fields)

    # ------------------------------------------------------------------
    # Authenticated: trading
    # ------------------------------------------------------------------

    def place_order(
        self,
        symbol: str,
        side: str,
        order_type: str,
        qty: Decimal,
        price: Optional[Decimal] = None,
        time_in_force: str = "GTC",
        order_link_id: Optional[str] = None,
        stop_loss: Optional[Decimal] = None,
        take_profit: Optional[Decimal] = None,
        reduce_only: bool = False,
        trigger_price: Optional[Decimal] = None,
        trigger_by: Optional[str] = None,
        position_idx: int = 0,
        trigger_direction: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Place an order with optional atomic SL/TP.

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            side: Order side as the API expects it (the code compares
                against ``"Sell"``).
            order_type: Sent as ``orderType``.
            qty: Order quantity; sent as a string.
            price: Optional price; sent as a string when not ``None``.
            time_in_force: Sent as ``timeInForce``.
            order_link_id: Optional client order id.
            stop_loss: Optional stop-loss price. When truthy it is attached
                as a market stop triggered by mark price.
            take_profit: Optional take-profit price. When truthy it is
                attached as a market take-profit triggered by mark price.
            reduce_only: When true, ``reduceOnly`` is sent.
            trigger_price: When truthy, the order becomes conditional.
            trigger_by: Trigger price type; defaults to ``"MarkPrice"``.
            position_idx: Sent as ``positionIdx``.
            trigger_direction: Optional explicit ``triggerDirection`` for a
                conditional order: 1 = trigger when the price rises to
                ``trigger_price``, 2 = trigger when it falls to it. When
                omitted, the previous default is kept (1 for ``"Sell"``,
                2 otherwise). That default cannot express e.g. a sell stop
                placed below the market, so pass it explicitly for those
                orders. Ignored when ``trigger_price`` is not set.

        Returns:
            The decoded API response.

        Raises:
            ValueError: If ``trigger_direction`` is given for a conditional
                order and is not 1 or 2.
        """
        params: Dict[str, Any] = {
            "category": "linear",
            "symbol": symbol,
            "side": side,
            "orderType": order_type,
            "qty": str(qty),
            "timeInForce": time_in_force,
            "positionIdx": position_idx,
        }

        if price is not None:
            params["price"] = str(price)
        if order_link_id:
            params["orderLinkId"] = order_link_id

        # Atomic SL / TP
        if stop_loss or take_profit:
            params["tpslMode"] = "Full"
        if stop_loss:
            params["stopLoss"] = str(stop_loss)
            params["slOrderType"] = "Market"
            params["slTriggerBy"] = "MarkPrice"
        if take_profit:
            params["takeProfit"] = str(take_profit)
            params["tpOrderType"] = "Market"
            params["tpTriggerBy"] = "MarkPrice"

        # Conditional trigger
        if trigger_price:
            if trigger_direction is None:
                # Historical default, kept for backward compatibility.
                trigger_direction = 1 if side == "Sell" else 2
            elif trigger_direction not in (1, 2):
                raise ValueError(
                    f"trigger_direction must be 1 (rise) or 2 (fall), got {trigger_direction!r}"
                )
            params["triggerPrice"] = str(trigger_price)
            params["triggerBy"] = trigger_by or "MarkPrice"
            params["triggerDirection"] = trigger_direction

        if reduce_only:
            params["reduceOnly"] = True

        return self._request("POST", BybitEndpoints.PLACE_ORDER, params, auth=True)

    def cancel_order(
        self,
        symbol: str,
        order_id: Optional[str] = None,
        order_link_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Cancel an open order by orderId or orderLinkId.

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            order_id: Exchange order id; used only when ``order_link_id``
                is not given.
            order_link_id: Client order id; takes precedence when both
                identifiers are supplied.

        Returns:
            The decoded API response.

        Raises:
            ValueError: If neither identifier is provided.
        """
        if not order_link_id and not order_id:
            raise ValueError("Either order_id or order_link_id is required")

        body: Dict[str, Any] = {"category": "linear", "symbol": symbol}
        if order_link_id:
            body["orderLinkId"] = order_link_id
        else:
            body["orderId"] = order_id

        return self._request("POST", BybitEndpoints.CANCEL_ORDER, body, auth=True)

    def set_leverage(self, symbol: str, leverage: int) -> Dict[str, Any]:
        """Set leverage for a symbol (buy and sell sides).

        Args:
            symbol: Instrument symbol, sent with ``category="linear"``.
            leverage: Leverage applied to both sides; sent as a string.

        Returns:
            The decoded API response.
        """
        leverage_text = str(leverage)
        body = {
            "category": "linear",
            "symbol": symbol,
            "buyLeverage": leverage_text,
            "sellLeverage": leverage_text,
        }
        return self._request("POST", BybitEndpoints.SET_LEVERAGE, body, auth=True)

    def set_trading_stop(
        self,
        symbol: str,
        tp_sl_mode: str = "Full",
        take_profit: Optional[str] = None,
        tp_order_type: Optional[str] = None,
        tp_size: Optional[str] = None,
        tp_limit_price: Optional[str] = None,
        stop_loss: Optional[str] = None,
        sl_order_type: Optional[str] = None,
        sl_size: Optional[str] = None,
        sl_limit_price: Optional[str] = None,
        position_idx: int = 0,
    ) -> Dict[str, Any]:
        """Set or modify trading stop (TP/SL) for a position.

        ``symbol``, ``tp_sl_mode`` and ``position_idx`` are always sent
        (with ``category="linear"``). Every other argument is optional and
        is included in the request only when it is truthy, under the API
        name shown in the table below.

        Returns:
            The decoded API response.
        """
        body: Dict[str, Any] = {
            "category": "linear",
            "symbol": symbol,
            "tpslMode": tp_sl_mode,
            "positionIdx": position_idx,
        }

        # (API field, supplied value) in the order the fields are added.
        optional_fields = (
            ("takeProfit", take_profit),
            ("tpOrderType", tp_order_type),
            ("tpSize", tp_size),
            ("tpLimitPrice", tp_limit_price),
            ("stopLoss", stop_loss),
            ("slOrderType", sl_order_type),
            ("slSize", sl_size),
            ("slLimitPrice", sl_limit_price),
        )
        for api_name, value in optional_fields:
            if value:
                body[api_name] = value

        return self._request("POST", BybitEndpoints.SET_TRADING_STOP, body, auth=True)

    # ------------------------------------------------------------------
    # Authenticated: account & positions
    # ------------------------------------------------------------------

    def get_wallet_balance(self) -> Dict[str, Any]:
        """Return wallet balance summary.

        Queries the wallet-balance endpoint with ``accountType="UNIFIED"``
        and summarises the first account in the response.

        Returns:
            A dict of ``Decimal`` values: ``total_equity``,
            ``wallet_balance``, ``available_balance``, ``used_margin`` and
            ``unrealized_pnl``. Fields that are missing or empty in the
            response are reported as ``Decimal("0")``.

        Raises:
            BybitAPIError: With code -1 when the response lists no account.
        """
        params = {"accountType": "UNIFIED"}
        result = self._request("GET", BybitEndpoints.GET_WALLET_BALANCE, params, auth=True)
        accounts = result.get("result", {}).get("list", [])
        if not accounts:
            raise BybitAPIError(-1, "No wallet balance data", result)
        a = accounts[0]

        def as_decimal(key: str) -> Decimal:
            # ``or "0"`` covers an absent key as well as an empty string; a
            # plain ``.get(key, "0")`` would pass "" on to Decimal() and
            # raise InvalidOperation. (The API may send "" for totals that
            # do not apply to the account's margin mode -- see the Get
            # Wallet Balance reference.)
            return Decimal(a.get(key) or "0")

        return {
            "total_equity": as_decimal("totalEquity"),
            "wallet_balance": as_decimal("totalWalletBalance"),
            "available_balance": as_decimal("totalAvailableBalance"),
            "used_margin": as_decimal("totalInitialMargin"),
            "unrealized_pnl": as_decimal("totalPerpUPL"),
        }

    def get_positions(
        self,
        symbol: Optional[str] = None,
        settle_coin: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return position records from the position-info endpoint.

        Args:
            symbol: When given, the query is filtered by symbol and
                ``settle_coin`` is ignored.
            settle_coin: Used only without ``symbol``; defaults to
                ``"USDT"`` when falsy.
            limit: Optional page size; sent only when truthy.

        Returns:
            The ``list`` from the response's ``result``, unmodified (empty
            list when absent).
        """
        query: Dict[str, Any] = {"category": "linear"}

        if symbol:
            query["symbol"] = symbol
        else:
            query["settleCoin"] = settle_coin or "USDT"

        if limit:
            query["limit"] = limit

        payload = self._request("GET", BybitEndpoints.GET_POSITION_INFO, query, auth=True)
        return payload.get("result", {}).get("list", [])

    def get_open_orders(
        self,
        symbol: Optional[str] = None,
        order_id: Optional[str] = None,
        order_link_id: Optional[str] = None,
        limit: int = 20,
    ) -> List[Dict[str, Any]]:
        """Return open (unfilled + partially filled) orders.

        The query always carries ``category="linear"``,
        ``settleCoin="USDT"`` and ``limit``; the identifier and symbol
        filters are added only when truthy.

        Returns:
            The ``list`` from the response's ``result``, unmodified (empty
            list when absent).
        """
        query: Dict[str, Any] = {
            "category": "linear",
            "settleCoin": "USDT",
            "limit": limit,
        }

        filters = (
            ("orderId", order_id),
            ("orderLinkId", order_link_id),
            ("symbol", symbol),
        )
        for api_name, value in filters:
            if value:
                query[api_name] = value

        payload = self._request("GET", BybitEndpoints.GET_OPEN_ORDERS, query, auth=True)
        return payload.get("result", {}).get("list", [])

    def get_order_history(
        self,
        symbol: Optional[str] = None,
        start_time: Optional[int] = None,
        end_time: Optional[int] = None,
        limit: int = 50,
        cursor: Optional[str] = None,
        order_id: Optional[str] = None,
        order_link_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return order history (filled / cancelled / rejected).

        ``category="linear"`` and ``limit`` are always sent; each remaining
        argument is added to the query only when truthy.

        Returns:
            The response's ``result`` object as-is (empty dict when
            absent), so callers can read both the list and any paging
            cursor it contains.
        """
        query: Dict[str, Any] = {"category": "linear", "limit": limit}

        filters = (
            ("orderId", order_id),
            ("orderLinkId", order_link_id),
            ("symbol", symbol),
            ("startTime", start_time),
            ("endTime", end_time),
            ("cursor", cursor),
        )
        for api_name, value in filters:
            if value:
                query[api_name] = value

        payload = self._request("GET", BybitEndpoints.GET_ORDER_HISTORY, query, auth=True)
        return payload.get("result", {})

    def get_closed_pnl(
        self,
        symbol: Optional[str] = None,
        limit: int = 50,
        start_time: Optional[int] = None,
        end_time: Optional[int] = None,
        cursor: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Return closed P&L history.

        ``category="linear"`` and ``limit`` (capped at 100) are always
        sent; the other arguments are added only when truthy.

        Returns:
            The response's ``result`` object (empty dict when absent). If
            ``result`` turns out to be a string, ``{"list": []}`` is
            returned instead so callers always receive a dict.
        """
        query: Dict[str, Any] = {"category": "linear", "limit": min(limit, 100)}

        filters = (
            ("symbol", symbol),
            ("startTime", start_time),
            ("endTime", end_time),
            ("cursor", cursor),
        )
        for api_name, value in filters:
            if value:
                query[api_name] = value

        payload = self._request("GET", BybitEndpoints.GET_CLOSED_PNL, query, auth=True)
        closed = payload.get("result", {})
        return {"list": []} if isinstance(closed, str) else closed

    def get_executions(
        self,
        symbol: str,
        order_link_id: Optional[str] = None,
        start_time: Optional[int] = None,
        end_time: Optional[int] = None,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """Return execution (fill) history for a symbol.

        ``category="linear"``, ``symbol`` and ``limit`` (capped at 100) are
        always sent; the other arguments are added only when truthy.

        Returns:
            The ``list`` from the response's ``result``, unmodified (empty
            list when absent).
        """
        query: Dict[str, Any] = {
            "category": "linear",
            "symbol": symbol,
            "limit": min(limit, 100),
        }

        filters = (
            ("orderLinkId", order_link_id),
            ("startTime", start_time),
            ("endTime", end_time),
        )
        for api_name, value in filters:
            if value:
                query[api_name] = value

        payload = self._request("GET", BybitEndpoints.EXECUTION_LIST, query, auth=True)
        return payload.get("result", {}).get("list", [])

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def close(self):
        """Close the underlying HTTP client."""
        self._http.close()
        logger.info("BybitClient closed")
