#!/usr/bin/env python3
"""
Grid Trading Backtest
Strategy: Buy/Sell at fixed grid levels within ±20% range
Uses daily candles (granularity=86400) for speed
"""

import requests
import pandas as pd
import time
import json
import os
from datetime import datetime
from typing import Dict, Any, List, Optional

# ── Configuration ─────────────────────────────────────────────────────────────

# Product IDs to backtest; main() runs one independent simulation per entry.
GRID_PAIRS = ["BTC-USD", "ETH-USD"]
# Cash (USD) every simulation starts with.
STARTING_BALANCE = 1000.0
# Number of evenly spaced price levels in the grid (both bounds included).
GRID_LEVELS = 10
GRID_ORDER_PCT = 0.08       # 8% of starting balance per order
GRID_RANGE_PCT = 0.20       # ±20% range
# History window in months; fetch_candles() treats a month as 30 days.
MONTHS_BACK = 6
GRANULARITY = 86400         # Daily candles

# Base URL that the candle requests are sent to.
BASE_URL = "https://api.exchange.coinbase.com"
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
OUTPUT_DIR = "/Users/chip/.openclaw/workspace-cody/backtest"
OUTPUT_JSON = os.path.join(OUTPUT_DIR, "grid_results.json")

# Seconds to pause between consecutive candle requests.
SLEEP_BETWEEN = 0.5
# Maximum number of tries per candle request before giving up on that chunk.
RETRY_ATTEMPTS = 3
# Seconds to wait after an HTTP 429 (rate limited) response before retrying.
RETRY_DELAY = 15


# ── Data Fetching ─────────────────────────────────────────────────────────────

def fetch_candles(pair: str) -> pd.DataFrame:
    """Fetch roughly ``MONTHS_BACK`` months of daily candles for *pair*.

    The history window (a month is counted as 30 days) is walked backwards
    from "now" in chunks of at most 300 candles per request. Each request is
    tried up to ``RETRY_ATTEMPTS`` times; a chunk that still fails is reported
    and skipped, so the returned data may contain a gap (best effort).

    Args:
        pair: Product identifier placed in the request path, e.g. "BTC-USD".

    Returns:
        A DataFrame with columns ``timestamp`` (UTC), ``low``, ``high``,
        ``open``, ``close`` and ``volume``, sorted by time with duplicate
        timestamps removed. An empty DataFrame if nothing could be fetched.
    """
    now_ts = int(time.time())
    start_ts = now_ts - (MONTHS_BACK * 30 * 24 * 3600)
    # Daily candles: 300 candles max per request = 300 days, we need ~180
    chunk_size = 300 * GRANULARITY
    url = f"{BASE_URL}/products/{pair}/candles"
    all_candles = []
    chunk_end = now_ts
    chunk_num = 0

    while chunk_end > start_ts:
        chunk_start = max(chunk_end - chunk_size, start_ts)
        chunk_num += 1
        params = {"granularity": GRANULARITY, "start": chunk_start, "end": chunk_end}

        for _ in range(RETRY_ATTEMPTS):
            try:
                resp = requests.get(url, params=params, timeout=30)
                data = resp.json() if resp.status_code == 200 else None
            except (requests.RequestException, ValueError) as e:
                # Network failures and undecodable bodies are retried;
                # anything else is a programming error and should surface.
                print(f"  Error: {e}, retrying...")
                time.sleep(5)
                continue

            if resp.status_code == 429:
                print(f"  Rate limited, waiting {RETRY_DELAY}s...")
                time.sleep(RETRY_DELAY)
            elif resp.status_code != 200:
                print(f"  HTTP {resp.status_code}, retrying...")
                time.sleep(5)
            elif not isinstance(data, list):
                # Extending with a dict would silently append its keys and
                # corrupt the candle rows, so treat it as a failed attempt.
                print("  Unexpected response body, retrying...")
                time.sleep(5)
            else:
                all_candles.extend(data)
                print(f"  {pair}: chunk {chunk_num}, got {len(data)} candles")
                break
        else:
            # Every attempt failed: keep going, but make the gap visible.
            print(
                f"  {pair}: chunk {chunk_num} failed after "
                f"{RETRY_ATTEMPTS} attempts, data will have a gap"
            )

        chunk_end = chunk_start
        # Only pause when another request is actually going to follow.
        if chunk_end > start_ts:
            time.sleep(SLEEP_BETWEEN)

    if not all_candles:
        return pd.DataFrame()

    df = pd.DataFrame(all_candles, columns=["timestamp", "low", "high", "open", "close", "volume"])
    df["timestamp"] = pd.to_datetime(df["timestamp"], unit="s", utc=True)
    df = df.sort_values("timestamp").drop_duplicates("timestamp").reset_index(drop=True)
    # Keep only the requested window. Reuse start_ts so the filter matches
    # the window that was requested rather than a second clock reading.
    cutoff = pd.Timestamp(start_ts, unit="s", tz="UTC")
    df = df[df["timestamp"] >= cutoff].reset_index(drop=True)
    print(f"  {pair}: total {len(df)} daily candles")
    return df


# ── Grid Strategy ─────────────────────────────────────────────────────────────

def _max_drawdown_pct(equities: List[float]) -> float:
    """Return the largest peak-to-trough decline in *equities*, in percent."""
    if not equities:
        return 0.0
    worst = 0.0
    peak = equities[0]
    for eq in equities:
        if eq > peak:
            peak = eq
        dd = (peak - eq) / peak * 100 if peak > 0 else 0
        if dd > worst:
            worst = dd
    return worst


def run_grid_backtest(
    df: pd.DataFrame,
    label: str,
    *,
    starting_balance: Optional[float] = None,
    grid_levels: Optional[int] = None,
    order_pct: Optional[float] = None,
    range_pct: Optional[float] = None,
) -> Dict[str, Any]:
    """Run grid trading simulation on daily candle data.

    The grid is centred on the first candle's close and spans
    ``±range_pct`` around it with ``grid_levels`` evenly spaced prices.
    A level counts as touched on a day when ``low <= level <= high``.
    A filled buy at level L places a sell one level up; a filled sell at
    level L re-arms the buy one level down. Days whose whole range lies
    outside the grid are counted and skipped.

    The top level never carries a buy order: there is no level above it to
    sell at, so coins bought there could not be tracked and would disappear
    from equity.

    Args:
        df: Candles with ``timestamp``, ``low``, ``high`` and ``close``
            columns, in chronological order.
        label: Name used in the printed report and in the result.
        starting_balance: Initial cash; defaults to ``STARTING_BALANCE``.
        grid_levels: Number of grid prices; defaults to ``GRID_LEVELS``.
        order_pct: Fraction of the starting balance spent per buy;
            defaults to ``GRID_ORDER_PCT``.
        range_pct: Half-width of the grid as a fraction of the starting
            price; defaults to ``GRID_RANGE_PCT``.

    Returns:
        A dict of summary statistics plus the daily ``equity_curve``, or an
        empty dict when *df* is empty.

    Raises:
        ValueError: If fewer than two grid levels are requested.
    """
    if df.empty:
        return {}

    # Resolve omitted overrides from the module configuration at call time.
    if starting_balance is None:
        starting_balance = STARTING_BALANCE
    if grid_levels is None:
        grid_levels = GRID_LEVELS
    if order_pct is None:
        order_pct = GRID_ORDER_PCT
    if range_pct is None:
        range_pct = GRID_RANGE_PCT
    if grid_levels < 2:
        # A single level has no spacing (division by zero below).
        raise ValueError("grid_levels must be at least 2")

    starting_price = float(df.iloc[0]["close"])
    lower_bound = starting_price * (1 - range_pct)
    upper_bound = starting_price * (1 + range_pct)
    grid_spacing = (upper_bound - lower_bound) / (grid_levels - 1)

    # Grid level prices
    grid_prices = [lower_bound + i * grid_spacing for i in range(grid_levels)]

    order_size_usd = starting_balance * order_pct

    print(f"\n  {label} Grid Setup:")
    print(f"    Starting price: ${starting_price:,.2f}")
    print(f"    Range: ${lower_bound:,.2f} – ${upper_bound:,.2f}")
    print(f"    Grid spacing: ${grid_spacing:,.2f}")
    print(f"    Order size: ${order_size_usd:.2f}")

    # State
    cash = starting_balance
    # Level indices with a pending buy order. Every level except the top
    # one starts armed; the top level is sell-only (see docstring).
    active_buys = set(range(grid_levels - 1))
    # level_idx -> (qty, cost_basis) of coin waiting to be sold at that level.
    active_sells: Dict[int, tuple] = {}

    equity_curve: List[Dict] = []
    total_trades = 0
    grid_profit = 0.0
    out_of_range_days = 0

    for ts, low, high, close in zip(df["timestamp"], df["low"], df["high"], df["close"]):
        day_low = float(low)
        day_high = float(high)
        day_close = float(close)

        # The whole daily range sits above or below the grid.
        entirely_out = (day_high < lower_bound) or (day_low > upper_bound)

        if entirely_out:
            out_of_range_days += 1
        else:
            # Sells are processed first (ascending), then buys (descending):
            # this models a day where price rises before it falls.
            for level_idx in sorted(active_sells):
                grid_price = grid_prices[level_idx]
                if not (day_low <= grid_price <= day_high):
                    continue
                qty, cost_basis = active_sells.pop(level_idx)
                revenue = qty * grid_price
                cash += revenue
                grid_profit += revenue - cost_basis
                total_trades += 1
                # Re-arm the buy one level below the sell that just filled.
                if level_idx - 1 >= 0:
                    active_buys.add(level_idx - 1)

            # sorted() returns a list, so mutating the set inside is safe.
            for level_idx in sorted(active_buys, reverse=True):
                grid_price = grid_prices[level_idx]
                if not (day_low <= grid_price <= day_high):
                    continue
                if cash < order_size_usd:
                    continue  # order stays pending until cash is available
                qty = order_size_usd / grid_price
                cash -= order_size_usd
                total_trades += 1
                active_buys.discard(level_idx)
                # Queue the matching sell one level up, merging with any
                # coin already waiting there. These sells fill on a later
                # day at the earliest, as the sell pass has already run.
                held_qty, held_cost = active_sells.get(level_idx + 1, (0.0, 0.0))
                active_sells[level_idx + 1] = (
                    held_qty + qty,
                    held_cost + order_size_usd,
                )

        # ── Track daily equity ──────────────────────────────────────────────
        # Cash plus all coin awaiting sale, marked at the closing price.
        open_position_value = sum(
            qty * day_close for (qty, _) in active_sells.values()
        )
        equity_curve.append({
            "time": ts.isoformat(),
            "equity": round(cash + open_position_value, 2),
        })

    # ── Final stats ───────────────────────────────────────────────────────────
    last_close = float(df.iloc[-1]["close"])
    open_position_value = sum(qty * last_close for (qty, _) in active_sells.values())
    open_cost_basis = sum(cost for (_, cost) in active_sells.values())
    unrealized_pnl = open_position_value - open_cost_basis
    final_balance = cash + open_position_value
    total_return = (final_balance - starting_balance) / starting_balance * 100

    # Drawdown is measured on the rounded daily equity values.
    max_drawdown = _max_drawdown_pct([e["equity"] for e in equity_curve])

    result = {
        "label": label,
        "final_balance": round(final_balance, 2),
        "total_return_pct": round(total_return, 2),
        "total_trades": total_trades,
        "grid_profit": round(grid_profit, 2),
        "unrealized_pnl": round(unrealized_pnl, 2),
        "max_drawdown": round(max_drawdown, 2),
        "out_of_range_days": out_of_range_days,
        "starting_price": round(starting_price, 2),
        "lower_bound": round(lower_bound, 2),
        "upper_bound": round(upper_bound, 2),
        "grid_spacing": round(grid_spacing, 2),
        "equity_curve": equity_curve,
    }

    print(f"  ✅ {label} done!")
    print(f"     Final: ${final_balance:,.2f} | Return: {total_return:+.2f}%")
    print(f"     Trades: {total_trades} | Grid P&L: ${grid_profit:+.2f} | Unrealized: ${unrealized_pnl:+.2f}")
    print(f"     Max DD: {max_drawdown:.2f}% | Out-of-range days: {out_of_range_days}")

    return result


# ── Main ──────────────────────────────────────────────────────────────────────

def main() -> Optional[List[Dict[str, Any]]]:
    """Backtest every pair in ``GRID_PAIRS`` and save the results as JSON.

    For each pair the candles are fetched and the grid simulation is run;
    pairs without data are skipped. The collected results are written to
    ``OUTPUT_JSON`` (the directory is created if needed).

    Returns:
        The list of per-pair result dicts, or ``None`` when no pair
        produced a result (nothing is written in that case).
    """
    print("=" * 60)
    # Banner values come from the configuration so they cannot drift from it.
    print(f"  Grid Trading Backtest — ±{GRID_RANGE_PCT:.0%} Range, {GRID_LEVELS} Levels")
    print(f"  {len(GRID_PAIRS)} pairs × {MONTHS_BACK} months daily candles")
    print("=" * 60)

    results = []

    for pair in GRID_PAIRS:
        label = pair.replace("-USD", "") + " Grid"
        print(f"\n{'─'*52}")
        print(f"  {label} ({pair})")
        print(f"{'─'*52}")

        df = fetch_candles(pair)
        if df.empty:
            print(f"  ⚠️  No data for {pair}, skipping.")
            continue

        result = run_grid_backtest(df, label)
        if result:
            results.append(result)

    if not results:
        print("\n❌ No results.")
        return None

    # Save JSON
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    with open(OUTPUT_JSON, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, default=str)

    print(f"\n{'='*60}")
    print(f"  ✅ Results saved to: {OUTPUT_JSON}")
    print(f"{'='*60}\n")

    return results


# Run the backtest only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
