import os
import re
import base64
import io
import threading
import shutil
import time
import sys
from typing import Optional, List, Dict, Set, Any
from contextlib import asynccontextmanager

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image
import numpy as np
import onnxruntime as ort

# ======================================================
# RESOURCE PATH
# ======================================================
def resource_path(relative_path: str) -> str:
    """Resolve *relative_path* against the application's base directory.

    When running from a PyInstaller bundle the base is the extraction
    directory exposed as ``sys._MEIPASS``; otherwise it is the directory that
    contains this source file.
    """
    if not hasattr(sys, "_MEIPASS"):
        base_dir = os.path.dirname(os.path.abspath(__file__))
    else:
        base_dir = sys._MEIPASS
    return os.path.join(base_dir, relative_path)

# ======================================================
# CONFIG
# ======================================================
# External model folders: one sub-folder per label, each holding an .onnx file.
#   ./model_hot   - models whose sessions are loaded into memory
#   ./model_cold  - models only indexed on disk; promoted to hot on first use
# NOTE(review): these are relative paths, so they resolve against the process's
# current working directory, not necessarily the folder containing the EXE /
# app.py. Start the server from that folder (or make the paths absolute).
HOT_DIR = "./model_hot"
COLD_DIR = "./model_cold"

# Seconds between evictor passes; also how long a hot model must sit unused
# before it becomes eligible to move back to cold.
EVICTION_INTERVAL_SEC = 300
# Used two ways: scan_hot_models() treats it as the maximum number of models
# kept hot at startup (0 = every hot folder is moved to cold), while
# evict_unused_models() never evicts below this many hot models.
MIN_HOT_MODELS = 0

# Which class score detect() tests (offset into the per-class score rows).
TARGET_CLASS_INDEX = 0

# ======================================================
# NORMALIZATION
# ======================================================
def normalize_label(s: str) -> str:
    """Canonicalize a label for registry lookups.

    Underscores and hyphens become spaces, the text is lower-cased, every
    whitespace run collapses to a single space, and leading/trailing
    whitespace is removed. Falsy input (``""`` or ``None``) yields ``""``.
    """
    if not s:
        return ""
    spaced = s.translate(str.maketrans("_-", "  ")).lower()
    return " ".join(spaced.split())

# ======================================================
# HOMOGLYPH STRIPPING (Cyrillic/Greek lookalikes -> Latin)
# ======================================================
# str.translate() table mapping Cyrillic and Greek letters that look like
# Latin letters onto their Latin counterparts (upper- and lower-case).
# NOTE(review): not referenced anywhere else in this file as shown.
_HOMOGLYPH_MAP = str.maketrans({
    '\u0410': 'A', '\u0430': 'a', '\u0415': 'E', '\u0435': 'e',
    '\u041E': 'O', '\u043E': 'o', '\u0421': 'C', '\u0441': 'c',
    '\u0420': 'P', '\u0440': 'p', '\u0425': 'X', '\u0445': 'x',
    '\u0422': 'T', '\u0412': 'B', '\u041A': 'K', '\u043A': 'k',
    '\u041C': 'M', '\u043C': 'm', '\u0423': 'Y', '\u0443': 'y',
    '\u0391': 'A', '\u03B1': 'a', '\u0395': 'E', '\u03B5': 'e',
    '\u039F': 'O', '\u03BF': 'o', '\u0399': 'I', '\u03B9': 'i',
    '\u039A': 'K', '\u03BA': 'k', '\u039C': 'M', '\u03BC': 'm',
    '\u039D': 'N', '\u03BD': 'n', '\u03A1': 'P', '\u03C1': 'p',
    '\u03A4': 'T', '\u03C4': 't', '\u03A5': 'Y', '\u03C5': 'y',
})


def decode_b64_to_pil(image_b64: str) -> Image.Image:
    """Decode a base64-encoded image into an RGB PIL image.

    A ``data:image...`` URL prefix is accepted: everything up to and
    including the first comma is discarded before decoding.
    """
    payload = image_b64
    if payload.startswith("data:image"):
        head, sep, tail = payload.partition(",")
        # Without a comma the whole string is kept, as with split(",", 1)[-1].
        payload = tail if sep else head
    buffer = io.BytesIO(base64.b64decode(payload))
    return Image.open(buffer).convert("RGB")

def preprocess(pil_img: Image.Image, img_size: int = 256) -> np.ndarray:
    """Turn a PIL image into a float32 model input with a leading batch axis.

    The image is resized to ``img_size`` x ``img_size`` and pixel values are
    scaled by 1/255. Axes are reordered (2, 0, 1) and a batch axis of size 1
    is prepended. Assumes an (H, W, 3) RGB image, which is what
    decode_b64_to_pil() produces, giving a (1, 3, img_size, img_size) result.
    """
    resized = pil_img.resize((img_size, img_size))
    scaled = np.array(resized, dtype=np.float32) / 255.0
    channels_first = scaled.transpose(2, 0, 1)
    return channels_first[np.newaxis, ...]

# ======================================================
# ONNX MODEL REGISTRY
# ======================================================
# Loaded ONNX Runtime sessions, keyed by .onnx file path.
_model_cache: Dict[str, ort.InferenceSession] = {}
# Normalized label -> .onnx path of a model registered as hot (in HOT_DIR).
_registry: Dict[str, str] = {}
# Normalized label -> .onnx path of a model indexed in COLD_DIR (not loaded).
_cold_index: Dict[str, str] = {}
# .onnx path -> time.time() of last load/use; evict_unused_models() uses it.
_model_last_used: Dict[str, float] = {}
# .onnx paths of every model currently counted as hot.
_hot_model_files: Set[str] = set()

# Guards every registry above; used both by request handlers and by the
# background evictor thread.
_lock = threading.Lock()

# ======================================================
# ONNX SESSION LOADER
# ======================================================
def load_session(model_path: str) -> ort.InferenceSession:
    """Return the cached ONNX Runtime session for *model_path*, creating it if needed.

    The session is built outside ``_lock`` so a slow model load does not block
    other threads. If two threads load the same path at once, the session
    stored first in ``_model_cache`` wins and is returned to both callers. This
    means the cache never holds a different object from the one a caller is
    using.

    Errors from ``ort.InferenceSession`` (e.g. a missing or invalid model file)
    propagate to the caller.
    """
    with _lock:
        cached = _model_cache.get(model_path)
    if cached is not None:
        return cached

    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])

    with _lock:
        # Keep a session another thread stored while this one was loading.
        return _model_cache.setdefault(model_path, sess)

def detect(
    sess: ort.InferenceSession,
    tensor: np.ndarray,
    conf_threshold: float,
) -> bool:
    """Return True if any candidate's target-class score reaches *conf_threshold*.

    Only the first model output, and within it the first batch item, is
    inspected. That item is indexed as ``[row, candidate]``: the first four
    rows are skipped and the remaining rows are per-class scores, of which
    row ``4 + TARGET_CLASS_INDEX`` is tested. The first four rows are
    presumably box coordinates (YOLO-style layout); TODO confirm against the
    exported models.

    Returns False when the output has too few rows for TARGET_CLASS_INDEX.
    """
    feeds = {sess.get_inputs()[0].name: tensor}
    output = sess.run(None, feeds)[0]
    candidates = output[0]

    score_row = 4 + TARGET_CLASS_INDEX
    if TARGET_CLASS_INDEX >= candidates.shape[0] - 4:
        return False

    return bool(np.any(candidates[score_row, :] >= conf_threshold))

# ======================================================
# MODEL SCANNING
# ======================================================
def _find_onnx_in_folder(folder: str) -> Optional[str]:
    if not os.path.isdir(folder):
        return None

    for f in os.listdir(folder):
        if f.lower().endswith(".onnx"):
            return os.path.join(folder, f)

    return None

def scan_hot_models():
    """Load the models found in HOT_DIR at startup, moving overflow to COLD_DIR.

    Sub-folders of HOT_DIR that contain an ``.onnx`` file are processed in
    sorted (deterministic) order. While fewer than MIN_HOT_MODELS models are
    hot, the folder's session is loaded and registered under the folder's
    normalized name. Once that count is reached, the folder is moved to
    COLD_DIR and indexed there instead. With MIN_HOT_MODELS = 0, every folder
    goes cold.

    Per-folder failures are logged and skipped, so one bad model does not
    stop startup.
    """
    os.makedirs(HOT_DIR, exist_ok=True)
    # Ensure the move target's parent exists before any shutil.move below.
    os.makedirs(COLD_DIR, exist_ok=True)

    for entry in sorted(os.listdir(HOT_DIR)):
        folder_path = os.path.join(HOT_DIR, entry)
        if not os.path.isdir(folder_path):
            continue

        onnx_path = _find_onnx_in_folder(folder_path)
        if not onnx_path:
            continue

        norm = normalize_label(entry)

        try:
            with _lock:
                hot_full = len(_hot_model_files) >= MIN_HOT_MODELS

            if hot_full:
                _move_startup_folder_to_cold(entry, folder_path, onnx_path, norm)
                continue

            load_session(onnx_path)

            with _lock:
                _registry[norm] = onnx_path
                _hot_model_files.add(onnx_path)
                _model_last_used[onnx_path] = time.time()

            print(f"[hot] Loaded: {entry}")

        except Exception as e:
            print(f"[hot] Error loading {folder_path}: {e}")


def _move_startup_folder_to_cold(
    entry: str, folder_path: str, onnx_path: str, norm: str
) -> None:
    """Move one HOT_DIR folder into COLD_DIR at startup and index it under *norm*."""
    cold_folder = os.path.join(COLD_DIR, entry)
    if os.path.exists(cold_folder):
        # shutil.move into an existing directory nests the source inside it,
        # which would leave the indexed .onnx path pointing at nothing. The
        # existing cold folder has the same name, so scan_cold_models()
        # indexes it under the same label.
        print(f"[hot->cold] '{entry}' already exists in cold; leaving hot copy in place")
        return

    try:
        shutil.move(folder_path, cold_folder)
    except Exception as e:
        print(f"[hot->cold] Failed to move '{entry}': {e}")
        return

    cold_onnx = os.path.join(cold_folder, os.path.basename(onnx_path))
    with _lock:
        _cold_index[norm] = cold_onnx
    print(f"[hot->cold] '{entry}' moved to cold at startup (hot full)")

def scan_cold_models():
    """Index every model folder in COLD_DIR by normalized name, without loading it.

    Labels already present in ``_cold_index`` (for example, folders that
    scan_hot_models moved to cold) are kept. When two folders normalize to
    the same label, the first in sorted order wins. Sorting makes that choice
    deterministic, matching scan_hot_models.

    Prints how many new labels were indexed.
    """
    os.makedirs(COLD_DIR, exist_ok=True)

    count = 0

    for entry in sorted(os.listdir(COLD_DIR)):
        folder_path = os.path.join(COLD_DIR, entry)
        if not os.path.isdir(folder_path):
            continue

        onnx_path = _find_onnx_in_folder(folder_path)
        if not onnx_path:
            continue

        norm = normalize_label(entry)

        with _lock:
            if norm not in _cold_index:
                _cold_index[norm] = onnx_path
                count += 1

    print(f"[cold] Indexed {count} model folder(s) in {COLD_DIR}")

# ======================================================
# PROMOTE MODEL: cold -> hot
# ======================================================
def promote_from_cold(norm_question: str) -> Optional[str]:
    """Move the cold model for *norm_question* into HOT_DIR, then load and register it.

    Returns the hot ``.onnx`` path on success. Returns None when:
      * no cold entry exists for the label;
      * the indexed cold folder no longer exists (the stale entry is dropped);
      * a folder with the same name already exists in HOT_DIR. shutil.move
        would nest the cold folder inside it, so promotion is refused;
      * the move fails (the cold entry is kept so a later call can retry);
      * the moved folder has no ``.onnx`` file or its session fails to load.
        The folder then stays in HOT_DIR and the cold entry is dropped,
        because it no longer points at an existing file.
    """
    with _lock:
        cold_onnx = _cold_index.get(norm_question)

    if cold_onnx is None:
        return None

    cold_folder = os.path.dirname(cold_onnx)
    folder_name = os.path.basename(cold_folder)
    hot_folder = os.path.join(HOT_DIR, folder_name)

    if not os.path.isdir(cold_folder):
        with _lock:
            _cold_index.pop(norm_question, None)
        return None

    if os.path.exists(hot_folder):
        print(f"[promote] '{folder_name}' already exists in hot; not promoting")
        return None

    os.makedirs(HOT_DIR, exist_ok=True)
    try:
        shutil.move(cold_folder, hot_folder)
    except Exception as e:
        print(f"[promote] Failed to move '{folder_name}' cold -> hot: {e}")
        return None

    print(f"[promote] '{folder_name}' cold -> hot")

    hot_onnx = _find_onnx_in_folder(hot_folder)
    loaded = False
    if hot_onnx:
        try:
            load_session(hot_onnx)
            loaded = True
        except Exception as e:
            print(f"[promote] Failed to load '{hot_onnx}': {e}")

    with _lock:
        # After the move the cold path is stale whether or not loading worked.
        _cold_index.pop(norm_question, None)
        if not loaded:
            return None
        _registry[norm_question] = hot_onnx
        _hot_model_files.add(hot_onnx)
        _model_last_used[hot_onnx] = time.time()

    return hot_onnx

# ======================================================
# EVICT MODEL: hot -> cold
# ======================================================
def evict_unused_models():
    """Move hot models unused for at least EVICTION_INTERVAL_SEC back to COLD_DIR.

    Candidates are visited least-recently-used first, and at most
    ``len(hot) - MIN_HOT_MODELS`` are evicted. Nothing is evicted when the
    hot count is already at or below MIN_HOT_MODELS.

    Each folder is moved *before* the in-memory registries are updated. A
    failed move therefore leaves the model fully hot and usable, instead of
    dropping it from the hot registries while its files stay in HOT_DIR and
    the cold index points at a folder that does not exist. While the move is
    in flight, lookups still hit the cached session. A folder whose cold
    destination already exists is skipped, because shutil.move would nest it
    inside that folder.
    """
    now = time.time()

    with _lock:
        hot_count = len(_hot_model_files)
        if hot_count <= MIN_HOT_MODELS:
            return
        candidates = sorted(
            _hot_model_files,
            key=lambda p: _model_last_used.get(p, 0),
        )

    evictable = hot_count - MIN_HOT_MODELS
    evicted = 0
    os.makedirs(COLD_DIR, exist_ok=True)

    for hot_onnx in candidates:
        if evicted >= evictable:
            break

        with _lock:
            last = _model_last_used.get(hot_onnx, 0)
        if now - last < EVICTION_INTERVAL_SEC:
            continue

        hot_folder = os.path.dirname(hot_onnx)
        folder_name = os.path.basename(hot_folder)
        cold_folder = os.path.join(COLD_DIR, folder_name)

        if os.path.exists(cold_folder):
            print(f"[evict] '{folder_name}' already exists in cold; skipping")
            continue

        try:
            shutil.move(hot_folder, cold_folder)
        except Exception as e:
            print(f"[evict] Failed to move '{folder_name}': {e}")
            continue

        cold_onnx = os.path.join(cold_folder, os.path.basename(hot_onnx))
        with _lock:
            labels = [q for q, p in _registry.items() if p == hot_onnx]
            for q in labels:
                del _registry[q]
                _cold_index[q] = cold_onnx
            _hot_model_files.discard(hot_onnx)
            _model_last_used.pop(hot_onnx, None)
            _model_cache.pop(hot_onnx, None)

        print(f"[evict] '{folder_name}' hot -> cold")
        evicted += 1

# ======================================================
# BACKGROUND EVICTOR THREAD
# ======================================================
# Set on application shutdown to stop the background evictor thread.
_evictor_stop = threading.Event()


def _evictor_loop():
    """Call evict_unused_models() every EVICTION_INTERVAL_SEC until stopped.

    Event.wait() returns True as soon as _evictor_stop is set. That ends the
    loop without a final eviction pass. Errors from a pass are printed so the
    thread keeps running.
    """
    while not _evictor_stop.wait(EVICTION_INTERVAL_SEC):
        try:
            evict_unused_models()
        except Exception as e:
            print(f"[evictor] error: {e}")

# ======================================================
# LIFESPAN
# ======================================================
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Scan the model folders and run the evictor thread for the app's lifetime.

    On startup the hot and cold folders are scanned and a daemon evictor
    thread is started. The stop event is cleared first; otherwise a previous
    shutdown in the same process (e.g. repeated TestClient use) would make the
    new thread exit immediately.

    On shutdown, including shutdown caused by an exception, the thread is
    signalled and briefly joined so an in-progress eviction can finish.
    """
    scan_hot_models()
    scan_cold_models()

    _evictor_stop.clear()
    evictor = threading.Thread(target=_evictor_loop, name="model-evictor", daemon=True)
    evictor.start()

    with _lock:
        hot_count = len(_hot_model_files)
        cold_count = len(_cold_index)
    print(
        f"[startup] hot={hot_count} models, "
        f"cold={cold_count} indexed questions"
    )

    try:
        yield
    finally:
        _evictor_stop.set()
        # The event wakes the thread at once; the timeout bounds a pass in progress.
        evictor.join(timeout=5)

# ======================================================
# FASTAPI
# ======================================================
# ASGI application; lifespan() scans model folders and runs the evictor thread.
app = FastAPI(title="YOLO ONNX Solver", version="12.0.0", lifespan=lifespan)

# ======================================================
# SCHEMAS
# ======================================================
class PredictRequest(BaseModel):
    """Request body for POST /predict (also built internally by /check)."""
    class_name: str                         # label; normalized to select a model
    image_base64: str                       # base64 image, optionally a data:image URL
    conf_threshold: Optional[float] = 0.20  # score a candidate must reach to count
    iou_threshold: Optional[float] = 0.70   # NOTE(review): not read anywhere in this file
    img_size: Optional[int] = 256           # side of the square resize before inference

class ExtensionCheckRequest(BaseModel):
    """Inner task payload handled by POST /check; built by OuterCheckRequest.to_inner()."""

    model_config = {"extra": "allow"}   # absorb any unknown fields -> no 422

    question: Optional[str] = ""  # label text; /check passes it to predict() as class_name
    questionType: Optional[str] = "objectClassify"  # echoed back in the response envelope
    # queries may be omitted OR sent as JSON null; queries_list returns [] for both.
    queries: Optional[List[str]] = None  # base64 images, one per tile

    @property
    def queries_list(self) -> List[str]:
        """Always returns a list, never None - use this instead of .queries directly."""
        return self.queries if self.queries is not None else []

class OuterCheckRequest(BaseModel):
    """Outer envelope sent by the captcha client: { apiKey, source, task: {...} }

    Unknown top-level fields are allowed. Only ``task`` (or, when it is
    missing or empty, the extra top-level fields) is consumed, via to_inner().
    apiKey/source/version/appID are accepted but not read elsewhere in this
    file.
    """
    model_config = {"extra": "allow"}
    apiKey: Optional[str] = None
    source: Optional[str] = None
    version: Optional[str] = None
    appID: Optional[Any] = None
    task: Optional[Dict[str, Any]] = None

    def to_inner(self) -> "ExtensionCheckRequest":
        """Build the inner request from ``task``, falling back to the extra
        top-level fields when ``task`` is None or an empty dict."""
        if self.task:
            return ExtensionCheckRequest(**self.task)
        # Flat payload: the task fields were sent at the top level.
        data = {k: v for k, v in (self.model_extra or {}).items()}
        return ExtensionCheckRequest(**data)

# ======================================================
# RESPONSE BUILDER
# ======================================================
def build_extension_response(answers, qtype):
    """Wrap per-tile *answers* in the response envelope returned by /check.

    *qtype* is echoed back as ``questionType``; any falsy value falls back to
    ``"objectClassify"``.
    """
    meta = {"pass_report": True, "fail_report": True, "data": ""}
    response = {"code": 200, "msg": "", "answers": answers, "meta": meta}
    response["questionType"] = qtype if qtype else "objectClassify"
    return response

# ======================================================
# /predict
# ======================================================
@app.post("/predict")
async def predict(req: PredictRequest):
    """Report whether the model for ``req.class_name`` detects its target in the image.

    The label is normalized and looked up among the hot models, and promoted
    from cold storage if needed. ``{"found": False}`` is returned when no
    model exists for it. A hot hit is marked as used under the same lock as
    the lookup, so the registry entry and its last-used time stay consistent
    with the evictor thread.

    ``conf_threshold`` and ``img_size`` may arrive as JSON null, since the
    schema types them Optional. Null falls back to the schema defaults instead
    of failing inside detect()/preprocess(). An undecodable image yields
    HTTP 400 instead of an unhandled error.
    """
    norm = normalize_label(req.class_name)

    with _lock:
        onnx_path = _registry.get(norm)
        if onnx_path:
            _model_last_used[onnx_path] = time.time()

    if not onnx_path:
        # promote_from_cold() records the last-used time itself on success.
        onnx_path = promote_from_cold(norm)

    if not onnx_path:
        return {"found": False}

    sess = load_session(onnx_path)

    # Fallbacks mirror the PredictRequest field defaults.
    conf_threshold = req.conf_threshold if req.conf_threshold is not None else 0.20
    img_size = req.img_size if req.img_size is not None else 256

    try:
        pil = decode_b64_to_pil(req.image_base64)
    except (ValueError, OSError) as e:
        # binascii.Error (bad base64) is a ValueError; PIL raises OSError
        # subclasses for unreadable or truncated image data.
        raise HTTPException(400, f"invalid image: {e}") from e

    tensor = preprocess(pil, img_size)
    found = detect(sess, tensor, conf_threshold)
    return {"found": found}

# ======================================================
# /check
# ======================================================
@app.post("/check")
async def check(outer: OuterCheckRequest):
    """Answer every tile of the task with one predict() call each.

    Raises HTTP 400 when the task has no queries. A tile whose prediction
    raises is answered False. When every answer is False, an empty JSON list
    is returned instead of the build_extension_response() envelope.
    """
    task = outer.to_inner()
    tiles = task.queries_list
    if not tiles:
        raise HTTPException(400, "queries empty")

    normalized = normalize_label(task.question)
    print(
        f"[check] question={task.question!r} "
        f"norm={normalized!r} "
        f"qtype={task.questionType!r} "
        f"queries={len(tiles)}"
    )

    answers: List[bool] = []
    for tile in tiles:
        try:
            verdict = await predict(
                PredictRequest(class_name=task.question, image_base64=tile)
            )
            answers.append(bool(verdict.get("found", False)))
        except Exception as e:
            print(f"[check] prediction error: {e}")
            answers.append(False)

    if any(answers) or not answers:
        return build_extension_response(answers, task.questionType)
    return []

# ======================================================
# DEBUG ENDPOINTS
# ======================================================

@app.get("/health")
async def health():
    """Liveness probe: always ok, plus the hot-model and cold-index counts."""
    hot_total = len(_hot_model_files)
    cold_total = len(_cold_index)
    return {"ok": True, "hot_models": hot_total, "cold_models": cold_total}

@app.get("/models")
async def list_models():
    """Debug view of both registries: normalized label -> model folder name.

    The snapshot is taken under ``_lock`` so counts and mappings are
    consistent with each other.
    """
    def folder_of(path):
        return os.path.basename(os.path.dirname(path))

    with _lock:
        snapshot = {
            "hot_count": len(_hot_model_files),
            "hot": {label: folder_of(p) for label, p in _registry.items()},
            "cold_count": len(_cold_index),
            "cold": {label: folder_of(p) for label, p in _cold_index.items()},
        }
    return snapshot

# ======================================================
# RUN:
# uvicorn app:app --host 0.0.0.0 --port 8000
# ======================================================
