import fitz  # PyMuPDF 用于 PDF 文本和图像提取
import pdfplumber  # PDF 表格提取
import easyocr  # OCR 识别
from pathlib import Path
from PIL import Image
import argparse
import os
import logging

# Configure root logging at import time: INFO level, with timestamped lines
# of the form "<time> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)


class Config:
    """Settings for one PDF -> Markdown conversion (CLI arguments / env vars).

    Both the output directory and the image directory are created eagerly
    (including any missing parent directories), so later stages can write
    into them without further checks.

    Args:
        input_pdf: Path to the source PDF (str or path-like).
        output_dir: Directory that receives the Markdown file.
        output_name: Markdown file name without a directory part. Defaults to
            the input file's stem plus ``.md``.
        img_dir: Directory where page images are saved for OCR. A relative
            path is resolved against the current working directory, not
            against ``output_dir``.
        ocr_gpu: Whether the OCR reader should use the GPU.
        min_text_len: Pages whose stripped text is shorter than this are
            OCR'd from their images.
    """

    def __init__(self, input_pdf, output_dir, output_name=None, img_dir='images', ocr_gpu=True, min_text_len=20):
        self.input_pdf = Path(input_pdf)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.output_name = output_name or (self.input_pdf.stem + ".md")
        self.output_md = self.output_dir / self.output_name
        self.img_dir = Path(img_dir)
        # parents=True: a nested image dir (e.g. "out/assets/images") must not
        # fail just because its parent does not exist yet.
        self.img_dir.mkdir(parents=True, exist_ok=True)
        self.ocr_gpu = ocr_gpu
        self.min_text_len = min_text_len


class PDFToMarkdownConverter:
    """
    Core PDF -> Markdown converter.

    Extracts page text with PyMuPDF and tables with pdfplumber. For pages with
    little extractable text, it OCRs the embedded images with EasyOCR.
    """

    def __init__(self, config: "Config | None" = None):
        # One EasyOCR reader per device key ('gpu' / 'cpu'), built lazily and
        # reused for the lifetime of this converter.
        self.reader_cache = {}
        self.config = config

    def get_ocr_reader(self, use_gpu=True):
        """Return the cached EasyOCR reader (English + Simplified Chinese).

        Args:
            use_gpu: Request a GPU-backed reader when True, CPU otherwise.

        Returns:
            The reader for the requested device, created on first use.
        """
        key = 'gpu' if use_gpu else 'cpu'
        if key not in self.reader_cache:
            logging.info(f"创建 OCR Reader，GPU={use_gpu}")
            self.reader_cache[key] = easyocr.Reader(
                ['en', 'ch_sim'], gpu=use_gpu)
        return self.reader_cache[key]

    def extract_text_blocks(self, pdf_path):
        """
        Extract the text of every page, one string per page.

        Text blocks are ordered by their top coordinate, then their left
        coordinate. Empty blocks are dropped and the rest are joined with
        newlines.

        Args:
            pdf_path (str | Path): Path to the PDF file.

        Returns:
            list[str]: Page texts in page order (empty string for pages with
            no text).
        """
        pages_text = []
        # Context manager closes the document instead of leaking the handle.
        with fitz.open(pdf_path) as doc:
            for idx, page in enumerate(doc):
                # PyMuPDF block tuples begin (x0, y0, x1, y1, text, ...):
                # sort top-to-bottom, then left-to-right.
                blocks = page.get_text("blocks")
                blocks.sort(key=lambda b: (b[1], b[0]))
                parts = [b[4].strip() for b in blocks if b[4].strip()]
                page_text = "\n".join(parts)
                pages_text.append(page_text)
                # Log the character length of the page text (not the block count).
                logging.info(f"提取页面 {idx+1} 文本，长度 {len(page_text)}")
        return pages_text

    @staticmethod
    def table_to_markdown(table):
        """Render one extracted table (a list of rows of cells) as Markdown.

        The first row becomes the header. ``None`` cells become empty strings.
        Line breaks inside a cell become spaces and ``|`` is escaped, so cell
        content cannot break the table layout. Short rows are padded with
        empty cells up to the widest row.

        Args:
            table: Sequence of rows, each a sequence of ``str``/``None`` cells
                (the shape pdfplumber's ``extract_tables`` produces).

        Returns:
            str: The Markdown table with one line per row plus the separator
            line, ending in a newline; ``""`` for an empty table.
        """
        if not table:
            return ""
        width = max(len(row) for row in table)
        if width == 0:
            return ""

        def fmt(cell):
            if cell is None:
                return ""
            text = str(cell).replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
            return text.replace("|", "\\|")

        rows = [[fmt(cell) for cell in row] + [""] * (width - len(row))
                for row in table]
        lines = ["| " + " | ".join(rows[0]) + " |",
                 "| " + " | ".join(["---"] * width) + " |"]
        lines.extend("| " + " | ".join(row) + " |" for row in rows[1:])
        return "\n".join(lines) + "\n"

    def extract_tables(self, pdf_path):
        """Extract every table in the PDF as a Markdown table string.

        Args:
            pdf_path (str | Path): Path to the PDF file.

        Returns:
            list[str]: Markdown tables in page order (see ``table_to_markdown``).
        """
        tables_md = []
        with pdfplumber.open(pdf_path) as pdf:
            for page_idx, page in enumerate(pdf.pages):
                for table in page.extract_tables():
                    md_table = self.table_to_markdown(table)
                    if not md_table:
                        continue
                    tables_md.append(md_table)
                    logging.info(f"提取页面 {page_idx+1} 表格，行数 {len(table)}")
        return tables_md

    @staticmethod
    def _save_image_png(doc, xref, img_path):
        """Write the image object ``xref`` of ``doc`` to ``img_path`` as a PNG."""
        pix = fitz.Pixmap(doc, xref)
        # PNG holds gray or RGB, with or without alpha. Pixmaps with more
        # colour components (e.g. CMYK) must be converted to RGB first.
        if pix.n - pix.alpha > 3:
            pix = fitz.Pixmap(fitz.csRGB, pix)
        pix.save(str(img_path))

    def extract_images_ocr(self, pdf_path, img_dir, use_gpu=True, page_idx=None):
        """OCR the images embedded in one page (or all pages) of a PDF.

        Each image is saved as ``page{N}_img{M}.png`` in ``img_dir`` (created
        if needed) and then read by the EasyOCR reader. Pages or images that
        PyMuPDF cannot read or convert are logged as warnings and skipped.

        Args:
            pdf_path: Path to the PDF file.
            img_dir: Directory for the extracted PNG files.
            use_gpu: Forwarded to ``get_ocr_reader``.
            page_idx: 0-based page index to process; ``None`` means all pages.

        Returns:
            list[str]: One ``### Image OCR page N`` section for each image
            whose OCR text is not blank.
        """
        reader = self.get_ocr_reader(use_gpu)
        img_dir = Path(img_dir)
        img_dir.mkdir(parents=True, exist_ok=True)
        images_text = []
        # Close the document on exit: this method is called once per
        # low-text page, so a leaked handle would accumulate.
        with fitz.open(pdf_path) as doc:
            pages_to_process = [
                page_idx] if page_idx is not None else range(len(doc))
            for i in pages_to_process:
                page = doc[i]
                try:
                    image_list = page.get_images(full=True)
                except Exception as e:
                    logging.warning(f"页面 {i+1} 获取图片异常，跳过: {e}")
                    continue

                for j, img in enumerate(image_list):
                    xref = img[0]
                    img_path = img_dir / f"page{i+1}_img{j+1}.png"
                    try:
                        self._save_image_png(doc, xref, img_path)
                    except Exception as e:
                        logging.warning(f"页面 {i+1} 图像 {j+1} 渲染异常，跳过: {e}")
                        continue
                    logging.info(f"保存页面 {i+1} 图像 {j+1} 到 {img_path}")

                    result = reader.readtext(str(img_path))
                    # EasyOCR items are (bbox, text, confidence); keep the text.
                    ocr_text = "\n".join(t[1] for t in result)
                    if ocr_text.strip():
                        images_text.append(
                            f"### Image OCR page {i+1}\n{ocr_text}\n")
        return images_text

    def generate_markdown(self, config: "Config | None" = None):
        """Convert ``cfg.input_pdf`` and write the Markdown file ``cfg.output_md``.

        Output layout: every page's text (each followed by a blank line), then
        all tables, then the OCR sections for pages whose stripped text is
        shorter than ``cfg.min_text_len``.

        Args:
            config: Settings to use; falls back to the constructor's config.

        Raises:
            ValueError: If no config was given here or at construction.
        """
        cfg = config if config is not None else self.config
        if cfg is None:
            raise ValueError("需要提供 Config 对象或在构造时传入 config")

        logging.info(f"开始提取文本 {cfg.input_pdf}")
        texts = self.extract_text_blocks(cfg.input_pdf)
        tables = self.extract_tables(cfg.input_pdf)
        images_ocr = []

        for idx, page_text in enumerate(texts):
            if len(page_text.strip()) < cfg.min_text_len:
                logging.info(f"页面 {idx+1} 文本较少，启用图像OCR")
                images_ocr.extend(self.extract_images_ocr(
                    cfg.input_pdf, cfg.img_dir, cfg.ocr_gpu, page_idx=idx))

        with open(cfg.output_md, "w", encoding="utf-8") as f:
            for page_text in texts:
                f.write(page_text + "\n\n")
            for table_md in tables:
                f.write(table_md + "\n")
            for img_text in images_ocr:
                f.write(img_text + "\n")
        logging.info(f"Markdown 已生成：{cfg.output_md}")

    def get_config(self):
        """Return the config passed at construction (may be ``None``)."""
        return self.config


class CLIParser:
    """Command-line parser for the PDF -> Markdown script.

    When a flag is not given, the optional settings fall back to the
    environment variables ``PDF2MD_IMG_DIR``, ``PDF2MD_OCR_GPU`` and
    ``PDF2MD_MIN_TEXT_LEN``.
    """

    # Case-insensitive spellings that enable GPU OCR; any other value disables it.
    _TRUE_VALUES = frozenset({"true", "1", "yes", "y", "on"})

    def __init__(self):
        self.parser = argparse.ArgumentParser(
            description="PDF → Markdown GPU OCR 脚本")
        self.parser.add_argument("input_pdf", help="输入 PDF 文件路径")
        self.parser.add_argument("output_dir", help="输出 Markdown 文件夹")
        # "--output-name" matches the hyphenated style of the other flags;
        # "--output_name" is kept so existing invocations keep working.
        self.parser.add_argument(
            "--output-name", "--output_name", dest="output_name",
            help="可选：自定义输出 Markdown 文件名（不带路径）", default=None)
        self.parser.add_argument("--img-dir", help="图片保存目录（默认 images）",
                                 default=os.environ.get("PDF2MD_IMG_DIR", "images"))
        self.parser.add_argument("--ocr-gpu", help="是否使用 GPU OCR (True/False，默认 True)",
                                 default=os.environ.get("PDF2MD_OCR_GPU", "True"))
        # argparse runs a *string* default through type=int, so a malformed
        # PDF2MD_MIN_TEXT_LEN becomes a usage error at parse time instead of a
        # ValueError traceback while building the parser.
        self.parser.add_argument("--min-text-len", type=int, help="页面文字少于此长度时启用图像分析 (默认 20)",
                                 default=os.environ.get("PDF2MD_MIN_TEXT_LEN", "20"))

    def parse_args(self, argv=None):
        """Parse the command line and normalise ``ocr_gpu`` to a bool.

        Args:
            argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

        Returns:
            argparse.Namespace with ``input_pdf``, ``output_dir``,
            ``output_name``, ``img_dir``, ``ocr_gpu`` (bool) and
            ``min_text_len`` (int).
        """
        args = self.parser.parse_args(argv)
        args.ocr_gpu = str(args.ocr_gpu).strip().lower() in self._TRUE_VALUES
        return args


def main():
    """Command-line entry point: parse arguments, build a Config, convert the PDF."""
    cli = CLIParser()
    args = cli.parse_args()

    # Fail fast with a usage error, before Config creates any output
    # directories, instead of a library traceback for a missing input file.
    if not Path(args.input_pdf).is_file():
        cli.parser.error(f"输入 PDF 文件不存在: {args.input_pdf}")

    config = Config(
        input_pdf=args.input_pdf,
        output_dir=args.output_dir,
        output_name=args.output_name,
        img_dir=args.img_dir,
        ocr_gpu=args.ocr_gpu,
        min_text_len=args.min_text_len
    )

    converter = PDFToMarkdownConverter(config)
    converter.generate_markdown()


if __name__ == "__main__":
    main()
