fix(ml): ocr inputs not resized correctly (#23541)

* fix resizing, use pillow

* unused import

* linting

* lanczos

* optimizations

fused operations

unused import
This commit is contained in:
Mert 2025-11-03 02:21:30 -05:00 committed by GitHub
parent f5ff36a1f8
commit 79d0e3e1ed
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 120 additions and 51 deletions

View file

@ -1,8 +1,10 @@
from typing import Any from typing import Any
import cv2
import numpy as np import numpy as np
from numpy.typing import NDArray
from PIL import Image from PIL import Image
from rapidocr.ch_ppocr_det import TextDetector as RapidTextDetector from rapidocr.ch_ppocr_det.utils import DBPostProcess
from rapidocr.inference_engine.base import FileInfo, InferSession from rapidocr.inference_engine.base import FileInfo, InferSession
from rapidocr.utils import DownloadFile, DownloadFileInput from rapidocr.utils import DownloadFile, DownloadFileInput
from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType from rapidocr.utils.typings import EngineType, LangDet, OCRVersion, TaskType
@ -10,11 +12,10 @@ from rapidocr.utils.typings import ModelType as RapidModelType
from immich_ml.config import log from immich_ml.config import log
from immich_ml.models.base import InferenceModel from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import decode_cv2
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession from immich_ml.sessions.ort import OrtSession
from .schemas import OcrOptions, TextDetectionOutput from .schemas import TextDetectionOutput
class TextDetector(InferenceModel): class TextDetector(InferenceModel):
@ -24,13 +25,20 @@ class TextDetector(InferenceModel):
def __init__(self, model_name: str, **model_kwargs: Any) -> None: def __init__(self, model_name: str, **model_kwargs: Any) -> None:
super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX) super().__init__(model_name, **model_kwargs, model_format=ModelFormat.ONNX)
self.max_resolution = 736 self.max_resolution = 736
self.min_score = 0.5 self.mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
self.score_mode = "fast" self.std_inv = np.float32(1.0) / (np.array([0.5, 0.5, 0.5], dtype=np.float32) * 255.0)
self._empty: TextDetectionOutput = { self._empty: TextDetectionOutput = {
"image": np.empty(0, dtype=np.float32),
"boxes": np.empty(0, dtype=np.float32), "boxes": np.empty(0, dtype=np.float32),
"scores": np.empty(0, dtype=np.float32), "scores": np.empty(0, dtype=np.float32),
} }
self.postprocess = DBPostProcess(
thresh=0.3,
box_thresh=model_kwargs.get("minScore", 0.5),
max_candidates=1000,
unclip_ratio=1.6,
use_dilation=True,
score_mode="fast",
)
def _download(self) -> None: def _download(self) -> None:
model_info = InferSession.get_model_url( model_info = InferSession.get_model_url(
@ -52,35 +60,65 @@ class TextDetector(InferenceModel):
def _load(self) -> ModelSession: def _load(self) -> ModelSession:
# TODO: support other runtime sessions # TODO: support other runtime sessions
session = OrtSession(self.model_path) return OrtSession(self.model_path)
self.model = RapidTextDetector(
OcrOptions(
session=session.session,
limit_side_len=self.max_resolution,
limit_type="min",
box_thresh=self.min_score,
score_mode=self.score_mode,
)
)
return session
def _predict(self, inputs: bytes | Image.Image) -> TextDetectionOutput: # partly adapted from RapidOCR
results = self.model(decode_cv2(inputs)) def _predict(self, inputs: Image.Image) -> TextDetectionOutput:
if results.boxes is None or results.scores is None or results.img is None: w, h = inputs.size
if w < 32 or h < 32:
return self._empty
out = self.session.run(None, {"x": self._transform(inputs)})[0]
boxes, scores = self.postprocess(out, (h, w))
if len(boxes) == 0:
return self._empty return self._empty
return { return {
"image": results.img, "boxes": self.sorted_boxes(boxes),
"boxes": np.array(results.boxes, dtype=np.float32), "scores": np.array(scores, dtype=np.float32),
"scores": np.array(results.scores, dtype=np.float32),
} }
# adapted from RapidOCR
def _transform(self, img: Image.Image) -> NDArray[np.float32]:
    """Resize and normalize an image into the detection model's input tensor.

    The shorter side is scaled to ``self.max_resolution`` with the aspect
    ratio preserved, then each side is rounded to a multiple of 32. The
    resized image is converted from RGB to BGR and normalized as
    ``(pixel / 255 - mean) / std``.

    Args:
        img: Input PIL image (converted with ``COLOR_RGB2BGR``, so RGB
            channel order is expected).

    Returns:
        A float32 array with the channel axis moved first and a leading
        batch axis of size 1.
    """
    ratio = float(self.max_resolution) / min(img.height, img.width)

    # Truncate first, then snap to the nearest multiple of 32. Clamp to 32 so
    # an extreme aspect ratio or tiny max_resolution cannot round a side to 0.
    resize_h = max(32, int(round(int(img.height * ratio) / 32) * 32))
    resize_w = max(32, int(round(int(img.width * ratio) / 32) * 32))
    resized_img = img.resize((resize_w, resize_h), resample=Image.Resampling.LANCZOS)

    img_np: NDArray[np.float32] = cv2.cvtColor(np.array(resized_img, dtype=np.float32), cv2.COLOR_RGB2BGR)  # type: ignore
    # Pixels are still in 0..255 here while self.mean is expressed on the 0..1
    # scale, so the mean must be scaled by 255 before subtracting; self.std_inv
    # already includes the 1/255 factor. Together: (x / 255 - mean) / std.
    img_np -= self.mean * np.float32(255.0)
    img_np *= self.std_inv
    return np.expand_dims(np.transpose(img_np, (2, 0, 1)), axis=0)
def sorted_boxes(self, dt_boxes: NDArray[np.float32]) -> NDArray[np.float32]:
    """Order boxes line by line: top-to-bottom, then left-to-right within a line.

    Boxes are ranked by the y coordinate of their first corner. Consecutive
    boxes whose y values differ by less than 10 are treated as the same line,
    and boxes within a line are ordered by the x coordinate of their first
    corner.

    Args:
        dt_boxes: Array indexed as ``[box, corner, (x, y)]``.

    Returns:
        The same boxes reordered; the input is returned as-is when empty.
    """
    if len(dt_boxes) == 0:
        return dt_boxes

    # Rank boxes by the y coordinate of their first corner.
    y_order = np.argsort(dt_boxes[:, 0, 1], kind="stable")
    sorted_y = dt_boxes[y_order, 0, 1]
    sorted_x = dt_boxes[y_order, 0, 0]

    # A jump of at least 10 between consecutive sorted y values starts a new line.
    line_ids = np.concatenate(([0], np.cumsum(np.diff(sorted_y) >= 10)))

    # line_ids is already aligned with the y-sorted order, so it is paired with
    # sorted_x directly (re-indexing it by y_order would scramble the lines).
    # lexsort uses the last key as the primary key: line first, then x.
    final_order = np.lexsort((sorted_x, line_ids))
    result: NDArray[np.float32] = dt_boxes[y_order[final_order]]
    return result
def configure(self, **kwargs: Any) -> None: def configure(self, **kwargs: Any) -> None:
if (max_resolution := kwargs.get("maxResolution")) is not None: if (max_resolution := kwargs.get("maxResolution")) is not None:
self.max_resolution = max_resolution self.max_resolution = max_resolution
self.model.limit_side_len = max_resolution
if (min_score := kwargs.get("minScore")) is not None: if (min_score := kwargs.get("minScore")) is not None:
self.min_score = min_score self.postprocess.box_thresh = min_score
self.model.postprocess_op.box_thresh = min_score
if (score_mode := kwargs.get("scoreMode")) is not None: if (score_mode := kwargs.get("scoreMode")) is not None:
self.score_mode = score_mode self.postprocess.score_mode = score_mode
self.model.postprocess_op.score_mode = score_mode

View file

@ -1,9 +1,8 @@
from typing import Any from typing import Any
import cv2
import numpy as np import numpy as np
from numpy.typing import NDArray from numpy.typing import NDArray
from PIL.Image import Image from PIL import Image
from rapidocr.ch_ppocr_rec import TextRecInput from rapidocr.ch_ppocr_rec import TextRecInput
from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer from rapidocr.ch_ppocr_rec import TextRecognizer as RapidTextRecognizer
from rapidocr.inference_engine.base import FileInfo, InferSession from rapidocr.inference_engine.base import FileInfo, InferSession
@ -14,6 +13,7 @@ from rapidocr.utils.vis_res import VisRes
from immich_ml.config import log, settings from immich_ml.config import log, settings
from immich_ml.models.base import InferenceModel from immich_ml.models.base import InferenceModel
from immich_ml.models.transforms import pil_to_cv2
from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType from immich_ml.schemas import ModelFormat, ModelSession, ModelTask, ModelType
from immich_ml.sessions.ort import OrtSession from immich_ml.sessions.ort import OrtSession
@ -65,17 +65,16 @@ class TextRecognizer(InferenceModel):
) )
return session return session
def _predict(self, _: Image, texts: TextDetectionOutput) -> TextRecognitionOutput: def _predict(self, img: Image.Image, texts: TextDetectionOutput) -> TextRecognitionOutput:
boxes, img, box_scores = texts["boxes"], texts["image"], texts["scores"] boxes, box_scores = texts["boxes"], texts["scores"]
if boxes.shape[0] == 0: if boxes.shape[0] == 0:
return self._empty return self._empty
rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes))) rec = self.model(TextRecInput(img=self.get_crop_img_list(img, boxes)))
if rec.txts is None: if rec.txts is None:
return self._empty return self._empty
height, width = img.shape[0:2] boxes[:, :, 0] /= img.width
boxes[:, :, 0] /= width boxes[:, :, 1] /= img.height
boxes[:, :, 1] /= height
text_scores = np.array(rec.scores) text_scores = np.array(rec.scores)
valid_text_score_idx = text_scores > self.min_score valid_text_score_idx = text_scores > self.min_score
@ -87,7 +86,7 @@ class TextRecognizer(InferenceModel):
"textScore": text_scores[valid_text_score_idx], "textScore": text_scores[valid_text_score_idx],
} }
def get_crop_img_list(self, img: NDArray[np.float32], boxes: NDArray[np.float32]) -> list[NDArray[np.float32]]: def get_crop_img_list(self, img: Image.Image, boxes: NDArray[np.float32]) -> list[NDArray[np.uint8]]:
img_crop_width = np.maximum( img_crop_width = np.maximum(
np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1) np.linalg.norm(boxes[:, 1] - boxes[:, 0], axis=1), np.linalg.norm(boxes[:, 2] - boxes[:, 3], axis=1)
).astype(np.int32) ).astype(np.int32)
@ -98,22 +97,55 @@ class TextRecognizer(InferenceModel):
pts_std[:, 1:3, 0] = img_crop_width[:, None] pts_std[:, 1:3, 0] = img_crop_width[:, None]
pts_std[:, 2:4, 1] = img_crop_height[:, None] pts_std[:, 2:4, 1] = img_crop_height[:, None]
img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1).tolist() img_crop_sizes = np.stack([img_crop_width, img_crop_height], axis=1)
imgs: list[NDArray[np.float32]] = [] all_coeffs = self._get_perspective_transform(pts_std, boxes)
for box, pts_std, dst_size in zip(list(boxes), list(pts_std), img_crop_sizes): imgs: list[NDArray[np.uint8]] = []
M = cv2.getPerspectiveTransform(box, pts_std) for coeffs, dst_size in zip(all_coeffs, img_crop_sizes):
dst_img: NDArray[np.float32] = cv2.warpPerspective( dst_img = img.transform(
img, size=tuple(dst_size),
M, method=Image.Transform.PERSPECTIVE,
dst_size, data=tuple(coeffs),
borderMode=cv2.BORDER_REPLICATE, resample=Image.Resampling.BICUBIC,
flags=cv2.INTER_CUBIC, )
) # type: ignore
dst_height, dst_width = dst_img.shape[0:2] dst_width, dst_height = dst_img.size
if dst_height * 1.0 / dst_width >= 1.5: if dst_height * 1.0 / dst_width >= 1.5:
dst_img = np.rot90(dst_img) dst_img = dst_img.rotate(90, expand=True)
imgs.append(dst_img) imgs.append(pil_to_cv2(dst_img))
return imgs return imgs
def _get_perspective_transform(self, src: NDArray[np.float32], dst: NDArray[np.float32]) -> NDArray[np.float32]:
N = src.shape[0]
x, y = src[:, :, 0], src[:, :, 1]
u, v = dst[:, :, 0], dst[:, :, 1]
A = np.zeros((N, 8, 9), dtype=np.float32)
# Fill even rows (0, 2, 4, 6): [x, y, 1, 0, 0, 0, -u*x, -u*y, -u]
A[:, ::2, 0] = x
A[:, ::2, 1] = y
A[:, ::2, 2] = 1
A[:, ::2, 6] = -u * x
A[:, ::2, 7] = -u * y
A[:, ::2, 8] = -u
# Fill odd rows (1, 3, 5, 7): [0, 0, 0, x, y, 1, -v*x, -v*y, -v]
A[:, 1::2, 3] = x
A[:, 1::2, 4] = y
A[:, 1::2, 5] = 1
A[:, 1::2, 6] = -v * x
A[:, 1::2, 7] = -v * y
A[:, 1::2, 8] = -v
# Solve using SVD for all matrices at once
_, _, Vt = np.linalg.svd(A)
H = Vt[:, -1, :].reshape(N, 3, 3)
H = H / H[:, 2:3, 2:3]
# Extract the 8 coefficients for each transformation
return np.column_stack(
[H[:, 0, 0], H[:, 0, 1], H[:, 0, 2], H[:, 1, 0], H[:, 1, 1], H[:, 1, 2], H[:, 2, 0], H[:, 2, 1]]
) # pyright: ignore[reportReturnType]
def configure(self, **kwargs: Any) -> None: def configure(self, **kwargs: Any) -> None:
self.min_score = kwargs.get("minScore", self.min_score) self.min_score = kwargs.get("minScore", self.min_score)

View file

@ -7,7 +7,6 @@ from typing_extensions import TypedDict
class TextDetectionOutput(TypedDict): class TextDetectionOutput(TypedDict):
image: npt.NDArray[np.float32]
boxes: npt.NDArray[np.float32] boxes: npt.NDArray[np.float32]
scores: npt.NDArray[np.float32] scores: npt.NDArray[np.float32]