import fitz  # PyMuPDF
import pdfplumber
import easyocr
from pathlib import Path
from PIL import Image
import argparse
import os


class PDFToMarkdownConverter:
    """Convert a PDF into a single Markdown file.

    The generated file contains, in this order: the text of every page,
    every table found by pdfplumber (rendered as a Markdown table), and the
    OCR text of every embedded image (recognised with EasyOCR).
    """

    def __init__(self, pdf_path, output_dir, output_name=None, img_dir='images', use_gpu=True):
        """Prepare output locations and load the OCR model.

        Args:
            pdf_path: Path of the input PDF.
            output_dir: Directory for the generated Markdown file; created
                (with parents) if missing.
            output_name: Markdown file name without a directory part;
                defaults to the PDF's stem plus ".md".
            img_dir: Directory where extracted images are written; created
                (with parents) if missing. A relative path resolves against
                the current working directory, not ``output_dir``.
            use_gpu: Forwarded to ``easyocr.Reader`` as ``gpu``.
        """
        self.pdf_path = Path(pdf_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.output_name = output_name or (self.pdf_path.stem + ".md")
        self.output_md = self.output_dir / self.output_name
        self.img_dir = Path(img_dir)
        # parents=True so nested locations such as "out/images" work too.
        self.img_dir.mkdir(parents=True, exist_ok=True)
        self.use_gpu = use_gpu
        # Recognition languages: English and Simplified Chinese.
        self.reader = easyocr.Reader(['en', 'ch_sim'], gpu=self.use_gpu)

    def extract_text_blocks(self):
        """Return one string per page containing that page's text blocks.

        Blocks are ordered top-to-bottom, then left-to-right, stripped, and
        joined with newlines. Image blocks and blank blocks are skipped.
        """
        pages_text = []
        # The context manager closes the document once we are done with it.
        with fitz.open(self.pdf_path) as doc:
            for page in doc:
                # Each block is (x0, y0, x1, y1, text, block_no, block_type).
                blocks = page.get_text("blocks")
                blocks.sort(key=lambda b: (b[1], b[0]))
                page_text = [
                    b[4].strip()
                    for b in blocks
                    # block_type 0 is text; anything else (e.g. an image
                    # placeholder) must not leak into the Markdown.
                    if b[6] == 0 and b[4].strip()
                ]
                pages_text.append("\n".join(page_text))
        return pages_text

    @staticmethod
    def _format_cell(cell):
        """Render one table cell as single-line, pipe-safe Markdown text."""
        if cell is None:
            return ""
        # A Markdown table row cannot span lines, and an unescaped "|" would
        # open a new column.
        return " ".join(str(cell).splitlines()).replace("|", "\\|")

    @classmethod
    def _table_to_markdown(cls, table):
        """Render a table (a list of rows of cells) as a Markdown table.

        The first row becomes the header. Every row, including the header, is
        padded with empty cells to the width of the widest row so all lines
        have the same column count. Returns "" when the table has no cells.
        """
        width = max((len(row) for row in table), default=0)
        if width == 0:
            return ""

        def render_row(row):
            cells = [cls._format_cell(c) for c in row]
            cells += [""] * (width - len(cells))
            return "| " + " | ".join(cells) + " |\n"

        lines = [render_row(table[0]),
                 "| " + " | ".join(["---"] * width) + " |\n"]
        lines.extend(render_row(row) for row in table[1:])
        return "".join(lines)

    def extract_tables(self):
        """Return every table pdfplumber finds in the PDF as Markdown text."""
        tables_md = []
        with pdfplumber.open(self.pdf_path) as pdf:
            for page in pdf.pages:
                for table in page.extract_tables():
                    md_table = self._table_to_markdown(table)
                    if md_table:
                        tables_md.append(md_table)
        return tables_md

    @staticmethod
    def _pixmap_to_pil(pix):
        """Convert a PyMuPDF Pixmap of any colorspace into an RGB PIL image."""
        # PIL's "RGB" mode expects exactly 3 bytes per pixel: drop alpha...
        if pix.alpha:
            pix = fitz.Pixmap(pix, 0)
        # ...and convert gray / CMYK / other colorspaces to RGB.
        if pix.n != 3:
            pix = fitz.Pixmap(fitz.csRGB, pix)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    def extract_images_ocr(self):
        """OCR every embedded image and return one Markdown section per image.

        Each image is saved as ``page{i}_img{j}.png`` (1-based indices) under
        ``self.img_dir`` and then passed to EasyOCR. Images whose OCR text is
        blank are omitted. Pages or images that cannot be read or converted
        are reported on stdout and skipped (best effort).
        """
        images_text = []
        with fitz.open(self.pdf_path) as doc:
            for i, page in enumerate(doc):
                try:
                    image_list = page.get_images(full=True)
                except Exception as e:
                    print(f"页面 {i+1} 获取图片异常，跳过: {e}")
                    continue

                for j, img in enumerate(image_list):
                    xref = img[0]  # first entry of each tuple is the image xref
                    try:
                        # Conversion is inside the try too: unusual colorspaces
                        # (e.g. masks without one) can fail to convert.
                        img_pil = self._pixmap_to_pil(fitz.Pixmap(doc, xref))
                    except Exception as e:
                        print(f"页面 {i+1} 图像 {j+1} 渲染异常，跳过: {e}")
                        continue

                    img_path = self.img_dir / f"page{i+1}_img{j+1}.png"
                    img_pil.save(img_path)

                    # readtext returns (bbox, text, confidence) entries.
                    result = self.reader.readtext(str(img_path))
                    ocr_text = "\n".join(t[1] for t in result)
                    if ocr_text.strip():
                        images_text.append(
                            f"### Image OCR page {i+1}\n{ocr_text}\n")
        return images_text

    def generate_markdown(self):
        """Write ``self.output_md``: page texts, then tables, then image OCR."""
        texts = self.extract_text_blocks()
        tables = self.extract_tables()
        images_ocr = self.extract_images_ocr()

        with open(self.output_md, "w", encoding="utf-8") as f:
            for page_text in texts:
                f.write(page_text + "\n\n")
            for table_md in tables:
                f.write(table_md + "\n")
            for img_text in images_ocr:
                f.write(img_text + "\n")
        print(f"Markdown 已生成：{self.output_md}")


class CLIParser:
    """Command-line interface for the PDF → Markdown converter.

    Defaults for ``--img-dir`` and ``--ocr-gpu`` may be supplied through the
    ``PDF2MD_IMG_DIR`` and ``PDF2MD_OCR_GPU`` environment variables.
    """

    # Case-insensitive strings that enable GPU OCR; any other value disables it.
    _TRUE_VALUES = frozenset({"true", "1", "yes", "y", "on"})

    def __init__(self):
        self.parser = argparse.ArgumentParser(
            description="PDF → Markdown GPU OCR 脚本")
        self.parser.add_argument("input_pdf", help="输入 PDF 文件路径")
        self.parser.add_argument("output_dir", help="输出 Markdown 文件夹")
        self.parser.add_argument(
            # "--output-name" matches the hyphenated style of the other
            # options; "--output_name" is kept for backward compatibility.
            "--output_name", "--output-name",
            dest="output_name",
            help="可选：自定义输出 Markdown 文件名（不带路径），默认使用 PDF 文件名",
            default=None
        )
        self.parser.add_argument(
            "--img-dir",
            help="图片保存目录（默认 images）",
            default=os.environ.get("PDF2MD_IMG_DIR", "images")
        )
        self.parser.add_argument(
            "--ocr-gpu",
            help="是否使用 GPU OCR (True/False，默认 True)",
            default=os.environ.get("PDF2MD_OCR_GPU", "True")
        )

    def parse_args(self, argv=None):
        """Parse the command line and normalise ``ocr_gpu`` to a bool.

        Args:
            argv: Argument list to parse; ``None`` (the default) lets argparse
                read ``sys.argv[1:]`` as before.

        Returns:
            The ``argparse.Namespace`` with ``ocr_gpu`` converted to ``bool``.
        """
        args = self.parser.parse_args(argv)
        args.ocr_gpu = str(args.ocr_gpu).strip().lower() in self._TRUE_VALUES
        return args


if __name__ == "__main__":
    # Parse the command line, then run the whole PDF → Markdown pipeline.
    cli_args = CLIParser().parse_args()
    PDFToMarkdownConverter(
        pdf_path=cli_args.input_pdf,
        output_dir=cli_args.output_dir,
        output_name=cli_args.output_name,
        img_dir=cli_args.img_dir,
        use_gpu=cli_args.ocr_gpu,
    ).generate_markdown()
