import os
import argparse
import fitz  # PyMuPDF
import pdfplumber
import easyocr
from pathlib import Path
from PIL import Image
import base64
import re
import cv2
from markdownify import markdownify as md

# # ----------------------------
# # 配置
# # ----------------------------
# INPUT_PDF = "产品简介.pdf"
# OUTPUT_MD = "产品简介.md"
# IMG_DIR = Path("images")
# IMG_DIR.mkdir(exist_ok=True)

# Shared EasyOCR engine (English + Simplified Chinese). It is created once at
# import time with GPU inference requested and is read as a module global by
# the image OCR step below.
reader = easyocr.Reader(lang_list=["en", "ch_sim"], gpu=True)

# ----------------------------
# 1. 提取 PDF 文本块
# ----------------------------


def extract_text_blocks(pdf_path: Path | str) -> list[str]:
    """Extract the text of every page of a PDF, one string per page.

    Text blocks on a page are ordered top-to-bottom, then left-to-right, and
    joined with newlines. Blocks that are empty after stripping whitespace
    are dropped.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A list with one entry per page, in page order. A page without any
        text contributes an empty string.
    """
    pages_text = []
    # Open the document as a context manager so the file handle is released
    # even if text extraction raises part-way through.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            # Each block is (x0, y0, x1, y1, text, block_no, block_type).
            blocks = page.get_text("blocks")
            # Sort by the top edge first, then by the left edge.
            blocks.sort(key=lambda b: (b[1], b[0]))
            page_text = [text for b in blocks if (text := b[4].strip())]
            pages_text.append("\n".join(page_text))
    return pages_text

# ----------------------------
# 2. 提取 PDF 表格
# ----------------------------


def _markdown_cell(cell) -> str:
    """Render one table cell as text that is safe inside a Markdown table row."""
    if cell is None:
        return ""
    # A Markdown table row must stay on one line: collapse newlines and other
    # whitespace runs into single spaces, then escape the column separator.
    text = " ".join(str(cell).split())
    return text.replace("|", "\\|")


def _table_to_markdown(table) -> str:
    """Convert a list-of-rows table into a Markdown table ("" if it is empty)."""
    rows = [list(row) for row in (table or []) if row is not None]
    width = max((len(row) for row in rows), default=0)
    if width == 0:
        return ""

    lines = []
    for index, row in enumerate(rows):
        # Pad short rows so every line has the same number of columns.
        cells = [_markdown_cell(cell) for cell in row]
        cells += [""] * (width - len(cells))
        lines.append("| " + " | ".join(cells) + " |")
        if index == 0:
            # The first row is treated as the header; follow it with the
            # separator line Markdown requires.
            lines.append("| " + " | ".join(["---"] * width) + " |")
    return "\n".join(lines) + "\n"


def extract_tables(pdf_path):
    """Extract every table found in a PDF as a Markdown table string.

    The first row of each table is used as the Markdown header row. Empty
    cells (``None``) are rendered as empty strings in the header as well as
    in the body, line breaks inside a cell are collapsed to spaces, and
    ``|`` characters are escaped so that they do not split a column.

    Args:
        pdf_path: Path of the PDF file to read.

    Returns:
        A list of Markdown table strings in page order. Tables that contain
        no cells are skipped.
    """
    tables_md = []
    with pdfplumber.open(pdf_path) as pdf:
        for page in pdf.pages:
            for table in page.extract_tables():
                md_table = _table_to_markdown(table)
                if md_table:
                    tables_md.append(md_table)
    return tables_md

# ----------------------------
# 3. 提取图片 + OCR
# ----------------------------


def extract_images_ocr(pdf_path: Path, img_dir: Path):
    """Save every embedded image of a PDF as PNG and run OCR on it.

    Images are written to ``img_dir`` as ``page<P>_img<I>.png`` (both numbers
    are 1-based) and recognised with the module-level ``reader``.

    Args:
        pdf_path: Path of the PDF file to read.
        img_dir: Directory that receives the extracted PNG files. It is
            created if missing, even when the PDF turns out to contain no
            images.

    Returns:
        A list of Markdown snippets, one per image in which OCR found text,
        each starting with a ``### Image OCR page <P>`` heading.
    """
    # Pillow raw modes keyed by the number of bytes per pixel in the pixmap
    # once CMYK data has been converted to RGB.
    pil_modes = {1: "L", 2: "LA", 3: "RGB", 4: "RGBA"}

    images_text = []
    # Creating the directory is loop-invariant, so do it once up front.
    img_dir.mkdir(parents=True, exist_ok=True)

    # The context manager closes the document even if extraction or OCR fails.
    with fitz.open(pdf_path) as doc:
        for i, page in enumerate(doc):
            for j, img in enumerate(page.get_images(full=True)):
                xref = img[0]
                pix = fitz.Pixmap(doc, xref)

                # pix.n counts the alpha channel too, so subtract it before
                # deciding whether the colour data is CMYK (4 components).
                if pix.n - pix.alpha >= 4:
                    pix = fitz.Pixmap(fitz.csRGB, pix)

                mode = pil_modes.get(pix.n)
                if mode is None:
                    # Unexpected pixel layout: skip rather than decode garbage.
                    continue

                # Decode with the pixmap's real layout (grayscale, with or
                # without alpha, ...) and only then normalise to RGB.
                img_pil = Image.frombytes(
                    mode, (pix.width, pix.height), pix.samples
                ).convert("RGB")

                img_path = img_dir / f"page{i+1}_img{j+1}.png"
                img_pil.save(img_path)

                # OCR the saved file; each result item carries the recognised
                # text at index 1.
                result = reader.readtext(str(img_path))
                ocr_text = "\n".join([t[1] for t in result])
                if ocr_text.strip():
                    images_text.append(
                        f"### Image OCR page {i+1}\n{ocr_text}\n")
    return images_text

# ----------------------------
# 4. 合并生成 Markdown
# ----------------------------


def generate_markdown(pdf_path: Path, output_md: Path, img_dir: Path):
    """Convert a PDF into a single Markdown file.

    The output contains, in this order: the text of every page, every
    extracted table, and the OCR text of every embedded image.

    Args:
        pdf_path: Path of the PDF file to convert.
        output_md: Path of the Markdown file to write (UTF-8, overwritten).
        img_dir: Directory that receives the images extracted for OCR.
    """
    page_texts = extract_text_blocks(pdf_path)
    table_blocks = extract_tables(pdf_path)
    ocr_blocks = extract_images_ocr(pdf_path, img_dir=img_dir)

    # Page text is followed by a blank line; tables and OCR sections are
    # followed by a single newline.
    sections = [text + "\n\n" for text in page_texts]
    sections.extend(table + "\n" for table in table_blocks)
    sections.extend(ocr + "\n" for ocr in ocr_blocks)

    with open(output_md, "w", encoding="utf-8") as out:
        out.write("".join(sections))
    print(f"Markdown 已生成：{output_md}")


# ----------------------------
# 主函数
# ----------------------------

if __name__ == "__main__":
    # Command-line entry point: parse the arguments, resolve the output
    # paths, apply the OCR device choice and run the conversion.
    parser = argparse.ArgumentParser(description="PDF → Markdown GPU 加速脚本")
    parser.add_argument("input_pdf", help="输入 PDF 文件路径")
    parser.add_argument("output_dir", help="输出 Markdown 文件夹")
    parser.add_argument(
        "--output_name",
        help="可选：自定义输出 Markdown 文件名（不带路径），默认使用 PDF 文件名",
        default=None
    )
    parser.add_argument(
        "--img-dir",
        help="图片保存目录（默认 images）",
        default=os.environ.get("PDF2MD_IMG_DIR", "images")
    )
    parser.add_argument(
        "--ocr-gpu",
        help="是否使用 GPU OCR (True/False，默认 True)",
        default=os.environ.get("PDF2MD_OCR_GPU", "True")
    )
    args = parser.parse_args()

    input_pdf = Path(args.input_pdf)
    # Fail early with a usage error instead of creating the output folder and
    # then crashing inside the PDF library.
    if not input_pdf.is_file():
        parser.error(f"找不到输入 PDF 文件：{input_pdf}")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Default the output file name to the PDF's name with an .md suffix.
    if args.output_name:
        output_md = output_dir / args.output_name
    else:
        output_md = output_dir / (input_pdf.stem + ".md")

    img_dir = Path(args.img_dir)
    use_gpu = args.ocr_gpu.strip().lower() == "true"

    # The module-level OCR reader is built at import time, before the command
    # line is parsed, so --ocr-gpu / PDF2MD_OCR_GPU would otherwise have no
    # effect. When CPU is requested and the existing reader is not already on
    # the CPU, replace the module global that the OCR step reads.
    if not use_gpu and getattr(reader, "device", None) != "cpu":
        reader = easyocr.Reader(['en', 'ch_sim'], gpu=False)

    print(f"输入 PDF: {input_pdf}")
    print(f"输出 Markdown: {output_md}")
    print(f"图片目录: {img_dir}")
    print(f"OCR GPU: {use_gpu}")

    generate_markdown(pdf_path=input_pdf, output_md=output_md, img_dir=img_dir)
